diff --git a/.gitignore b/.gitignore index 7d3fe28..ea1d498 100644 --- a/.gitignore +++ b/.gitignore @@ -59,4 +59,4 @@ output* scripts/tools/ weights apps/sessions/ -apps/assets/ \ No newline at end of file +apps/assets/ diff --git a/.gitmodules b/.gitmodules index ca1c2c7..c6b0a7b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,4 +2,9 @@ path = thirdparty/TRELLIS url = https://github.com/microsoft/TRELLIS.git branch = main - shallow = true \ No newline at end of file + shallow = true +[submodule "thirdparty/pano2room"] + path = thirdparty/pano2room + url = https://github.com/TrickyGo/Pano2Room.git + branch = main + shallow = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 28d6ba6..e76d477 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - - repo: git@gitlab.hobot.cc:ptd/3rd/pre-commit/pre-commit-hooks.git - rev: v2.3.0 # Use the ref you want to point at + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.2.0 # Use the ref you want to point at hooks: - id: trailing-whitespace - id: check-added-large-files diff --git a/README.md b/README.md index 0cfc0cd..4de144e 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,11 @@ ```sh git clone https://github.com/HorizonRobotics/EmbodiedGen.git cd EmbodiedGen -git checkout v0.1.1 +git checkout v0.1.2 git submodule update --init --recursive --progress -conda create -n embodiedgen python=3.10.13 -y +conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env. conda activate embodiedgen -bash install.sh +bash install.sh basic ``` ### ✅ Setup GPT Agent @@ -94,7 +94,7 @@ python apps/text_to_3d.py ### ⚡ API Text-to-image model based on SD3.5 Medium, English prompts only. -Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically. 
(ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py`) +Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically. For large-scale 3D assets generation, set `--n_pipe_retry=2` to ensure high end-to-end 3D asset usability through automatic quality check and retries. For more diverse results, do not set `--seed_img`. @@ -110,6 +110,7 @@ bash embodied_gen/scripts/textto3d.sh \ --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \ --output_root outputs/textto3d_k ``` +ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py` --- @@ -146,10 +147,22 @@ bash embodied_gen/scripts/texture_gen.sh \
+### ⚡ API
+> Run `bash install.sh extra` to install additional requirements if you need to use `scene3d-cli`.
+
+It takes ~30mins to generate a color mesh and 3DGS per scene.
+
+```sh
+CUDA_VISIBLE_DEVICES=0 scene3d-cli \
+--prompts "Art studio with easel and canvas" \
+--output_dir outputs/bg_scenes/ \
+--seed 0 \
+--gs3d.max_steps 4000 \
+--disable_pano_check
+```
+
---
@@ -189,7 +202,7 @@ bash embodied_gen/scripts/texture_gen.sh \
## For Developer
```sh
-pip install .[dev] && pre-commit install
+pip install -e .[dev] && pre-commit install
python -m pytest # Pass all unit-test are required.
```
diff --git a/apps/common.py b/apps/common.py
index f7c7f40..76e6335 100644
--- a/apps/common.py
+++ b/apps/common.py
@@ -94,9 +94,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"
MAX_SEED = 100000
-DELIGHT = DelightingModel()
-IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
-# IMAGESR_MODEL = ImageStableSR()
def patched_setup_functions(self):
@@ -136,6 +133,9 @@ def patched_setup_functions(self):
Gaussian.setup_functions = patched_setup_functions
+DELIGHT = DelightingModel()
+IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
+# IMAGESR_MODEL = ImageStableSR()
if os.getenv("GRADIO_APP") == "imageto3d":
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
diff --git a/embodied_gen/data/datasets.py b/embodied_gen/data/datasets.py
index 4a9563a..eaa529b 100644
--- a/embodied_gen/data/datasets.py
+++ b/embodied_gen/data/datasets.py
@@ -19,8 +19,9 @@ import json
import logging
import os
import random
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Literal, Tuple
+import numpy as np
import torch
import torch.utils.checkpoint
from PIL import Image
@@ -36,6 +37,7 @@ logger = logging.getLogger(__name__)
__all__ = [
"Asset3dGenDataset",
+ "PanoGSplatDataset",
]
@@ -222,6 +224,68 @@ class Asset3dGenDataset(Dataset):
return data
+class PanoGSplatDataset(Dataset):
+ """A PyTorch Dataset for loading panorama-based 3D Gaussian Splatting data.
+
+ This dataset is designed to be compatible with train and eval pipelines
+ that use COLMAP-style camera conventions.
+
+ Args:
+ data_dir (str): Root directory where the dataset file is located.
+ split (str): Dataset split to use, either "train" or "eval".
+ data_name (str, optional): Name of the dataset file (default: "gs_data.pt").
+ max_sample_num (int, optional): Maximum number of samples to load. If None,
+ all available samples in the split will be used.
+ """
+
+ def __init__(
+ self,
+ data_dir: str,
+        split: Literal["train", "eval"] = "train",
+ data_name: str = "gs_data.pt",
+        max_sample_num: int | None = None,
+ ) -> None:
+ self.data_path = os.path.join(data_dir, data_name)
+ self.split = split
+ self.max_sample_num = max_sample_num
+ if not os.path.exists(self.data_path):
+ raise FileNotFoundError(
+ f"Dataset file {self.data_path} not found. Please provide the correct path."
+ )
+ self.data = torch.load(self.data_path, weights_only=False)
+ self.frames = self.data[split]
+ if max_sample_num is not None:
+ self.frames = self.frames[:max_sample_num]
+ self.points = self.data.get("points", None)
+ self.points_rgb = self.data.get("points_rgb", None)
+
+ def __len__(self) -> int:
+ return len(self.frames)
+
+ def cvt_blender_to_colmap_coord(self, c2w: np.ndarray) -> np.ndarray:
+ # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
+        transformed_c2w = np.copy(c2w)
+        transformed_c2w[:3, 1:3] *= -1
+
+        return transformed_c2w
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+ data = self.frames[index]
+ c2w = self.cvt_blender_to_colmap_coord(data["camtoworld"])
+ item = dict(
+ camtoworld=c2w,
+ K=data["K"],
+ image_h=data["image_h"],
+ image_w=data["image_w"],
+ )
+ if "image" in data:
+ item["image"] = data["image"]
+ if "image_id" in data:
+ item["image_id"] = data["image_id"]
+
+ return item
+
+
if __name__ == "__main__":
index_file = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json" # noqa
target_hw = (512, 512)
diff --git a/embodied_gen/data/utils.py b/embodied_gen/data/utils.py
index 83cb39d..0d39f71 100644
--- a/embodied_gen/data/utils.py
+++ b/embodied_gen/data/utils.py
@@ -158,8 +158,9 @@ class DiffrastRender(object):
return normalized_maps
+ @staticmethod
def normalize_map_by_mask(
- self, map: torch.Tensor, mask: torch.Tensor
+ map: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
# Normalize all maps in total by mask, normalized map in [0, 1].
foreground = (mask == 1).squeeze(dim=-1)
diff --git a/embodied_gen/scripts/gen_scene3d.py b/embodied_gen/scripts/gen_scene3d.py
new file mode 100644
index 0000000..42d2527
--- /dev/null
+++ b/embodied_gen/scripts/gen_scene3d.py
@@ -0,0 +1,191 @@
+import logging
+import os
+import random
+import time
+import warnings
+from dataclasses import dataclass, field
+from shutil import copy, rmtree
+
+import torch
+import tyro
+from huggingface_hub import snapshot_download
+from packaging import version
+
+# Suppress warnings
+warnings.filterwarnings("ignore", category=FutureWarning)
+logging.getLogger("transformers").setLevel(logging.ERROR)
+logging.getLogger("diffusers").setLevel(logging.ERROR)
+
+# TorchVision >=0.16 (paired with torch >=2.1) removed functional_tensor.
+if version.parse(torch.__version__) >= version.parse("2.1"):
+ import sys
+ import types
+
+ import torchvision.transforms.functional as TF
+
+ functional_tensor = types.ModuleType(
+ "torchvision.transforms.functional_tensor"
+ )
+ functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
+ sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
+
+from gsplat.distributed import cli
+from txt2panoimg import Text2360PanoramaImagePipeline
+from embodied_gen.trainer.gsplat_trainer import (
+ DefaultStrategy,
+ GsplatTrainConfig,
+)
+from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint
+from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
+from embodied_gen.utils.config import Pano2MeshSRConfig
+from embodied_gen.utils.gaussian import restore_scene_scale_and_position
+from embodied_gen.utils.gpt_clients import GPT_CLIENT
+from embodied_gen.utils.log import logger
+from embodied_gen.utils.process_media import is_image_file, parse_text_prompts
+from embodied_gen.validators.quality_checkers import (
+ PanoHeightEstimator,
+ PanoImageOccChecker,
+)
+
+__all__ = [
+ "generate_pano_image",
+ "entrypoint",
+]
+
+
+@dataclass
+class Scene3DGenConfig:
+ prompts: list[str] # Text desc of indoor room or style reference image.
+ output_dir: str
+ seed: int | None = None
+ real_height: float | None = None # The real height of the room in meters.
+ pano_image_only: bool = False
+ disable_pano_check: bool = False
+ keep_middle_result: bool = False
+ n_retry: int = 7
+ gs3d: GsplatTrainConfig = field(
+ default_factory=lambda: GsplatTrainConfig(
+ strategy=DefaultStrategy(verbose=True),
+ max_steps=4000,
+ init_opa=0.9,
+ opacity_reg=2e-3,
+ sh_degree=0,
+ means_lr=1e-4,
+ scales_lr=1e-3,
+ )
+ )
+
+
+def generate_pano_image(
+ prompt: str,
+ output_path: str,
+ pipeline,
+ seed: int,
+ n_retry: int,
+ checker=None,
+ num_inference_steps: int = 40,
+) -> None:
+ for i in range(n_retry):
+ logger.info(
+ f"GEN Panorama: Retry {i+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
+ )
+ if is_image_file(prompt):
+ raise NotImplementedError("Image mode not implemented yet.")
+ else:
+ txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
+ inputs = {
+ "prompt": txt_prompt,
+ "num_inference_steps": num_inference_steps,
+ "upscale": False,
+ "seed": seed,
+ }
+ pano_image = pipeline(inputs)
+
+ pano_image.save(output_path)
+ if checker is None:
+ break
+
+ flag, response = checker(pano_image)
+ logger.warning(f"{response}, image saved in {output_path}")
+ if flag is True or flag is None:
+ break
+
+ seed = random.randint(0, 100000)
+
+ return
+
+
+def entrypoint(*args, **kwargs):
+ cfg = tyro.cli(Scene3DGenConfig)
+
+ # Init global models.
+ model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
+ IMG2PANO_PIPE = Text2360PanoramaImagePipeline(
+ model_path, torch_dtype=torch.float16, device="cuda"
+ )
+ PANOMESH_CFG = Pano2MeshSRConfig()
+ PANO2MESH_PIPE = Pano2MeshSRPipeline(PANOMESH_CFG)
+ PANO_CHECKER = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
+ PANOHEIGHT_ESTOR = PanoHeightEstimator(GPT_CLIENT)
+
+ prompts = parse_text_prompts(cfg.prompts)
+ for idx, prompt in enumerate(prompts):
+ start_time = time.time()
+ output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
+ os.makedirs(output_dir, exist_ok=True)
+ pano_path = os.path.join(output_dir, "pano_image.png")
+ with open(f"{output_dir}/prompt.txt", "w") as f:
+ f.write(prompt)
+
+ generate_pano_image(
+ prompt,
+ pano_path,
+ IMG2PANO_PIPE,
+ cfg.seed if cfg.seed is not None else random.randint(0, 100000),
+ cfg.n_retry,
+ checker=None if cfg.disable_pano_check else PANO_CHECKER,
+ )
+
+ if cfg.pano_image_only:
+ continue
+
+ logger.info("GEN and REPAIR Mesh from Panorama...")
+ PANO2MESH_PIPE(pano_path, output_dir)
+
+ logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
+ cfg.gs3d.data_dir = output_dir
+ cfg.gs3d.result_dir = f"{output_dir}/gaussian"
+ cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)
+ torch.set_default_device("cpu") # recover default setting.
+ cli(gsplat_entrypoint, cfg.gs3d, verbose=True)
+
+ # Clean up the middle results.
+ gs_path = (
+ f"{cfg.gs3d.result_dir}/ply/point_cloud_{cfg.gs3d.max_steps-1}.ply"
+ )
+ copy(gs_path, f"{output_dir}/gs_model.ply")
+ video_path = f"{cfg.gs3d.result_dir}/renders/video_step{cfg.gs3d.max_steps-1}.mp4"
+ copy(video_path, f"{output_dir}/video.mp4")
+ gs_cfg_path = f"{cfg.gs3d.result_dir}/cfg.yml"
+ copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")
+ if not cfg.keep_middle_result:
+ rmtree(cfg.gs3d.result_dir, ignore_errors=True)
+ os.remove(f"{output_dir}/{PANOMESH_CFG.gs_data_file}")
+
+ real_height = (
+ PANOHEIGHT_ESTOR(pano_path)
+ if cfg.real_height is None
+ else cfg.real_height
+ )
+ gs_path = os.path.join(output_dir, "gs_model.ply")
+ mesh_path = os.path.join(output_dir, "mesh_model.ply")
+ restore_scene_scale_and_position(real_height, mesh_path, gs_path)
+
+ elapsed_time = (time.time() - start_time) / 60
+ logger.info(
+ f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
+ )
+
+
+if __name__ == "__main__":
+ entrypoint()
diff --git a/embodied_gen/scripts/imageto3d.py b/embodied_gen/scripts/imageto3d.py
index 00a0958..f36cfa8 100644
--- a/embodied_gen/scripts/imageto3d.py
+++ b/embodied_gen/scripts/imageto3d.py
@@ -62,25 +62,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"
random.seed(0)
-logger.info("Loading Models...")
-DELIGHT = DelightingModel()
-IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
-
-RBG_REMOVER = RembgRemover()
-RBG14_REMOVER = BMGG14Remover()
-SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
-PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
- "microsoft/TRELLIS-image-large"
-)
-# PIPELINE.cuda()
-SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
-GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
-AESTHETIC_CHECKER = ImageAestheticChecker()
-CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
-TMP_DIR = os.path.join(
- os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
-)
-
def parse_args():
parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
@@ -128,6 +109,19 @@ def entrypoint(**kwargs):
if hasattr(args, k) and v is not None:
setattr(args, k, v)
+ logger.info("Loading Models...")
+ DELIGHT = DelightingModel()
+ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
+ RBG_REMOVER = RembgRemover()
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+ "microsoft/TRELLIS-image-large"
+ )
+ # PIPELINE.cuda()
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
+ AESTHETIC_CHECKER = ImageAestheticChecker()
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
+
assert (
args.image_path or args.image_root
), "Please provide either --image_path or --image_root."
diff --git a/embodied_gen/scripts/text2image.py b/embodied_gen/scripts/text2image.py
index ac1587c..9f25dbc 100644
--- a/embodied_gen/scripts/text2image.py
+++ b/embodied_gen/scripts/text2image.py
@@ -31,6 +31,7 @@ from embodied_gen.models.text_model import (
build_text2img_pipeline,
text2img_gen,
)
+from embodied_gen.utils.process_media import parse_text_prompts
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -101,14 +102,7 @@ def entrypoint(
if hasattr(args, k) and v is not None:
setattr(args, k, v)
- prompts = args.prompts
- if len(prompts) == 1 and prompts[0].endswith(".txt"):
- with open(prompts[0], "r") as f:
- prompts = f.readlines()
- prompts = [
- prompt.strip() for prompt in prompts if prompt.strip() != ""
- ]
-
+ prompts = parse_text_prompts(args.prompts)
os.makedirs(args.output_root, exist_ok=True)
ip_img_paths = args.ref_image
diff --git a/embodied_gen/scripts/textto3d.py b/embodied_gen/scripts/textto3d.py
index a4262a3..4d28093 100644
--- a/embodied_gen/scripts/textto3d.py
+++ b/embodied_gen/scripts/textto3d.py
@@ -44,13 +44,6 @@ __all__ = [
"text_to_3d",
]
-logger.info("Loading Models...")
-SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
-SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
-TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
-PIPE_IMG = build_hf_image_pipeline("sd35")
-BG_REMOVER = RembgRemover()
-
def text_to_image(
prompt: str,
@@ -121,6 +114,14 @@ def text_to_3d(**kwargs) -> dict:
if hasattr(args, k) and v is not None:
setattr(args, k, v)
+ logger.info("Loading Models...")
+ global SEMANTIC_CHECKER, SEG_CHECKER, TXTGEN_CHECKER, PIPE_IMG, BG_REMOVER
+ SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
+ TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
+ PIPE_IMG = build_hf_image_pipeline(args.text_model)
+ BG_REMOVER = RembgRemover()
+
if args.asset_names is None or len(args.asset_names) == 0:
args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
img_save_dir = os.path.join(args.output_root, "images")
@@ -260,6 +261,11 @@ def parse_args():
default=0,
help="Random seed for 3D generation",
)
+ parser.add_argument(
+ "--text_model",
+ type=str,
+ default="sd35",
+ )
parser.add_argument("--keep_intermediate", action="store_true")
args, unknown = parser.parse_known_args()
diff --git a/embodied_gen/trainer/gsplat_trainer.py b/embodied_gen/trainer/gsplat_trainer.py
new file mode 100644
index 0000000..53aaf28
--- /dev/null
+++ b/embodied_gen/trainer/gsplat_trainer.py
@@ -0,0 +1,678 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+# Part of the code comes from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
+# Both under the Apache License, Version 2.0.
+
+
+import json
+import os
+import time
+from collections import defaultdict
+from typing import Dict, Optional, Tuple
+
+import cv2
+import imageio
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tqdm
+import tyro
+import yaml
+from fused_ssim import fused_ssim
+from gsplat.distributed import cli
+from gsplat.rendering import rasterization
+from gsplat.strategy import DefaultStrategy, MCMCStrategy
+from torch import Tensor
+from torch.utils.tensorboard import SummaryWriter
+from torchmetrics.image import (
+ PeakSignalNoiseRatio,
+ StructuralSimilarityIndexMeasure,
+)
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from typing_extensions import Literal, assert_never
+from embodied_gen.data.datasets import PanoGSplatDataset
+from embodied_gen.utils.config import GsplatTrainConfig
+from embodied_gen.utils.gaussian import (
+ create_splats_with_optimizers,
+ export_splats,
+ resize_pinhole_intrinsics,
+ set_random_seed,
+)
+
+
+class Runner:
+ """Engine for training and testing from gsplat example.
+
+ Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
+ """
+
+ def __init__(
+ self,
+ local_rank: int,
+ world_rank,
+ world_size: int,
+ cfg: GsplatTrainConfig,
+ ) -> None:
+ set_random_seed(42 + local_rank)
+
+ self.cfg = cfg
+ self.world_rank = world_rank
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.device = f"cuda:{local_rank}"
+
+ # Where to dump results.
+ os.makedirs(cfg.result_dir, exist_ok=True)
+
+ # Setup output directories.
+ self.ckpt_dir = f"{cfg.result_dir}/ckpts"
+ os.makedirs(self.ckpt_dir, exist_ok=True)
+ self.stats_dir = f"{cfg.result_dir}/stats"
+ os.makedirs(self.stats_dir, exist_ok=True)
+ self.render_dir = f"{cfg.result_dir}/renders"
+ os.makedirs(self.render_dir, exist_ok=True)
+ self.ply_dir = f"{cfg.result_dir}/ply"
+ os.makedirs(self.ply_dir, exist_ok=True)
+
+ # Tensorboard
+ self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
+ self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
+ self.valset = PanoGSplatDataset(
+ cfg.data_dir, split="train", max_sample_num=6
+ )
+ self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
+ self.scene_scale = cfg.scene_scale
+
+ # Model
+ self.splats, self.optimizers = create_splats_with_optimizers(
+ self.trainset.points,
+ self.trainset.points_rgb,
+ init_num_pts=cfg.init_num_pts,
+ init_extent=cfg.init_extent,
+ init_opacity=cfg.init_opa,
+ init_scale=cfg.init_scale,
+ means_lr=cfg.means_lr,
+ scales_lr=cfg.scales_lr,
+ opacities_lr=cfg.opacities_lr,
+ quats_lr=cfg.quats_lr,
+ sh0_lr=cfg.sh0_lr,
+ shN_lr=cfg.shN_lr,
+ scene_scale=self.scene_scale,
+ sh_degree=cfg.sh_degree,
+ sparse_grad=cfg.sparse_grad,
+ visible_adam=cfg.visible_adam,
+ batch_size=cfg.batch_size,
+ feature_dim=None,
+ device=self.device,
+ world_rank=world_rank,
+ world_size=world_size,
+ )
+ print("Model initialized. Number of GS:", len(self.splats["means"]))
+
+ # Densification Strategy
+ self.cfg.strategy.check_sanity(self.splats, self.optimizers)
+
+ if isinstance(self.cfg.strategy, DefaultStrategy):
+ self.strategy_state = self.cfg.strategy.initialize_state(
+ scene_scale=self.scene_scale
+ )
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
+ self.strategy_state = self.cfg.strategy.initialize_state()
+ else:
+ assert_never(self.cfg.strategy)
+
+ # Losses & Metrics.
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
+ self.device
+ )
+ self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
+
+ if cfg.lpips_net == "alex":
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
+ net_type="alex", normalize=True
+ ).to(self.device)
+ elif cfg.lpips_net == "vgg":
+ # The 3DGS official repo uses lpips vgg, which is equivalent with the following:
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
+ net_type="vgg", normalize=False
+ ).to(self.device)
+ else:
+ raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")
+
+ def rasterize_splats(
+ self,
+ camtoworlds: Tensor,
+ Ks: Tensor,
+ width: int,
+ height: int,
+ masks: Optional[Tensor] = None,
+ rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
+ camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
+ **kwargs,
+ ) -> Tuple[Tensor, Tensor, Dict]:
+ means = self.splats["means"] # [N, 3]
+ # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
+ # rasterization does normalization internally
+ quats = self.splats["quats"] # [N, 4]
+ scales = torch.exp(self.splats["scales"]) # [N, 3]
+ opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
+ image_ids = kwargs.pop("image_ids", None)
+
+ colors = torch.cat(
+ [self.splats["sh0"], self.splats["shN"]], 1
+ ) # [N, K, 3]
+
+ if rasterize_mode is None:
+ rasterize_mode = (
+ "antialiased" if self.cfg.antialiased else "classic"
+ )
+ if camera_model is None:
+ camera_model = self.cfg.camera_model
+
+ render_colors, render_alphas, info = rasterization(
+ means=means,
+ quats=quats,
+ scales=scales,
+ opacities=opacities,
+ colors=colors,
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
+ Ks=Ks, # [C, 3, 3]
+ width=width,
+ height=height,
+ packed=self.cfg.packed,
+ absgrad=(
+ self.cfg.strategy.absgrad
+ if isinstance(self.cfg.strategy, DefaultStrategy)
+ else False
+ ),
+ sparse_grad=self.cfg.sparse_grad,
+ rasterize_mode=rasterize_mode,
+ distributed=self.world_size > 1,
+            camera_model=camera_model,
+ with_ut=self.cfg.with_ut,
+ with_eval3d=self.cfg.with_eval3d,
+ **kwargs,
+ )
+ if masks is not None:
+ render_colors[~masks] = 0
+ return render_colors, render_alphas, info
+
+ def train(self):
+ cfg = self.cfg
+ device = self.device
+ world_rank = self.world_rank
+
+ # Dump cfg.
+ if world_rank == 0:
+ with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
+ yaml.dump(vars(cfg), f)
+
+ max_steps = cfg.max_steps
+ init_step = 0
+
+ schedulers = [
+ # means has a learning rate schedule, that end at 0.01 of the initial value
+ torch.optim.lr_scheduler.ExponentialLR(
+ self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
+ ),
+ ]
+ trainloader = torch.utils.data.DataLoader(
+ self.trainset,
+ batch_size=cfg.batch_size,
+ shuffle=True,
+ num_workers=4,
+ persistent_workers=True,
+ pin_memory=True,
+ )
+ trainloader_iter = iter(trainloader)
+
+ # Training loop.
+ global_tic = time.time()
+ pbar = tqdm.tqdm(range(init_step, max_steps))
+ for step in pbar:
+ try:
+ data = next(trainloader_iter)
+ except StopIteration:
+ trainloader_iter = iter(trainloader)
+ data = next(trainloader_iter)
+
+ camtoworlds = data["camtoworld"].to(device) # [1, 4, 4]
+ Ks = data["K"].to(device) # [1, 3, 3]
+ pixels = data["image"].to(device) / 255.0 # [1, H, W, 3]
+ image_ids = data["image_id"].to(device)
+ masks = (
+ data["mask"].to(device) if "mask" in data else None
+ ) # [1, H, W]
+ if cfg.depth_loss:
+ points = data["points"].to(device) # [1, M, 2]
+ depths_gt = data["depths"].to(device) # [1, M]
+
+ height, width = pixels.shape[1:3]
+
+ # sh schedule
+ sh_degree_to_use = min(
+ step // cfg.sh_degree_interval, cfg.sh_degree
+ )
+
+ # forward
+ renders, alphas, info = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks,
+ width=width,
+ height=height,
+ sh_degree=sh_degree_to_use,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ image_ids=image_ids,
+ render_mode="RGB+ED" if cfg.depth_loss else "RGB",
+ masks=masks,
+ )
+ if renders.shape[-1] == 4:
+ colors, depths = renders[..., 0:3], renders[..., 3:4]
+ else:
+ colors, depths = renders, None
+
+ if cfg.random_bkgd:
+ bkgd = torch.rand(1, 3, device=device)
+ colors = colors + bkgd * (1.0 - alphas)
+
+ self.cfg.strategy.step_pre_backward(
+ params=self.splats,
+ optimizers=self.optimizers,
+ state=self.strategy_state,
+ step=step,
+ info=info,
+ )
+
+ # loss
+ l1loss = F.l1_loss(colors, pixels)
+ ssimloss = 1.0 - fused_ssim(
+ colors.permute(0, 3, 1, 2),
+ pixels.permute(0, 3, 1, 2),
+ padding="valid",
+ )
+ loss = (
+ l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
+ )
+ if cfg.depth_loss:
+ # query depths from depth map
+ points = torch.stack(
+ [
+ points[:, :, 0] / (width - 1) * 2 - 1,
+ points[:, :, 1] / (height - 1) * 2 - 1,
+ ],
+ dim=-1,
+ ) # normalize to [-1, 1]
+ grid = points.unsqueeze(2) # [1, M, 1, 2]
+ depths = F.grid_sample(
+ depths.permute(0, 3, 1, 2), grid, align_corners=True
+ ) # [1, 1, M, 1]
+ depths = depths.squeeze(3).squeeze(1) # [1, M]
+ # calculate loss in disparity space
+ disp = torch.where(
+ depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
+ )
+ disp_gt = 1.0 / depths_gt # [1, M]
+ depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
+ loss += depthloss * cfg.depth_lambda
+
+ # regularizations
+ if cfg.opacity_reg > 0.0:
+ loss += (
+ cfg.opacity_reg
+ * torch.sigmoid(self.splats["opacities"]).mean()
+ )
+ if cfg.scale_reg > 0.0:
+ loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()
+
+ loss.backward()
+
+ desc = (
+ f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
+ )
+ if cfg.depth_loss:
+ desc += f"depth loss={depthloss.item():.6f}| "
+ pbar.set_description(desc)
+
+ # write images (gt and render)
+ # if world_rank == 0 and step % 800 == 0:
+ # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
+ # canvas = canvas.reshape(-1, *canvas.shape[2:])
+ # imageio.imwrite(
+ # f"{self.render_dir}/train_rank{self.world_rank}.png",
+ # (canvas * 255).astype(np.uint8),
+ # )
+
+ if (
+ world_rank == 0
+ and cfg.tb_every > 0
+ and step % cfg.tb_every == 0
+ ):
+ mem = torch.cuda.max_memory_allocated() / 1024**3
+ self.writer.add_scalar("train/loss", loss.item(), step)
+ self.writer.add_scalar("train/l1loss", l1loss.item(), step)
+ self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
+ self.writer.add_scalar(
+ "train/num_GS", len(self.splats["means"]), step
+ )
+ self.writer.add_scalar("train/mem", mem, step)
+ if cfg.depth_loss:
+ self.writer.add_scalar(
+ "train/depthloss", depthloss.item(), step
+ )
+ if cfg.tb_save_image:
+ canvas = (
+ torch.cat([pixels, colors], dim=2)
+ .detach()
+ .cpu()
+ .numpy()
+ )
+ canvas = canvas.reshape(-1, *canvas.shape[2:])
+ self.writer.add_image("train/render", canvas, step)
+ self.writer.flush()
+
+ # save checkpoint before updating the model
+ if (
+ step in [i - 1 for i in cfg.save_steps]
+ or step == max_steps - 1
+ ):
+ mem = torch.cuda.max_memory_allocated() / 1024**3
+ stats = {
+ "mem": mem,
+ "ellipse_time": time.time() - global_tic,
+ "num_GS": len(self.splats["means"]),
+ }
+ print("Step: ", step, stats)
+ with open(
+ f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
+ "w",
+ ) as f:
+ json.dump(stats, f)
+ data = {"step": step, "splats": self.splats.state_dict()}
+ torch.save(
+ data,
+ f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
+ )
+ if (
+ step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
+ ) and cfg.save_ply:
+ sh0 = self.splats["sh0"]
+ shN = self.splats["shN"]
+ means = self.splats["means"]
+ scales = self.splats["scales"]
+ quats = self.splats["quats"]
+ opacities = self.splats["opacities"]
+ export_splats(
+ means=means,
+ scales=scales,
+ quats=quats,
+ opacities=opacities,
+ sh0=sh0,
+ shN=shN,
+ format="ply",
+ save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
+ )
+
+ # Turn Gradients into Sparse Tensor before running optimizer
+ if cfg.sparse_grad:
+ assert (
+ cfg.packed
+ ), "Sparse gradients only work with packed mode."
+ gaussian_ids = info["gaussian_ids"]
+ for k in self.splats.keys():
+ grad = self.splats[k].grad
+ if grad is None or grad.is_sparse:
+ continue
+ self.splats[k].grad = torch.sparse_coo_tensor(
+ indices=gaussian_ids[None], # [1, nnz]
+ values=grad[gaussian_ids], # [nnz, ...]
+ size=self.splats[k].size(), # [N, ...]
+ is_coalesced=len(Ks) == 1,
+ )
+
+ if cfg.visible_adam:
+ gaussian_cnt = self.splats.means.shape[0]
+ if cfg.packed:
+ visibility_mask = torch.zeros_like(
+ self.splats["opacities"], dtype=bool
+ )
+ visibility_mask.scatter_(0, info["gaussian_ids"], 1)
+ else:
+ visibility_mask = (info["radii"] > 0).all(-1).any(0)
+
+ # optimize
+ for optimizer in self.optimizers.values():
+ if cfg.visible_adam:
+ optimizer.step(visibility_mask)
+ else:
+ optimizer.step()
+ optimizer.zero_grad(set_to_none=True)
+ for scheduler in schedulers:
+ scheduler.step()
+
+ # Run post-backward steps after backward and optimizer
+ if isinstance(self.cfg.strategy, DefaultStrategy):
+ self.cfg.strategy.step_post_backward(
+ params=self.splats,
+ optimizers=self.optimizers,
+ state=self.strategy_state,
+ step=step,
+ info=info,
+ packed=cfg.packed,
+ )
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
+ self.cfg.strategy.step_post_backward(
+ params=self.splats,
+ optimizers=self.optimizers,
+ state=self.strategy_state,
+ step=step,
+ info=info,
+ lr=schedulers[0].get_last_lr()[0],
+ )
+ else:
+ assert_never(self.cfg.strategy)
+
+ # eval the full set
+ if step in [i - 1 for i in cfg.eval_steps]:
+ self.eval(step)
+ self.render_video(step)
+
    @torch.no_grad()
    def eval(
        self,
        step: int,
        stage: str = "val",
        canvas_h: int = 512,
        canvas_w: int = 1024,
    ):
        """Entry for evaluation.

        Renders every validation view, accumulates PSNR/SSIM/LPIPS and the
        per-image render time, and (rank 0 only) writes GT/render canvases
        plus a stats JSON and tensorboard scalars.

        Args:
            step: Training step used to tag output files and scalars.
            stage: Prefix for output files (e.g. "val").
            canvas_h: Height each image is resized to before comparison.
            canvas_w: Full canvas width; each image uses canvas_w // 2.
        """
        print("Running evaluation...")
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        valloader = torch.utils.data.DataLoader(
            self.valset, batch_size=1, shuffle=False, num_workers=1
        )
        # Accumulated per-image render time in seconds ("ellipse" is a
        # historical misspelling of "elapsed" kept for stat-key stability).
        ellipse_time = 0
        metrics = defaultdict(list)
        for i, data in enumerate(valloader):
            camtoworlds = data["camtoworld"].to(device)
            Ks = data["K"].to(device)
            pixels = data["image"].to(device) / 255.0  # [1, H, W, 3] in [0, 1]
            height, width = pixels.shape[1:3]
            masks = data["mask"].to(device) if "mask" in data else None

            pixels = pixels.permute(0, 3, 1, 2)  # NHWC -> NCHW
            # Resize GT to the fixed metric/canvas resolution.
            pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))

            # Synchronize around rasterization so the timer measures GPU work.
            torch.cuda.synchronize()
            tic = time.time()
            colors, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                masks=masks,
            )  # [1, H, W, 3]
            torch.cuda.synchronize()
            ellipse_time += max(time.time() - tic, 1e-10)

            colors = colors.permute(0, 3, 1, 2)  # NHWC -> NCHW
            # Render at native size, then resize to match the resized GT.
            colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
            colors = torch.clamp(colors, 0.0, 1.0)
            canvas_list = [pixels, colors]

            if world_rank == 0:
                # Stack GT above render (cat along H of NCHW), save as BGR.
                canvas = torch.cat(canvas_list, dim=2).squeeze(0)
                canvas = canvas.permute(1, 2, 0)  # CHW -> HWC
                canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
                cv2.imwrite(
                    f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
                    canvas[..., ::-1],
                )
                # Metrics are computed on rank 0 only.
                metrics["psnr"].append(self.psnr(colors, pixels))
                metrics["ssim"].append(self.ssim(colors, pixels))
                metrics["lpips"].append(self.lpips(colors, pixels))

        if world_rank == 0:
            ellipse_time /= len(valloader)

            stats = {
                k: torch.stack(v).mean().item() for k, v in metrics.items()
            }
            stats.update(
                {
                    "ellipse_time": ellipse_time,
                    "num_GS": len(self.splats["means"]),
                }
            )
            print(
                f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
                f"Time: {stats['ellipse_time']:.3f}s/image "
                f"Number of GS: {stats['num_GS']}"
            )
            # save stats as json
            with open(
                f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
            ) as f:
                json.dump(stats, f)
            # save stats to tensorboard
            for k, v in stats.items():
                self.writer.add_scalar(f"{stage}/{k}", v, step)
            self.writer.flush()
+
+ @torch.no_grad()
+ def render_video(
+ self, step: int, canvas_h: int = 512, canvas_w: int = 1024
+ ):
+ testloader = torch.utils.data.DataLoader(
+ self.testset, batch_size=1, shuffle=False, num_workers=1
+ )
+
+ images_cache = []
+ depth_global_min, depth_global_max = float("inf"), -float("inf")
+ for data in testloader:
+ camtoworlds = data["camtoworld"].to(self.device)
+ Ks = resize_pinhole_intrinsics(
+ data["K"].squeeze(),
+ raw_hw=(data["image_h"].item(), data["image_w"].item()),
+ new_hw=(canvas_h, canvas_w // 2),
+ ).to(self.device)
+ renders, _, _ = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks[None, ...],
+ width=canvas_w // 2,
+ height=canvas_h,
+ sh_degree=self.cfg.sh_degree,
+ near_plane=self.cfg.near_plane,
+ far_plane=self.cfg.far_plane,
+ render_mode="RGB+ED",
+ ) # [1, H, W, 4]
+ colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3]
+ colors = (colors * 255).to(torch.uint8).cpu().numpy()
+ depths = renders[0, ..., 3:4] # [H, W, 1], tensor in device.
+ images_cache.append([colors, depths])
+ depth_global_min = min(depth_global_min, depths.min().item())
+ depth_global_max = max(depth_global_max, depths.max().item())
+
+ video_path = f"{self.render_dir}/video_step{step}.mp4"
+ writer = imageio.get_writer(video_path, fps=30)
+ for rgb, depth in images_cache:
+ depth_normalized = torch.clip(
+ (depth - depth_global_min)
+ / (depth_global_max - depth_global_min),
+ 0,
+ 1,
+ )
+ depth_normalized = (
+ (depth_normalized * 255).to(torch.uint8).cpu().numpy()
+ )
+ depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET)
+ image = np.concatenate([rgb, depth_map], axis=1)
+ writer.append_data(image)
+
+ writer.close()
+
+
def entrypoint(
    local_rank: int, world_rank, world_size: int, cfg: GsplatTrainConfig
):
    """Launch training, or evaluation-only when checkpoints are provided.

    Args:
        local_rank: Rank of this process on its node.
        world_rank: Global rank of this process.
        world_size: Total number of processes.
        cfg: Training configuration; if ``cfg.ckpt`` is set, training is
            skipped and the checkpoints are merged and evaluated instead.
    """
    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is None:
        # Fresh training run followed by a final trajectory video.
        runner.train()
        runner.render_video(step=cfg.max_steps - 1)
        return

    # Evaluation-only path: load every checkpoint shard and concatenate the
    # Gaussian parameters back into a single model.
    shards = [
        torch.load(path, map_location=runner.device, weights_only=True)
        for path in cfg.ckpt
    ]
    for name in runner.splats.keys():
        runner.splats[name].data = torch.cat(
            [shard["splats"][name] for shard in shards]
        )
    eval_step = shards[0]["step"]
    runner.eval(step=eval_step)
    runner.render_video(step=eval_step)
+
+
if __name__ == "__main__":
    # Named presets selectable as the first CLI positional argument
    # (e.g. `python gsplat_trainer.py default` or `... mcmc`); any field of
    # GsplatTrainConfig can then be overridden via flags thanks to tyro.
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            GsplatTrainConfig(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            GsplatTrainConfig(
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    # Rescale every step-based schedule by the user-supplied factor.
    cfg.adjust_steps(cfg.steps_scaler)

    cli(entrypoint, cfg, verbose=True)
diff --git a/embodied_gen/trainer/pono2mesh_trainer.py b/embodied_gen/trainer/pono2mesh_trainer.py
new file mode 100644
index 0000000..b234b79
--- /dev/null
+++ b/embodied_gen/trainer/pono2mesh_trainer.py
@@ -0,0 +1,538 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
from embodied_gen.utils.monkey_patches import monkey_patch_pano2room

# Must run before the pano2room imports below so the patched symbols are
# already in place when those modules are first loaded.
monkey_patch_pano2room()
+
+import os
+
+import cv2
+import numpy as np
+import torch
+import trimesh
+from equilib import cube2equi, equi2pers
+from kornia.morphology import dilation
+from PIL import Image
+from embodied_gen.models.sr_model import ImageRealESRGAN
+from embodied_gen.utils.config import Pano2MeshSRConfig
+from embodied_gen.utils.gaussian import compute_pinhole_intrinsics
+from embodied_gen.utils.log import logger
+from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
+from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
+ PanoFusionDistancePredictor,
+)
+from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter
+from thirdparty.pano2room.modules.mesh_fusion.render import (
+ features_to_world_space_mesh,
+ render_mesh,
+)
+from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool
+from thirdparty.pano2room.utils.camera_utils import gen_pano_rays
+from thirdparty.pano2room.utils.functions import (
+ depth_to_distance,
+ get_cubemap_views_world_to_cam,
+ resize_image_with_aspect_ratio,
+ rot_z_world_to_cam,
+ tensor_to_pil,
+)
+
+
class Pano2MeshSRPipeline:
    """Convert a panoramic RGB image into a 3D mesh representation, followed by inpainting and mesh refinement.

    This class integrates several key components including:
    - Depth estimation from RGB panorama
    - Inpainting of missing regions under offsets
    - RGB-D to mesh conversion
    - Multi-view mesh repair
    - 3D Gaussian Splatting (3DGS) dataset generation

    Args:
        config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.

    Example:
        ```python
        pipeline = Pano2MeshSRPipeline(config)
        pipeline(pano_image='example.png', output_dir='./output')
        ```
    """

    def __init__(self, config: Pano2MeshSRConfig) -> None:
        self.cfg = config
        self.device = config.device

        # Init models.
        self.inpainter = PanoPersFusionInpainter(save_path=None)
        self.geo_predictor = PanoJointPredictor(save_path=None)
        self.pano_fusion_distance_predictor = PanoFusionDistancePredictor()
        self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor)

        # Init poses.
        cubemap_w2cs = get_cubemap_views_world_to_cam()
        self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
        self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)

        # Elliptical kernel used to dilate inpainting masks before filling.
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, self.cfg.kernel_size
        )
        self.kernel = torch.from_numpy(kernel).float().to(self.device)

    def init_mesh_params(self) -> None:
        """Reset the mesh buffers to an empty state.

        NOTE: this also switches torch's global default device to
        ``self.device`` so tensors created by downstream pano2room helpers
        land on the right device; ``__call__`` restores it to CPU at the end.
        """
        torch.set_default_device(self.device)
        self.inpaint_mask = torch.ones(
            (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
        )
        # Mesh stored column-wise: (3, N) vertices/colors, (3, M) faces.
        self.vertices = torch.empty((3, 0), requires_grad=False)
        self.colors = torch.empty((3, 0), requires_grad=False)
        self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)

    @staticmethod
    def read_camera_pose_file(filepath: str) -> np.ndarray:
        """Read a whitespace-separated 4x4 camera pose matrix from a file."""
        with open(filepath, "r") as f:
            values = [float(num) for line in f for num in line.split()]

        return np.array(values).reshape(4, 4)

    def load_camera_poses(
        self, trajectory_dir: str
    ) -> list[torch.Tensor]:
        """Load trajectory pose files and convert them to relative poses.

        Poses are made relative to a panorama reference pose (offset by
        ``cfg.pano_center_offset``), converted to the mesh camera convention
        (x/y flip plus a 180-degree z rotation), and translation-scaled by
        ``cfg.pose_scale``.
        """
        pose_filenames = sorted(
            [
                fname
                for fname in os.listdir(trajectory_dir)
                if fname.startswith("camera_pose")
            ]
        )

        pano_pose_world = None
        relative_poses = []
        for idx, filename in enumerate(pose_filenames):
            pose_path = os.path.join(trajectory_dir, filename)
            pose_matrix = self.read_camera_pose_file(pose_path)

            if pano_pose_world is None:
                pano_pose_world = pose_matrix.copy()
                pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
                pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]

            # Use different reference for the first 6 cubemap views
            # (those become identity-relative to themselves).
            reference_pose = pose_matrix if idx < 6 else pano_pose_world
            relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
            relative_matrix[0:2, :] *= -1  # flip_xy
            relative_matrix = (
                relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
            )
            relative_matrix[:3, 3] *= self.cfg.pose_scale
            relative_matrix = torch.tensor(
                relative_matrix, dtype=torch.float32
            )
            relative_poses.append(relative_matrix)

        return relative_poses

    def load_inpaint_poses(
        self, poses: torch.Tensor
    ) -> dict[int, torch.Tensor]:
        """Subsample trajectory poses for inpainting viewpoints.

        Every ``cfg.inpaint_frame_stride``-th pose is kept; only the (negated)
        camera-center translation is retained, with identity rotation.
        """
        inpaint_poses = dict()
        sampled_views = poses[:: self.cfg.inpaint_frame_stride]
        init_pose = torch.eye(4)
        for idx, w2c_tensor in enumerate(sampled_views):
            w2c = w2c_tensor.cpu().numpy().astype(np.float32)
            c2w = np.linalg.inv(w2c)
            pose_tensor = init_pose.clone()
            pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
            pose_tensor[:3, 3] *= -1
            inpaint_poses[idx] = pose_tensor.to(self.device)

        return inpaint_poses

    def project(
        self, world_to_cam: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Rasterize the current mesh from one camera.

        Returns:
            (rgb, inpaint_mask, depth) where ``inpaint_mask`` marks pixels
            not covered by the mesh; those pixels are zeroed in ``rgb``.
        """
        (
            project_image,
            project_depth,
            inpaint_mask,
            _,
            z_buf,
            mesh,
        ) = render_mesh(
            vertices=self.vertices,
            faces=self.faces,
            vertex_features=self.colors,
            H=self.cfg.cubemap_h,
            W=self.cfg.cubemap_w,
            fov_in_degrees=self.cfg.fov,
            RT=world_to_cam,
            blur_radius=self.cfg.blur_radius,
            faces_per_pixel=self.cfg.faces_per_pixel,
        )
        project_image = project_image * ~inpaint_mask

        return project_image[:3, ...], inpaint_mask, project_depth

    def render_pano(
        self, pose: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Render an equirectangular panorama of the current mesh at `pose`.

        Renders the six cubemap faces and stitches them into
        (pano_rgb, pano_distance, pano_mask); the mask marks missing regions.
        """
        cubemap_list = []
        for cubemap_pose in self.cubemap_w2cs:
            project_pose = cubemap_pose @ pose
            rgb, inpaint_mask, depth = self.project(project_pose)
            # Convert planar depth to ray distance for equirect stitching.
            distance_map = depth_to_distance(depth[None, ...])
            mask = inpaint_mask[None, ...]
            cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0))

        # Set default tensor type for CPU operation in cube2equi
        with torch.device("cpu"):
            pano_rgbd = cube2equi(
                cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w
            )

        pano_rgb = pano_rgbd[:3, :, :]
        pano_depth = pano_rgbd[3:4, :, :].squeeze(0)
        pano_mask = pano_rgbd[4:, :, :].squeeze(0)

        return pano_rgb, pano_depth, pano_mask

    def rgbd_to_mesh(
        self,
        rgb: torch.Tensor,
        depth: torch.Tensor,
        inpaint_mask: torch.Tensor,
        world_to_cam: torch.Tensor = None,
        using_distance_map: bool = True,
    ) -> None:
        """Triangulate an RGB-D view and append it to the mesh buffers.

        Only pixels where ``inpaint_mask`` is set are meshed; no-op when the
        mask is empty. New faces are index-shifted past existing vertices.
        """
        if world_to_cam is None:
            world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)

        if inpaint_mask.sum() == 0:
            return

        vertices, faces, colors = features_to_world_space_mesh(
            colors=rgb.squeeze(0),
            depth=depth,
            fov_in_degrees=self.cfg.fov,
            world_to_cam=world_to_cam,
            mask=inpaint_mask,
            faces=self.faces,
            vertices=self.vertices,
            using_distance_map=using_distance_map,
            edge_threshold=0.05,
        )

        faces += self.vertices.shape[1]
        self.vertices = torch.cat([self.vertices, vertices], dim=1)
        self.colors = torch.cat([self.colors, colors], dim=1)
        self.faces = torch.cat([self.faces, faces], dim=1)

    def get_edge_image_by_depth(
        self, depth: torch.Tensor, dilate_iter: int = 1
    ) -> np.ndarray:
        """Detect (and optionally dilate) depth-discontinuity edges.

        Depth is normalized to uint8 and run through Canny; the returned
        edge map is used to exclude unreliable pixels from meshing.
        """
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().detach().numpy()

        gray = (depth / depth.max() * 255).astype(np.uint8)
        edges = cv2.Canny(gray, 60, 150)
        if dilate_iter > 0:
            kernel = np.ones((3, 3), np.uint8)
            edges = cv2.dilate(edges, kernel, iterations=dilate_iter)

        return edges

    def mesh_repair_by_greedy_view_selection(
        self, pose_dict: dict[str, torch.Tensor], output_dir: str
    ) -> list:
        """Iteratively inpaint the mesh from selected panorama viewpoints.

        Each round renders every remaining candidate pose, picks one by
        completeness ranking, inpaints its missing regions, fuses the result
        back into the mesh, and registers it for geometric conflict checks.

        Returns:
            List of [pano_tensor (1, 3, H, W), pose] pairs, one per round.
        """
        inpainted_panos_w_pose = []
        while len(pose_dict) > 0:
            logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
            sampled_views = []
            for key, pose in pose_dict.items():
                pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
                # Fraction of pano already covered by the mesh.
                completeness = torch.sum(1 - pano_mask) / (pano_mask.numel())
                sampled_views.append((key, completeness.item(), pose))

            if len(sampled_views) == 0:
                break

            # Rank views by completeness (ascending) and pick the one at the
            # 2/3 quantile: moderately incomplete, but not a worst-case view.
            sampled_views = sorted(sampled_views, key=lambda x: x[1])
            key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
            pose_dict.pop(key)

            pano_rgb, pano_distance, pano_mask = self.render_pano(pose)

            colors = pano_rgb.permute(1, 2, 0).clone()
            distances = pano_distance.unsqueeze(-1).clone()
            pano_inpaint_mask = pano_mask.clone()
            init_pose = pose.clone()
            normals = None
            # Only inpaint when some region is actually missing.
            if pano_inpaint_mask.min().item() < 0.5:
                colors, distances, normals = self.inpaint_panorama(
                    idx=key,
                    colors=colors,
                    distances=distances,
                    pano_mask=pano_inpaint_mask,
                )

                # Rewrite translation into the sup-pool ray convention
                # (negated x, y <- z, zeroed z) before ray generation.
                # NOTE(review): convention inferred from pano2room — confirm.
                init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
                    -pose[0, 3],
                    pose[2, 3],
                    0,
                )
                rays = gen_pano_rays(
                    init_pose, self.cfg.pano_h, self.cfg.pano_w
                )
                conflict_mask = self.sup_pool.geo_check(
                    rays, distances.unsqueeze(-1)
                )  # 0 is conflict, 1 not conflict
                # Drop inpainted pixels that conflict with known geometry.
                pano_inpaint_mask *= conflict_mask

            self.rgbd_to_mesh(
                colors.permute(2, 0, 1),
                distances,
                pano_inpaint_mask,
                world_to_cam=pose,
            )

            self.sup_pool.register_sup_info(
                pose=init_pose,
                mask=pano_inpaint_mask.clone(),
                rgb=colors,
                distance=distances.unsqueeze(-1),
                normal=normals,
            )

            colors = colors.permute(2, 0, 1).unsqueeze(0)
            inpainted_panos_w_pose.append([colors, pose])

            if self.cfg.visualize:
                from embodied_gen.data.utils import DiffrastRender

                tensor_to_pil(pano_rgb.unsqueeze(0)).save(
                    f"{output_dir}/rendered_pano_{key}.jpg"
                )
                tensor_to_pil(colors).save(
                    f"{output_dir}/inpainted_pano_{key}.jpg"
                )
                norm_depth = DiffrastRender.normalize_map_by_mask(
                    distances, torch.ones_like(distances)
                )
                heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
                heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
                Image.fromarray(heatmap).save(
                    f"{output_dir}/inpainted_depth_{key}.png"
                )

        return inpainted_panos_w_pose

    def inpaint_panorama(
        self,
        idx: int,
        colors: torch.Tensor,
        distances: torch.Tensor,
        pano_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Inpaint missing pano regions and re-predict geometry there.

        The mask is dilated to cover boundary seams, RGB is filled by the
        inpainter, and distances/normals are refined by the geo predictor.

        Returns:
            (inpainted_rgb, inpainted_distances, inpainted_normals).
        """
        mask = (pano_mask[None, ..., None] > 0.5).float()
        mask = mask.permute(0, 3, 1, 2)
        mask = dilation(mask, kernel=self.kernel)
        mask = mask[0, 0, ..., None]  # hwc
        inpainted_img = self.inpainter.inpaint(idx, colors, mask)
        # Keep original pixels outside the mask; blend inpaint inside.
        inpainted_img = colors * (1 - mask) + inpainted_img * mask
        inpainted_distances, inpainted_normals = self.geo_predictor(
            idx,
            inpainted_img,
            distances[..., None],
            mask=mask,
            reg_loss_weight=0.0,
            normal_loss_weight=5e-2,
            normal_tv_loss_weight=5e-2,
        )

        return inpainted_img, inpainted_distances.squeeze(), inpainted_normals

    def preprocess_pano(
        self, image: Image.Image | str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Load/normalize the input pano and predict its distance map.

        Returns:
            (rgb, distance): rgb is (3, H, W) in [0, 1] on ``self.device``;
            distance is max-normalized then scaled by cfg.depth_scale_factor.
        """
        if isinstance(image, str):
            image = Image.open(image)

        image = image.convert("RGB")

        # Equirect panos must be landscape (W >= H); transpose if needed.
        if image.size[0] < image.size[1]:
            image = image.transpose(Image.TRANSPOSE)

        image = resize_image_with_aspect_ratio(image, self.cfg.pano_w)
        image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255
        image_rgb = image_rgb.to(self.device)
        image_depth = self.pano_fusion_distance_predictor.predict(
            image_rgb.permute(1, 2, 0)
        )
        image_depth = (
            image_depth / image_depth.max() * self.cfg.depth_scale_factor
        )

        return image_rgb, image_depth

    def pano_to_perpective(
        self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
    ) -> torch.Tensor:
        """Extract one perspective view from an equirectangular panorama.

        Args:
            pano_image: Equirectangular image tensor (optionally batched).
            pitch: Camera pitch in radians.
            yaw: Camera yaw in radians.
            fov: Horizontal field of view in degrees.
        """
        rots = dict(
            roll=0,
            pitch=pitch,
            yaw=yaw,
        )
        perspective = equi2pers(
            equi=pano_image.squeeze(0),
            rots=rots,
            height=self.cfg.cubemap_h,
            width=self.cfg.cubemap_w,
            fov_x=fov,
            mode="bilinear",
        ).unsqueeze(0)

        return perspective

    def pano_to_cubemap(self, pano_rgb: torch.Tensor) -> list[torch.Tensor]:
        """Split a panorama into the six cubemap face views (on CPU)."""
        # Define six canonical cube directions in (pitch, yaw)
        directions = [
            (0, 0),
            (0, 1.5 * np.pi),
            (0, 1.0 * np.pi),
            (0, 0.5 * np.pi),
            (-0.5 * np.pi, 0),
            (0.5 * np.pi, 0),
        ]

        cubemaps_rgb = []
        for pitch, yaw in directions:
            rgb_view = self.pano_to_perpective(
                pano_rgb, pitch, yaw, fov=self.cfg.fov
            )
            cubemaps_rgb.append(rgb_view.cpu())

        return cubemaps_rgb

    def save_mesh(self, output_path: str) -> None:
        """Export the accumulated vertex-colored mesh via trimesh."""
        vertices_np = self.vertices.T.cpu().numpy()
        colors_np = self.colors.T.cpu().numpy()
        faces_np = self.faces.T.cpu().numpy()
        mesh = trimesh.Trimesh(
            vertices=vertices_np, faces=faces_np, vertex_colors=colors_np
        )

        mesh.export(output_path)

    def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
        """Convert a mesh-space world-to-cam pose to a 3DGS cam-to-world.

        Flips the x/y rows, reverses the y/z axes, and inverts to obtain a
        4x4 camera-to-world matrix for the gsplat trainer.
        NOTE(review): axis conventions inferred from usage — confirm against
        the gsplat dataset loader.
        """
        pose = mesh_pose.clone()
        pose[0, :] *= -1
        pose[1, :] *= -1

        Rw2c = pose[:3, :3].cpu().numpy()
        Tw2c = pose[:3, 3:].cpu().numpy()
        yz_reverse = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])

        Rc2w = (yz_reverse @ Rw2c).T
        Tc2w = -(Rc2w @ yz_reverse @ Tw2c)
        c2w = np.concatenate((Rc2w, Tc2w), axis=1)
        c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)

        return c2w

    def __call__(self, pano_image: Image.Image | str, output_dir: str):
        """Run the full pano -> mesh -> 3DGS-dataset pipeline.

        Builds an initial mesh from the pano RGB-D, repairs it from sampled
        viewpoints, optionally exports the mesh, and dumps super-resolved
        cubemap views plus camera data for 3DGS training into ``output_dir``.
        """
        self.init_mesh_params()
        pano_rgb, pano_depth = self.preprocess_pano(pano_image)
        # Register the input pano as the initial supervision source.
        self.sup_pool = SupInfoPool()
        self.sup_pool.register_sup_info(
            pose=torch.eye(4).to(self.device),
            mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]),
            rgb=pano_rgb.permute(1, 2, 0),
            distance=pano_depth[..., None],
        )
        self.sup_pool.gen_occ_grid(res=256)

        logger.info("Init mesh from pano RGBD image...")
        # Exclude depth-discontinuity pixels from the initial triangulation.
        depth_edge = self.get_edge_image_by_depth(pano_depth)
        inpaint_edge_mask = (
            ~torch.from_numpy(depth_edge).to(self.device).bool()
        )
        self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask)

        repair_poses = self.load_inpaint_poses(self.camera_poses)
        inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection(
            repair_poses, output_dir
        )
        torch.cuda.empty_cache()
        # Restore the default device changed in init_mesh_params().
        torch.set_default_device("cpu")

        if self.cfg.mesh_file is not None:
            mesh_path = os.path.join(output_dir, self.cfg.mesh_file)
            self.save_mesh(mesh_path)

        if self.cfg.gs_data_file is None:
            return

        logger.info(f"Dump data for 3DGS training...")
        points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8)
        data = {
            "points": self.vertices.permute(1, 0).cpu().numpy(),  # (N, 3)
            "points_rgb": points_rgb.permute(1, 0).cpu().numpy(),  # (N, 3)
            "train": [],
            "eval": [],
        }
        # Training images are super-resolved, so intrinsics use upscaled size.
        image_h = self.cfg.cubemap_h * self.cfg.upscale_factor
        image_w = self.cfg.cubemap_w * self.cfg.upscale_factor
        Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov)
        for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses):
            cubemaps = self.pano_to_cubemap(pano_img)
            for i in range(len(cubemaps)):
                cubemap = tensor_to_pil(cubemaps[i])
                cubemap = self.super_model(cubemap)
                mesh_pose = self.cubemap_w2cs[i] @ pano_pose
                c2w = self.mesh_pose_to_gs_pose(mesh_pose)
                data["train"].append(
                    {
                        "camtoworld": c2w.astype(np.float32),
                        "K": Ks.astype(np.float32),
                        "image": np.array(cubemap),
                        "image_h": image_h,
                        "image_w": image_w,
                        "image_id": len(cubemaps) * idx + i,
                    }
                )

        # Camera poses for evaluation.
        for idx in range(len(self.camera_poses)):
            c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx])
            data["eval"].append(
                {
                    "camtoworld": c2w.astype(np.float32),
                    "K": Ks.astype(np.float32),
                    "image_h": image_h,
                    "image_w": image_w,
                    "image_id": idx,
                }
            )

        data_path = os.path.join(output_dir, self.cfg.gs_data_file)
        torch.save(data, data_path)

        return
+
+
if __name__ == "__main__":
    # Example invocation with repo-relative demo assets; adjust paths to run
    # on your own panorama.
    output_dir = "outputs/bg_v2/test3"
    input_pano = "apps/assets/example_scene/result_pano.png"
    config = Pano2MeshSRConfig()
    pipeline = Pano2MeshSRPipeline(config)
    pipeline(input_pano, output_dir)
diff --git a/embodied_gen/utils/config.py b/embodied_gen/utils/config.py
new file mode 100644
index 0000000..8c08590
--- /dev/null
+++ b/embodied_gen/utils/config.py
@@ -0,0 +1,190 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import List, Optional, Union
+
+from gsplat.strategy import DefaultStrategy, MCMCStrategy
+from typing_extensions import Literal, assert_never
+
+__all__ = [
+ "Pano2MeshSRConfig",
+ "GsplatTrainConfig",
+]
+
+
@dataclass
class Pano2MeshSRConfig:
    """Configuration for the panorama -> mesh -> 3DGS-data pipeline."""

    # Output mesh filename inside the output dir (None skips export).
    mesh_file: str = "mesh_model.ply"
    # Output filename for the dumped 3DGS training data (None skips dumping).
    gs_data_file: str = "gs_data.pt"
    device: str = "cuda"
    # Mesh rasterization parameters.
    blur_radius: int = 0
    faces_per_pixel: int = 8
    # Per-face field of view in degrees (90 gives seamless cubemap faces).
    fov: int = 90
    # Equirectangular panorama resolution.
    pano_w: int = 2048
    pano_h: int = 1024
    # Per-cubemap-face resolution.
    cubemap_w: int = 512
    cubemap_h: int = 512
    # Scale applied to trajectory pose translations.
    pose_scale: float = 0.6
    # (x, z) offset applied to the panorama reference pose.
    pano_center_offset: tuple = (-0.2, 0.3)
    # Sample every Nth trajectory pose as an inpainting viewpoint.
    inpaint_frame_stride: int = 20
    trajectory_dir: str = "apps/assets/example_scene/camera_trajectory"
    # Dump intermediate pano/depth visualizations for debugging.
    visualize: bool = False
    # Scale applied to the max-normalized predicted distance map.
    # NOTE(review): value looks empirically tuned — confirm before changing.
    depth_scale_factor: float = 3.4092
    # Elliptical dilation kernel size for inpaint masks.
    kernel_size: tuple = (9, 9)
    # Real-ESRGAN super-resolution factor for dumped training images.
    upscale_factor: int = 4
+
+
@dataclass
class GsplatTrainConfig:
    """Configuration for gsplat-based 3D Gaussian Splatting training."""

    # Path to the .pt files. If provided, it will skip training and run evaluation only.
    ckpt: Optional[List[str]] = None
    # Render trajectory path
    render_traj_path: str = "interp"

    # Path to the Mip-NeRF 360 dataset
    data_dir: str = "outputs/bg"
    # Downsample factor for the dataset
    data_factor: int = 4
    # Directory to save results
    result_dir: str = "outputs/bg"
    # Every N images there is a test image
    test_every: int = 8
    # Random crop size for training (experimental)
    patch_size: Optional[int] = None
    # A global scaler that applies to the scene size related parameters
    global_scale: float = 1.0
    # Normalize the world space
    normalize_world_space: bool = True
    # Camera model
    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"

    # Port for the viewer server
    port: int = 8080

    # Batch size for training. Learning rates are scaled automatically
    batch_size: int = 1
    # A global factor to scale the number of training steps
    steps_scaler: float = 1.0

    # Number of training steps
    max_steps: int = 30_000
    # Steps to evaluate the model
    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Steps to save the model
    save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to save ply file (storage size can be large)
    save_ply: bool = True
    # Steps to save the model as ply
    ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to disable video generation during training and evaluation
    disable_video: bool = False

    # Initial number of GSs. Ignored if using sfm
    init_num_pts: int = 100_000
    # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
    init_extent: float = 3.0
    # Degree of spherical harmonics
    sh_degree: int = 1
    # Turn on another SH degree every this steps
    sh_degree_interval: int = 1000
    # Initial opacity of GS
    init_opa: float = 0.1
    # Initial scale of GS
    init_scale: float = 1.0
    # Weight for SSIM loss
    ssim_lambda: float = 0.2

    # Near plane clipping distance
    near_plane: float = 0.01
    # Far plane clipping distance
    far_plane: float = 1e10

    # Strategy for GS densification
    strategy: Union[DefaultStrategy, MCMCStrategy] = field(
        default_factory=DefaultStrategy
    )
    # Use packed mode for rasterization, this leads to less memory usage but slightly slower.
    packed: bool = False
    # Use sparse gradients for optimization. (experimental)
    sparse_grad: bool = False
    # Use visible adam from Taming 3DGS. (experimental)
    visible_adam: bool = False
    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
    antialiased: bool = False

    # Use random background for training to discourage transparency
    random_bkgd: bool = False

    # LR for 3D point positions
    means_lr: float = 1.6e-4
    # LR for Gaussian scale factors
    scales_lr: float = 5e-3
    # LR for alpha blending weights
    opacities_lr: float = 5e-2
    # LR for orientation (quaternions)
    quats_lr: float = 1e-3
    # LR for SH band 0 (brightness)
    sh0_lr: float = 2.5e-3
    # LR for higher-order SH (detail)
    shN_lr: float = 2.5e-3 / 20

    # Opacity regularization
    opacity_reg: float = 0.0
    # Scale regularization
    scale_reg: float = 0.0

    # Enable depth loss. (experimental)
    depth_loss: bool = False
    # Weight for depth loss
    depth_lambda: float = 1e-2

    # Dump information to tensorboard every this steps
    tb_every: int = 200
    # Save training images to tensorboard
    tb_save_image: bool = False

    # Network backbone for the LPIPS metric
    lpips_net: Literal["vgg", "alex"] = "alex"

    # 3DGUT (unscented transform + eval 3D)
    with_ut: bool = False
    with_eval3d: bool = False

    scene_scale: float = 1.0

    def adjust_steps(self, factor: float):
        """Scale every step-based schedule (and the densification
        strategy's refinement schedule) by ``factor``."""
        self.eval_steps = [int(i * factor) for i in self.eval_steps]
        self.save_steps = [int(i * factor) for i in self.save_steps]
        self.ply_steps = [int(i * factor) for i in self.ply_steps]
        self.max_steps = int(self.max_steps * factor)
        self.sh_degree_interval = int(self.sh_degree_interval * factor)

        strategy = self.strategy
        if isinstance(strategy, DefaultStrategy):
            strategy.refine_start_iter = int(
                strategy.refine_start_iter * factor
            )
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.reset_every = int(strategy.reset_every * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        elif isinstance(strategy, MCMCStrategy):
            strategy.refine_start_iter = int(
                strategy.refine_start_iter * factor
            )
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        else:
            # Exhaustiveness check: fails at type-check time if a new
            # strategy type is added without handling it here.
            assert_never(strategy)
diff --git a/embodied_gen/utils/gaussian.py b/embodied_gen/utils/gaussian.py
new file mode 100644
index 0000000..68ee73e
--- /dev/null
+++ b/embodied_gen/utils/gaussian.py
@@ -0,0 +1,331 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+# Part of the code comes from https://github.com/nerfstudio-project/gsplat
+# Both under the Apache License, Version 2.0.
+
+
+import math
+import random
+from io import BytesIO
+from typing import Dict, Literal, Optional, Tuple
+
+import numpy as np
+import torch
+import trimesh
+from gsplat.optimizers import SelectiveAdam
+from scipy.spatial.transform import Rotation
+from sklearn.neighbors import NearestNeighbors
+from torch import Tensor
+from embodied_gen.models.gs_model import GaussianOperator
+
+__all__ = [
+ "set_random_seed",
+ "export_splats",
+ "create_splats_with_optimizers",
+ "compute_pinhole_intrinsics",
+ "resize_pinhole_intrinsics",
+ "restore_scene_scale_and_position",
+]
+
+
+def knn(x: Tensor, K: int = 4) -> Tensor:
+    """Return distances to the K nearest neighbors of each row of ``x``.
+
+    Runs sklearn's exact Euclidean k-NN on CPU, then moves the (N, K)
+    distance matrix back to ``x``'s device/dtype via ``.to(x)``. The query
+    points are part of the fitted set, so column 0 is the zero
+    self-distance; callers typically slice it off (see
+    ``create_splats_with_optimizers``).
+    """
+    x_np = x.cpu().numpy()
+    model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
+    distances, _ = model.kneighbors(x_np)
+    return torch.from_numpy(distances).to(x)
+
+
+def rgb_to_sh(rgb: Tensor) -> Tensor:
+    """Convert RGB colors in [0, 1] to degree-0 SH coefficients.
+
+    C0 is the constant spherical-harmonics basis value 1 / (2 * sqrt(pi));
+    centering at 0.5 makes mid-gray map to a zero coefficient.
+    """
+    C0 = 0.28209479177387814
+    return (rgb - 0.5) / C0
+
+
+def set_random_seed(seed: int):
+    """Seed Python's ``random``, NumPy, and PyTorch RNGs for reproducibility.
+
+    NOTE(review): full determinism on GPU may additionally require
+    ``torch.use_deterministic_algorithms`` / cuDNN flags — confirm if
+    bit-exact reruns are needed.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+
+def splat2ply_bytes(
+    means: torch.Tensor,
+    scales: torch.Tensor,
+    quats: torch.Tensor,
+    opacities: torch.Tensor,
+    sh0: torch.Tensor,
+    shN: torch.Tensor,
+) -> bytes:
+    """Serialize Gaussian-splat tensors into a binary little-endian PLY blob.
+
+    Field order per vertex follows the de-facto 3DGS PLY convention:
+    x/y/z, f_dc_*, f_rest_*, opacity, scale_*, rot_*. Inputs are expected
+    already flattened: ``sh0`` as (N, 3) and ``shN`` as (N, K*3) — see
+    ``export_splats`` which prepares them.
+
+    Args:
+        means: Gaussian centers, (N, 3).
+        scales: Per-axis scale values, (N, 3).
+        quats: Orientation quaternions, (N, 4).
+        opacities: Per-splat opacities, (N,).
+        sh0: DC spherical-harmonics coefficients, (N, 3).
+        shN: Higher-order SH coefficients, (N, K*3).
+
+    Returns:
+        The complete PLY file content as bytes.
+    """
+    num_splats = means.shape[0]
+    buffer = BytesIO()
+
+    # Write PLY header (property order must match the data layout below).
+    buffer.write(b"ply\n")
+    buffer.write(b"format binary_little_endian 1.0\n")
+    buffer.write(f"element vertex {num_splats}\n".encode())
+    buffer.write(b"property float x\n")
+    buffer.write(b"property float y\n")
+    buffer.write(b"property float z\n")
+    for i, data in enumerate([sh0, shN]):
+        prefix = "f_dc" if i == 0 else "f_rest"
+        for j in range(data.shape[1]):
+            buffer.write(f"property float {prefix}_{j}\n".encode())
+    buffer.write(b"property float opacity\n")
+    for i in range(scales.shape[1]):
+        buffer.write(f"property float scale_{i}\n".encode())
+    for i in range(quats.shape[1]):
+        buffer.write(f"property float rot_{i}\n".encode())
+    buffer.write(b"end_header\n")
+
+    # Concatenate all tensors in the correct order (must mirror the header).
+    splat_data = torch.cat(
+        [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
+    )
+    # Ensure correct dtype
+    splat_data = splat_data.to(torch.float32)
+
+    # Write binary data as little-endian float32, row-major per vertex.
+    float_dtype = np.dtype(np.float32).newbyteorder("<")
+    buffer.write(
+        splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()
+    )
+
+    return buffer.getvalue()
+
+
+def export_splats(
+ means: torch.Tensor,
+ scales: torch.Tensor,
+ quats: torch.Tensor,
+ opacities: torch.Tensor,
+ sh0: torch.Tensor,
+ shN: torch.Tensor,
+ format: Literal["ply"] = "ply",
+ save_to: Optional[str] = None,
+) -> bytes:
+ """Export a Gaussian Splats model to bytes in PLY file format."""
+ total_splats = means.shape[0]
+ assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
+ assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
+ assert quats.shape == (
+ total_splats,
+ 4,
+ ), "Quaternions must be of shape (N, 4)"
+ assert opacities.shape == (
+ total_splats,
+ ), "Opacities must be of shape (N,)"
+ assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
+ assert (
+ shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
+ ), f"shN must be of shape (N, K, 3), got {shN.shape}"
+
+ # Reshape spherical harmonics
+ sh0 = sh0.squeeze(1) # Shape (N, 3)
+ shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1) # Shape (N, K * 3)
+
+ # Check for NaN or Inf values
+ invalid_mask = (
+ torch.isnan(means).any(dim=1)
+ | torch.isinf(means).any(dim=1)
+ | torch.isnan(scales).any(dim=1)
+ | torch.isinf(scales).any(dim=1)
+ | torch.isnan(quats).any(dim=1)
+ | torch.isinf(quats).any(dim=1)
+ | torch.isnan(opacities).any(dim=0)
+ | torch.isinf(opacities).any(dim=0)
+ | torch.isnan(sh0).any(dim=1)
+ | torch.isinf(sh0).any(dim=1)
+ | torch.isnan(shN).any(dim=1)
+ | torch.isinf(shN).any(dim=1)
+ )
+
+ # Filter out invalid entries
+ valid_mask = ~invalid_mask
+ means = means[valid_mask]
+ scales = scales[valid_mask]
+ quats = quats[valid_mask]
+ opacities = opacities[valid_mask]
+ sh0 = sh0[valid_mask]
+ shN = shN[valid_mask]
+
+ if format == "ply":
+ data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
+ else:
+ raise ValueError(f"Unsupported format: {format}")
+
+ if save_to:
+ with open(save_to, "wb") as binary_file:
+ binary_file.write(data)
+
+ return data
+
+
+def create_splats_with_optimizers(
+    points: Optional[np.ndarray] = None,
+    points_rgb: Optional[np.ndarray] = None,
+    init_num_pts: int = 100_000,
+    init_extent: float = 3.0,
+    init_opacity: float = 0.1,
+    init_scale: float = 1.0,
+    means_lr: float = 1.6e-4,
+    scales_lr: float = 5e-3,
+    opacities_lr: float = 5e-2,
+    quats_lr: float = 1e-3,
+    sh0_lr: float = 2.5e-3,
+    shN_lr: float = 2.5e-3 / 20,
+    scene_scale: float = 1.0,
+    sh_degree: int = 3,
+    sparse_grad: bool = False,
+    visible_adam: bool = False,
+    batch_size: int = 1,
+    feature_dim: Optional[int] = None,
+    device: str = "cuda",
+    world_rank: int = 0,
+    world_size: int = 1,
+) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
+    """Initialize Gaussian-splat parameters and one Adam optimizer per param.
+
+    If ``points``/``points_rgb`` are given (N, 3) arrays (RGB in 0..255),
+    splats are seeded from them; otherwise ``init_num_pts`` random points
+    are sampled uniformly in a cube of half-extent
+    ``init_extent * scene_scale``. Points are sharded across ranks via
+    ``points[world_rank::world_size]`` for distributed training.
+
+    Returns:
+        A (splats, optimizers) pair: a ParameterDict with keys
+        means/scales/quats/opacities plus either sh0/shN (SH colors) or
+        features/colors (when ``feature_dim`` is set), and a dict mapping
+        each parameter name to its optimizer.
+    """
+    if points is not None and points_rgb is not None:
+        points = torch.from_numpy(points).float()
+        rgbs = torch.from_numpy(points_rgb / 255.0).float()
+    else:
+        points = (
+            init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
+        )
+        rgbs = torch.rand((init_num_pts, 3))
+
+    # Initialize the GS size to be the average dist of the 3 nearest neighbors
+    # (knn column 0 is the zero self-distance, hence the [:, 1:] slice).
+    dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
+    dist_avg = torch.sqrt(dist2_avg)
+    # Scales are stored in log space.
+    scales = (
+        torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3)
+    )  # [N, 3]
+
+    # Distribute the GSs to different ranks (also works for single rank)
+    points = points[world_rank::world_size]
+    rgbs = rgbs[world_rank::world_size]
+    scales = scales[world_rank::world_size]
+
+    N = points.shape[0]
+    quats = torch.rand((N, 4))  # [N, 4]
+    # Opacities are stored as logits so the activation is sigmoid.
+    opacities = torch.logit(torch.full((N,), init_opacity))  # [N,]
+
+    params = [
+        # name, value, lr
+        ("means", torch.nn.Parameter(points), means_lr * scene_scale),
+        ("scales", torch.nn.Parameter(scales), scales_lr),
+        ("quats", torch.nn.Parameter(quats), quats_lr),
+        ("opacities", torch.nn.Parameter(opacities), opacities_lr),
+    ]
+
+    if feature_dim is None:
+        # color is SH coefficients.
+        colors = torch.zeros((N, (sh_degree + 1) ** 2, 3))  # [N, K, 3]
+        colors[:, 0, :] = rgb_to_sh(rgbs)
+        params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr))
+        params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr))
+    else:
+        # features will be used for appearance and view-dependent shading
+        features = torch.rand(N, feature_dim)  # [N, feature_dim]
+        params.append(("features", torch.nn.Parameter(features), sh0_lr))
+        colors = torch.logit(rgbs)  # [N, 3]
+        params.append(("colors", torch.nn.Parameter(colors), sh0_lr))
+
+    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
+    # Scale learning rate based on batch size, reference:
+    # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
+    # Note that this would not make the training exactly equivalent, see
+    # https://arxiv.org/pdf/2402.18824v1
+    BS = batch_size * world_size
+    optimizer_class = None
+    if sparse_grad:
+        optimizer_class = torch.optim.SparseAdam
+    elif visible_adam:
+        optimizer_class = SelectiveAdam
+    else:
+        optimizer_class = torch.optim.Adam
+    optimizers = {
+        name: optimizer_class(
+            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
+            eps=1e-15 / math.sqrt(BS),
+            # TODO(review): betas scaling breaks for large BS — at BS >= 10
+            # betas[0] becomes <= 0, which Adam rejects. Revisit if large
+            # effective batch sizes are used.
+            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
+        )
+        for name, _, lr in params
+    }
+    return splats, optimizers
+
+
+def compute_pinhole_intrinsics(
+    image_w: int, image_h: int, fov_deg: float
+) -> np.ndarray:
+    """Build a 3x3 pinhole intrinsics matrix from image size and FOV.
+
+    ``fov_deg`` is interpreted as the horizontal field of view (fx is
+    derived from image width); fy = fx assumes square pixels, and the
+    principal point is placed at the image center.
+
+    Returns:
+        A (3, 3) float camera matrix [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
+    """
+    fov_rad = np.deg2rad(fov_deg)
+    fx = image_w / (2 * np.tan(fov_rad / 2))
+    fy = fx  # assuming square pixels
+    cx = image_w / 2
+    cy = image_h / 2
+    K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
+
+    return K
+
+
+def resize_pinhole_intrinsics(
+    raw_K: np.ndarray | torch.Tensor,
+    raw_hw: tuple[int, int],
+    new_hw: tuple[int, int],
+) -> np.ndarray | torch.Tensor:
+    """Rescale pinhole intrinsics for a resized image.
+
+    Scales fx/cx by the width ratio and fy/cy by the height ratio. The
+    input is not modified: a copy (ndarray) or clone (tensor) of the same
+    type is returned.
+
+    Args:
+        raw_K: Original (3, 3) intrinsics matrix.
+        raw_hw: Original image size as (height, width).
+        new_hw: Target image size as (height, width).
+
+    Returns:
+        Adjusted (3, 3) intrinsics of the same type as ``raw_K``.
+    """
+    raw_h, raw_w = raw_hw
+    new_h, new_w = new_hw
+
+    scale_x = new_w / raw_w
+    scale_y = new_h / raw_h
+
+    new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
+    new_K[0, 0] *= scale_x  # fx
+    new_K[0, 2] *= scale_x  # cx
+    new_K[1, 1] *= scale_y  # fy
+    new_K[1, 2] *= scale_y  # cy
+
+    return new_K
+
+
+def restore_scene_scale_and_position(
+    real_height: float, mesh_path: str, gs_path: str
+) -> None:
+    """Scales a mesh and corresponding GS model to match a given real-world height.
+
+    Uses the 1st and 99th percentile of the mesh's vertical extent to
+    estimate height, applies scaling and vertical alignment, and updates
+    both the mesh and GS model in place on disk.
+
+    NOTE(review): the height is measured along vertex column 1 (Y in the
+    usual trimesh convention) although the docstring refers to the Z axis —
+    presumably the asset convention here is Y-up; confirm against the
+    generation pipeline.
+
+    Args:
+        real_height (float): Target real-world height among Z axis.
+        mesh_path (str): Path to the input mesh file.
+        gs_path (str): Path to the Gaussian Splatting model file.
+    """
+    mesh = trimesh.load(mesh_path)
+    # Percentiles instead of min/max to be robust to outlier vertices.
+    z_min = np.percentile(mesh.vertices[:, 1], 1)
+    z_max = np.percentile(mesh.vertices[:, 1], 99)
+    height = z_max - z_min
+    scale = real_height / height
+
+    # Quaternion [x, y, z, w] = [0, 1, 0, 0]: a 180-degree rotation about
+    # the Y axis applied before re-anchoring and scaling.
+    rot = Rotation.from_quat([0, 1, 0, 0])
+    mesh.vertices = rot.apply(mesh.vertices)
+    mesh.vertices[:, 1] -= z_min
+    mesh.vertices *= scale
+    # Overwrites the input mesh file in place.
+    mesh.export(mesh_path)
+
+    # Apply the matching translation/orientation and scale to the GS model,
+    # then overwrite the input PLY in place.
+    gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
+    gs_model = gs_model.get_gaussians(
+        instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0])
+    )
+    gs_model.rescale(scale)
+    gs_model.save_to_ply(gs_path)
diff --git a/embodied_gen/utils/monkey_patches.py b/embodied_gen/utils/monkey_patches.py
new file mode 100644
index 0000000..79e77d5
--- /dev/null
+++ b/embodied_gen/utils/monkey_patches.py
@@ -0,0 +1,152 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import os
+import sys
+import zipfile
+
+import torch
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
+from PIL import Image
+from torchvision import transforms
+
+
+def monkey_patch_pano2room():
+    """Monkey-patch the vendored Pano2Room package for use inside EmbodiedGen.
+
+    Replaces the ``__init__`` of several Pano2Room classes so that model
+    weights are fetched from torch.hub / Hugging Face Hub instead of
+    Pano2Room's original local checkpoint paths:
+      * OmnidataPredictor / OmnidataNormalPredictor: depth and surface-normal
+        DPT models loaded via ``torch.hub.load('alexsax/omnidata_models', ...)``.
+      * PanoJointPredictor: rebuilt on top of the two patched predictors.
+      * LamaInpainter: big-lama checkpoint downloaded and unzipped from the
+        ``smartywu/big-lama`` HF repo.
+      * SDFTInpainter: backed by the SD2 inpainting diffusers pipeline.
+
+    Must be called before instantiating any of these classes. Side effects:
+    appends the repo root and ``thirdparty/pano2room`` to ``sys.path`` and
+    sets the ``NODE_RANK`` env var.
+    """
+    # Make both the repo root and the pano2room checkout importable.
+    current_file_path = os.path.abspath(__file__)
+    current_dir = os.path.dirname(current_file_path)
+    sys.path.append(os.path.join(current_dir, "../.."))
+    sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room"))
+    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
+        OmnidataNormalPredictor,
+    )
+    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
+        OmnidataPredictor,
+    )
+
+    def patched_omni_depth_init(self):
+        # Load the omnidata depth model from torch.hub instead of a local
+        # checkpoint path.
+        self.img_size = 384
+        self.model = torch.hub.load(
+            'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
+        )
+        self.model.eval()
+        self.trans_totensor = transforms.Compose(
+            [
+                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
+                transforms.CenterCrop(self.img_size),
+                transforms.Normalize(mean=0.5, std=0.5),
+            ]
+        )
+
+    OmnidataPredictor.__init__ = patched_omni_depth_init
+
+    def patched_omni_normal_init(self):
+        # Same as above for the surface-normal model.
+        self.img_size = 384
+        self.model = torch.hub.load(
+            'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
+        )
+        self.model.eval()
+        self.trans_totensor = transforms.Compose(
+            [
+                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
+                transforms.CenterCrop(self.img_size),
+                transforms.Normalize(mean=0.5, std=0.5),
+            ]
+        )
+
+    OmnidataNormalPredictor.__init__ = patched_omni_normal_init
+
+    def patched_panojoint_init(self, save_path=None):
+        # Rebuild the joint predictor on the two patched predictors above.
+        self.depth_predictor = OmnidataPredictor()
+        self.normal_predictor = OmnidataNormalPredictor()
+        self.save_path = save_path
+
+    # Import resolves via the sys.path entry added above.
+    from modules.geo_predictors import PanoJointPredictor
+
+    PanoJointPredictor.__init__ = patched_panojoint_init
+
+    # NOTE: We use gsplat instead.
+    # import depth_diff_gaussian_rasterization_min as ddgr
+    # from dataclasses import dataclass
+    # @dataclass
+    # class PatchedGaussianRasterizationSettings:
+    #     image_height: int
+    #     image_width: int
+    #     tanfovx: float
+    #     tanfovy: float
+    #     bg: torch.Tensor
+    #     scale_modifier: float
+    #     viewmatrix: torch.Tensor
+    #     projmatrix: torch.Tensor
+    #     sh_degree: int
+    #     campos: torch.Tensor
+    #     prefiltered: bool
+    #     debug: bool = False
+    # ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings
+
+    # disable get_has_ddp_rank print in `BaseInpaintingTrainingModule`
+    os.environ["NODE_RANK"] = "0"
+
+    from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
+        load_checkpoint,
+    )
+    from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
+        LamaInpainter,
+    )
+
+    def patched_lama_inpaint_init(self):
+        # Download the big-lama archive once and unzip next to the cache file;
+        # subsequent calls reuse the extracted directory.
+        zip_path = hf_hub_download(
+            repo_id="smartywu/big-lama",
+            filename="big-lama.zip",
+            repo_type="model",
+        )
+        extract_dir = os.path.splitext(zip_path)[0]
+
+        if not os.path.exists(extract_dir):
+            os.makedirs(extract_dir, exist_ok=True)
+            with zipfile.ZipFile(zip_path, "r") as zip_ref:
+                zip_ref.extractall(extract_dir)
+
+        config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
+        checkpoint_path = os.path.join(
+            extract_dir, 'big-lama/models/best.ckpt'
+        )
+        train_config = OmegaConf.load(config_path)
+        # Inference-only: no trainer state, no visualizer output.
+        train_config.training_model.predict_only = True
+        train_config.visualizer.kind = 'noop'
+
+        self.model = load_checkpoint(
+            train_config, checkpoint_path, strict=False, map_location='cpu'
+        )
+        self.model.freeze()
+
+    LamaInpainter.__init__ = patched_lama_inpaint_init
+
+    from diffusers import StableDiffusionInpaintPipeline
+    from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
+        SDFTInpainter,
+    )
+
+    def patched_sd_inpaint_init(self, subset_name=None):
+        # `subset_name` is kept for signature compatibility with the original
+        # SDFTInpainter.__init__ but is unused here.
+        super(SDFTInpainter, self).__init__()
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-2-inpainting",
+            torch_dtype=torch.float16,
+        ).to("cuda")
+        # NOTE(review): enable_model_cpu_offload() after .to("cuda") —
+        # offload manages device placement itself; confirm the explicit
+        # .to("cuda") is intended.
+        pipe.enable_model_cpu_offload()
+        self.inpaint_pipe = pipe
+
+    SDFTInpainter.__init__ = patched_sd_inpaint_init
diff --git a/embodied_gen/utils/process_media.py b/embodied_gen/utils/process_media.py
index edfdcff..3087d4f 100644
--- a/embodied_gen/utils/process_media.py
+++ b/embodied_gen/utils/process_media.py
@@ -17,6 +17,7 @@
import logging
import math
+import mimetypes
import os
import textwrap
from glob import glob
@@ -27,10 +28,10 @@ import imageio
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
-import PIL.Image as Image
import spaces
from matplotlib.patches import Patch
from moviepy.editor import VideoFileClip, clips_array
+from PIL import Image
from embodied_gen.data.differentiable_render import entrypoint as render_api
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
@@ -45,6 +46,8 @@ __all__ = [
"filter_image_small_connected_components",
"combine_images_to_grid",
"SceneTreeVisualizer",
+ "is_image_file",
+ "parse_text_prompts",
]
@@ -356,6 +359,23 @@ def load_scene_dict(file_path: str) -> dict:
return scene_dict
+def is_image_file(filename: str) -> bool:
+    """Return True if ``filename``'s extension maps to an image MIME type.
+
+    Based purely on the file name via ``mimetypes.guess_type``; the file
+    is never opened, so contents are not validated.
+    """
+    mime_type, _ = mimetypes.guess_type(filename)
+
+    return mime_type is not None and mime_type.startswith('image')
+
+
+def parse_text_prompts(prompts: list[str]) -> list[str]:
+    """Expand a single ``.txt`` file argument into a list of prompts.
+
+    If ``prompts`` is exactly one element ending in ``.txt``, the file is
+    read and each non-empty, non-``#``-comment line becomes one prompt.
+    Otherwise the input list is returned unchanged.
+    """
+    if len(prompts) == 1 and prompts[0].endswith(".txt"):
+        with open(prompts[0], "r") as f:
+            prompts = [
+                line.strip()
+                for line in f
+                if line.strip() and not line.strip().startswith("#")
+            ]
+    return prompts
+
+
if __name__ == "__main__":
merge_video_video(
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
diff --git a/embodied_gen/utils/tags.py b/embodied_gen/utils/tags.py
index 56deb45..60de686 100644
--- a/embodied_gen/utils/tags.py
+++ b/embodied_gen/utils/tags.py
@@ -1 +1 @@
-VERSION = "v0.1.1"
+VERSION = "v0.1.2"
diff --git a/embodied_gen/validators/quality_checkers.py b/embodied_gen/validators/quality_checkers.py
index 88636f6..186346f 100644
--- a/embodied_gen/validators/quality_checkers.py
+++ b/embodied_gen/validators/quality_checkers.py
@@ -33,6 +33,9 @@ __all__ = [
"ImageAestheticChecker",
"SemanticConsistChecker",
"TextGenAlignChecker",
+ "PanoImageGenChecker",
+ "PanoHeightEstimator",
+ "PanoImageOccChecker",
]
@@ -328,6 +331,159 @@ class TextGenAlignChecker(BaseChecker):
)
+class PanoImageGenChecker(BaseChecker):
+ """A checker class that validates the quality and realism of generated panoramic indoor images.
+
+ Attributes:
+ gpt_client (GPTclient): A GPT client instance used to query for image validation.
+ prompt (str): The instruction prompt passed to the GPT model. If None, a default prompt is used.
+ verbose (bool): Whether to print internal processing information for debugging.
+ """
+
+ def __init__(
+ self,
+ gpt_client: GPTclient,
+ prompt: str = None,
+ verbose: bool = False,
+ ) -> None:
+ super().__init__(prompt, verbose)
+ self.gpt_client = gpt_client
+ if self.prompt is None:
+ self.prompt = """
+ You are a panoramic image analyzer specializing in indoor room structure validation.
+
+ Given a generated panoramic image, assess if it meets all the criteria:
+ - Floor Space: ≥30 percent of the floor is free of objects or obstructions.
+ - Visual Clarity: Floor, walls, and ceiling are clear, with no distortion, blur, noise.
+ - Structural Continuity: Surfaces form plausible, continuous geometry
+ without breaks, floating parts, or abrupt cuts.
+ - Spatial Completeness: Full 360° coverage without missing areas,
+ seams, gaps, or stitching artifacts.
+ Instructions:
+ - If all criteria are met, reply with "YES".
+ - Otherwise, reply with "NO: