feat(pipe): Release 3D scene generation pipeline. (#25)
Release 3D scene generation pipeline and tag as v0.1.2. --------- Co-authored-by: xinjie.wang <xinjie.wang@gpu-4090-dev015.hogpu.cc>
This commit is contained in:
parent
51759f011a
commit
e82f02a9a5
5
.gitmodules
vendored
5
.gitmodules
vendored
@ -3,3 +3,8 @@
|
|||||||
url = https://github.com/microsoft/TRELLIS.git
|
url = https://github.com/microsoft/TRELLIS.git
|
||||||
branch = main
|
branch = main
|
||||||
shallow = true
|
shallow = true
|
||||||
|
[submodule "thirdparty/pano2room"]
|
||||||
|
path = thirdparty/pano2room
|
||||||
|
url = https://github.com/TrickyGo/Pano2Room.git
|
||||||
|
branch = main
|
||||||
|
shallow = true
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: git@gitlab.hobot.cc:ptd/3rd/pre-commit/pre-commit-hooks.git
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v2.3.0 # Use the ref you want to point at
|
rev: v4.2.0 # Use the ref you want to point at
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
|
|||||||
27
README.md
27
README.md
@ -30,11 +30,11 @@
|
|||||||
```sh
|
```sh
|
||||||
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
|
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
|
||||||
cd EmbodiedGen
|
cd EmbodiedGen
|
||||||
git checkout v0.1.1
|
git checkout v0.1.2
|
||||||
git submodule update --init --recursive --progress
|
git submodule update --init --recursive --progress
|
||||||
conda create -n embodiedgen python=3.10.13 -y
|
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
|
||||||
conda activate embodiedgen
|
conda activate embodiedgen
|
||||||
bash install.sh
|
bash install.sh basic
|
||||||
```
|
```
|
||||||
|
|
||||||
### ✅ Setup GPT Agent
|
### ✅ Setup GPT Agent
|
||||||
@ -94,7 +94,7 @@ python apps/text_to_3d.py
|
|||||||
|
|
||||||
### ⚡ API
|
### ⚡ API
|
||||||
Text-to-image model based on SD3.5 Medium, English prompts only.
|
Text-to-image model based on SD3.5 Medium, English prompts only.
|
||||||
Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically. (ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py`)
|
Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically.
|
||||||
|
|
||||||
For large-scale 3D assets generation, set `--n_pipe_retry=2` to ensure high end-to-end 3D asset usability through automatic quality check and retries. For more diverse results, do not set `--seed_img`.
|
For large-scale 3D assets generation, set `--n_pipe_retry=2` to ensure high end-to-end 3D asset usability through automatic quality check and retries. For more diverse results, do not set `--seed_img`.
|
||||||
|
|
||||||
@ -110,6 +110,7 @@ bash embodied_gen/scripts/textto3d.sh \
|
|||||||
--prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
|
--prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
|
||||||
--output_root outputs/textto3d_k
|
--output_root outputs/textto3d_k
|
||||||
```
|
```
|
||||||
|
ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py`
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -146,10 +147,22 @@ bash embodied_gen/scripts/texture_gen.sh \
|
|||||||
|
|
||||||
<h2 id="3d-scene-generation">🌍 3D Scene Generation</h2>
|
<h2 id="3d-scene-generation">🌍 3D Scene Generation</h2>
|
||||||
|
|
||||||
🚧 *Coming Soon*
|
|
||||||
|
|
||||||
<img src="apps/assets/scene3d.gif" alt="scene3d" style="width: 640px;">
|
<img src="apps/assets/scene3d.gif" alt="scene3d" style="width: 640px;">
|
||||||
|
|
||||||
|
### ⚡ API
|
||||||
|
> Run `bash install.sh extra` to install additional requirements if you need to use `scene3d-cli`.
|
||||||
|
|
||||||
|
It takes ~30mins to generate a color mesh and 3DGS per scene.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
CUDA_VISIBLE_DEVICES=0 scene3d-cli \
|
||||||
|
--prompts "Art studio with easel and canvas" \
|
||||||
|
--output_dir outputs/bg_scenes/ \
|
||||||
|
--seed 0 \
|
||||||
|
--gs3d.max_steps 4000 \
|
||||||
|
--disable_pano_check
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
||||||
@ -189,7 +202,7 @@ bash embodied_gen/scripts/texture_gen.sh \
|
|||||||
|
|
||||||
## For Developer
|
## For Developer
|
||||||
```sh
|
```sh
|
||||||
pip install .[dev] && pre-commit install
|
pip install -e .[dev] && pre-commit install
|
||||||
python -m pytest # Pass all unit-test are required.
|
python -m pytest # Pass all unit-test are required.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@ -94,9 +94,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
|
|||||||
os.environ["SPCONV_ALGO"] = "native"
|
os.environ["SPCONV_ALGO"] = "native"
|
||||||
|
|
||||||
MAX_SEED = 100000
|
MAX_SEED = 100000
|
||||||
DELIGHT = DelightingModel()
|
|
||||||
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
|
||||||
# IMAGESR_MODEL = ImageStableSR()
|
|
||||||
|
|
||||||
|
|
||||||
def patched_setup_functions(self):
|
def patched_setup_functions(self):
|
||||||
@ -136,6 +133,9 @@ def patched_setup_functions(self):
|
|||||||
Gaussian.setup_functions = patched_setup_functions
|
Gaussian.setup_functions = patched_setup_functions
|
||||||
|
|
||||||
|
|
||||||
|
DELIGHT = DelightingModel()
|
||||||
|
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
||||||
|
# IMAGESR_MODEL = ImageStableSR()
|
||||||
if os.getenv("GRADIO_APP") == "imageto3d":
|
if os.getenv("GRADIO_APP") == "imageto3d":
|
||||||
RBG_REMOVER = RembgRemover()
|
RBG_REMOVER = RembgRemover()
|
||||||
RBG14_REMOVER = BMGG14Remover()
|
RBG14_REMOVER = BMGG14Remover()
|
||||||
|
|||||||
@ -19,8 +19,9 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import Any, Callable, Dict, List, Tuple
|
from typing import Any, Callable, Dict, List, Literal, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -36,6 +37,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Asset3dGenDataset",
|
"Asset3dGenDataset",
|
||||||
|
"PanoGSplatDataset",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -222,6 +224,68 @@ class Asset3dGenDataset(Dataset):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class PanoGSplatDataset(Dataset):
|
||||||
|
"""A PyTorch Dataset for loading panorama-based 3D Gaussian Splatting data.
|
||||||
|
|
||||||
|
This dataset is designed to be compatible with train and eval pipelines
|
||||||
|
that use COLMAP-style camera conventions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dir (str): Root directory where the dataset file is located.
|
||||||
|
split (str): Dataset split to use, either "train" or "eval".
|
||||||
|
data_name (str, optional): Name of the dataset file (default: "gs_data.pt").
|
||||||
|
max_sample_num (int, optional): Maximum number of samples to load. If None,
|
||||||
|
all available samples in the split will be used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_dir: str,
|
||||||
|
split: str = Literal["train", "eval"],
|
||||||
|
data_name: str = "gs_data.pt",
|
||||||
|
max_sample_num: int = None,
|
||||||
|
) -> None:
|
||||||
|
self.data_path = os.path.join(data_dir, data_name)
|
||||||
|
self.split = split
|
||||||
|
self.max_sample_num = max_sample_num
|
||||||
|
if not os.path.exists(self.data_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Dataset file {self.data_path} not found. Please provide the correct path."
|
||||||
|
)
|
||||||
|
self.data = torch.load(self.data_path, weights_only=False)
|
||||||
|
self.frames = self.data[split]
|
||||||
|
if max_sample_num is not None:
|
||||||
|
self.frames = self.frames[:max_sample_num]
|
||||||
|
self.points = self.data.get("points", None)
|
||||||
|
self.points_rgb = self.data.get("points_rgb", None)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.frames)
|
||||||
|
|
||||||
|
def cvt_blender_to_colmap_coord(self, c2w: np.ndarray) -> np.ndarray:
|
||||||
|
# change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
|
||||||
|
tranformed_c2w = np.copy(c2w)
|
||||||
|
tranformed_c2w[:3, 1:3] *= -1
|
||||||
|
|
||||||
|
return tranformed_c2w
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> dict[str, any]:
|
||||||
|
data = self.frames[index]
|
||||||
|
c2w = self.cvt_blender_to_colmap_coord(data["camtoworld"])
|
||||||
|
item = dict(
|
||||||
|
camtoworld=c2w,
|
||||||
|
K=data["K"],
|
||||||
|
image_h=data["image_h"],
|
||||||
|
image_w=data["image_w"],
|
||||||
|
)
|
||||||
|
if "image" in data:
|
||||||
|
item["image"] = data["image"]
|
||||||
|
if "image_id" in data:
|
||||||
|
item["image_id"] = data["image_id"]
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
index_file = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json" # noqa
|
index_file = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json" # noqa
|
||||||
target_hw = (512, 512)
|
target_hw = (512, 512)
|
||||||
|
|||||||
@ -158,8 +158,9 @@ class DiffrastRender(object):
|
|||||||
|
|
||||||
return normalized_maps
|
return normalized_maps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def normalize_map_by_mask(
|
def normalize_map_by_mask(
|
||||||
self, map: torch.Tensor, mask: torch.Tensor
|
map: torch.Tensor, mask: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Normalize all maps in total by mask, normalized map in [0, 1].
|
# Normalize all maps in total by mask, normalized map in [0, 1].
|
||||||
foreground = (mask == 1).squeeze(dim=-1)
|
foreground = (mask == 1).squeeze(dim=-1)
|
||||||
|
|||||||
191
embodied_gen/scripts/gen_scene3d.py
Normal file
191
embodied_gen/scripts/gen_scene3d.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from shutil import copy, rmtree
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tyro
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
# Suppress warnings
|
||||||
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||||
|
logging.getLogger("transformers").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
# TorchVision monkey patch for >0.16
|
||||||
|
if version.parse(torch.__version__) >= version.parse("0.16"):
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import torchvision.transforms.functional as TF
|
||||||
|
|
||||||
|
functional_tensor = types.ModuleType(
|
||||||
|
"torchvision.transforms.functional_tensor"
|
||||||
|
)
|
||||||
|
functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
|
||||||
|
sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
|
||||||
|
|
||||||
|
from gsplat.distributed import cli
|
||||||
|
from txt2panoimg import Text2360PanoramaImagePipeline
|
||||||
|
from embodied_gen.trainer.gsplat_trainer import (
|
||||||
|
DefaultStrategy,
|
||||||
|
GsplatTrainConfig,
|
||||||
|
)
|
||||||
|
from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint
|
||||||
|
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
|
||||||
|
from embodied_gen.utils.config import Pano2MeshSRConfig
|
||||||
|
from embodied_gen.utils.gaussian import restore_scene_scale_and_position
|
||||||
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||||
|
from embodied_gen.utils.log import logger
|
||||||
|
from embodied_gen.utils.process_media import is_image_file, parse_text_prompts
|
||||||
|
from embodied_gen.validators.quality_checkers import (
|
||||||
|
PanoHeightEstimator,
|
||||||
|
PanoImageOccChecker,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"generate_pano_image",
|
||||||
|
"entrypoint",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Scene3DGenConfig:
|
||||||
|
prompts: list[str] # Text desc of indoor room or style reference image.
|
||||||
|
output_dir: str
|
||||||
|
seed: int | None = None
|
||||||
|
real_height: float | None = None # The real height of the room in meters.
|
||||||
|
pano_image_only: bool = False
|
||||||
|
disable_pano_check: bool = False
|
||||||
|
keep_middle_result: bool = False
|
||||||
|
n_retry: int = 7
|
||||||
|
gs3d: GsplatTrainConfig = field(
|
||||||
|
default_factory=lambda: GsplatTrainConfig(
|
||||||
|
strategy=DefaultStrategy(verbose=True),
|
||||||
|
max_steps=4000,
|
||||||
|
init_opa=0.9,
|
||||||
|
opacity_reg=2e-3,
|
||||||
|
sh_degree=0,
|
||||||
|
means_lr=1e-4,
|
||||||
|
scales_lr=1e-3,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pano_image(
|
||||||
|
prompt: str,
|
||||||
|
output_path: str,
|
||||||
|
pipeline,
|
||||||
|
seed: int,
|
||||||
|
n_retry: int,
|
||||||
|
checker=None,
|
||||||
|
num_inference_steps: int = 40,
|
||||||
|
) -> None:
|
||||||
|
for i in range(n_retry):
|
||||||
|
logger.info(
|
||||||
|
f"GEN Panorama: Retry {i+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
|
||||||
|
)
|
||||||
|
if is_image_file(prompt):
|
||||||
|
raise NotImplementedError("Image mode not implemented yet.")
|
||||||
|
else:
|
||||||
|
txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
|
||||||
|
inputs = {
|
||||||
|
"prompt": txt_prompt,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
"upscale": False,
|
||||||
|
"seed": seed,
|
||||||
|
}
|
||||||
|
pano_image = pipeline(inputs)
|
||||||
|
|
||||||
|
pano_image.save(output_path)
|
||||||
|
if checker is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
flag, response = checker(pano_image)
|
||||||
|
logger.warning(f"{response}, image saved in {output_path}")
|
||||||
|
if flag is True or flag is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
seed = random.randint(0, 100000)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def entrypoint(*args, **kwargs):
|
||||||
|
cfg = tyro.cli(Scene3DGenConfig)
|
||||||
|
|
||||||
|
# Init global models.
|
||||||
|
model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
|
||||||
|
IMG2PANO_PIPE = Text2360PanoramaImagePipeline(
|
||||||
|
model_path, torch_dtype=torch.float16, device="cuda"
|
||||||
|
)
|
||||||
|
PANOMESH_CFG = Pano2MeshSRConfig()
|
||||||
|
PANO2MESH_PIPE = Pano2MeshSRPipeline(PANOMESH_CFG)
|
||||||
|
PANO_CHECKER = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
|
||||||
|
PANOHEIGHT_ESTOR = PanoHeightEstimator(GPT_CLIENT)
|
||||||
|
|
||||||
|
prompts = parse_text_prompts(cfg.prompts)
|
||||||
|
for idx, prompt in enumerate(prompts):
|
||||||
|
start_time = time.time()
|
||||||
|
output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
pano_path = os.path.join(output_dir, "pano_image.png")
|
||||||
|
with open(f"{output_dir}/prompt.txt", "w") as f:
|
||||||
|
f.write(prompt)
|
||||||
|
|
||||||
|
generate_pano_image(
|
||||||
|
prompt,
|
||||||
|
pano_path,
|
||||||
|
IMG2PANO_PIPE,
|
||||||
|
cfg.seed if cfg.seed is not None else random.randint(0, 100000),
|
||||||
|
cfg.n_retry,
|
||||||
|
checker=None if cfg.disable_pano_check else PANO_CHECKER,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.pano_image_only:
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info("GEN and REPAIR Mesh from Panorama...")
|
||||||
|
PANO2MESH_PIPE(pano_path, output_dir)
|
||||||
|
|
||||||
|
logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
|
||||||
|
cfg.gs3d.data_dir = output_dir
|
||||||
|
cfg.gs3d.result_dir = f"{output_dir}/gaussian"
|
||||||
|
cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)
|
||||||
|
torch.set_default_device("cpu") # recover default setting.
|
||||||
|
cli(gsplat_entrypoint, cfg.gs3d, verbose=True)
|
||||||
|
|
||||||
|
# Clean up the middle results.
|
||||||
|
gs_path = (
|
||||||
|
f"{cfg.gs3d.result_dir}/ply/point_cloud_{cfg.gs3d.max_steps-1}.ply"
|
||||||
|
)
|
||||||
|
copy(gs_path, f"{output_dir}/gs_model.ply")
|
||||||
|
video_path = f"{cfg.gs3d.result_dir}/renders/video_step{cfg.gs3d.max_steps-1}.mp4"
|
||||||
|
copy(video_path, f"{output_dir}/video.mp4")
|
||||||
|
gs_cfg_path = f"{cfg.gs3d.result_dir}/cfg.yml"
|
||||||
|
copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")
|
||||||
|
if not cfg.keep_middle_result:
|
||||||
|
rmtree(cfg.gs3d.result_dir, ignore_errors=True)
|
||||||
|
os.remove(f"{output_dir}/{PANOMESH_CFG.gs_data_file}")
|
||||||
|
|
||||||
|
real_height = (
|
||||||
|
PANOHEIGHT_ESTOR(pano_path)
|
||||||
|
if cfg.real_height is None
|
||||||
|
else cfg.real_height
|
||||||
|
)
|
||||||
|
gs_path = os.path.join(output_dir, "gs_model.ply")
|
||||||
|
mesh_path = os.path.join(output_dir, "mesh_model.ply")
|
||||||
|
restore_scene_scale_and_position(real_height, mesh_path, gs_path)
|
||||||
|
|
||||||
|
elapsed_time = (time.time() - start_time) / 60
|
||||||
|
logger.info(
|
||||||
|
f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
entrypoint()
|
||||||
@ -62,25 +62,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
|
|||||||
os.environ["SPCONV_ALGO"] = "native"
|
os.environ["SPCONV_ALGO"] = "native"
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
|
|
||||||
logger.info("Loading Models...")
|
|
||||||
DELIGHT = DelightingModel()
|
|
||||||
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
|
||||||
|
|
||||||
RBG_REMOVER = RembgRemover()
|
|
||||||
RBG14_REMOVER = BMGG14Remover()
|
|
||||||
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
|
|
||||||
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
|
||||||
"microsoft/TRELLIS-image-large"
|
|
||||||
)
|
|
||||||
# PIPELINE.cuda()
|
|
||||||
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
|
||||||
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
|
|
||||||
AESTHETIC_CHECKER = ImageAestheticChecker()
|
|
||||||
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
|
|
||||||
TMP_DIR = os.path.join(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
|
parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
|
||||||
@ -128,6 +109,19 @@ def entrypoint(**kwargs):
|
|||||||
if hasattr(args, k) and v is not None:
|
if hasattr(args, k) and v is not None:
|
||||||
setattr(args, k, v)
|
setattr(args, k, v)
|
||||||
|
|
||||||
|
logger.info("Loading Models...")
|
||||||
|
DELIGHT = DelightingModel()
|
||||||
|
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
||||||
|
RBG_REMOVER = RembgRemover()
|
||||||
|
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
||||||
|
"microsoft/TRELLIS-image-large"
|
||||||
|
)
|
||||||
|
# PIPELINE.cuda()
|
||||||
|
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
||||||
|
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
|
||||||
|
AESTHETIC_CHECKER = ImageAestheticChecker()
|
||||||
|
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
args.image_path or args.image_root
|
args.image_path or args.image_root
|
||||||
), "Please provide either --image_path or --image_root."
|
), "Please provide either --image_path or --image_root."
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from embodied_gen.models.text_model import (
|
|||||||
build_text2img_pipeline,
|
build_text2img_pipeline,
|
||||||
text2img_gen,
|
text2img_gen,
|
||||||
)
|
)
|
||||||
|
from embodied_gen.utils.process_media import parse_text_prompts
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -101,14 +102,7 @@ def entrypoint(
|
|||||||
if hasattr(args, k) and v is not None:
|
if hasattr(args, k) and v is not None:
|
||||||
setattr(args, k, v)
|
setattr(args, k, v)
|
||||||
|
|
||||||
prompts = args.prompts
|
prompts = parse_text_prompts(args.prompts)
|
||||||
if len(prompts) == 1 and prompts[0].endswith(".txt"):
|
|
||||||
with open(prompts[0], "r") as f:
|
|
||||||
prompts = f.readlines()
|
|
||||||
prompts = [
|
|
||||||
prompt.strip() for prompt in prompts if prompt.strip() != ""
|
|
||||||
]
|
|
||||||
|
|
||||||
os.makedirs(args.output_root, exist_ok=True)
|
os.makedirs(args.output_root, exist_ok=True)
|
||||||
|
|
||||||
ip_img_paths = args.ref_image
|
ip_img_paths = args.ref_image
|
||||||
|
|||||||
@ -44,13 +44,6 @@ __all__ = [
|
|||||||
"text_to_3d",
|
"text_to_3d",
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.info("Loading Models...")
|
|
||||||
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
|
|
||||||
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
|
||||||
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
|
|
||||||
PIPE_IMG = build_hf_image_pipeline("sd35")
|
|
||||||
BG_REMOVER = RembgRemover()
|
|
||||||
|
|
||||||
|
|
||||||
def text_to_image(
|
def text_to_image(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -121,6 +114,14 @@ def text_to_3d(**kwargs) -> dict:
|
|||||||
if hasattr(args, k) and v is not None:
|
if hasattr(args, k) and v is not None:
|
||||||
setattr(args, k, v)
|
setattr(args, k, v)
|
||||||
|
|
||||||
|
logger.info("Loading Models...")
|
||||||
|
global SEMANTIC_CHECKER, SEG_CHECKER, TXTGEN_CHECKER, PIPE_IMG, BG_REMOVER
|
||||||
|
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
|
||||||
|
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
||||||
|
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
|
||||||
|
PIPE_IMG = build_hf_image_pipeline(args.text_model)
|
||||||
|
BG_REMOVER = RembgRemover()
|
||||||
|
|
||||||
if args.asset_names is None or len(args.asset_names) == 0:
|
if args.asset_names is None or len(args.asset_names) == 0:
|
||||||
args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
|
args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
|
||||||
img_save_dir = os.path.join(args.output_root, "images")
|
img_save_dir = os.path.join(args.output_root, "images")
|
||||||
@ -260,6 +261,11 @@ def parse_args():
|
|||||||
default=0,
|
default=0,
|
||||||
help="Random seed for 3D generation",
|
help="Random seed for 3D generation",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text_model",
|
||||||
|
type=str,
|
||||||
|
default="sd35",
|
||||||
|
)
|
||||||
parser.add_argument("--keep_intermediate", action="store_true")
|
parser.add_argument("--keep_intermediate", action="store_true")
|
||||||
|
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|||||||
678
embodied_gen/trainer/gsplat_trainer.py
Normal file
678
embodied_gen/trainer/gsplat_trainer.py
Normal file
@ -0,0 +1,678 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
# Part of the code comes from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
|
||||||
|
# Both under the Apache License, Version 2.0.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import imageio
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import tqdm
|
||||||
|
import tyro
|
||||||
|
import yaml
|
||||||
|
from fused_ssim import fused_ssim
|
||||||
|
from gsplat.distributed import cli
|
||||||
|
from gsplat.rendering import rasterization
|
||||||
|
from gsplat.strategy import DefaultStrategy, MCMCStrategy
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from torchmetrics.image import (
|
||||||
|
PeakSignalNoiseRatio,
|
||||||
|
StructuralSimilarityIndexMeasure,
|
||||||
|
)
|
||||||
|
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||||
|
from typing_extensions import Literal, assert_never
|
||||||
|
from embodied_gen.data.datasets import PanoGSplatDataset
|
||||||
|
from embodied_gen.utils.config import GsplatTrainConfig
|
||||||
|
from embodied_gen.utils.gaussian import (
|
||||||
|
create_splats_with_optimizers,
|
||||||
|
export_splats,
|
||||||
|
resize_pinhole_intrinsics,
|
||||||
|
set_random_seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Runner:
|
||||||
|
"""Engine for training and testing from gsplat example.
|
||||||
|
|
||||||
|
Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
local_rank: int,
|
||||||
|
world_rank,
|
||||||
|
world_size: int,
|
||||||
|
cfg: GsplatTrainConfig,
|
||||||
|
) -> None:
|
||||||
|
set_random_seed(42 + local_rank)
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
self.world_rank = world_rank
|
||||||
|
self.local_rank = local_rank
|
||||||
|
self.world_size = world_size
|
||||||
|
self.device = f"cuda:{local_rank}"
|
||||||
|
|
||||||
|
# Where to dump results.
|
||||||
|
os.makedirs(cfg.result_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Setup output directories.
|
||||||
|
self.ckpt_dir = f"{cfg.result_dir}/ckpts"
|
||||||
|
os.makedirs(self.ckpt_dir, exist_ok=True)
|
||||||
|
self.stats_dir = f"{cfg.result_dir}/stats"
|
||||||
|
os.makedirs(self.stats_dir, exist_ok=True)
|
||||||
|
self.render_dir = f"{cfg.result_dir}/renders"
|
||||||
|
os.makedirs(self.render_dir, exist_ok=True)
|
||||||
|
self.ply_dir = f"{cfg.result_dir}/ply"
|
||||||
|
os.makedirs(self.ply_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Tensorboard
|
||||||
|
self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
|
||||||
|
self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
|
||||||
|
self.valset = PanoGSplatDataset(
|
||||||
|
cfg.data_dir, split="train", max_sample_num=6
|
||||||
|
)
|
||||||
|
self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
|
||||||
|
self.scene_scale = cfg.scene_scale
|
||||||
|
|
||||||
|
# Model
|
||||||
|
self.splats, self.optimizers = create_splats_with_optimizers(
|
||||||
|
self.trainset.points,
|
||||||
|
self.trainset.points_rgb,
|
||||||
|
init_num_pts=cfg.init_num_pts,
|
||||||
|
init_extent=cfg.init_extent,
|
||||||
|
init_opacity=cfg.init_opa,
|
||||||
|
init_scale=cfg.init_scale,
|
||||||
|
means_lr=cfg.means_lr,
|
||||||
|
scales_lr=cfg.scales_lr,
|
||||||
|
opacities_lr=cfg.opacities_lr,
|
||||||
|
quats_lr=cfg.quats_lr,
|
||||||
|
sh0_lr=cfg.sh0_lr,
|
||||||
|
shN_lr=cfg.shN_lr,
|
||||||
|
scene_scale=self.scene_scale,
|
||||||
|
sh_degree=cfg.sh_degree,
|
||||||
|
sparse_grad=cfg.sparse_grad,
|
||||||
|
visible_adam=cfg.visible_adam,
|
||||||
|
batch_size=cfg.batch_size,
|
||||||
|
feature_dim=None,
|
||||||
|
device=self.device,
|
||||||
|
world_rank=world_rank,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
print("Model initialized. Number of GS:", len(self.splats["means"]))
|
||||||
|
|
||||||
|
# Densification Strategy
|
||||||
|
self.cfg.strategy.check_sanity(self.splats, self.optimizers)
|
||||||
|
|
||||||
|
if isinstance(self.cfg.strategy, DefaultStrategy):
|
||||||
|
self.strategy_state = self.cfg.strategy.initialize_state(
|
||||||
|
scene_scale=self.scene_scale
|
||||||
|
)
|
||||||
|
elif isinstance(self.cfg.strategy, MCMCStrategy):
|
||||||
|
self.strategy_state = self.cfg.strategy.initialize_state()
|
||||||
|
else:
|
||||||
|
assert_never(self.cfg.strategy)
|
||||||
|
|
||||||
|
# Losses & Metrics.
|
||||||
|
self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
|
||||||
|
self.device
|
||||||
|
)
|
||||||
|
self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
|
||||||
|
|
||||||
|
if cfg.lpips_net == "alex":
|
||||||
|
self.lpips = LearnedPerceptualImagePatchSimilarity(
|
||||||
|
net_type="alex", normalize=True
|
||||||
|
).to(self.device)
|
||||||
|
elif cfg.lpips_net == "vgg":
|
||||||
|
# The 3DGS official repo uses lpips vgg, which is equivalent with the following:
|
||||||
|
self.lpips = LearnedPerceptualImagePatchSimilarity(
|
||||||
|
net_type="vgg", normalize=False
|
||||||
|
).to(self.device)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")
|
||||||
|
|
||||||
|
def rasterize_splats(
|
||||||
|
self,
|
||||||
|
camtoworlds: Tensor,
|
||||||
|
Ks: Tensor,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
masks: Optional[Tensor] = None,
|
||||||
|
rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
|
||||||
|
camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[Tensor, Tensor, Dict]:
|
||||||
|
means = self.splats["means"] # [N, 3]
|
||||||
|
# quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
|
||||||
|
# rasterization does normalization internally
|
||||||
|
quats = self.splats["quats"] # [N, 4]
|
||||||
|
scales = torch.exp(self.splats["scales"]) # [N, 3]
|
||||||
|
opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
|
||||||
|
image_ids = kwargs.pop("image_ids", None)
|
||||||
|
|
||||||
|
colors = torch.cat(
|
||||||
|
[self.splats["sh0"], self.splats["shN"]], 1
|
||||||
|
) # [N, K, 3]
|
||||||
|
|
||||||
|
if rasterize_mode is None:
|
||||||
|
rasterize_mode = (
|
||||||
|
"antialiased" if self.cfg.antialiased else "classic"
|
||||||
|
)
|
||||||
|
if camera_model is None:
|
||||||
|
camera_model = self.cfg.camera_model
|
||||||
|
|
||||||
|
render_colors, render_alphas, info = rasterization(
|
||||||
|
means=means,
|
||||||
|
quats=quats,
|
||||||
|
scales=scales,
|
||||||
|
opacities=opacities,
|
||||||
|
colors=colors,
|
||||||
|
viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
|
||||||
|
Ks=Ks, # [C, 3, 3]
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
packed=self.cfg.packed,
|
||||||
|
absgrad=(
|
||||||
|
self.cfg.strategy.absgrad
|
||||||
|
if isinstance(self.cfg.strategy, DefaultStrategy)
|
||||||
|
else False
|
||||||
|
),
|
||||||
|
sparse_grad=self.cfg.sparse_grad,
|
||||||
|
rasterize_mode=rasterize_mode,
|
||||||
|
distributed=self.world_size > 1,
|
||||||
|
camera_model=self.cfg.camera_model,
|
||||||
|
with_ut=self.cfg.with_ut,
|
||||||
|
with_eval3d=self.cfg.with_eval3d,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if masks is not None:
|
||||||
|
render_colors[~masks] = 0
|
||||||
|
return render_colors, render_alphas, info
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
device = self.device
|
||||||
|
world_rank = self.world_rank
|
||||||
|
|
||||||
|
# Dump cfg.
|
||||||
|
if world_rank == 0:
|
||||||
|
with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
|
||||||
|
yaml.dump(vars(cfg), f)
|
||||||
|
|
||||||
|
max_steps = cfg.max_steps
|
||||||
|
init_step = 0
|
||||||
|
|
||||||
|
schedulers = [
|
||||||
|
# means has a learning rate schedule, that end at 0.01 of the initial value
|
||||||
|
torch.optim.lr_scheduler.ExponentialLR(
|
||||||
|
self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
trainloader = torch.utils.data.DataLoader(
|
||||||
|
self.trainset,
|
||||||
|
batch_size=cfg.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=4,
|
||||||
|
persistent_workers=True,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
trainloader_iter = iter(trainloader)
|
||||||
|
|
||||||
|
# Training loop.
|
||||||
|
global_tic = time.time()
|
||||||
|
pbar = tqdm.tqdm(range(init_step, max_steps))
|
||||||
|
for step in pbar:
|
||||||
|
try:
|
||||||
|
data = next(trainloader_iter)
|
||||||
|
except StopIteration:
|
||||||
|
trainloader_iter = iter(trainloader)
|
||||||
|
data = next(trainloader_iter)
|
||||||
|
|
||||||
|
camtoworlds = data["camtoworld"].to(device) # [1, 4, 4]
|
||||||
|
Ks = data["K"].to(device) # [1, 3, 3]
|
||||||
|
pixels = data["image"].to(device) / 255.0 # [1, H, W, 3]
|
||||||
|
image_ids = data["image_id"].to(device)
|
||||||
|
masks = (
|
||||||
|
data["mask"].to(device) if "mask" in data else None
|
||||||
|
) # [1, H, W]
|
||||||
|
if cfg.depth_loss:
|
||||||
|
points = data["points"].to(device) # [1, M, 2]
|
||||||
|
depths_gt = data["depths"].to(device) # [1, M]
|
||||||
|
|
||||||
|
height, width = pixels.shape[1:3]
|
||||||
|
|
||||||
|
# sh schedule
|
||||||
|
sh_degree_to_use = min(
|
||||||
|
step // cfg.sh_degree_interval, cfg.sh_degree
|
||||||
|
)
|
||||||
|
|
||||||
|
# forward
|
||||||
|
renders, alphas, info = self.rasterize_splats(
|
||||||
|
camtoworlds=camtoworlds,
|
||||||
|
Ks=Ks,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
sh_degree=sh_degree_to_use,
|
||||||
|
near_plane=cfg.near_plane,
|
||||||
|
far_plane=cfg.far_plane,
|
||||||
|
image_ids=image_ids,
|
||||||
|
render_mode="RGB+ED" if cfg.depth_loss else "RGB",
|
||||||
|
masks=masks,
|
||||||
|
)
|
||||||
|
if renders.shape[-1] == 4:
|
||||||
|
colors, depths = renders[..., 0:3], renders[..., 3:4]
|
||||||
|
else:
|
||||||
|
colors, depths = renders, None
|
||||||
|
|
||||||
|
if cfg.random_bkgd:
|
||||||
|
bkgd = torch.rand(1, 3, device=device)
|
||||||
|
colors = colors + bkgd * (1.0 - alphas)
|
||||||
|
|
||||||
|
self.cfg.strategy.step_pre_backward(
|
||||||
|
params=self.splats,
|
||||||
|
optimizers=self.optimizers,
|
||||||
|
state=self.strategy_state,
|
||||||
|
step=step,
|
||||||
|
info=info,
|
||||||
|
)
|
||||||
|
|
||||||
|
# loss
|
||||||
|
l1loss = F.l1_loss(colors, pixels)
|
||||||
|
ssimloss = 1.0 - fused_ssim(
|
||||||
|
colors.permute(0, 3, 1, 2),
|
||||||
|
pixels.permute(0, 3, 1, 2),
|
||||||
|
padding="valid",
|
||||||
|
)
|
||||||
|
loss = (
|
||||||
|
l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
|
||||||
|
)
|
||||||
|
if cfg.depth_loss:
|
||||||
|
# query depths from depth map
|
||||||
|
points = torch.stack(
|
||||||
|
[
|
||||||
|
points[:, :, 0] / (width - 1) * 2 - 1,
|
||||||
|
points[:, :, 1] / (height - 1) * 2 - 1,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
) # normalize to [-1, 1]
|
||||||
|
grid = points.unsqueeze(2) # [1, M, 1, 2]
|
||||||
|
depths = F.grid_sample(
|
||||||
|
depths.permute(0, 3, 1, 2), grid, align_corners=True
|
||||||
|
) # [1, 1, M, 1]
|
||||||
|
depths = depths.squeeze(3).squeeze(1) # [1, M]
|
||||||
|
# calculate loss in disparity space
|
||||||
|
disp = torch.where(
|
||||||
|
depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
|
||||||
|
)
|
||||||
|
disp_gt = 1.0 / depths_gt # [1, M]
|
||||||
|
depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
|
||||||
|
loss += depthloss * cfg.depth_lambda
|
||||||
|
|
||||||
|
# regularizations
|
||||||
|
if cfg.opacity_reg > 0.0:
|
||||||
|
loss += (
|
||||||
|
cfg.opacity_reg
|
||||||
|
* torch.sigmoid(self.splats["opacities"]).mean()
|
||||||
|
)
|
||||||
|
if cfg.scale_reg > 0.0:
|
||||||
|
loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
desc = (
|
||||||
|
f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
|
||||||
|
)
|
||||||
|
if cfg.depth_loss:
|
||||||
|
desc += f"depth loss={depthloss.item():.6f}| "
|
||||||
|
pbar.set_description(desc)
|
||||||
|
|
||||||
|
# write images (gt and render)
|
||||||
|
# if world_rank == 0 and step % 800 == 0:
|
||||||
|
# canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
|
||||||
|
# canvas = canvas.reshape(-1, *canvas.shape[2:])
|
||||||
|
# imageio.imwrite(
|
||||||
|
# f"{self.render_dir}/train_rank{self.world_rank}.png",
|
||||||
|
# (canvas * 255).astype(np.uint8),
|
||||||
|
# )
|
||||||
|
|
||||||
|
if (
|
||||||
|
world_rank == 0
|
||||||
|
and cfg.tb_every > 0
|
||||||
|
and step % cfg.tb_every == 0
|
||||||
|
):
|
||||||
|
mem = torch.cuda.max_memory_allocated() / 1024**3
|
||||||
|
self.writer.add_scalar("train/loss", loss.item(), step)
|
||||||
|
self.writer.add_scalar("train/l1loss", l1loss.item(), step)
|
||||||
|
self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/num_GS", len(self.splats["means"]), step
|
||||||
|
)
|
||||||
|
self.writer.add_scalar("train/mem", mem, step)
|
||||||
|
if cfg.depth_loss:
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/depthloss", depthloss.item(), step
|
||||||
|
)
|
||||||
|
if cfg.tb_save_image:
|
||||||
|
canvas = (
|
||||||
|
torch.cat([pixels, colors], dim=2)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
canvas = canvas.reshape(-1, *canvas.shape[2:])
|
||||||
|
self.writer.add_image("train/render", canvas, step)
|
||||||
|
self.writer.flush()
|
||||||
|
|
||||||
|
# save checkpoint before updating the model
|
||||||
|
if (
|
||||||
|
step in [i - 1 for i in cfg.save_steps]
|
||||||
|
or step == max_steps - 1
|
||||||
|
):
|
||||||
|
mem = torch.cuda.max_memory_allocated() / 1024**3
|
||||||
|
stats = {
|
||||||
|
"mem": mem,
|
||||||
|
"ellipse_time": time.time() - global_tic,
|
||||||
|
"num_GS": len(self.splats["means"]),
|
||||||
|
}
|
||||||
|
print("Step: ", step, stats)
|
||||||
|
with open(
|
||||||
|
f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
|
||||||
|
"w",
|
||||||
|
) as f:
|
||||||
|
json.dump(stats, f)
|
||||||
|
data = {"step": step, "splats": self.splats.state_dict()}
|
||||||
|
torch.save(
|
||||||
|
data,
|
||||||
|
f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
|
||||||
|
) and cfg.save_ply:
|
||||||
|
sh0 = self.splats["sh0"]
|
||||||
|
shN = self.splats["shN"]
|
||||||
|
means = self.splats["means"]
|
||||||
|
scales = self.splats["scales"]
|
||||||
|
quats = self.splats["quats"]
|
||||||
|
opacities = self.splats["opacities"]
|
||||||
|
export_splats(
|
||||||
|
means=means,
|
||||||
|
scales=scales,
|
||||||
|
quats=quats,
|
||||||
|
opacities=opacities,
|
||||||
|
sh0=sh0,
|
||||||
|
shN=shN,
|
||||||
|
format="ply",
|
||||||
|
save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Turn Gradients into Sparse Tensor before running optimizer
|
||||||
|
if cfg.sparse_grad:
|
||||||
|
assert (
|
||||||
|
cfg.packed
|
||||||
|
), "Sparse gradients only work with packed mode."
|
||||||
|
gaussian_ids = info["gaussian_ids"]
|
||||||
|
for k in self.splats.keys():
|
||||||
|
grad = self.splats[k].grad
|
||||||
|
if grad is None or grad.is_sparse:
|
||||||
|
continue
|
||||||
|
self.splats[k].grad = torch.sparse_coo_tensor(
|
||||||
|
indices=gaussian_ids[None], # [1, nnz]
|
||||||
|
values=grad[gaussian_ids], # [nnz, ...]
|
||||||
|
size=self.splats[k].size(), # [N, ...]
|
||||||
|
is_coalesced=len(Ks) == 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.visible_adam:
|
||||||
|
gaussian_cnt = self.splats.means.shape[0]
|
||||||
|
if cfg.packed:
|
||||||
|
visibility_mask = torch.zeros_like(
|
||||||
|
self.splats["opacities"], dtype=bool
|
||||||
|
)
|
||||||
|
visibility_mask.scatter_(0, info["gaussian_ids"], 1)
|
||||||
|
else:
|
||||||
|
visibility_mask = (info["radii"] > 0).all(-1).any(0)
|
||||||
|
|
||||||
|
# optimize
|
||||||
|
for optimizer in self.optimizers.values():
|
||||||
|
if cfg.visible_adam:
|
||||||
|
optimizer.step(visibility_mask)
|
||||||
|
else:
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
for scheduler in schedulers:
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
# Run post-backward steps after backward and optimizer
|
||||||
|
if isinstance(self.cfg.strategy, DefaultStrategy):
|
||||||
|
self.cfg.strategy.step_post_backward(
|
||||||
|
params=self.splats,
|
||||||
|
optimizers=self.optimizers,
|
||||||
|
state=self.strategy_state,
|
||||||
|
step=step,
|
||||||
|
info=info,
|
||||||
|
packed=cfg.packed,
|
||||||
|
)
|
||||||
|
elif isinstance(self.cfg.strategy, MCMCStrategy):
|
||||||
|
self.cfg.strategy.step_post_backward(
|
||||||
|
params=self.splats,
|
||||||
|
optimizers=self.optimizers,
|
||||||
|
state=self.strategy_state,
|
||||||
|
step=step,
|
||||||
|
info=info,
|
||||||
|
lr=schedulers[0].get_last_lr()[0],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert_never(self.cfg.strategy)
|
||||||
|
|
||||||
|
# eval the full set
|
||||||
|
if step in [i - 1 for i in cfg.eval_steps]:
|
||||||
|
self.eval(step)
|
||||||
|
self.render_video(step)
|
||||||
|
|
||||||
|
    @torch.no_grad()
    def eval(
        self,
        step: int,
        stage: str = "val",
        canvas_h: int = 512,
        canvas_w: int = 1024,
    ):
        """Evaluate the current splats on the validation set.

        Renders every validation view, resizes both GT and render to
        (canvas_h, canvas_w // 2), writes a side-by-side PNG per view,
        and (on rank 0) aggregates PSNR/SSIM/LPIPS into a JSON file and
        tensorboard scalars.

        Args:
            step: Training step the evaluation belongs to (used in names).
            stage: Tag used in output file names and tensorboard keys.
            canvas_h: Height both images are resized to before metrics.
            canvas_w: Full canvas width; each half holds one image.
        """
        print("Running evaluation...")
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        valloader = torch.utils.data.DataLoader(
            self.valset, batch_size=1, shuffle=False, num_workers=1
        )
        # NOTE(review): "ellipse_time" is presumably "elapsed time" — the
        # accumulated per-view render wall time.
        ellipse_time = 0
        metrics = defaultdict(list)
        for i, data in enumerate(valloader):
            camtoworlds = data["camtoworld"].to(device)
            Ks = data["K"].to(device)
            pixels = data["image"].to(device) / 255.0
            height, width = pixels.shape[1:3]
            masks = data["mask"].to(device) if "mask" in data else None

            pixels = pixels.permute(0, 3, 1, 2)  # NHWC -> NCHW
            # Resize GT to the canvas half; the render is resized the same
            # way below so metrics compare at identical resolution.
            pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))

            # Synchronize around the render so the timing below measures
            # GPU work, not just kernel-launch latency.
            torch.cuda.synchronize()
            tic = time.time()
            colors, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                masks=masks,
            )  # [1, H, W, 3]
            torch.cuda.synchronize()
            ellipse_time += max(time.time() - tic, 1e-10)

            colors = colors.permute(0, 3, 1, 2)  # NHWC -> NCHW
            colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
            colors = torch.clamp(colors, 0.0, 1.0)
            canvas_list = [pixels, colors]

            if world_rank == 0:
                # GT | render side-by-side, saved as BGR for cv2.
                canvas = torch.cat(canvas_list, dim=2).squeeze(0)
                canvas = canvas.permute(1, 2, 0)  # CHW -> HWC
                canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
                cv2.imwrite(
                    f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
                    canvas[..., ::-1],
                )
                metrics["psnr"].append(self.psnr(colors, pixels))
                metrics["ssim"].append(self.ssim(colors, pixels))
                metrics["lpips"].append(self.lpips(colors, pixels))

        if world_rank == 0:
            ellipse_time /= len(valloader)

            stats = {
                k: torch.stack(v).mean().item() for k, v in metrics.items()
            }
            stats.update(
                {
                    "ellipse_time": ellipse_time,
                    "num_GS": len(self.splats["means"]),
                }
            )
            print(
                f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
                f"Time: {stats['ellipse_time']:.3f}s/image "
                f"Number of GS: {stats['num_GS']}"
            )
            # save stats as json
            with open(
                f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
            ) as f:
                json.dump(stats, f)
            # save stats to tensorboard
            for k, v in stats.items():
                self.writer.add_scalar(f"{stage}/{k}", v, step)
            self.writer.flush()
|
||||||
|
|
||||||
|
    @torch.no_grad()
    def render_video(
        self, step: int, canvas_h: int = 512, canvas_w: int = 1024
    ):
        """Render the test trajectory to an MP4 of RGB | colorized depth.

        Two passes: first render all views and record the global depth
        range, then normalize each depth map with that shared range so
        the colormap is consistent across frames.

        Args:
            step: Training step used in the output file name.
            canvas_h: Frame height.
            canvas_w: Full frame width; RGB and depth each take half.
        """
        testloader = torch.utils.data.DataLoader(
            self.testset, batch_size=1, shuffle=False, num_workers=1
        )

        images_cache = []
        depth_global_min, depth_global_max = float("inf"), -float("inf")
        for data in testloader:
            camtoworlds = data["camtoworld"].to(self.device)
            # Rescale intrinsics from the dataset resolution to the
            # half-canvas render resolution.
            Ks = resize_pinhole_intrinsics(
                data["K"].squeeze(),
                raw_hw=(data["image_h"].item(), data["image_w"].item()),
                new_hw=(canvas_h, canvas_w // 2),
            ).to(self.device)
            renders, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks[None, ...],
                width=canvas_w // 2,
                height=canvas_h,
                sh_degree=self.cfg.sh_degree,
                near_plane=self.cfg.near_plane,
                far_plane=self.cfg.far_plane,
                render_mode="RGB+ED",  # RGB plus expected-depth channel
            )  # [1, H, W, 4]
            colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0)  # [H, W, 3]
            colors = (colors * 255).to(torch.uint8).cpu().numpy()
            depths = renders[0, ..., 3:4]  # [H, W, 1], tensor in device.
            images_cache.append([colors, depths])
            depth_global_min = min(depth_global_min, depths.min().item())
            depth_global_max = max(depth_global_max, depths.max().item())

        video_path = f"{self.render_dir}/video_step{step}.mp4"
        writer = imageio.get_writer(video_path, fps=30)
        for rgb, depth in images_cache:
            # Normalize with the global range for a stable colormap.
            # NOTE(review): divides by (max - min) — a constant-depth
            # trajectory would divide by zero; confirm inputs rule that out.
            depth_normalized = torch.clip(
                (depth - depth_global_min)
                / (depth_global_max - depth_global_min),
                0,
                1,
            )
            depth_normalized = (
                (depth_normalized * 255).to(torch.uint8).cpu().numpy()
            )
            depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET)
            image = np.concatenate([rgb, depth_map], axis=1)
            writer.append_data(image)

        writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def entrypoint(
    local_rank: int, world_rank, world_size: int, cfg: GsplatTrainConfig
):
    """Launch one (possibly distributed) process: train, or eval from ckpt.

    When ``cfg.ckpt`` is set, the listed checkpoint files (one per rank)
    are loaded, their splats concatenated into the runner, and only
    evaluation plus video rendering are executed. Otherwise the full
    training loop runs, followed by a final video render.
    """
    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is None:
        runner.train()
        runner.render_video(step=cfg.max_steps - 1)
        return

    # Eval-only path: restore splats from the given checkpoint file(s).
    snapshots = [
        torch.load(path, map_location=runner.device, weights_only=True)
        for path in cfg.ckpt
    ]
    for name in runner.splats.keys():
        # Concatenate per-rank shards of each splat parameter.
        runner.splats[name].data = torch.cat(
            [snap["splats"][name] for snap in snapshots]
        )
    resume_step = snapshots[0]["step"]
    runner.eval(step=resume_step)
    runner.render_video(step=resume_step)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Named CLI presets; the first positional argument selects one, and any
    # field can be overridden on the command line via tyro.
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            GsplatTrainConfig(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            GsplatTrainConfig(
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    # NOTE(review): presumably rescales step-count fields (save/eval/ply
    # steps) by steps_scaler — confirm in GsplatTrainConfig.adjust_steps.
    cfg.adjust_steps(cfg.steps_scaler)

    cli(entrypoint, cfg, verbose=True)
|
||||||
538
embodied_gen/trainer/pono2mesh_trainer.py
Normal file
538
embodied_gen/trainer/pono2mesh_trainer.py
Normal file
@ -0,0 +1,538 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the Pano2Room monkey patch before any `thirdparty.pano2room` import
# below, so the patched symbols are already in place when those modules load.
from embodied_gen.utils.monkey_patches import monkey_patch_pano2room

monkey_patch_pano2room()
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import trimesh
|
||||||
|
from equilib import cube2equi, equi2pers
|
||||||
|
from kornia.morphology import dilation
|
||||||
|
from PIL import Image
|
||||||
|
from embodied_gen.models.sr_model import ImageRealESRGAN
|
||||||
|
from embodied_gen.utils.config import Pano2MeshSRConfig
|
||||||
|
from embodied_gen.utils.gaussian import compute_pinhole_intrinsics
|
||||||
|
from embodied_gen.utils.log import logger
|
||||||
|
from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
|
||||||
|
from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
|
||||||
|
PanoFusionDistancePredictor,
|
||||||
|
)
|
||||||
|
from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter
|
||||||
|
from thirdparty.pano2room.modules.mesh_fusion.render import (
|
||||||
|
features_to_world_space_mesh,
|
||||||
|
render_mesh,
|
||||||
|
)
|
||||||
|
from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool
|
||||||
|
from thirdparty.pano2room.utils.camera_utils import gen_pano_rays
|
||||||
|
from thirdparty.pano2room.utils.functions import (
|
||||||
|
depth_to_distance,
|
||||||
|
get_cubemap_views_world_to_cam,
|
||||||
|
resize_image_with_aspect_ratio,
|
||||||
|
rot_z_world_to_cam,
|
||||||
|
tensor_to_pil,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Pano2MeshSRPipeline:
|
||||||
|
"""Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement.
|
||||||
|
|
||||||
|
This class integrates several key components including:
|
||||||
|
- Depth estimation from RGB panorama
|
||||||
|
- Inpainting of missing regions under offsets
|
||||||
|
- RGB-D to mesh conversion
|
||||||
|
- Multi-view mesh repair
|
||||||
|
- 3D Gaussian Splatting (3DGS) dataset generation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
pipeline = Pano2MeshSRPipeline(config)
|
||||||
|
pipeline(pano_image='example.png', output_dir='./output')
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
    def __init__(self, config: Pano2MeshSRConfig) -> None:
        """Build the pipeline: models, camera trajectory, and mask kernel.

        Args:
            config: Pipeline configuration; models and tensors live on
                ``config.device``.
        """
        self.cfg = config
        self.device = config.device

        # Init models.
        self.inpainter = PanoPersFusionInpainter(save_path=None)
        self.geo_predictor = PanoJointPredictor(save_path=None)
        self.pano_fusion_distance_predictor = PanoFusionDistancePredictor()
        # Real-ESRGAN super-resolution model, upscaling by `upscale_factor`.
        self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor)

        # Init poses.
        cubemap_w2cs = get_cubemap_views_world_to_cam()
        self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
        self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)

        # Elliptical structuring element used to dilate inpainting masks.
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, self.cfg.kernel_size
        )
        self.kernel = torch.from_numpy(kernel).float().to(self.device)
|
||||||
|
|
||||||
|
def init_mesh_params(self) -> None:
|
||||||
|
torch.set_default_device(self.device)
|
||||||
|
self.inpaint_mask = torch.ones(
|
||||||
|
(self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
|
||||||
|
)
|
||||||
|
self.vertices = torch.empty((3, 0), requires_grad=False)
|
||||||
|
self.colors = torch.empty((3, 0), requires_grad=False)
|
||||||
|
self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def read_camera_pose_file(filepath: str) -> np.ndarray:
|
||||||
|
with open(filepath, "r") as f:
|
||||||
|
values = [float(num) for line in f for num in line.split()]
|
||||||
|
|
||||||
|
return np.array(values).reshape(4, 4)
|
||||||
|
|
||||||
|
    def load_camera_poses(
        self, trajectory_dir: str
    ) -> list[torch.Tensor]:
        """Load trajectory pose files and convert them to relative poses.

        Reads every ``camera_pose*`` file in ``trajectory_dir`` (sorted by
        name), offsets the first pose by ``cfg.pano_center_offset`` to form
        the panorama reference, and returns each pose relative to its
        reference, axis-flipped, z-rotated 180 degrees, and scaled by
        ``cfg.pose_scale``.

        Args:
            trajectory_dir: Directory containing ``camera_pose*`` text files.

        Returns:
            List of [4, 4] float32 relative pose tensors, one per file.
        """
        pose_filenames = sorted(
            [
                fname
                for fname in os.listdir(trajectory_dir)
                if fname.startswith("camera_pose")
            ]
        )

        pano_pose_world = None
        relative_poses = []
        for idx, filename in enumerate(pose_filenames):
            pose_path = os.path.join(trajectory_dir, filename)
            pose_matrix = self.read_camera_pose_file(pose_path)

            # The first file defines the panorama-center reference pose,
            # shifted by the configured (x, z) offset.
            if pano_pose_world is None:
                pano_pose_world = pose_matrix.copy()
                pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
                pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]

            # Use different reference for the first 6 cubemap views: each is
            # relative to itself; later views are relative to the pano center.
            reference_pose = pose_matrix if idx < 6 else pano_pose_world
            relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
            relative_matrix[0:2, :] *= -1  # flip_xy
            relative_matrix = (
                relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
            )
            relative_matrix[:3, 3] *= self.cfg.pose_scale
            relative_matrix = torch.tensor(
                relative_matrix, dtype=torch.float32
            )
            relative_poses.append(relative_matrix)

        return relative_poses
|
||||||
|
|
||||||
|
def load_inpaint_poses(
|
||||||
|
self, poses: torch.Tensor
|
||||||
|
) -> dict[int, torch.Tensor]:
|
||||||
|
inpaint_poses = dict()
|
||||||
|
sampled_views = poses[:: self.cfg.inpaint_frame_stride]
|
||||||
|
init_pose = torch.eye(4)
|
||||||
|
for idx, w2c_tensor in enumerate(sampled_views):
|
||||||
|
w2c = w2c_tensor.cpu().numpy().astype(np.float32)
|
||||||
|
c2w = np.linalg.inv(w2c)
|
||||||
|
pose_tensor = init_pose.clone()
|
||||||
|
pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
|
||||||
|
pose_tensor[:3, 3] *= -1
|
||||||
|
inpaint_poses[idx] = pose_tensor.to(self.device)
|
||||||
|
|
||||||
|
return inpaint_poses
|
||||||
|
|
||||||
|
    def project(self, world_to_cam: torch.Tensor):
        """Rasterize the fused mesh into one perspective cubemap-sized view.

        Args:
            world_to_cam: [4, 4] view pose, passed as ``RT`` to ``render_mesh``.

        Returns:
            Tuple of (rgb image with uncovered pixels zeroed, hole/inpaint
            mask, rendered depth map).
        """
        (
            project_image,
            project_depth,
            inpaint_mask,
            _,
            z_buf,
            mesh,
        ) = render_mesh(
            vertices=self.vertices,
            faces=self.faces,
            vertex_features=self.colors,
            H=self.cfg.cubemap_h,
            W=self.cfg.cubemap_w,
            fov_in_degrees=self.cfg.fov,
            RT=world_to_cam,
            blur_radius=self.cfg.blur_radius,
            faces_per_pixel=self.cfg.faces_per_pixel,
        )
        # Zero out pixels flagged by the mask (no mesh coverage there).
        project_image = project_image * ~inpaint_mask

        return project_image[:3, ...], inpaint_mask, project_depth
|
||||||
|
|
||||||
|
    def render_pano(self, pose: torch.Tensor):
        """Render the fused mesh into an equirectangular panorama at `pose`.

        Renders the six cubemap faces via ``self.project`` and stitches
        RGB, distance, and hole mask into one panorama with ``cube2equi``.

        Args:
            pose: [4, 4] pose of the panorama center (composed with each
                cubemap face pose).

        Returns:
            Tuple of (pano_rgb [3, H, W], pano_depth [H, W] distance map,
            pano_mask [H, W] hole mask).
        """
        cubemap_list = []
        for cubemap_pose in self.cubemap_w2cs:
            # Compose the face-local rotation with the panorama pose.
            project_pose = cubemap_pose @ pose
            rgb, inpaint_mask, depth = self.project(project_pose)
            # Convert planar depth to ray distance before stitching.
            distance_map = depth_to_distance(depth[None, ...])
            mask = inpaint_mask[None, ...]
            # Per-face stack of channels: 0-2 rgb, 3 distance, 4 mask.
            cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0))

        # Set default tensor type for CPU operation in cube2equi.
        with torch.device("cpu"):
            pano_rgbd = cube2equi(
                cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w
            )

        pano_rgb = pano_rgbd[:3, :, :]
        pano_depth = pano_rgbd[3:4, :, :].squeeze(0)
        pano_mask = pano_rgbd[4:, :, :].squeeze(0)

        return pano_rgb, pano_depth, pano_mask
|
||||||
|
|
||||||
|
    def rgbd_to_mesh(
        self,
        rgb: torch.Tensor,
        depth: torch.Tensor,
        inpaint_mask: torch.Tensor,
        world_to_cam: torch.Tensor = None,
        using_distance_map: bool = True,
    ) -> None:
        """Back-project an RGB-D view into the fused mesh buffers, in place.

        New geometry is created only where ``inpaint_mask`` is set and
        appended to ``self.vertices`` / ``self.colors`` / ``self.faces``.

        Args:
            rgb: Color image; squeezed and passed as vertex colors.
            depth: Depth/distance map matching ``rgb``'s spatial size.
            inpaint_mask: Pixels to convert to geometry; if it sums to
                zero the call is a no-op.
            world_to_cam: Optional [4, 4] pose; identity when omitted.
            using_distance_map: Whether ``depth`` holds ray distances
                instead of planar depth (forwarded to the mesh builder).
        """
        if world_to_cam is None:
            world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)

        # Nothing to add when the mask selects no pixels.
        if inpaint_mask.sum() == 0:
            return

        vertices, faces, colors = features_to_world_space_mesh(
            colors=rgb.squeeze(0),
            depth=depth,
            fov_in_degrees=self.cfg.fov,
            world_to_cam=world_to_cam,
            mask=inpaint_mask,
            faces=self.faces,
            vertices=self.vertices,
            using_distance_map=using_distance_map,
            edge_threshold=0.05,
        )

        # Shift new face indices past the existing vertex count, then
        # append to the column-major [3, N] buffers.
        faces += self.vertices.shape[1]
        self.vertices = torch.cat([self.vertices, vertices], dim=1)
        self.colors = torch.cat([self.colors, colors], dim=1)
        self.faces = torch.cat([self.faces, faces], dim=1)
|
||||||
|
|
||||||
|
def get_edge_image_by_depth(
|
||||||
|
self, depth: torch.Tensor, dilate_iter: int = 1
|
||||||
|
) -> np.ndarray:
|
||||||
|
if isinstance(depth, torch.Tensor):
|
||||||
|
depth = depth.cpu().detach().numpy()
|
||||||
|
|
||||||
|
gray = (depth / depth.max() * 255).astype(np.uint8)
|
||||||
|
edges = cv2.Canny(gray, 60, 150)
|
||||||
|
if dilate_iter > 0:
|
||||||
|
kernel = np.ones((3, 3), np.uint8)
|
||||||
|
edges = cv2.dilate(edges, kernel, iterations=dilate_iter)
|
||||||
|
|
||||||
|
return edges
|
||||||
|
|
||||||
|
    def mesh_repair_by_greedy_view_selection(
        self, pose_dict: dict[str, torch.Tensor], output_dir: str
    ) -> list:
        """Iteratively repair mesh holes by inpainting selected pano views.

        Each round renders every remaining candidate pose, scores it by
        view completeness (fraction of covered pixels), picks one view,
        inpaints its holes, fuses the new RGB-D content into the mesh, and
        registers it with the supervision pool. Candidates are consumed
        (popped from ``pose_dict``, which is mutated) until none remain.

        Args:
            pose_dict: Candidate poses keyed by view id; emptied in place.
            output_dir: Directory for debug images when ``cfg.visualize``.

        Returns:
            List of ``[colors [1, 3, H, W], pose]`` pairs, one per
            inpainted view.
        """
        inpainted_panos_w_pose = []
        while len(pose_dict) > 0:
            logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
            sampled_views = []
            for key, pose in pose_dict.items():
                pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
                # Completeness = fraction of pixels already covered by mesh.
                completeness = torch.sum(1 - pano_mask) / (pano_mask.numel())
                sampled_views.append((key, completeness.item(), pose))

            if len(sampled_views) == 0:
                break

            # Sort ascending by completeness and pick the view at the 2/3
            # position — fairly complete, but still with holes to repair.
            sampled_views = sorted(sampled_views, key=lambda x: x[1])
            key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
            pose_dict.pop(key)

            pano_rgb, pano_distance, pano_mask = self.render_pano(pose)

            colors = pano_rgb.permute(1, 2, 0).clone()
            distances = pano_distance.unsqueeze(-1).clone()
            pano_inpaint_mask = pano_mask.clone()
            init_pose = pose.clone()
            normals = None
            # Any mask value below 0.5 means covered pixels exist to
            # anchor inpainting of the holes.
            if pano_inpaint_mask.min().item() < 0.5:
                colors, distances, normals = self.inpaint_panorama(
                    idx=key,
                    colors=colors,
                    distances=distances,
                    pano_mask=pano_inpaint_mask,
                )

            # Remap translation for the supervision-pool frame.
            # NOTE(review): axis convention (-x, z, 0) taken as-is — verify
            # against gen_pano_rays' expected coordinate frame.
            init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
                -pose[0, 3],
                pose[2, 3],
                0,
            )
            rays = gen_pano_rays(
                init_pose, self.cfg.pano_h, self.cfg.pano_w
            )
            conflict_mask = self.sup_pool.geo_check(
                rays, distances.unsqueeze(-1)
            )  # 0 is conflict, 1 not conflict
            # Drop inpainted pixels that conflict with registered geometry.
            pano_inpaint_mask *= conflict_mask

            self.rgbd_to_mesh(
                colors.permute(2, 0, 1),
                distances,
                pano_inpaint_mask,
                world_to_cam=pose,
            )

            self.sup_pool.register_sup_info(
                pose=init_pose,
                mask=pano_inpaint_mask.clone(),
                rgb=colors,
                distance=distances.unsqueeze(-1),
                normal=normals,
            )

            colors = colors.permute(2, 0, 1).unsqueeze(0)
            inpainted_panos_w_pose.append([colors, pose])

            if self.cfg.visualize:
                # Local import keeps the heavy dependency optional.
                from embodied_gen.data.utils import DiffrastRender

                tensor_to_pil(pano_rgb.unsqueeze(0)).save(
                    f"{output_dir}/rendered_pano_{key}.jpg"
                )
                tensor_to_pil(colors).save(
                    f"{output_dir}/inpainted_pano_{key}.jpg"
                )
                norm_depth = DiffrastRender.normalize_map_by_mask(
                    distances, torch.ones_like(distances)
                )
                heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
                heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
                Image.fromarray(heatmap).save(
                    f"{output_dir}/inpainted_depth_{key}.png"
                )

        return inpainted_panos_w_pose
|
||||||
|
|
||||||
|
    def inpaint_panorama(
        self,
        idx: int,
        colors: torch.Tensor,
        distances: torch.Tensor,
        pano_mask: torch.Tensor,
    ) -> tuple[torch.Tensor]:
        """Inpaint masked panorama regions and re-predict their geometry.

        Args:
            idx: View identifier forwarded to the inpainter/geo predictor.
            colors: Panorama colors in HWC layout.
            distances: Distance map; a trailing singleton axis is added
                before the geometry predictor.
            pano_mask: Hole mask; values > 0.5 mark pixels to inpaint.

        Returns:
            Tuple of (inpainted image, inpainted distances with singleton
            dims squeezed, predicted normals).
        """
        # Binarize the hole mask and dilate it so inpainting covers seams.
        mask = (pano_mask[None, ..., None] > 0.5).float()
        mask = mask.permute(0, 3, 1, 2)
        mask = dilation(mask, kernel=self.kernel)
        mask = mask[0, 0, ..., None]  # hwc
        inpainted_img = self.inpainter.inpaint(idx, colors, mask)
        # Keep original pixels outside the mask; blend only inside it.
        inpainted_img = colors * (1 - mask) + inpainted_img * mask
        inpainted_distances, inpainted_normals = self.geo_predictor(
            idx,
            inpainted_img,
            distances[..., None],
            mask=mask,
            reg_loss_weight=0.0,
            normal_loss_weight=5e-2,
            normal_tv_loss_weight=5e-2,
        )

        return inpainted_img, inpainted_distances.squeeze(), inpainted_normals
|
||||||
|
|
||||||
|
    def preprocess_pano(
        self, image: Image.Image | str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Load and normalize a panorama, then predict its distance map.

        Args:
            image: PIL image or path to one.

        Returns:
            Tuple of (image_rgb [3, H, W] float in [0, 1] on ``self.device``,
            image_depth distance map rescaled so its maximum equals
            ``cfg.depth_scale_factor``).
        """
        if isinstance(image, str):
            image = Image.open(image)

        image = image.convert("RGB")

        # Ensure landscape orientation (width >= height), as expected for
        # an equirectangular panorama.
        if image.size[0] < image.size[1]:
            image = image.transpose(Image.TRANSPOSE)

        image = resize_image_with_aspect_ratio(image, self.cfg.pano_w)
        image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255
        image_rgb = image_rgb.to(self.device)
        image_depth = self.pano_fusion_distance_predictor.predict(
            image_rgb.permute(1, 2, 0)
        )
        # Normalize predicted distances to a fixed working scale.
        image_depth = (
            image_depth / image_depth.max() * self.cfg.depth_scale_factor
        )

        return image_rgb, image_depth
|
||||||
|
|
||||||
|
def pano_to_perpective(
    self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
) -> torch.Tensor:
    """Render one perspective view from an equirectangular panorama.

    Args:
        pano_image: Equirectangular tensor (leading batch dim is squeezed).
        pitch: Camera pitch in the rotation dict passed to equi2pers.
        yaw: Camera yaw in the rotation dict passed to equi2pers.
        fov: Horizontal field of view forwarded as equi2pers ``fov_x``.

    Returns:
        Perspective view of shape (1, C, cubemap_h, cubemap_w).
    """
    view = equi2pers(
        equi=pano_image.squeeze(0),
        rots={"roll": 0, "pitch": pitch, "yaw": yaw},
        height=self.cfg.cubemap_h,
        width=self.cfg.cubemap_w,
        fov_x=fov,
        mode="bilinear",
    )
    return view.unsqueeze(0)
|
def pano_to_cubemap(self, pano_rgb: torch.Tensor):
    """Project an equirectangular panorama into six cubemap faces.

    Face order follows the (pitch, yaw) list below. Each view is moved
    to CPU before being returned.
    """
    # Six canonical cube directions as (pitch, yaw) pairs.
    face_rots = [
        (0, 0),
        (0, 1.5 * np.pi),
        (0, 1.0 * np.pi),
        (0, 0.5 * np.pi),
        (-0.5 * np.pi, 0),
        (0.5 * np.pi, 0),
    ]
    return [
        self.pano_to_perpective(pano_rgb, pitch, yaw, fov=self.cfg.fov).cpu()
        for pitch, yaw in face_rots
    ]
|
def save_mesh(self, output_path: str) -> None:
    """Export the current vertex-colored mesh to ``output_path``.

    Internal buffers are stored transposed (3, N) / (3, F); trimesh
    expects row-major (N, 3) arrays, hence the ``.T``.
    """
    mesh = trimesh.Trimesh(
        vertices=self.vertices.T.cpu().numpy(),
        faces=self.faces.T.cpu().numpy(),
        vertex_colors=self.colors.T.cpu().numpy(),
    )
    mesh.export(output_path)
|
def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
    """Convert a mesh-space world-to-camera pose to a 3DGS camera-to-world.

    Negates the x/y rows, applies the y/z axis reversal used by the GS
    convention, and inverts the rigid transform.

    Args:
        mesh_pose: (4, 4) world-to-camera matrix (torch tensor).

    Returns:
        (4, 4) camera-to-world matrix as a numpy array.
    """
    w2c = mesh_pose.clone()
    w2c[:2, :] *= -1  # negate the x and y rows

    R_w2c = w2c[:3, :3].cpu().numpy()
    t_w2c = w2c[:3, 3:].cpu().numpy()
    flip_yz = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])

    # Invert the axis-flipped transform: R^T and -R^T t.
    R_c2w = (flip_yz @ R_w2c).T
    t_c2w = -(R_c2w @ flip_yz @ t_w2c)

    c2w = np.concatenate((R_c2w, t_c2w), axis=1)
    return np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)
|
def __call__(self, pano_image: Image.Image | str, output_dir: str):
    """Run the pano-to-mesh pipeline and dump 3DGS training data.

    Steps: preprocess the panorama into RGBD, build an initial mesh,
    repair it by greedy view selection with inpainting, then export the
    mesh (``cfg.mesh_file``) and a ``.pt`` file of super-resolved cubemap
    views for 3DGS training (``cfg.gs_data_file``).

    Args:
        pano_image: Input panorama (PIL image or file path).
        output_dir: Directory receiving the mesh and training data.
    """
    self.init_mesh_params()
    pano_rgb, pano_depth = self.preprocess_pano(pano_image)
    # Register the input pano as the reference supervision view.
    self.sup_pool = SupInfoPool()
    self.sup_pool.register_sup_info(
        pose=torch.eye(4).to(self.device),
        mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]),
        rgb=pano_rgb.permute(1, 2, 0),
        distance=pano_depth[..., None],
    )
    self.sup_pool.gen_occ_grid(res=256)

    logger.info("Init mesh from pano RGBD image...")
    # Mask out depth discontinuities so triangulation skips stretched faces.
    depth_edge = self.get_edge_image_by_depth(pano_depth)
    inpaint_edge_mask = (
        ~torch.from_numpy(depth_edge).to(self.device).bool()
    )
    self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask)

    # Fill mesh holes revealed along the camera trajectory.
    repair_poses = self.load_inpaint_poses(self.camera_poses)
    inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection(
        repair_poses, output_dir
    )
    torch.cuda.empty_cache()
    torch.set_default_device("cpu")

    if self.cfg.mesh_file is not None:
        mesh_path = os.path.join(output_dir, self.cfg.mesh_file)
        self.save_mesh(mesh_path)

    if self.cfg.gs_data_file is None:
        return

    # Fix: plain string instead of a placeholder-free f-string (F541).
    logger.info("Dump data for 3DGS training...")
    points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8)
    data = {
        "points": self.vertices.permute(1, 0).cpu().numpy(),  # (N, 3)
        "points_rgb": points_rgb.permute(1, 0).cpu().numpy(),  # (N, 3)
        "train": [],
        "eval": [],
    }
    # Training views: super-resolved cubemap renders of each inpainted pano.
    image_h = self.cfg.cubemap_h * self.cfg.upscale_factor
    image_w = self.cfg.cubemap_w * self.cfg.upscale_factor
    Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov)
    for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses):
        cubemaps = self.pano_to_cubemap(pano_img)
        for i in range(len(cubemaps)):
            cubemap = tensor_to_pil(cubemaps[i])
            cubemap = self.super_model(cubemap)
            mesh_pose = self.cubemap_w2cs[i] @ pano_pose
            c2w = self.mesh_pose_to_gs_pose(mesh_pose)
            data["train"].append(
                {
                    "camtoworld": c2w.astype(np.float32),
                    "K": Ks.astype(np.float32),
                    "image": np.array(cubemap),
                    "image_h": image_h,
                    "image_w": image_w,
                    "image_id": len(cubemaps) * idx + i,
                }
            )

    # Camera poses for evaluation (no images attached).
    for idx in range(len(self.camera_poses)):
        c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx])
        data["eval"].append(
            {
                "camtoworld": c2w.astype(np.float32),
                "K": Ks.astype(np.float32),
                "image_h": image_h,
                "image_w": image_w,
                "image_id": idx,
            }
        )

    data_path = os.path.join(output_dir, self.cfg.gs_data_file)
    torch.save(data, data_path)
||||||
|
if __name__ == "__main__":
    # Demo entry point: reconstruct the bundled example panorama.
    pano_path = "apps/assets/example_scene/result_pano.png"
    out_dir = "outputs/bg_v2/test3"
    pipeline = Pano2MeshSRPipeline(Pano2MeshSRConfig())
    pipeline(pano_path, out_dir)
||||||
190
embodied_gen/utils/config.py
Normal file
190
embodied_gen/utils/config.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
from gsplat.strategy import DefaultStrategy, MCMCStrategy
|
||||||
|
from typing_extensions import Literal, assert_never
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Pano2MeshSRConfig",
|
||||||
|
"GsplatTrainConfig",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Pano2MeshSRConfig:
    """Config for the pano -> mesh -> super-resolved 3DGS-data pipeline."""

    # Mesh output filename (relative to the run's output dir); None skips export.
    mesh_file: str = "mesh_model.ply"
    # 3DGS training-data filename (torch.save .pt); None skips dumping.
    gs_data_file: str = "gs_data.pt"
    device: str = "cuda"
    # Rasterization settings — presumably pytorch3d rasterizer params; confirm.
    blur_radius: int = 0
    faces_per_pixel: int = 8
    # Horizontal FoV (degrees) of each cubemap/perspective view.
    fov: int = 90
    # Equirectangular panorama resolution.
    pano_w: int = 2048
    pano_h: int = 1024
    # Per-face cubemap resolution (before super-resolution upscaling).
    cubemap_w: int = 512
    cubemap_h: int = 512
    # NOTE(review): used by trajectory/pose generation outside this view — confirm.
    pose_scale: float = 0.6
    pano_center_offset: tuple = (-0.2, 0.3)
    # Presumably: take every Nth trajectory frame for inpainting — confirm.
    inpaint_frame_stride: int = 20
    trajectory_dir: str = "apps/assets/example_scene/camera_trajectory"
    visualize: bool = False
    # Distances are normalized so their max equals this value.
    depth_scale_factor: float = 3.4092
    # Presumably the dilation kernel size for the hole mask — confirm.
    kernel_size: tuple = (9, 9)
    # Super-resolution factor applied to each cubemap face.
    upscale_factor: int = 4
|
||||||
|
|
||||||
|
@dataclass
class GsplatTrainConfig:
    """Training configuration for the gsplat-based 3DGS trainer."""

    # Path to the .pt files. If provide, it will skip training and run evaluation only.
    ckpt: Optional[List[str]] = None
    # Render trajectory path
    render_traj_path: str = "interp"

    # Path to the Mip-NeRF 360 dataset
    data_dir: str = "outputs/bg"
    # Downsample factor for the dataset
    data_factor: int = 4
    # Directory to save results
    result_dir: str = "outputs/bg"
    # Every N images there is a test image
    test_every: int = 8
    # Random crop size for training (experimental)
    patch_size: Optional[int] = None
    # A global scaler that applies to the scene size related parameters
    global_scale: float = 1.0
    # Normalize the world space
    normalize_world_space: bool = True
    # Camera model
    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"

    # Port for the viewer server
    port: int = 8080

    # Batch size for training. Learning rates are scaled automatically
    batch_size: int = 1
    # A global factor to scale the number of training steps
    steps_scaler: float = 1.0

    # Number of training steps
    max_steps: int = 30_000
    # Steps to evaluate the model
    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Steps to save the model
    save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to save ply file (storage size can be large)
    save_ply: bool = True
    # Steps to save the model as ply
    ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to disable video generation during training and evaluation
    disable_video: bool = False

    # Initial number of GSs. Ignored if using sfm
    init_num_pts: int = 100_000
    # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
    init_extent: float = 3.0
    # Degree of spherical harmonics
    sh_degree: int = 1
    # Turn on another SH degree every this steps
    sh_degree_interval: int = 1000
    # Initial opacity of GS
    init_opa: float = 0.1
    # Initial scale of GS
    init_scale: float = 1.0
    # Weight for SSIM loss
    ssim_lambda: float = 0.2

    # Near plane clipping distance
    near_plane: float = 0.01
    # Far plane clipping distance
    far_plane: float = 1e10

    # Strategy for GS densification
    strategy: Union[DefaultStrategy, MCMCStrategy] = field(
        default_factory=DefaultStrategy
    )
    # Use packed mode for rasterization, this leads to less memory usage but slightly slower.
    packed: bool = False
    # Use sparse gradients for optimization. (experimental)
    sparse_grad: bool = False
    # Use visible adam from Taming 3DGS. (experimental)
    visible_adam: bool = False
    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
    antialiased: bool = False

    # Use random background for training to discourage transparency
    random_bkgd: bool = False

    # LR for 3D point positions
    means_lr: float = 1.6e-4
    # LR for Gaussian scale factors
    scales_lr: float = 5e-3
    # LR for alpha blending weights
    opacities_lr: float = 5e-2
    # LR for orientation (quaternions)
    quats_lr: float = 1e-3
    # LR for SH band 0 (brightness)
    sh0_lr: float = 2.5e-3
    # LR for higher-order SH (detail)
    shN_lr: float = 2.5e-3 / 20

    # Opacity regularization
    opacity_reg: float = 0.0
    # Scale regularization
    scale_reg: float = 0.0

    # Enable depth loss. (experimental)
    depth_loss: bool = False
    # Weight for depth loss
    depth_lambda: float = 1e-2

    # Dump information to tensorboard every this steps
    tb_every: int = 200
    # Save training images to tensorboard
    tb_save_image: bool = False

    # Network backbone for the LPIPS metric
    lpips_net: Literal["vgg", "alex"] = "alex"

    # 3DGUT (uncented transform + eval 3D)
    with_ut: bool = False
    with_eval3d: bool = False

    # NOTE(review): appears to be a derived/runtime value rather than a user
    # setting — confirm how callers populate it.
    scene_scale: float = 1.0

    def adjust_steps(self, factor: float):
        """Scale every step-based schedule field (and the densification
        strategy's iteration bounds) by ``factor``, in place."""
        self.eval_steps = [int(i * factor) for i in self.eval_steps]
        self.save_steps = [int(i * factor) for i in self.save_steps]
        self.ply_steps = [int(i * factor) for i in self.ply_steps]
        self.max_steps = int(self.max_steps * factor)
        self.sh_degree_interval = int(self.sh_degree_interval * factor)

        # The two strategy types expose slightly different schedules.
        strategy = self.strategy
        if isinstance(strategy, DefaultStrategy):
            strategy.refine_start_iter = int(
                strategy.refine_start_iter * factor
            )
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.reset_every = int(strategy.reset_every * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        elif isinstance(strategy, MCMCStrategy):
            strategy.refine_start_iter = int(
                strategy.refine_start_iter * factor
            )
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        else:
            # Exhaustiveness check: fails loudly on a new strategy type.
            assert_never(strategy)
||||||
331
embodied_gen/utils/gaussian.py
Normal file
331
embodied_gen/utils/gaussian.py
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
# Part of the code comes from https://github.com/nerfstudio-project/gsplat
|
||||||
|
# Both under the Apache License, Version 2.0.
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Dict, Literal, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import trimesh
|
||||||
|
from gsplat.optimizers import SelectiveAdam
|
||||||
|
from scipy.spatial.transform import Rotation
|
||||||
|
from sklearn.neighbors import NearestNeighbors
|
||||||
|
from torch import Tensor
|
||||||
|
from embodied_gen.models.gs_model import GaussianOperator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"set_random_seed",
|
||||||
|
"export_splats",
|
||||||
|
"create_splats_with_optimizers",
|
||||||
|
"compute_pinhole_intrinsics",
|
||||||
|
"resize_pinhole_intrinsics",
|
||||||
|
"restore_scene_scale_and_position",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def knn(x: Tensor, K: int = 4) -> Tensor:
    """Distances from each point in ``x`` to its K nearest neighbors.

    The query set equals the data set, so each point's first neighbor is
    itself (distance 0).

    Args:
        x: (N, D) point tensor on any device.
        K: Number of neighbors to query.

    Returns:
        (N, K) euclidean distances, cast back to ``x``'s device/dtype.
    """
    pts = x.cpu().numpy()
    index = NearestNeighbors(n_neighbors=K, metric="euclidean")
    index.fit(pts)
    dists, _ = index.kneighbors(pts)
    return torch.from_numpy(dists).to(x)
|
||||||
|
|
||||||
|
def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert [0, 1] RGB values to zeroth-order SH coefficients."""
    sh0_basis = 0.28209479177387814  # 1 / (2 * sqrt(pi))
    return (rgb - 0.5) / sh0_basis
|
|
||||||
|
|
||||||
|
def set_random_seed(seed: int):
    """Seed the Python, NumPy and torch RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
|
||||||
|
|
||||||
|
def splat2ply_bytes(
|
||||||
|
means: torch.Tensor,
|
||||||
|
scales: torch.Tensor,
|
||||||
|
quats: torch.Tensor,
|
||||||
|
opacities: torch.Tensor,
|
||||||
|
sh0: torch.Tensor,
|
||||||
|
shN: torch.Tensor,
|
||||||
|
) -> bytes:
|
||||||
|
num_splats = means.shape[0]
|
||||||
|
buffer = BytesIO()
|
||||||
|
|
||||||
|
# Write PLY header
|
||||||
|
buffer.write(b"ply\n")
|
||||||
|
buffer.write(b"format binary_little_endian 1.0\n")
|
||||||
|
buffer.write(f"element vertex {num_splats}\n".encode())
|
||||||
|
buffer.write(b"property float x\n")
|
||||||
|
buffer.write(b"property float y\n")
|
||||||
|
buffer.write(b"property float z\n")
|
||||||
|
for i, data in enumerate([sh0, shN]):
|
||||||
|
prefix = "f_dc" if i == 0 else "f_rest"
|
||||||
|
for j in range(data.shape[1]):
|
||||||
|
buffer.write(f"property float {prefix}_{j}\n".encode())
|
||||||
|
buffer.write(b"property float opacity\n")
|
||||||
|
for i in range(scales.shape[1]):
|
||||||
|
buffer.write(f"property float scale_{i}\n".encode())
|
||||||
|
for i in range(quats.shape[1]):
|
||||||
|
buffer.write(f"property float rot_{i}\n".encode())
|
||||||
|
buffer.write(b"end_header\n")
|
||||||
|
|
||||||
|
# Concatenate all tensors in the correct order
|
||||||
|
splat_data = torch.cat(
|
||||||
|
[means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
|
||||||
|
)
|
||||||
|
# Ensure correct dtype
|
||||||
|
splat_data = splat_data.to(torch.float32)
|
||||||
|
|
||||||
|
# Write binary data
|
||||||
|
float_dtype = np.dtype(np.float32).newbyteorder("<")
|
||||||
|
buffer.write(
|
||||||
|
splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()
|
||||||
|
)
|
||||||
|
|
||||||
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def export_splats(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    format: Literal["ply"] = "ply",
    save_to: Optional[str] = None,
) -> bytes:
    """Export a Gaussian Splats model to bytes in PLY file format.

    Args:
        means: (N, 3) splat centers.
        scales: (N, 3) scale parameters.
        quats: (N, 4) rotations.
        opacities: (N,) opacities.
        sh0: (N, 1, 3) DC spherical-harmonic coefficients.
        shN: (N, K, 3) higher-order SH coefficients.
        format: Output format; only "ply" is supported.
        save_to: Optional path; when given, the bytes are also written there.

    Returns:
        Serialized PLY bytes with non-finite splats filtered out.

    Raises:
        ValueError: If ``format`` is not "ply".
    """
    total_splats = means.shape[0]
    assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
    assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
    assert quats.shape == (
        total_splats,
        4,
    ), "Quaternions must be of shape (N, 4)"
    assert opacities.shape == (
        total_splats,
    ), "Opacities must be of shape (N,)"
    assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
    assert (
        shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
    ), f"shN must be of shape (N, K, 3), got {shN.shape}"

    # Reshape spherical harmonics
    sh0 = sh0.squeeze(1)  # Shape (N, 3)
    shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1)  # Shape (N, K * 3)

    # Per-splat NaN/Inf detection. BUGFIX: the opacity terms previously used
    # `.any(dim=0)`, which reduces the (N,) vector to a scalar; that scalar
    # broadcast over the mask, so a single bad opacity invalidated ALL
    # splats. The check is now element-wise like the other attributes.
    invalid_mask = (
        torch.isnan(means).any(dim=1)
        | torch.isinf(means).any(dim=1)
        | torch.isnan(scales).any(dim=1)
        | torch.isinf(scales).any(dim=1)
        | torch.isnan(quats).any(dim=1)
        | torch.isinf(quats).any(dim=1)
        | torch.isnan(opacities)
        | torch.isinf(opacities)
        | torch.isnan(sh0).any(dim=1)
        | torch.isinf(sh0).any(dim=1)
        | torch.isnan(shN).any(dim=1)
        | torch.isinf(shN).any(dim=1)
    )

    # Filter out invalid entries
    valid_mask = ~invalid_mask
    means = means[valid_mask]
    scales = scales[valid_mask]
    quats = quats[valid_mask]
    opacities = opacities[valid_mask]
    sh0 = sh0[valid_mask]
    shN = shN[valid_mask]

    if format == "ply":
        data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
    else:
        raise ValueError(f"Unsupported format: {format}")

    if save_to:
        with open(save_to, "wb") as binary_file:
            binary_file.write(data)

    return data
|
|
||||||
|
|
||||||
|
def create_splats_with_optimizers(
    points: Optional[np.ndarray] = None,
    points_rgb: Optional[np.ndarray] = None,
    init_num_pts: int = 100_000,
    init_extent: float = 3.0,
    init_opacity: float = 0.1,
    init_scale: float = 1.0,
    means_lr: float = 1.6e-4,
    scales_lr: float = 5e-3,
    opacities_lr: float = 5e-2,
    quats_lr: float = 1e-3,
    sh0_lr: float = 2.5e-3,
    shN_lr: float = 2.5e-3 / 20,
    scene_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
    world_rank: int = 0,
    world_size: int = 1,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
    """Initialize Gaussian-splat parameters and one optimizer per group.

    Splats are seeded from ``points``/``points_rgb`` when both are given,
    otherwise from ``init_num_pts`` uniform random points. Learning rates
    are scaled by sqrt(batch_size * world_size); points are strided across
    ranks for distributed training.
    """
    if points is not None and points_rgb is not None:
        points = torch.from_numpy(points).float()
        # points_rgb assumed to be uint8 [0, 255] — TODO confirm at callers.
        rgbs = torch.from_numpy(points_rgb / 255.0).float()
    else:
        points = (
            init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
        )
        rgbs = torch.rand((init_num_pts, 3))

    # Initialize the GS size to be the average dist of the 3 nearest neighbors
    dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
    dist_avg = torch.sqrt(dist2_avg)
    # Scales are stored in log space.
    scales = (
        torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3)
    )  # [N, 3]

    # Distribute the GSs to different ranks (also works for single rank)
    points = points[world_rank::world_size]
    rgbs = rgbs[world_rank::world_size]
    scales = scales[world_rank::world_size]

    N = points.shape[0]
    quats = torch.rand((N, 4))  # [N, 4]
    # Opacities are stored pre-sigmoid (logit space).
    opacities = torch.logit(torch.full((N,), init_opacity))  # [N,]

    params = [
        # name, value, lr
        ("means", torch.nn.Parameter(points), means_lr * scene_scale),
        ("scales", torch.nn.Parameter(scales), scales_lr),
        ("quats", torch.nn.Parameter(quats), quats_lr),
        ("opacities", torch.nn.Parameter(opacities), opacities_lr),
    ]

    if feature_dim is None:
        # color is SH coefficients.
        colors = torch.zeros((N, (sh_degree + 1) ** 2, 3))  # [N, K, 3]
        colors[:, 0, :] = rgb_to_sh(rgbs)
        params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr))
        params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr))
    else:
        # features will be used for appearance and view-dependent shading
        features = torch.rand(N, feature_dim)  # [N, feature_dim]
        params.append(("features", torch.nn.Parameter(features), sh0_lr))
        colors = torch.logit(rgbs)  # [N, 3]
        params.append(("colors", torch.nn.Parameter(colors), sh0_lr))

    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
    # Scale learning rate based on batch size, reference:
    # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
    # Note that this would not make the training exactly equivalent, see
    # https://arxiv.org/pdf/2402.18824v1
    BS = batch_size * world_size
    optimizer_class = None
    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam
    optimizers = {
        name: optimizer_class(
            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
            eps=1e-15 / math.sqrt(BS),
            # TODO: check betas logic when BS is larger than 10 betas[0] will be zero.
            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
        )
        for name, _, lr in params
    }
    return splats, optimizers
||||||
|
|
||||||
|
|
||||||
|
def compute_pinhole_intrinsics(
    image_w: int, image_h: int, fov_deg: float
) -> np.ndarray:
    """Build a 3x3 pinhole intrinsic matrix from image size and horizontal FoV.

    Args:
        image_w: Image width in pixels.
        image_h: Image height in pixels.
        fov_deg: Horizontal field of view in degrees.

    Returns:
        (3, 3) intrinsic matrix with the principal point at the image center.
    """
    half_fov = np.deg2rad(fov_deg) / 2
    focal = image_w / (2 * np.tan(half_fov))
    # Square pixels: fy == fx.
    return np.array(
        [
            [focal, 0.0, image_w / 2],
            [0.0, focal, image_h / 2],
            [0.0, 0.0, 1.0],
        ]
    )
||||||
|
|
||||||
|
|
||||||
|
def resize_pinhole_intrinsics(
|
||||||
|
raw_K: np.ndarray | torch.Tensor,
|
||||||
|
raw_hw: tuple[int, int],
|
||||||
|
new_hw: tuple[int, int],
|
||||||
|
) -> np.ndarray:
|
||||||
|
raw_h, raw_w = raw_hw
|
||||||
|
new_h, new_w = new_hw
|
||||||
|
|
||||||
|
scale_x = new_w / raw_w
|
||||||
|
scale_y = new_h / raw_h
|
||||||
|
|
||||||
|
new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
|
||||||
|
new_K[0, 0] *= scale_x # fx
|
||||||
|
new_K[0, 2] *= scale_x # cx
|
||||||
|
new_K[1, 1] *= scale_y # fy
|
||||||
|
new_K[1, 2] *= scale_y # cy
|
||||||
|
|
||||||
|
return new_K
|
||||||
|
|
||||||
|
|
||||||
|
def restore_scene_scale_and_position(
    real_height: float, mesh_path: str, gs_path: str
) -> None:
    """Scales a mesh and corresponding GS model to match a given real-world height.

    Uses the 1st and 99th percentile of mesh vertex column 1 to estimate
    height, applies scaling and vertical alignment, and updates both the
    mesh file and the GS model file in place.

    NOTE(review): the original doc said "Z-axis", but the code measures
    column 1 of the vertices (Y in trimesh convention); the rotation below
    (quat (0, 1, 0, 0), i.e. 180 degrees about Y) preserves that column.
    Confirm the intended axis convention.

    Args:
        real_height (float): Target real-world height among Z axis.
        mesh_path (str): Path to the input mesh file.
        gs_path (str): Path to the Gaussian Splatting model file.
    """
    mesh = trimesh.load(mesh_path)
    # Robust height estimate: clip 1% outliers at each end.
    z_min = np.percentile(mesh.vertices[:, 1], 1)
    z_max = np.percentile(mesh.vertices[:, 1], 99)
    height = z_max - z_min
    scale = real_height / height

    # scipy quat order is (x, y, z, w): 180-degree rotation about Y.
    rot = Rotation.from_quat([0, 1, 0, 0])
    mesh.vertices = rot.apply(mesh.vertices)
    mesh.vertices[:, 1] -= z_min  # drop the (robust) floor to 0
    mesh.vertices *= scale
    mesh.export(mesh_path)  # overwrite in place

    # Apply the same translation and scale to the Gaussian model.
    # NOTE(review): instance_pose packs translation + quaternion — exact
    # layout is defined by GaussianOperator; confirm against its docs.
    gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
    gs_model = gs_model.get_gaussians(
        instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0])
    )
    gs_model.rescale(scale)
    gs_model.save_to_ply(gs_path)
||||||
152
embodied_gen/utils/monkey_patches.py
Normal file
152
embodied_gen/utils/monkey_patches.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
def monkey_patch_pano2room():
    """Patch third-party Pano2Room classes at import time.

    Replaces the ``__init__`` of the Omnidata depth/normal predictors, the
    pano joint predictor, the LaMa inpainter and the SD inpainter so that
    model weights are fetched from public hubs (torch.hub / Hugging Face)
    instead of the hard-coded local checkpoint paths Pano2Room expects.
    Must be called before those classes are instantiated.
    """
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    # Make `thirdparty.pano2room` and its internal absolute imports resolvable.
    sys.path.append(os.path.join(current_dir, "../.."))
    sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room"))
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
        OmnidataNormalPredictor,
    )
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
        OmnidataPredictor,
    )

    def patched_omni_depth_init(self):
        # Load the Omnidata depth model from torch.hub instead of a local file.
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )

    OmnidataPredictor.__init__ = patched_omni_depth_init

    def patched_omni_normal_init(self):
        # Same as above, for the surface-normal model.
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )

    OmnidataNormalPredictor.__init__ = patched_omni_normal_init

    def patched_panojoint_init(self, save_path=None):
        # Rebuild the joint predictor on top of the patched predictors above.
        self.depth_predictor = OmnidataPredictor()
        self.normal_predictor = OmnidataNormalPredictor()
        self.save_path = save_path

    from modules.geo_predictors import PanoJointPredictor

    PanoJointPredictor.__init__ = patched_panojoint_init

    # NOTE: We use gsplat instead.
    # import depth_diff_gaussian_rasterization_min as ddgr
    # from dataclasses import dataclass
    # @dataclass
    # class PatchedGaussianRasterizationSettings:
    #     image_height: int
    #     image_width: int
    #     tanfovx: float
    #     tanfovy: float
    #     bg: torch.Tensor
    #     scale_modifier: float
    #     viewmatrix: torch.Tensor
    #     projmatrix: torch.Tensor
    #     sh_degree: int
    #     campos: torch.Tensor
    #     prefiltered: bool
    #     debug: bool = False
    # ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings

    # disable get_has_ddp_rank print in `BaseInpaintingTrainingModule`
    os.environ["NODE_RANK"] = "0"

    from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
        load_checkpoint,
    )
    from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
        LamaInpainter,
    )

    def patched_lama_inpaint_init(self):
        # Fetch and unzip the big-lama weights from the Hugging Face Hub
        # (cached: extraction is skipped when the directory already exists).
        zip_path = hf_hub_download(
            repo_id="smartywu/big-lama",
            filename="big-lama.zip",
            repo_type="model",
        )
        extract_dir = os.path.splitext(zip_path)[0]

        if not os.path.exists(extract_dir):
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_dir)

        config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
        checkpoint_path = os.path.join(
            extract_dir, 'big-lama/models/best.ckpt'
        )
        train_config = OmegaConf.load(config_path)
        # Inference-only: disable training extras and visualization.
        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        self.model = load_checkpoint(
            train_config, checkpoint_path, strict=False, map_location='cpu'
        )
        self.model.freeze()

    LamaInpainter.__init__ = patched_lama_inpaint_init

    from diffusers import StableDiffusionInpaintPipeline
    from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
        SDFTInpainter,
    )

    def patched_sd_inpaint_init(self, subset_name=None):
        # Replace the fine-tuned SDFT checkpoint with the stock SD2
        # inpainting pipeline in fp16, with CPU offload to save VRAM.
        super(SDFTInpainter, self).__init__()
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        ).to("cuda")
        pipe.enable_model_cpu_offload()
        self.inpaint_pipe = pipe

    SDFTInpainter.__init__ = patched_sd_inpaint_init
||||||
@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import textwrap
|
import textwrap
|
||||||
from glob import glob
|
from glob import glob
|
||||||
@ -27,10 +28,10 @@ import imageio
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL.Image as Image
|
|
||||||
import spaces
|
import spaces
|
||||||
from matplotlib.patches import Patch
|
from matplotlib.patches import Patch
|
||||||
from moviepy.editor import VideoFileClip, clips_array
|
from moviepy.editor import VideoFileClip, clips_array
|
||||||
|
from PIL import Image
|
||||||
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
||||||
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
|
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
|
||||||
|
|
||||||
@ -45,6 +46,8 @@ __all__ = [
|
|||||||
"filter_image_small_connected_components",
|
"filter_image_small_connected_components",
|
||||||
"combine_images_to_grid",
|
"combine_images_to_grid",
|
||||||
"SceneTreeVisualizer",
|
"SceneTreeVisualizer",
|
||||||
|
"is_image_file",
|
||||||
|
"parse_text_prompts",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -356,6 +359,23 @@ def load_scene_dict(file_path: str) -> dict:
|
|||||||
return scene_dict
|
return scene_dict
|
||||||
|
|
||||||
|
|
||||||
|
def is_image_file(filename: str) -> bool:
|
||||||
|
mime_type, _ = mimetypes.guess_type(filename)
|
||||||
|
|
||||||
|
return mime_type is not None and mime_type.startswith('image')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
||||||
|
if len(prompts) == 1 and prompts[0].endswith(".txt"):
|
||||||
|
with open(prompts[0], "r") as f:
|
||||||
|
prompts = [
|
||||||
|
line.strip()
|
||||||
|
for line in f
|
||||||
|
if line.strip() and not line.strip().startswith("#")
|
||||||
|
]
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
merge_video_video(
|
merge_video_video(
|
||||||
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
|
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
VERSION = "v0.1.1"
|
VERSION = "v0.1.2"
|
||||||
|
|||||||
@ -33,6 +33,9 @@ __all__ = [
|
|||||||
"ImageAestheticChecker",
|
"ImageAestheticChecker",
|
||||||
"SemanticConsistChecker",
|
"SemanticConsistChecker",
|
||||||
"TextGenAlignChecker",
|
"TextGenAlignChecker",
|
||||||
|
"PanoImageGenChecker",
|
||||||
|
"PanoHeightEstimator",
|
||||||
|
"PanoImageOccChecker",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -328,6 +331,159 @@ class TextGenAlignChecker(BaseChecker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PanoImageGenChecker(BaseChecker):
|
||||||
|
"""A checker class that validates the quality and realism of generated panoramic indoor images.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
gpt_client (GPTclient): A GPT client instance used to query for image validation.
|
||||||
|
prompt (str): The instruction prompt passed to the GPT model. If None, a default prompt is used.
|
||||||
|
verbose (bool): Whether to print internal processing information for debugging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gpt_client: GPTclient,
|
||||||
|
prompt: str = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(prompt, verbose)
|
||||||
|
self.gpt_client = gpt_client
|
||||||
|
if self.prompt is None:
|
||||||
|
self.prompt = """
|
||||||
|
You are a panoramic image analyzer specializing in indoor room structure validation.
|
||||||
|
|
||||||
|
Given a generated panoramic image, assess if it meets all the criteria:
|
||||||
|
- Floor Space: ≥30 percent of the floor is free of objects or obstructions.
|
||||||
|
- Visual Clarity: Floor, walls, and ceiling are clear, with no distortion, blur, noise.
|
||||||
|
- Structural Continuity: Surfaces form plausible, continuous geometry
|
||||||
|
without breaks, floating parts, or abrupt cuts.
|
||||||
|
- Spatial Completeness: Full 360° coverage without missing areas,
|
||||||
|
seams, gaps, or stitching artifacts.
|
||||||
|
Instructions:
|
||||||
|
- If all criteria are met, reply with "YES".
|
||||||
|
- Otherwise, reply with "NO: <brief explanation>" (max 20 words).
|
||||||
|
|
||||||
|
Respond exactly as:
|
||||||
|
"YES"
|
||||||
|
or
|
||||||
|
"NO: brief explanation."
|
||||||
|
"""
|
||||||
|
|
||||||
|
def query(self, image_paths: str | Image.Image) -> str:
|
||||||
|
|
||||||
|
return self.gpt_client.query(
|
||||||
|
text_prompt=self.prompt,
|
||||||
|
image_base64=image_paths,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PanoImageOccChecker(BaseChecker):
|
||||||
|
"""Checks for physical obstacles in the bottom-center region of a panoramic image.
|
||||||
|
|
||||||
|
This class crops a specified region from the input panoramic image and uses
|
||||||
|
a GPT client to determine whether any physical obstacles there.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gpt_client (GPTclient): The GPT-based client used for visual reasoning.
|
||||||
|
box_hw (tuple[int, int]): The height and width of the crop box.
|
||||||
|
prompt (str, optional): Custom prompt for the GPT client. Defaults to a predefined one.
|
||||||
|
verbose (bool, optional): Whether to print verbose logs. Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gpt_client: GPTclient,
|
||||||
|
box_hw: tuple[int, int],
|
||||||
|
prompt: str = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(prompt, verbose)
|
||||||
|
self.gpt_client = gpt_client
|
||||||
|
self.box_hw = box_hw
|
||||||
|
if self.prompt is None:
|
||||||
|
self.prompt = """
|
||||||
|
This image is a cropped region from the bottom-center of a panoramic view.
|
||||||
|
Please determine whether there is any obstacle present — such as furniture, tables, or other physical objects.
|
||||||
|
Ignore floor textures, rugs, carpets, shadows, and lighting effects — they do not count as obstacles.
|
||||||
|
Only consider real, physical objects that could block walking or movement.
|
||||||
|
|
||||||
|
Instructions:
|
||||||
|
- If there is no obstacle, reply: "YES".
|
||||||
|
- Otherwise, reply: "NO: <brief explanation>" (max 20 words).
|
||||||
|
|
||||||
|
Respond exactly as:
|
||||||
|
"YES"
|
||||||
|
or
|
||||||
|
"NO: brief explanation."
|
||||||
|
"""
|
||||||
|
|
||||||
|
def query(self, image_paths: str | Image.Image) -> str:
|
||||||
|
if isinstance(image_paths, str):
|
||||||
|
image_paths = Image.open(image_paths)
|
||||||
|
|
||||||
|
w, h = image_paths.size
|
||||||
|
image_paths = image_paths.crop(
|
||||||
|
(
|
||||||
|
(w - self.box_hw[1]) // 2,
|
||||||
|
h - self.box_hw[0],
|
||||||
|
(w + self.box_hw[1]) // 2,
|
||||||
|
h,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.gpt_client.query(
|
||||||
|
text_prompt=self.prompt,
|
||||||
|
image_base64=image_paths,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PanoHeightEstimator(object):
|
||||||
|
"""Estimate the real ceiling height of an indoor space from a 360° panoramic image.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
gpt_client (GPTclient): The GPT client used to perform image-based reasoning and return height estimates.
|
||||||
|
default_value (float): The fallback height in meters if parsing the GPT output fails.
|
||||||
|
prompt (str): The textual instruction used to guide the GPT model for height estimation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gpt_client: GPTclient,
|
||||||
|
default_value: float = 3.5,
|
||||||
|
) -> None:
|
||||||
|
self.gpt_client = gpt_client
|
||||||
|
self.default_value = default_value
|
||||||
|
self.prompt = """
|
||||||
|
You are an expert in building height estimation and panoramic image analysis.
|
||||||
|
Your task is to analyze a 360° indoor panoramic image and estimate the **actual height** of the space in meters.
|
||||||
|
|
||||||
|
Consider the following visual cues:
|
||||||
|
1. Ceiling visibility and reference objects (doors, windows, furniture, appliances).
|
||||||
|
2. Floor features or level differences.
|
||||||
|
3. Room type (e.g., residential, office, commercial).
|
||||||
|
4. Object-to-ceiling proportions (e.g., height of doors relative to ceiling).
|
||||||
|
5. Architectural elements (e.g., chandeliers, shelves, kitchen cabinets).
|
||||||
|
|
||||||
|
Input: A full 360° panoramic indoor photo.
|
||||||
|
Output: A single number in meters representing the estimated room height. Only return the number (e.g., `3.2`)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, image_paths: str | Image.Image) -> float:
|
||||||
|
result = self.gpt_client.query(
|
||||||
|
text_prompt=self.prompt,
|
||||||
|
image_base64=image_paths,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = float(result.strip())
|
||||||
|
except ValueError:
|
||||||
|
logger.error(
|
||||||
|
f"Parser error: failed convert {result} to float, use default value {self.default_value}."
|
||||||
|
)
|
||||||
|
result = self.default_value
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class SemanticMatcher(BaseChecker):
|
class SemanticMatcher(BaseChecker):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
75
install.sh
75
install.sh
@ -1,65 +1,28 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
RED='\033[0;31m'
|
STAGE=$1 # "basic" | "extra" | "all"
|
||||||
GREEN='\033[0;32m'
|
STAGE=${STAGE:-all}
|
||||||
NC='\033[0m'
|
|
||||||
|
|
||||||
echo -e "${GREEN}Starting installation process...${NC}"
|
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||||
|
|
||||||
|
source "$SCRIPT_DIR/install/_utils.sh"
|
||||||
git config --global http.postBuffer 524288000
|
git config --global http.postBuffer 524288000
|
||||||
|
# Patch submodule .gitignore to ignore __pycache__, only if submodule exists
|
||||||
|
PANO2ROOM_PATH="$SCRIPT_DIR/thirdparty/pano2room"
|
||||||
|
if [ -d "$PANO2ROOM_PATH" ]; then
|
||||||
|
echo "__pycache__/" > "$PANO2ROOM_PATH/.gitignore"
|
||||||
|
log_info "Added .gitignore to ignore __pycache__ in $PANO2ROOM_PATH"
|
||||||
|
fi
|
||||||
|
|
||||||
echo -e "${GREEN}Installing flash-attn...${NC}"
|
log_info "===== Starting installation stage: $STAGE ====="
|
||||||
pip install flash-attn==2.7.0.post2 --no-build-isolation || {
|
|
||||||
echo -e "${RED}Failed to install flash-attn${NC}"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installing dependencies from requirements.txt...${NC}"
|
if [[ "$STAGE" == "basic" || "$STAGE" == "all" ]]; then
|
||||||
pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60 || {
|
bash "$SCRIPT_DIR/install/install_basic.sh"
|
||||||
echo -e "${RED}Failed to install requirements${NC}"
|
fi
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if [[ "$STAGE" == "extra" || "$STAGE" == "all" ]]; then
|
||||||
|
bash "$SCRIPT_DIR/install/install_extra.sh"
|
||||||
|
fi
|
||||||
|
|
||||||
echo -e "${GREEN}Installing kolors from GitHub...${NC}"
|
log_info "===== Installation completed successfully. ====="
|
||||||
pip install kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d || {
|
|
||||||
echo -e "${RED}Failed to install kolors${NC}"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installing kaolin from GitHub...${NC}"
|
|
||||||
pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || {
|
|
||||||
echo -e "${RED}Failed to install kaolin${NC}"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installing diff-gaussian-rasterization...${NC}"
|
|
||||||
TMP_DIR="/tmp/mip-splatting"
|
|
||||||
rm -rf "$TMP_DIR"
|
|
||||||
git clone --recursive https://github.com/autonomousvision/mip-splatting.git "$TMP_DIR" && \
|
|
||||||
pip install "$TMP_DIR/submodules/diff-gaussian-rasterization" && \
|
|
||||||
rm -rf "$TMP_DIR" || {
|
|
||||||
echo -e "${RED}Failed to clone or install diff-gaussian-rasterization${NC}"
|
|
||||||
rm -rf "$TMP_DIR"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installing gsplat from GitHub...${NC}"
|
|
||||||
pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.0 || {
|
|
||||||
echo -e "${RED}Failed to install gsplat${NC}"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installing EmbodiedGen...${NC}"
|
|
||||||
pip install triton==2.1.0
|
|
||||||
pip install -e . || {
|
|
||||||
echo -e "${RED}Failed to install EmbodiedGen pyproject.toml${NC}"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installation completed successfully!${NC}"
|
|
||||||
|
|
||||||
|
|||||||
21
install/_utils.sh
Normal file
21
install/_utils.sh
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
NC='\033[0m'
|
||||||
|
|
||||||
|
log_info() {
|
||||||
|
echo -e "${GREEN}[INFO] $1${NC}"
|
||||||
|
}
|
||||||
|
|
||||||
|
log_error() {
|
||||||
|
echo -e "${RED}[ERROR] $1${NC}" >&2
|
||||||
|
}
|
||||||
|
|
||||||
|
try_install() {
|
||||||
|
log_info "$1"
|
||||||
|
eval "$2" || {
|
||||||
|
log_error "$3"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
}
|
||||||
35
install/install_basic.sh
Normal file
35
install/install_basic.sh
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||||
|
source "$SCRIPT_DIR/_utils.sh"
|
||||||
|
|
||||||
|
try_install "Installing flash-attn..." \
|
||||||
|
"pip install flash-attn==2.7.0.post2 --no-build-isolation" \
|
||||||
|
"flash-attn installation failed."
|
||||||
|
|
||||||
|
try_install "Installing requirements.txt..." \
|
||||||
|
"pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60" \
|
||||||
|
"requirements installation failed."
|
||||||
|
|
||||||
|
try_install "Installing kolors..." \
|
||||||
|
"pip install kolors@git+https://github.com/HochCC/Kolors.git" \
|
||||||
|
"kolors installation failed."
|
||||||
|
|
||||||
|
try_install "Installing kaolin..." \
|
||||||
|
"pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0" \
|
||||||
|
"kaolin installation failed."
|
||||||
|
|
||||||
|
log_info "Installing diff-gaussian-rasterization..."
|
||||||
|
TMP_DIR="/tmp/mip-splatting"
|
||||||
|
rm -rf "$TMP_DIR"
|
||||||
|
git clone --recursive https://github.com/autonomousvision/mip-splatting.git "$TMP_DIR"
|
||||||
|
pip install "$TMP_DIR/submodules/diff-gaussian-rasterization"
|
||||||
|
rm -rf "$TMP_DIR"
|
||||||
|
|
||||||
|
try_install "Installing gsplat..." \
|
||||||
|
"pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.3" \
|
||||||
|
"gsplat installation failed."
|
||||||
|
|
||||||
|
try_install "Installing EmbodiedGen..." \
|
||||||
|
"pip install triton==2.1.0 --no-deps && pip install -e ." \
|
||||||
|
"EmbodiedGen installation failed."
|
||||||
52
install/install_extra.sh
Normal file
52
install/install_extra.sh
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||||
|
source "$SCRIPT_DIR/_utils.sh"
|
||||||
|
|
||||||
|
# try_install "Installing txt2panoimg..." \
|
||||||
|
# "pip install txt2panoimg@git+https://github.com/HochCC/SD-T2I-360PanoImage --no-deps" \
|
||||||
|
# "txt2panoimg installation failed."
|
||||||
|
|
||||||
|
# try_install "Installing fused-ssim..." \
|
||||||
|
# "pip install fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98" \
|
||||||
|
# "fused-ssim installation failed."
|
||||||
|
|
||||||
|
# try_install "Installing tiny-cuda-nn..." \
|
||||||
|
# "pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch" \
|
||||||
|
# "tiny-cuda-nn installation failed."
|
||||||
|
|
||||||
|
# try_install "Installing pytorch3d" \
|
||||||
|
# "pip install git+https://github.com/facebookresearch/pytorch3d.git@v0.7.7" \
|
||||||
|
# "pytorch3d installation failed."
|
||||||
|
|
||||||
|
|
||||||
|
PYTHON_PACKAGES_NODEPS=(
|
||||||
|
timm
|
||||||
|
txt2panoimg@git+https://github.com/HochCC/SD-T2I-360PanoImage
|
||||||
|
kornia
|
||||||
|
kornia_rs
|
||||||
|
)
|
||||||
|
|
||||||
|
PYTHON_PACKAGES=(
|
||||||
|
fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98
|
||||||
|
git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
|
||||||
|
git+https://github.com/facebookresearch/pytorch3d.git@v0.7.7
|
||||||
|
h5py
|
||||||
|
albumentations==0.5.2
|
||||||
|
webdataset
|
||||||
|
icecream
|
||||||
|
open3d
|
||||||
|
pyequilib
|
||||||
|
numpy==1.26.4
|
||||||
|
triton==2.1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
for pkg in "${PYTHON_PACKAGES_NODEPS[@]}"; do
|
||||||
|
try_install "Installing $pkg without dependencies..." \
|
||||||
|
"pip install --no-deps $pkg" \
|
||||||
|
"$pkg installation failed."
|
||||||
|
done
|
||||||
|
|
||||||
|
try_install "Installing other Python dependencies..." \
|
||||||
|
"pip install ${PYTHON_PACKAGES[*]}" \
|
||||||
|
"Python dependencies installation failed."
|
||||||
@ -7,7 +7,7 @@ packages = ["embodied_gen"]
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "embodied_gen"
|
name = "embodied_gen"
|
||||||
version = "v0.1.1"
|
version = "v0.1.2"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
license-files = ["LICENSE", "NOTICE"]
|
license-files = ["LICENSE", "NOTICE"]
|
||||||
@ -31,6 +31,7 @@ drender-cli = "embodied_gen.data.differentiable_render:entrypoint"
|
|||||||
backproject-cli = "embodied_gen.data.backproject_v2:entrypoint"
|
backproject-cli = "embodied_gen.data.backproject_v2:entrypoint"
|
||||||
img3d-cli = "embodied_gen.scripts.imageto3d:entrypoint"
|
img3d-cli = "embodied_gen.scripts.imageto3d:entrypoint"
|
||||||
text3d-cli = "embodied_gen.scripts.textto3d:text_to_3d"
|
text3d-cli = "embodied_gen.scripts.textto3d:text_to_3d"
|
||||||
|
scene3d-cli = "embodied_gen.scripts.gen_scene3d:entrypoint"
|
||||||
|
|
||||||
[tool.pydocstyle]
|
[tool.pydocstyle]
|
||||||
match = '(?!test_).*(?!_pb2)\.py'
|
match = '(?!test_).*(?!_pb2)\.py'
|
||||||
|
|||||||
@ -32,6 +32,9 @@ vtk==9.3.1
|
|||||||
spaces
|
spaces
|
||||||
colorlog
|
colorlog
|
||||||
json-repair
|
json-repair
|
||||||
|
scikit-learn
|
||||||
|
omegaconf
|
||||||
|
tyro
|
||||||
utils3d@git+https://github.com/EasternJournalist/utils3d.git#egg=9a4eb15
|
utils3d@git+https://github.com/EasternJournalist/utils3d.git#egg=9a4eb15
|
||||||
clip@git+https://github.com/openai/CLIP.git
|
clip@git+https://github.com/openai/CLIP.git
|
||||||
segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f
|
segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||||
@ -25,6 +26,9 @@ from embodied_gen.validators.quality_checkers import (
|
|||||||
ImageAestheticChecker,
|
ImageAestheticChecker,
|
||||||
ImageSegChecker,
|
ImageSegChecker,
|
||||||
MeshGeoChecker,
|
MeshGeoChecker,
|
||||||
|
PanoHeightEstimator,
|
||||||
|
PanoImageGenChecker,
|
||||||
|
PanoImageOccChecker,
|
||||||
SemanticConsistChecker,
|
SemanticConsistChecker,
|
||||||
TextGenAlignChecker,
|
TextGenAlignChecker,
|
||||||
)
|
)
|
||||||
@ -57,6 +61,21 @@ def textalign_checker():
|
|||||||
return TextGenAlignChecker(GPT_CLIENT)
|
return TextGenAlignChecker(GPT_CLIENT)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def pano_checker():
|
||||||
|
return PanoImageGenChecker(GPT_CLIENT)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def pano_height_estimator():
|
||||||
|
return PanoHeightEstimator(GPT_CLIENT)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def panoocc_checker():
|
||||||
|
return PanoImageOccChecker(GPT_CLIENT, box_hw=[90, 1000])
|
||||||
|
|
||||||
|
|
||||||
def test_geo_checker(geo_checker):
|
def test_geo_checker(geo_checker):
|
||||||
flag, result = geo_checker(
|
flag, result = geo_checker(
|
||||||
[
|
[
|
||||||
@ -117,3 +136,28 @@ def test_textgen_checker(textalign_checker, mesh_path, text_desc):
|
|||||||
)
|
)
|
||||||
flag, result = textalign_checker(text_desc, image_list)
|
flag, result = textalign_checker(text_desc, image_list)
|
||||||
logger.info(f"textalign_checker: {flag}, {result}")
|
logger.info(f"textalign_checker: {flag}, {result}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_panoheight_estimator(pano_height_estimator):
|
||||||
|
image_paths = glob("outputs/bg_v3/test2/*/*.png")
|
||||||
|
for image_path in image_paths:
|
||||||
|
result = pano_height_estimator(image_path)
|
||||||
|
logger.info(f"{type(result)}, {result}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_pano_checker(pano_checker):
|
||||||
|
# image_paths = [
|
||||||
|
# "outputs/bg_gen2/scene_0000/pano_image.png",
|
||||||
|
# "outputs/bg_gen2/scene_0001/pano_image.png",
|
||||||
|
# ]
|
||||||
|
image_paths = glob("outputs/bg_gen/*/*.png")
|
||||||
|
for image_path in image_paths:
|
||||||
|
flag, result = pano_checker(image_path)
|
||||||
|
logger.info(f"{image_path} {flag}, {result}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_panoocc_checker(panoocc_checker):
|
||||||
|
image_paths = glob("outputs/bg_gen/*/*.png")
|
||||||
|
for image_path in image_paths:
|
||||||
|
flag, result = panoocc_checker(image_path)
|
||||||
|
logger.info(f"{image_path} {flag}, {result}")
|
||||||
|
|||||||
@ -23,6 +23,8 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
|||||||
from embodied_gen.validators.quality_checkers import (
|
from embodied_gen.validators.quality_checkers import (
|
||||||
ImageSegChecker,
|
ImageSegChecker,
|
||||||
MeshGeoChecker,
|
MeshGeoChecker,
|
||||||
|
PanoHeightEstimator,
|
||||||
|
PanoImageGenChecker,
|
||||||
SemanticConsistChecker,
|
SemanticConsistChecker,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -93,3 +95,16 @@ def test_semantic_checker(gptclient_query_case2):
|
|||||||
)
|
)
|
||||||
assert isinstance(flag, (bool, type(None)))
|
assert isinstance(flag, (bool, type(None)))
|
||||||
assert isinstance(result, str)
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_panoheight_estimator():
|
||||||
|
checker = PanoHeightEstimator(GPT_CLIENT, default_value=3.5)
|
||||||
|
result = checker(image_paths="dummy_path/pano.png")
|
||||||
|
assert isinstance(result, float)
|
||||||
|
|
||||||
|
|
||||||
|
def test_panogen_checker():
|
||||||
|
checker = PanoImageGenChecker(GPT_CLIENT)
|
||||||
|
flag, result = checker(image_paths="dummy_path/pano.png")
|
||||||
|
assert isinstance(flag, (bool, type(None)))
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|||||||
1
thirdparty/pano2room
vendored
Submodule
1
thirdparty/pano2room
vendored
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit bbf93ae57086ed700edc6ee445852d4457a9d704
|
||||||
Loading…
x
Reference in New Issue
Block a user