From e82f02a9a5caa42612fde59d64e68e62c2cc7473 Mon Sep 17 00:00:00 2001 From: Xinjie Date: Mon, 21 Jul 2025 23:31:15 +0800 Subject: [PATCH] feat(pipe): Release 3D scene generation pipeline. (#25) Release 3D scene generation pipeline and tag as v0.1.2. --------- Co-authored-by: xinjie.wang --- .gitignore | 2 +- .gitmodules | 7 +- .pre-commit-config.yaml | 4 +- README.md | 27 +- apps/common.py | 6 +- embodied_gen/data/datasets.py | 66 +- embodied_gen/data/utils.py | 3 +- embodied_gen/scripts/gen_scene3d.py | 191 ++++++ embodied_gen/scripts/imageto3d.py | 32 +- embodied_gen/scripts/text2image.py | 10 +- embodied_gen/scripts/textto3d.py | 20 +- embodied_gen/trainer/gsplat_trainer.py | 678 +++++++++++++++++++ embodied_gen/trainer/pono2mesh_trainer.py | 538 +++++++++++++++ embodied_gen/utils/config.py | 190 ++++++ embodied_gen/utils/gaussian.py | 331 +++++++++ embodied_gen/utils/monkey_patches.py | 152 +++++ embodied_gen/utils/process_media.py | 22 +- embodied_gen/utils/tags.py | 2 +- embodied_gen/validators/quality_checkers.py | 156 +++++ install.sh | 75 +- install/_utils.sh | 21 + install/install_basic.sh | 35 + install/install_extra.sh | 52 ++ pyproject.toml | 3 +- requirements.txt | 3 + tests/test_examples/test_quality_checkers.py | 44 ++ tests/test_unit/test_agents.py | 15 + thirdparty/pano2room | 1 + 28 files changed, 2577 insertions(+), 109 deletions(-) create mode 100644 embodied_gen/scripts/gen_scene3d.py create mode 100644 embodied_gen/trainer/gsplat_trainer.py create mode 100644 embodied_gen/trainer/pono2mesh_trainer.py create mode 100644 embodied_gen/utils/config.py create mode 100644 embodied_gen/utils/gaussian.py create mode 100644 embodied_gen/utils/monkey_patches.py create mode 100644 install/_utils.sh create mode 100644 install/install_basic.sh create mode 100644 install/install_extra.sh create mode 160000 thirdparty/pano2room diff --git a/.gitignore b/.gitignore index 7d3fe28..ea1d498 100644 --- a/.gitignore +++ b/.gitignore @@ -59,4 +59,4 @@ output* 
scripts/tools/ weights apps/sessions/ -apps/assets/ \ No newline at end of file +apps/assets/ diff --git a/.gitmodules b/.gitmodules index ca1c2c7..c6b0a7b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,4 +2,9 @@ path = thirdparty/TRELLIS url = https://github.com/microsoft/TRELLIS.git branch = main - shallow = true \ No newline at end of file + shallow = true +[submodule "thirdparty/pano2room"] + path = thirdparty/pano2room + url = https://github.com/TrickyGo/Pano2Room.git + branch = main + shallow = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 28d6ba6..e76d477 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - - repo: git@gitlab.hobot.cc:ptd/3rd/pre-commit/pre-commit-hooks.git - rev: v2.3.0 # Use the ref you want to point at + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.2.0 # Use the ref you want to point at hooks: - id: trailing-whitespace - id: check-added-large-files diff --git a/README.md b/README.md index 0cfc0cd..4de144e 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,11 @@ ```sh git clone https://github.com/HorizonRobotics/EmbodiedGen.git cd EmbodiedGen -git checkout v0.1.1 +git checkout v0.1.2 git submodule update --init --recursive --progress -conda create -n embodiedgen python=3.10.13 -y +conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env. conda activate embodiedgen -bash install.sh +bash install.sh basic ``` ### ✅ Setup GPT Agent @@ -94,7 +94,7 @@ python apps/text_to_3d.py ### ⚡ API Text-to-image model based on SD3.5 Medium, English prompts only. -Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically. 
(ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py`) +Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically. For large-scale 3D assets generation, set `--n_pipe_retry=2` to ensure high end-to-end 3D asset usability through automatic quality check and retries. For more diverse results, do not set `--seed_img`. @@ -110,6 +110,7 @@ bash embodied_gen/scripts/textto3d.sh \ --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \ --output_root outputs/textto3d_k ``` +ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py` --- @@ -146,10 +147,22 @@ bash embodied_gen/scripts/texture_gen.sh \

🌍 3D Scene Generation

-🚧 *Coming Soon* - scene3d +### ⚡ API +> Run `bash install.sh extra` to install additional requirements if you need to use `scene3d-cli`. + +It takes ~30mins to generate a color mesh and 3DGS per scene. + +```sh +CUDA_VISIBLE_DEVICES=0 scene3d-cli \ +--prompts "Art studio with easel and canvas" \ +--output_dir outputs/bg_scenes/ \ +--seed 0 \ +--gs3d.max_steps 4000 \ +--disable_pano_check +``` + --- @@ -189,7 +202,7 @@ bash embodied_gen/scripts/texture_gen.sh \ ## For Developer ```sh -pip install .[dev] && pre-commit install +pip install -e .[dev] && pre-commit install python -m pytest # Pass all unit-test are required. ``` diff --git a/apps/common.py b/apps/common.py index f7c7f40..76e6335 100644 --- a/apps/common.py +++ b/apps/common.py @@ -94,9 +94,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false" os.environ["SPCONV_ALGO"] = "native" MAX_SEED = 100000 -DELIGHT = DelightingModel() -IMAGESR_MODEL = ImageRealESRGAN(outscale=4) -# IMAGESR_MODEL = ImageStableSR() def patched_setup_functions(self): @@ -136,6 +133,9 @@ def patched_setup_functions(self): Gaussian.setup_functions = patched_setup_functions +DELIGHT = DelightingModel() +IMAGESR_MODEL = ImageRealESRGAN(outscale=4) +# IMAGESR_MODEL = ImageStableSR() if os.getenv("GRADIO_APP") == "imageto3d": RBG_REMOVER = RembgRemover() RBG14_REMOVER = BMGG14Remover() diff --git a/embodied_gen/data/datasets.py b/embodied_gen/data/datasets.py index 4a9563a..eaa529b 100644 --- a/embodied_gen/data/datasets.py +++ b/embodied_gen/data/datasets.py @@ -19,8 +19,9 @@ import json import logging import os import random -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Literal, Tuple +import numpy as np import torch import torch.utils.checkpoint from PIL import Image @@ -36,6 +37,7 @@ logger = logging.getLogger(__name__) __all__ = [ "Asset3dGenDataset", + "PanoGSplatDataset", ] @@ -222,6 +224,68 @@ class Asset3dGenDataset(Dataset): return data +class PanoGSplatDataset(Dataset): + 
"""A PyTorch Dataset for loading panorama-based 3D Gaussian Splatting data. + + This dataset is designed to be compatible with train and eval pipelines + that use COLMAP-style camera conventions. + + Args: + data_dir (str): Root directory where the dataset file is located. + split (str): Dataset split to use, either "train" or "eval". + data_name (str, optional): Name of the dataset file (default: "gs_data.pt"). + max_sample_num (int, optional): Maximum number of samples to load. If None, + all available samples in the split will be used. + """ + + def __init__( + self, + data_dir: str, + split: str = Literal["train", "eval"], + data_name: str = "gs_data.pt", + max_sample_num: int = None, + ) -> None: + self.data_path = os.path.join(data_dir, data_name) + self.split = split + self.max_sample_num = max_sample_num + if not os.path.exists(self.data_path): + raise FileNotFoundError( + f"Dataset file {self.data_path} not found. Please provide the correct path." + ) + self.data = torch.load(self.data_path, weights_only=False) + self.frames = self.data[split] + if max_sample_num is not None: + self.frames = self.frames[:max_sample_num] + self.points = self.data.get("points", None) + self.points_rgb = self.data.get("points_rgb", None) + + def __len__(self) -> int: + return len(self.frames) + + def cvt_blender_to_colmap_coord(self, c2w: np.ndarray) -> np.ndarray: + # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) + tranformed_c2w = np.copy(c2w) + tranformed_c2w[:3, 1:3] *= -1 + + return tranformed_c2w + + def __getitem__(self, index: int) -> dict[str, any]: + data = self.frames[index] + c2w = self.cvt_blender_to_colmap_coord(data["camtoworld"]) + item = dict( + camtoworld=c2w, + K=data["K"], + image_h=data["image_h"], + image_w=data["image_w"], + ) + if "image" in data: + item["image"] = data["image"] + if "image_id" in data: + item["image_id"] = data["image_id"] + + return item + + if __name__ == "__main__": index_file = 
"""Text-to-3D-scene generation pipeline.

Generates a panorama image from a text prompt, lifts it to a textured mesh,
then trains a 3D Gaussian Splatting model initialized from that mesh.
"""
import logging
import os
import random
import time
import warnings
from dataclasses import dataclass, field
from shutil import copy, rmtree

import torch
import torchvision
import tyro
from huggingface_hub import snapshot_download
from packaging import version

# Silence noisy deprecation chatter from the HF model stacks.
warnings.filterwarnings("ignore", category=FutureWarning)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)

# TorchVision monkey patch: newer torchvision removed the private
# `torchvision.transforms.functional_tensor` module; re-create it as an alias
# exposing `rgb_to_grayscale` so third-party code that still imports it works.
# BUG FIX: the guard previously compared `torch.__version__` against "0.16",
# which is always true for any modern torch (2.x); the intent — per the
# comment and the patched module — is torchvision's version.
if version.parse(torchvision.__version__) >= version.parse("0.16"):
    import sys
    import types

    import torchvision.transforms.functional as TF

    functional_tensor = types.ModuleType(
        "torchvision.transforms.functional_tensor"
    )
    functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor

from gsplat.distributed import cli
from txt2panoimg import Text2360PanoramaImagePipeline

from embodied_gen.trainer.gsplat_trainer import (
    DefaultStrategy,
    GsplatTrainConfig,
)
from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
from embodied_gen.utils.config import Pano2MeshSRConfig
from embodied_gen.utils.gaussian import restore_scene_scale_and_position
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import is_image_file, parse_text_prompts
from embodied_gen.validators.quality_checkers import (
    PanoHeightEstimator,
    PanoImageOccChecker,
)

__all__ = [
    "generate_pano_image",
    "entrypoint",
]


@dataclass
class Scene3DGenConfig:
    """CLI configuration for the 3D scene generation pipeline."""

    prompts: list[str]  # Text desc of indoor room or style reference image.
    output_dir: str
    seed: int | None = None
    real_height: float | None = None  # The real height of the room in meters.
    pano_image_only: bool = False  # Stop after the panorama stage.
    disable_pano_check: bool = False  # Skip the GPT occlusion checker.
    keep_middle_result: bool = False  # Keep gsplat trainer work dirs.
    n_retry: int = 7  # Max panorama regeneration attempts.
    gs3d: GsplatTrainConfig = field(
        default_factory=lambda: GsplatTrainConfig(
            strategy=DefaultStrategy(verbose=True),
            max_steps=4000,
            init_opa=0.9,
            opacity_reg=2e-3,
            sh_degree=0,
            means_lr=1e-4,
            scales_lr=1e-3,
        )
    )


def generate_pano_image(
    prompt: str,
    output_path: str,
    pipeline,
    seed: int,
    n_retry: int,
    checker=None,
    num_inference_steps: int = 40,
) -> None:
    """Generate a panorama image for `prompt`, retrying until `checker` passes.

    Args:
        prompt: Text description (image-reference mode not implemented yet).
        output_path: Where the panorama PNG is written (overwritten per retry).
        pipeline: Text-to-panorama pipeline callable.
        seed: Initial RNG seed; re-randomized on each failed retry.
        n_retry: Maximum number of generation attempts.
        checker: Optional quality checker; None accepts the first result.
        num_inference_steps: Diffusion steps per attempt.
    """
    for i in range(n_retry):
        logger.info(
            f"GEN Panorama: Retry {i+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
        )
        if is_image_file(prompt):
            raise NotImplementedError("Image mode not implemented yet.")
        else:
            # Bias toward sparse rooms so later mesh repair has less occlusion.
            txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
            inputs = {
                "prompt": txt_prompt,
                "num_inference_steps": num_inference_steps,
                "upscale": False,
                "seed": seed,
            }
            pano_image = pipeline(inputs)

        pano_image.save(output_path)
        if checker is None:
            break

        flag, response = checker(pano_image)
        logger.warning(f"{response}, image saved in {output_path}")
        # None means the checker could not decide; accept the image either way.
        if flag is True or flag is None:
            break

        seed = random.randint(0, 100000)

    return


def entrypoint(*args, **kwargs):
    """CLI entrypoint: parse `Scene3DGenConfig` and run the full pipeline."""
    cfg = tyro.cli(Scene3DGenConfig)

    # Init global models.
    model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    IMG2PANO_PIPE = Text2360PanoramaImagePipeline(
        model_path, torch_dtype=torch.float16, device="cuda"
    )
    PANOMESH_CFG = Pano2MeshSRConfig()
    PANO2MESH_PIPE = Pano2MeshSRPipeline(PANOMESH_CFG)
    PANO_CHECKER = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
    PANOHEIGHT_ESTOR = PanoHeightEstimator(GPT_CLIENT)

    prompts = parse_text_prompts(cfg.prompts)
    for idx, prompt in enumerate(prompts):
        start_time = time.time()
        output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
        os.makedirs(output_dir, exist_ok=True)
        pano_path = os.path.join(output_dir, "pano_image.png")
        with open(f"{output_dir}/prompt.txt", "w") as f:
            f.write(prompt)

        generate_pano_image(
            prompt,
            pano_path,
            IMG2PANO_PIPE,
            cfg.seed if cfg.seed is not None else random.randint(0, 100000),
            cfg.n_retry,
            checker=None if cfg.disable_pano_check else PANO_CHECKER,
        )

        if cfg.pano_image_only:
            continue

        logger.info("GEN and REPAIR Mesh from Panorama...")
        PANO2MESH_PIPE(pano_path, output_dir)

        logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
        cfg.gs3d.data_dir = output_dir
        cfg.gs3d.result_dir = f"{output_dir}/gaussian"
        cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)
        torch.set_default_device("cpu")  # recover default setting.
        cli(gsplat_entrypoint, cfg.gs3d, verbose=True)

        # Copy the final artifacts out, then clean up the middle results.
        gs_path = (
            f"{cfg.gs3d.result_dir}/ply/point_cloud_{cfg.gs3d.max_steps-1}.ply"
        )
        copy(gs_path, f"{output_dir}/gs_model.ply")
        video_path = f"{cfg.gs3d.result_dir}/renders/video_step{cfg.gs3d.max_steps-1}.mp4"
        copy(video_path, f"{output_dir}/video.mp4")
        gs_cfg_path = f"{cfg.gs3d.result_dir}/cfg.yml"
        copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")
        if not cfg.keep_middle_result:
            rmtree(cfg.gs3d.result_dir, ignore_errors=True)
            os.remove(f"{output_dir}/{PANOMESH_CFG.gs_data_file}")

        # Rescale the scene to metric units using the (estimated) room height.
        real_height = (
            PANOHEIGHT_ESTOR(pano_path)
            if cfg.real_height is None
            else cfg.real_height
        )
        gs_path = os.path.join(output_dir, "gs_model.ply")
        mesh_path = os.path.join(output_dir, "mesh_model.ply")
        restore_scene_scale_and_position(real_height, mesh_path, gs_path)

        elapsed_time = (time.time() - start_time) / 60
        logger.info(
            f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
        )


if __name__ == "__main__":
    entrypoint()
to 3D pipeline args.") @@ -128,6 +109,19 @@ def entrypoint(**kwargs): if hasattr(args, k) and v is not None: setattr(args, k, v) + logger.info("Loading Models...") + DELIGHT = DelightingModel() + IMAGESR_MODEL = ImageRealESRGAN(outscale=4) + RBG_REMOVER = RembgRemover() + PIPELINE = TrellisImageTo3DPipeline.from_pretrained( + "microsoft/TRELLIS-image-large" + ) + # PIPELINE.cuda() + SEG_CHECKER = ImageSegChecker(GPT_CLIENT) + GEO_CHECKER = MeshGeoChecker(GPT_CLIENT) + AESTHETIC_CHECKER = ImageAestheticChecker() + CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER] + assert ( args.image_path or args.image_root ), "Please provide either --image_path or --image_root." diff --git a/embodied_gen/scripts/text2image.py b/embodied_gen/scripts/text2image.py index ac1587c..9f25dbc 100644 --- a/embodied_gen/scripts/text2image.py +++ b/embodied_gen/scripts/text2image.py @@ -31,6 +31,7 @@ from embodied_gen.models.text_model import ( build_text2img_pipeline, text2img_gen, ) +from embodied_gen.utils.process_media import parse_text_prompts logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -101,14 +102,7 @@ def entrypoint( if hasattr(args, k) and v is not None: setattr(args, k, v) - prompts = args.prompts - if len(prompts) == 1 and prompts[0].endswith(".txt"): - with open(prompts[0], "r") as f: - prompts = f.readlines() - prompts = [ - prompt.strip() for prompt in prompts if prompt.strip() != "" - ] - + prompts = parse_text_prompts(args.prompts) os.makedirs(args.output_root, exist_ok=True) ip_img_paths = args.ref_image diff --git a/embodied_gen/scripts/textto3d.py b/embodied_gen/scripts/textto3d.py index a4262a3..4d28093 100644 --- a/embodied_gen/scripts/textto3d.py +++ b/embodied_gen/scripts/textto3d.py @@ -44,13 +44,6 @@ __all__ = [ "text_to_3d", ] -logger.info("Loading Models...") -SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT) -SEG_CHECKER = ImageSegChecker(GPT_CLIENT) -TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT) -PIPE_IMG = 
build_hf_image_pipeline("sd35") -BG_REMOVER = RembgRemover() - def text_to_image( prompt: str, @@ -121,6 +114,14 @@ def text_to_3d(**kwargs) -> dict: if hasattr(args, k) and v is not None: setattr(args, k, v) + logger.info("Loading Models...") + global SEMANTIC_CHECKER, SEG_CHECKER, TXTGEN_CHECKER, PIPE_IMG, BG_REMOVER + SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT) + SEG_CHECKER = ImageSegChecker(GPT_CLIENT) + TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT) + PIPE_IMG = build_hf_image_pipeline(args.text_model) + BG_REMOVER = RembgRemover() + if args.asset_names is None or len(args.asset_names) == 0: args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))] img_save_dir = os.path.join(args.output_root, "images") @@ -260,6 +261,11 @@ def parse_args(): default=0, help="Random seed for 3D generation", ) + parser.add_argument( + "--text_model", + type=str, + default="sd35", + ) parser.add_argument("--keep_intermediate", action="store_true") args, unknown = parser.parse_known_args() diff --git a/embodied_gen/trainer/gsplat_trainer.py b/embodied_gen/trainer/gsplat_trainer.py new file mode 100644 index 0000000..53aaf28 --- /dev/null +++ b/embodied_gen/trainer/gsplat_trainer.py @@ -0,0 +1,678 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
+# Part of the code comes from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py +# Both under the Apache License, Version 2.0. + + +import json +import os +import time +from collections import defaultdict +from typing import Dict, Optional, Tuple + +import cv2 +import imageio +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import yaml +from fused_ssim import fused_ssim +from gsplat.distributed import cli +from gsplat.rendering import rasterization +from gsplat.strategy import DefaultStrategy, MCMCStrategy +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import ( + PeakSignalNoiseRatio, + StructuralSimilarityIndexMeasure, +) +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from typing_extensions import Literal, assert_never +from embodied_gen.data.datasets import PanoGSplatDataset +from embodied_gen.utils.config import GsplatTrainConfig +from embodied_gen.utils.gaussian import ( + create_splats_with_optimizers, + export_splats, + resize_pinhole_intrinsics, + set_random_seed, +) + + +class Runner: + """Engine for training and testing from gsplat example. + + Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py + """ + + def __init__( + self, + local_rank: int, + world_rank, + world_size: int, + cfg: GsplatTrainConfig, + ) -> None: + set_random_seed(42 + local_rank) + + self.cfg = cfg + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. 
+ self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + self.ply_dir = f"{cfg.result_dir}/ply" + os.makedirs(self.ply_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + self.trainset = PanoGSplatDataset(cfg.data_dir, split="train") + self.valset = PanoGSplatDataset( + cfg.data_dir, split="train", max_sample_num=6 + ) + self.testset = PanoGSplatDataset(cfg.data_dir, split="eval") + self.scene_scale = cfg.scene_scale + + # Model + self.splats, self.optimizers = create_splats_with_optimizers( + self.trainset.points, + self.trainset.points_rgb, + init_num_pts=cfg.init_num_pts, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + means_lr=cfg.means_lr, + scales_lr=cfg.scales_lr, + opacities_lr=cfg.opacities_lr, + quats_lr=cfg.quats_lr, + sh0_lr=cfg.sh0_lr, + shN_lr=cfg.shN_lr, + scene_scale=self.scene_scale, + sh_degree=cfg.sh_degree, + sparse_grad=cfg.sparse_grad, + visible_adam=cfg.visible_adam, + batch_size=cfg.batch_size, + feature_dim=None, + device=self.device, + world_rank=world_rank, + world_size=world_size, + ) + print("Model initialized. Number of GS:", len(self.splats["means"])) + + # Densification Strategy + self.cfg.strategy.check_sanity(self.splats, self.optimizers) + + if isinstance(self.cfg.strategy, DefaultStrategy): + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.strategy_state = self.cfg.strategy.initialize_state() + else: + assert_never(self.cfg.strategy) + + # Losses & Metrics. 
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to( + self.device + ) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + + if cfg.lpips_net == "alex": + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(self.device) + elif cfg.lpips_net == "vgg": + # The 3DGS official repo uses lpips vgg, which is equivalent with the following: + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=False + ).to(self.device) + else: + raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + masks: Optional[Tensor] = None, + rasterize_mode: Optional[Literal["classic", "antialiased"]] = None, + camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + means = self.splats["means"] # [N, 3] + # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + image_ids = kwargs.pop("image_ids", None) + + colors = torch.cat( + [self.splats["sh0"], self.splats["shN"]], 1 + ) # [N, K, 3] + + if rasterize_mode is None: + rasterize_mode = ( + "antialiased" if self.cfg.antialiased else "classic" + ) + if camera_model is None: + camera_model = self.cfg.camera_model + + render_colors, render_alphas, info = rasterization( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=( + self.cfg.strategy.absgrad + if isinstance(self.cfg.strategy, DefaultStrategy) + else False + ), + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + 
distributed=self.world_size > 1, + camera_model=self.cfg.camera_model, + with_ut=self.cfg.with_ut, + with_eval3d=self.cfg.with_eval3d, + **kwargs, + ) + if masks is not None: + render_colors[~masks] = 0 + return render_colors, render_alphas, info + + def train(self): + cfg = self.cfg + device = self.device + world_rank = self.world_rank + + # Dump cfg. + if world_rank == 0: + with open(f"{cfg.result_dir}/cfg.yml", "w") as f: + yaml.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # means has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) + ), + ] + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # Training loop. + global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + image_ids = data["image_id"].to(device) + masks = ( + data["mask"].to(device) if "mask" in data else None + ) # [1, H, W] + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + # sh schedule + sh_degree_to_use = min( + step // cfg.sh_degree_interval, cfg.sh_degree + ) + + # forward + renders, alphas, info = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else 
"RGB", + masks=masks, + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + self.cfg.strategy.step_pre_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + ) + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - fused_ssim( + colors.permute(0, 3, 1, 2), + pixels.permute(0, 3, 1, 2), + padding="valid", + ) + loss = ( + l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + ) + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where( + depths > 0.0, 1.0 / depths, torch.zeros_like(depths) + ) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + + # regularizations + if cfg.opacity_reg > 0.0: + loss += ( + cfg.opacity_reg + * torch.sigmoid(self.splats["opacities"]).mean() + ) + if cfg.scale_reg > 0.0: + loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean() + + loss.backward() + + desc = ( + f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " + ) + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + pbar.set_description(desc) + + # write images (gt and render) + # if world_rank == 0 and step % 800 == 0: + # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + # canvas = canvas.reshape(-1, *canvas.shape[2:]) + # imageio.imwrite( + # 
f"{self.render_dir}/train_rank{self.world_rank}.png", + # (canvas * 255).astype(np.uint8), + # ) + + if ( + world_rank == 0 + and cfg.tb_every > 0 + and step % cfg.tb_every == 0 + ): + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar( + "train/num_GS", len(self.splats["means"]), step + ) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar( + "train/depthloss", depthloss.item(), step + ) + if cfg.tb_save_image: + canvas = ( + torch.cat([pixels, colors], dim=2) + .detach() + .cpu() + .numpy() + ) + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # save checkpoint before updating the model + if ( + step in [i - 1 for i in cfg.save_steps] + or step == max_steps - 1 + ): + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means"]), + } + print("Step: ", step, stats) + with open( + f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", + "w", + ) as f: + json.dump(stats, f) + data = {"step": step, "splats": self.splats.state_dict()} + torch.save( + data, + f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt", + ) + if ( + step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1 + ) and cfg.save_ply: + sh0 = self.splats["sh0"] + shN = self.splats["shN"] + means = self.splats["means"] + scales = self.splats["scales"] + quats = self.splats["quats"] + opacities = self.splats["opacities"] + export_splats( + means=means, + scales=scales, + quats=quats, + opacities=opacities, + sh0=sh0, + shN=shN, + format="ply", + save_to=f"{self.ply_dir}/point_cloud_{step}.ply", + ) + + # Turn Gradients into Sparse Tensor before running 
optimizer + if cfg.sparse_grad: + assert ( + cfg.packed + ), "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] + is_coalesced=len(Ks) == 1, + ) + + if cfg.visible_adam: + gaussian_cnt = self.splats.means.shape[0] + if cfg.packed: + visibility_mask = torch.zeros_like( + self.splats["opacities"], dtype=bool + ) + visibility_mask.scatter_(0, info["gaussian_ids"], 1) + else: + visibility_mask = (info["radii"] > 0).all(-1).any(0) + + # optimize + for optimizer in self.optimizers.values(): + if cfg.visible_adam: + optimizer.step(visibility_mask) + else: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # Run post-backward steps after backward and optimizer + if isinstance(self.cfg.strategy, DefaultStrategy): + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + packed=cfg.packed, + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + lr=schedulers[0].get_last_lr()[0], + ) + else: + assert_never(self.cfg.strategy) + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps]: + self.eval(step) + self.render_video(step) + + @torch.no_grad() + def eval( + self, + step: int, + stage: str = "val", + canvas_h: int = 512, + canvas_w: int = 1024, + ): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + world_rank = self.world_rank + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, 
shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = defaultdict(list) + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + masks = data["mask"].to(device) if "mask" in data else None + + pixels = pixels.permute(0, 3, 1, 2) # NHWC -> NCHW + pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2)) + + torch.cuda.synchronize() + tic = time.time() + colors, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + masks=masks, + ) # [1, H, W, 3] + torch.cuda.synchronize() + ellipse_time += max(time.time() - tic, 1e-10) + + colors = colors.permute(0, 3, 1, 2) # NHWC -> NCHW + colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2)) + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + + if world_rank == 0: + canvas = torch.cat(canvas_list, dim=2).squeeze(0) + canvas = canvas.permute(1, 2, 0) # CHW -> HWC + canvas = (canvas * 255).to(torch.uint8).cpu().numpy() + cv2.imwrite( + f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", + canvas[..., ::-1], + ) + metrics["psnr"].append(self.psnr(colors, pixels)) + metrics["ssim"].append(self.ssim(colors, pixels)) + metrics["lpips"].append(self.lpips(colors, pixels)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + stats = { + k: torch.stack(v).mean().item() for k, v in metrics.items() + } + stats.update( + { + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means"]), + } + ) + print( + f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " + f"Time: {stats['ellipse_time']:.3f}s/image " + f"Number of GS: {stats['num_GS']}" + ) + # save stats as json + with open( + f"{self.stats_dir}/{stage}_step{step:04d}.json", "w" + ) as f: + json.dump(stats, f) + # save stats to 
tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"{stage}/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_video( + self, step: int, canvas_h: int = 512, canvas_w: int = 1024 + ): + testloader = torch.utils.data.DataLoader( + self.testset, batch_size=1, shuffle=False, num_workers=1 + ) + + images_cache = [] + depth_global_min, depth_global_max = float("inf"), -float("inf") + for data in testloader: + camtoworlds = data["camtoworld"].to(self.device) + Ks = resize_pinhole_intrinsics( + data["K"].squeeze(), + raw_hw=(data["image_h"].item(), data["image_w"].item()), + new_hw=(canvas_h, canvas_w // 2), + ).to(self.device) + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks[None, ...], + width=canvas_w // 2, + height=canvas_h, + sh_degree=self.cfg.sh_degree, + near_plane=self.cfg.near_plane, + far_plane=self.cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] + colors = (colors * 255).to(torch.uint8).cpu().numpy() + depths = renders[0, ..., 3:4] # [H, W, 1], tensor in device. 
+ images_cache.append([colors, depths]) + depth_global_min = min(depth_global_min, depths.min().item()) + depth_global_max = max(depth_global_max, depths.max().item()) + + video_path = f"{self.render_dir}/video_step{step}.mp4" + writer = imageio.get_writer(video_path, fps=30) + for rgb, depth in images_cache: + depth_normalized = torch.clip( + (depth - depth_global_min) + / (depth_global_max - depth_global_min), + 0, + 1, + ) + depth_normalized = ( + (depth_normalized * 255).to(torch.uint8).cpu().numpy() + ) + depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET) + image = np.concatenate([rgb, depth_map], axis=1) + writer.append_data(image) + + writer.close() + + +def entrypoint( + local_rank: int, world_rank, world_size: int, cfg: GsplatTrainConfig +): + runner = Runner(local_rank, world_rank, world_size, cfg) + + if cfg.ckpt is not None: + # run eval only + ckpts = [ + torch.load(file, map_location=runner.device, weights_only=True) + for file in cfg.ckpt + ] + for k in runner.splats.keys(): + runner.splats[k].data = torch.cat( + [ckpt["splats"][k] for ckpt in ckpts] + ) + step = ckpts[0]["step"] + runner.eval(step=step) + runner.render_video(step=step) + else: + runner.train() + runner.render_video(step=cfg.max_steps - 1) + + +if __name__ == "__main__": + configs = { + "default": ( + "Gaussian splatting training using densification heuristics from the original paper.", + GsplatTrainConfig( + strategy=DefaultStrategy(verbose=True), + ), + ), + "mcmc": ( + "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.", + GsplatTrainConfig( + init_scale=0.1, + opacity_reg=0.01, + scale_reg=0.01, + strategy=MCMCStrategy(verbose=True), + ), + ), + } + cfg = tyro.extras.overridable_config_cli(configs) + cfg.adjust_steps(cfg.steps_scaler) + + cli(entrypoint, cfg, verbose=True) diff --git a/embodied_gen/trainer/pono2mesh_trainer.py b/embodied_gen/trainer/pono2mesh_trainer.py new file mode 100644 index 
0000000..b234b79 --- /dev/null +++ b/embodied_gen/trainer/pono2mesh_trainer.py @@ -0,0 +1,538 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +from embodied_gen.utils.monkey_patches import monkey_patch_pano2room + +monkey_patch_pano2room() + +import os + +import cv2 +import numpy as np +import torch +import trimesh +from equilib import cube2equi, equi2pers +from kornia.morphology import dilation +from PIL import Image +from embodied_gen.models.sr_model import ImageRealESRGAN +from embodied_gen.utils.config import Pano2MeshSRConfig +from embodied_gen.utils.gaussian import compute_pinhole_intrinsics +from embodied_gen.utils.log import logger +from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor +from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import ( + PanoFusionDistancePredictor, +) +from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter +from thirdparty.pano2room.modules.mesh_fusion.render import ( + features_to_world_space_mesh, + render_mesh, +) +from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool +from thirdparty.pano2room.utils.camera_utils import gen_pano_rays +from thirdparty.pano2room.utils.functions import ( + depth_to_distance, + get_cubemap_views_world_to_cam, + resize_image_with_aspect_ratio, + rot_z_world_to_cam, + tensor_to_pil, +) + + +class Pano2MeshSRPipeline: + 
"""Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement. + + This class integrates several key components including: + - Depth estimation from RGB panorama + - Inpainting of missing regions under offsets + - RGB-D to mesh conversion + - Multi-view mesh repair + - 3D Gaussian Splatting (3DGS) dataset generation + + Args: + config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters. + + Example: + ```python + pipeline = Pano2MeshSRPipeline(config) + pipeline(pano_image='example.png', output_dir='./output') + ``` + """ + + def __init__(self, config: Pano2MeshSRConfig) -> None: + self.cfg = config + self.device = config.device + + # Init models. + self.inpainter = PanoPersFusionInpainter(save_path=None) + self.geo_predictor = PanoJointPredictor(save_path=None) + self.pano_fusion_distance_predictor = PanoFusionDistancePredictor() + self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor) + + # Init poses. 
+        cubemap_w2cs = get_cubemap_views_world_to_cam()
+        self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
+        self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)
+
+        kernel = cv2.getStructuringElement(
+            cv2.MORPH_ELLIPSE, self.cfg.kernel_size
+        )
+        self.kernel = torch.from_numpy(kernel).float().to(self.device)
+
+    def init_mesh_params(self) -> None:
+        torch.set_default_device(self.device)
+        self.inpaint_mask = torch.ones(
+            (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
+        )
+        self.vertices = torch.empty((3, 0), requires_grad=False)
+        self.colors = torch.empty((3, 0), requires_grad=False)
+        self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)
+
+    @staticmethod
+    def read_camera_pose_file(filepath: str) -> np.ndarray:
+        with open(filepath, "r") as f:
+            values = [float(num) for line in f for num in line.split()]
+
+        return np.array(values).reshape(4, 4)
+
+    def load_camera_poses(
+        self, trajectory_dir: str
+    ) -> list[torch.Tensor]:
+        pose_filenames = sorted(
+            [
+                fname
+                for fname in os.listdir(trajectory_dir)
+                if fname.startswith("camera_pose")
+            ]
+        )
+
+        pano_pose_world = None
+        relative_poses = []
+        for idx, filename in enumerate(pose_filenames):
+            pose_path = os.path.join(trajectory_dir, filename)
+            pose_matrix = self.read_camera_pose_file(pose_path)
+
+            if pano_pose_world is None:
+                pano_pose_world = pose_matrix.copy()
+                pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
+                pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]
+
+            # Use different reference for the first 6 cubemap views
+            reference_pose = pose_matrix if idx < 6 else pano_pose_world
+            relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
+            relative_matrix[0:2, :] *= -1  # flip_xy
+            relative_matrix = (
+                relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
+            )
+            relative_matrix[:3, 3] *= self.cfg.pose_scale
+            relative_matrix = torch.tensor(
+                relative_matrix, dtype=torch.float32
+            )
+            relative_poses.append(relative_matrix)
+
+        return relative_poses
+
+    def load_inpaint_poses(
+        self, poses: list[torch.Tensor]
+    ) -> dict[int, torch.Tensor]:
+        inpaint_poses = dict()
+        sampled_views = poses[:: self.cfg.inpaint_frame_stride]
+        init_pose = torch.eye(4)
+        for idx, w2c_tensor in enumerate(sampled_views):
+            w2c = w2c_tensor.cpu().numpy().astype(np.float32)
+            c2w = np.linalg.inv(w2c)
+            pose_tensor = init_pose.clone()
+            pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
+            pose_tensor[:3, 3] *= -1
+            inpaint_poses[idx] = pose_tensor.to(self.device)
+
+        return inpaint_poses
+
+    def project(self, world_to_cam: torch.Tensor):
+        (
+            project_image,
+            project_depth,
+            inpaint_mask,
+            _,
+            z_buf,
+            mesh,
+        ) = render_mesh(
+            vertices=self.vertices,
+            faces=self.faces,
+            vertex_features=self.colors,
+            H=self.cfg.cubemap_h,
+            W=self.cfg.cubemap_w,
+            fov_in_degrees=self.cfg.fov,
+            RT=world_to_cam,
+            blur_radius=self.cfg.blur_radius,
+            faces_per_pixel=self.cfg.faces_per_pixel,
+        )
+        project_image = project_image * ~inpaint_mask
+
+        return project_image[:3, ...], inpaint_mask, project_depth
+
+    def render_pano(self, pose: torch.Tensor):
+        cubemap_list = []
+        for cubemap_pose in self.cubemap_w2cs:
+            project_pose = cubemap_pose @ pose
+            rgb, inpaint_mask, depth = self.project(project_pose)
+            distance_map = depth_to_distance(depth[None, ...])
+            mask = inpaint_mask[None, ...]
+ cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0)) + + # Set default tensor type for CPU operation in cube2equi + with torch.device("cpu"): + pano_rgbd = cube2equi( + cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w + ) + + pano_rgb = pano_rgbd[:3, :, :] + pano_depth = pano_rgbd[3:4, :, :].squeeze(0) + pano_mask = pano_rgbd[4:, :, :].squeeze(0) + + return pano_rgb, pano_depth, pano_mask + + def rgbd_to_mesh( + self, + rgb: torch.Tensor, + depth: torch.Tensor, + inpaint_mask: torch.Tensor, + world_to_cam: torch.Tensor = None, + using_distance_map: bool = True, + ) -> None: + if world_to_cam is None: + world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device) + + if inpaint_mask.sum() == 0: + return + + vertices, faces, colors = features_to_world_space_mesh( + colors=rgb.squeeze(0), + depth=depth, + fov_in_degrees=self.cfg.fov, + world_to_cam=world_to_cam, + mask=inpaint_mask, + faces=self.faces, + vertices=self.vertices, + using_distance_map=using_distance_map, + edge_threshold=0.05, + ) + + faces += self.vertices.shape[1] + self.vertices = torch.cat([self.vertices, vertices], dim=1) + self.colors = torch.cat([self.colors, colors], dim=1) + self.faces = torch.cat([self.faces, faces], dim=1) + + def get_edge_image_by_depth( + self, depth: torch.Tensor, dilate_iter: int = 1 + ) -> np.ndarray: + if isinstance(depth, torch.Tensor): + depth = depth.cpu().detach().numpy() + + gray = (depth / depth.max() * 255).astype(np.uint8) + edges = cv2.Canny(gray, 60, 150) + if dilate_iter > 0: + kernel = np.ones((3, 3), np.uint8) + edges = cv2.dilate(edges, kernel, iterations=dilate_iter) + + return edges + + def mesh_repair_by_greedy_view_selection( + self, pose_dict: dict[str, torch.Tensor], output_dir: str + ) -> list: + inpainted_panos_w_pose = [] + while len(pose_dict) > 0: + logger.info(f"Repairing mesh left rounds {len(pose_dict)}") + sampled_views = [] + for key, pose in pose_dict.items(): + pano_rgb, pano_distance, pano_mask = 
self.render_pano(pose)
+                completeness = torch.sum(1 - pano_mask) / (pano_mask.numel())
+                sampled_views.append((key, completeness.item(), pose))
+
+            if len(sampled_views) == 0:
+                break
+
+            # Sort ascending by completeness and pick the view at the 2/3
+            # quantile (a relatively complete view, not the least complete).
+            sampled_views = sorted(sampled_views, key=lambda x: x[1])
+            key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
+            pose_dict.pop(key)
+
+            pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
+
+            colors = pano_rgb.permute(1, 2, 0).clone()
+            distances = pano_distance.unsqueeze(-1).clone()
+            pano_inpaint_mask = pano_mask.clone()
+            init_pose = pose.clone()
+            normals = None
+            if pano_inpaint_mask.min().item() < 0.5:
+                colors, distances, normals = self.inpaint_panorama(
+                    idx=key,
+                    colors=colors,
+                    distances=distances,
+                    pano_mask=pano_inpaint_mask,
+                )
+
+            init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
+                -pose[0, 3],
+                pose[2, 3],
+                0,
+            )
+            rays = gen_pano_rays(
+                init_pose, self.cfg.pano_h, self.cfg.pano_w
+            )
+            conflict_mask = self.sup_pool.geo_check(
+                rays, distances.unsqueeze(-1)
+            )  # 0 is conflict, 1 not conflict
+            pano_inpaint_mask *= conflict_mask
+
+            self.rgbd_to_mesh(
+                colors.permute(2, 0, 1),
+                distances,
+                pano_inpaint_mask,
+                world_to_cam=pose,
+            )
+
+            self.sup_pool.register_sup_info(
+                pose=init_pose,
+                mask=pano_inpaint_mask.clone(),
+                rgb=colors,
+                distance=distances.unsqueeze(-1),
+                normal=normals,
+            )
+
+            colors = colors.permute(2, 0, 1).unsqueeze(0)
+            inpainted_panos_w_pose.append([colors, pose])
+
+            if self.cfg.visualize:
+                from embodied_gen.data.utils import DiffrastRender
+
+                tensor_to_pil(pano_rgb.unsqueeze(0)).save(
+                    f"{output_dir}/rendered_pano_{key}.jpg"
+                )
+                tensor_to_pil(colors).save(
+                    f"{output_dir}/inpainted_pano_{key}.jpg"
+                )
+                norm_depth = DiffrastRender.normalize_map_by_mask(
+                    distances, torch.ones_like(distances)
+                )
+                heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
+                heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+                Image.fromarray(heatmap).save(
+
f"{output_dir}/inpainted_depth_{key}.png" + ) + + return inpainted_panos_w_pose + + def inpaint_panorama( + self, + idx: int, + colors: torch.Tensor, + distances: torch.Tensor, + pano_mask: torch.Tensor, + ) -> tuple[torch.Tensor]: + mask = (pano_mask[None, ..., None] > 0.5).float() + mask = mask.permute(0, 3, 1, 2) + mask = dilation(mask, kernel=self.kernel) + mask = mask[0, 0, ..., None] # hwc + inpainted_img = self.inpainter.inpaint(idx, colors, mask) + inpainted_img = colors * (1 - mask) + inpainted_img * mask + inpainted_distances, inpainted_normals = self.geo_predictor( + idx, + inpainted_img, + distances[..., None], + mask=mask, + reg_loss_weight=0.0, + normal_loss_weight=5e-2, + normal_tv_loss_weight=5e-2, + ) + + return inpainted_img, inpainted_distances.squeeze(), inpainted_normals + + def preprocess_pano( + self, image: Image.Image | str + ) -> tuple[torch.Tensor, torch.Tensor]: + if isinstance(image, str): + image = Image.open(image) + + image = image.convert("RGB") + + if image.size[0] < image.size[1]: + image = image.transpose(Image.TRANSPOSE) + + image = resize_image_with_aspect_ratio(image, self.cfg.pano_w) + image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255 + image_rgb = image_rgb.to(self.device) + image_depth = self.pano_fusion_distance_predictor.predict( + image_rgb.permute(1, 2, 0) + ) + image_depth = ( + image_depth / image_depth.max() * self.cfg.depth_scale_factor + ) + + return image_rgb, image_depth + + def pano_to_perpective( + self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float + ) -> torch.Tensor: + rots = dict( + roll=0, + pitch=pitch, + yaw=yaw, + ) + perspective = equi2pers( + equi=pano_image.squeeze(0), + rots=rots, + height=self.cfg.cubemap_h, + width=self.cfg.cubemap_w, + fov_x=fov, + mode="bilinear", + ).unsqueeze(0) + + return perspective + + def pano_to_cubemap(self, pano_rgb: torch.Tensor): + # Define six canonical cube directions in (pitch, yaw) + directions = [ + (0, 0), + (0, 1.5 * np.pi), + 
(0, 1.0 * np.pi), + (0, 0.5 * np.pi), + (-0.5 * np.pi, 0), + (0.5 * np.pi, 0), + ] + + cubemaps_rgb = [] + for pitch, yaw in directions: + rgb_view = self.pano_to_perpective( + pano_rgb, pitch, yaw, fov=self.cfg.fov + ) + cubemaps_rgb.append(rgb_view.cpu()) + + return cubemaps_rgb + + def save_mesh(self, output_path: str) -> None: + vertices_np = self.vertices.T.cpu().numpy() + colors_np = self.colors.T.cpu().numpy() + faces_np = self.faces.T.cpu().numpy() + mesh = trimesh.Trimesh( + vertices=vertices_np, faces=faces_np, vertex_colors=colors_np + ) + + mesh.export(output_path) + + def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray: + pose = mesh_pose.clone() + pose[0, :] *= -1 + pose[1, :] *= -1 + + Rw2c = pose[:3, :3].cpu().numpy() + Tw2c = pose[:3, 3:].cpu().numpy() + yz_reverse = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + + Rc2w = (yz_reverse @ Rw2c).T + Tc2w = -(Rc2w @ yz_reverse @ Tw2c) + c2w = np.concatenate((Rc2w, Tc2w), axis=1) + c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0) + + return c2w + + def __call__(self, pano_image: Image.Image | str, output_dir: str): + self.init_mesh_params() + pano_rgb, pano_depth = self.preprocess_pano(pano_image) + self.sup_pool = SupInfoPool() + self.sup_pool.register_sup_info( + pose=torch.eye(4).to(self.device), + mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]), + rgb=pano_rgb.permute(1, 2, 0), + distance=pano_depth[..., None], + ) + self.sup_pool.gen_occ_grid(res=256) + + logger.info("Init mesh from pano RGBD image...") + depth_edge = self.get_edge_image_by_depth(pano_depth) + inpaint_edge_mask = ( + ~torch.from_numpy(depth_edge).to(self.device).bool() + ) + self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask) + + repair_poses = self.load_inpaint_poses(self.camera_poses) + inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection( + repair_poses, output_dir + ) + torch.cuda.empty_cache() + torch.set_default_device("cpu") + + if self.cfg.mesh_file is not None: + 
mesh_path = os.path.join(output_dir, self.cfg.mesh_file) + self.save_mesh(mesh_path) + + if self.cfg.gs_data_file is None: + return + + logger.info(f"Dump data for 3DGS training...") + points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8) + data = { + "points": self.vertices.permute(1, 0).cpu().numpy(), # (N, 3) + "points_rgb": points_rgb.permute(1, 0).cpu().numpy(), # (N, 3) + "train": [], + "eval": [], + } + image_h = self.cfg.cubemap_h * self.cfg.upscale_factor + image_w = self.cfg.cubemap_w * self.cfg.upscale_factor + Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov) + for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses): + cubemaps = self.pano_to_cubemap(pano_img) + for i in range(len(cubemaps)): + cubemap = tensor_to_pil(cubemaps[i]) + cubemap = self.super_model(cubemap) + mesh_pose = self.cubemap_w2cs[i] @ pano_pose + c2w = self.mesh_pose_to_gs_pose(mesh_pose) + data["train"].append( + { + "camtoworld": c2w.astype(np.float32), + "K": Ks.astype(np.float32), + "image": np.array(cubemap), + "image_h": image_h, + "image_w": image_w, + "image_id": len(cubemaps) * idx + i, + } + ) + + # Camera poses for evaluation. 
+ for idx in range(len(self.camera_poses)): + c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx]) + data["eval"].append( + { + "camtoworld": c2w.astype(np.float32), + "K": Ks.astype(np.float32), + "image_h": image_h, + "image_w": image_w, + "image_id": idx, + } + ) + + data_path = os.path.join(output_dir, self.cfg.gs_data_file) + torch.save(data, data_path) + + return + + +if __name__ == "__main__": + output_dir = "outputs/bg_v2/test3" + input_pano = "apps/assets/example_scene/result_pano.png" + config = Pano2MeshSRConfig() + pipeline = Pano2MeshSRPipeline(config) + pipeline(input_pano, output_dir) diff --git a/embodied_gen/utils/config.py b/embodied_gen/utils/config.py new file mode 100644 index 0000000..8c08590 --- /dev/null +++ b/embodied_gen/utils/config.py @@ -0,0 +1,190 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from gsplat.strategy import DefaultStrategy, MCMCStrategy +from typing_extensions import Literal, assert_never + +__all__ = [ + "Pano2MeshSRConfig", + "GsplatTrainConfig", +] + + +@dataclass +class Pano2MeshSRConfig: + mesh_file: str = "mesh_model.ply" + gs_data_file: str = "gs_data.pt" + device: str = "cuda" + blur_radius: int = 0 + faces_per_pixel: int = 8 + fov: int = 90 + pano_w: int = 2048 + pano_h: int = 1024 + cubemap_w: int = 512 + cubemap_h: int = 512 + pose_scale: float = 0.6 + pano_center_offset: tuple = (-0.2, 0.3) + inpaint_frame_stride: int = 20 + trajectory_dir: str = "apps/assets/example_scene/camera_trajectory" + visualize: bool = False + depth_scale_factor: float = 3.4092 + kernel_size: tuple = (9, 9) + upscale_factor: int = 4 + + +@dataclass +class GsplatTrainConfig: + # Path to the .pt files. If provide, it will skip training and run evaluation only. + ckpt: Optional[List[str]] = None + # Render trajectory path + render_traj_path: str = "interp" + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "outputs/bg" + # Downsample factor for the dataset + data_factor: int = 4 + # Directory to save results + result_dir: str = "outputs/bg" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + # Normalize the world space + normalize_world_space: bool = True + # Camera model + camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole" + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. 
Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Whether to save ply file (storage size can be large) + save_ply: bool = True + # Steps to save the model as ply + ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Whether to disable video generation during training and evaluation + disable_video: bool = False + + # Initial number of GSs. Ignored if using sfm + init_num_pts: int = 100_000 + # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm + init_extent: float = 3.0 + # Degree of spherical harmonics + sh_degree: int = 1 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # Strategy for GS densification + strategy: Union[DefaultStrategy, MCMCStrategy] = field( + default_factory=DefaultStrategy + ) + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Use visible adam from Taming 3DGS. (experimental) + visible_adam: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. 
+ antialiased: bool = False + + # Use random background for training to discourage transparency + random_bkgd: bool = False + + # LR for 3D point positions + means_lr: float = 1.6e-4 + # LR for Gaussian scale factors + scales_lr: float = 5e-3 + # LR for alpha blending weights + opacities_lr: float = 5e-2 + # LR for orientation (quaternions) + quats_lr: float = 1e-3 + # LR for SH band 0 (brightness) + sh0_lr: float = 2.5e-3 + # LR for higher-order SH (detail) + shN_lr: float = 2.5e-3 / 20 + + # Opacity regularization + opacity_reg: float = 0.0 + # Scale regularization + scale_reg: float = 0.0 + + # Enable depth loss. (experimental) + depth_loss: bool = False + # Weight for depth loss + depth_lambda: float = 1e-2 + + # Dump information to tensorboard every this steps + tb_every: int = 200 + # Save training images to tensorboard + tb_save_image: bool = False + + lpips_net: Literal["vgg", "alex"] = "alex" + + # 3DGUT (uncented transform + eval 3D) + with_ut: bool = False + with_eval3d: bool = False + + scene_scale: float = 1.0 + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.ply_steps = [int(i * factor) for i in self.ply_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + + strategy = self.strategy + if isinstance(strategy, DefaultStrategy): + strategy.refine_start_iter = int( + strategy.refine_start_iter * factor + ) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.reset_every = int(strategy.reset_every * factor) + strategy.refine_every = int(strategy.refine_every * factor) + elif isinstance(strategy, MCMCStrategy): + strategy.refine_start_iter = int( + strategy.refine_start_iter * factor + ) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.refine_every = int(strategy.refine_every * factor) + else: + 
assert_never(strategy) diff --git a/embodied_gen/utils/gaussian.py b/embodied_gen/utils/gaussian.py new file mode 100644 index 0000000..68ee73e --- /dev/null +++ b/embodied_gen/utils/gaussian.py @@ -0,0 +1,331 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# Part of the code comes from https://github.com/nerfstudio-project/gsplat +# Both under the Apache License, Version 2.0. + + +import math +import random +from io import BytesIO +from typing import Dict, Literal, Optional, Tuple + +import numpy as np +import torch +import trimesh +from gsplat.optimizers import SelectiveAdam +from scipy.spatial.transform import Rotation +from sklearn.neighbors import NearestNeighbors +from torch import Tensor +from embodied_gen.models.gs_model import GaussianOperator + +__all__ = [ + "set_random_seed", + "export_splats", + "create_splats_with_optimizers", + "compute_pinhole_intrinsics", + "resize_pinhole_intrinsics", + "restore_scene_scale_and_position", +] + + +def knn(x: Tensor, K: int = 4) -> Tensor: + x_np = x.cpu().numpy() + model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np) + distances, _ = model.kneighbors(x_np) + return torch.from_numpy(distances).to(x) + + +def rgb_to_sh(rgb: Tensor) -> Tensor: + C0 = 0.28209479177387814 + return (rgb - 0.5) / C0 + + +def set_random_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def 
splat2ply_bytes( + means: torch.Tensor, + scales: torch.Tensor, + quats: torch.Tensor, + opacities: torch.Tensor, + sh0: torch.Tensor, + shN: torch.Tensor, +) -> bytes: + num_splats = means.shape[0] + buffer = BytesIO() + + # Write PLY header + buffer.write(b"ply\n") + buffer.write(b"format binary_little_endian 1.0\n") + buffer.write(f"element vertex {num_splats}\n".encode()) + buffer.write(b"property float x\n") + buffer.write(b"property float y\n") + buffer.write(b"property float z\n") + for i, data in enumerate([sh0, shN]): + prefix = "f_dc" if i == 0 else "f_rest" + for j in range(data.shape[1]): + buffer.write(f"property float {prefix}_{j}\n".encode()) + buffer.write(b"property float opacity\n") + for i in range(scales.shape[1]): + buffer.write(f"property float scale_{i}\n".encode()) + for i in range(quats.shape[1]): + buffer.write(f"property float rot_{i}\n".encode()) + buffer.write(b"end_header\n") + + # Concatenate all tensors in the correct order + splat_data = torch.cat( + [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1 + ) + # Ensure correct dtype + splat_data = splat_data.to(torch.float32) + + # Write binary data + float_dtype = np.dtype(np.float32).newbyteorder("<") + buffer.write( + splat_data.detach().cpu().numpy().astype(float_dtype).tobytes() + ) + + return buffer.getvalue() + + +def export_splats( + means: torch.Tensor, + scales: torch.Tensor, + quats: torch.Tensor, + opacities: torch.Tensor, + sh0: torch.Tensor, + shN: torch.Tensor, + format: Literal["ply"] = "ply", + save_to: Optional[str] = None, +) -> bytes: + """Export a Gaussian Splats model to bytes in PLY file format.""" + total_splats = means.shape[0] + assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)" + assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)" + assert quats.shape == ( + total_splats, + 4, + ), "Quaternions must be of shape (N, 4)" + assert opacities.shape == ( + total_splats, + ), "Opacities must be of shape (N,)" 
+    assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
+    assert (
+        shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
+    ), f"shN must be of shape (N, K, 3), got {shN.shape}"
+
+    # Reshape spherical harmonics
+    sh0 = sh0.squeeze(1)  # Shape (N, 3)
+    shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1)  # Shape (N, K * 3)
+
+    # Check for NaN or Inf values. Note: opacities is 1-D (N,), so the mask
+    # must stay element-wise; reducing with .any(dim=0) would yield a scalar
+    # that broadcasts over the whole mask and drops every splat when a single
+    # opacity is NaN/Inf.
+    invalid_mask = (
+        torch.isnan(means).any(dim=1)
+        | torch.isinf(means).any(dim=1)
+        | torch.isnan(scales).any(dim=1)
+        | torch.isinf(scales).any(dim=1)
+        | torch.isnan(quats).any(dim=1)
+        | torch.isinf(quats).any(dim=1)
+        | torch.isnan(opacities)
+        | torch.isinf(opacities)
+        | torch.isnan(sh0).any(dim=1)
+        | torch.isinf(sh0).any(dim=1)
+        | torch.isnan(shN).any(dim=1)
+        | torch.isinf(shN).any(dim=1)
+    )
+
+    # Filter out invalid entries
+    valid_mask = ~invalid_mask
+    means = means[valid_mask]
+    scales = scales[valid_mask]
+    quats = quats[valid_mask]
+    opacities = opacities[valid_mask]
+    sh0 = sh0[valid_mask]
+    shN = shN[valid_mask]
+
+    if format == "ply":
+        data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
+    else:
+        raise ValueError(f"Unsupported format: {format}")
+
+    if save_to:
+        with open(save_to, "wb") as binary_file:
+            binary_file.write(data)
+
+    return data
+
+
+def create_splats_with_optimizers(
+    points: np.ndarray = None,
+    points_rgb: np.ndarray = None,
+    init_num_pts: int = 100_000,
+    init_extent: float = 3.0,
+    init_opacity: float = 0.1,
+    init_scale: float = 1.0,
+    means_lr: float = 1.6e-4,
+    scales_lr: float = 5e-3,
+    opacities_lr: float = 5e-2,
+    quats_lr: float = 1e-3,
+    sh0_lr: float = 2.5e-3,
+    shN_lr: float = 2.5e-3 / 20,
+    scene_scale: float = 1.0,
+    sh_degree: int = 3,
+    sparse_grad: bool = False,
+    visible_adam: bool = False,
+    batch_size: int = 1,
+    feature_dim: Optional[int] = None,
+    device: str = "cuda",
+    world_rank: int = 0,
+    world_size: int = 1,
+) -> Tuple[torch.nn.ParameterDict, Dict[str,
torch.optim.Optimizer]]: + if points is not None and points_rgb is not None: + points = torch.from_numpy(points).float() + rgbs = torch.from_numpy(points_rgb / 255.0).float() + else: + points = ( + init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1) + ) + rgbs = torch.rand((init_num_pts, 3)) + + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = ( + torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) + ) # [N, 3] + + # Distribute the GSs to different ranks (also works for single rank) + points = points[world_rank::world_size] + rgbs = rgbs[world_rank::world_size] + scales = scales[world_rank::world_size] + + N = points.shape[0] + quats = torch.rand((N, 4)) # [N, 4] + opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] + + params = [ + # name, value, lr + ("means", torch.nn.Parameter(points), means_lr * scene_scale), + ("scales", torch.nn.Parameter(scales), scales_lr), + ("quats", torch.nn.Parameter(quats), quats_lr), + ("opacities", torch.nn.Parameter(opacities), opacities_lr), + ] + + if feature_dim is None: + # color is SH coefficients. 
+ colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] + colors[:, 0, :] = rgb_to_sh(rgbs) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr)) + else: + # features will be used for appearance and view-dependent shading + features = torch.rand(N, feature_dim) # [N, feature_dim] + params.append(("features", torch.nn.Parameter(features), sh0_lr)) + colors = torch.logit(rgbs) # [N, 3] + params.append(("colors", torch.nn.Parameter(colors), sh0_lr)) + + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + # Scale learning rate based on batch size, reference: + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + # Note that this would not make the training exactly equivalent, see + # https://arxiv.org/pdf/2402.18824v1 + BS = batch_size * world_size + optimizer_class = None + if sparse_grad: + optimizer_class = torch.optim.SparseAdam + elif visible_adam: + optimizer_class = SelectiveAdam + else: + optimizer_class = torch.optim.Adam + optimizers = { + name: optimizer_class( + [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], + eps=1e-15 / math.sqrt(BS), + # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. 
+ betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name, _, lr in params + } + return splats, optimizers + + +def compute_pinhole_intrinsics( + image_w: int, image_h: int, fov_deg: float +) -> np.ndarray: + fov_rad = np.deg2rad(fov_deg) + fx = image_w / (2 * np.tan(fov_rad / 2)) + fy = fx # assuming square pixels + cx = image_w / 2 + cy = image_h / 2 + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + + return K + + +def resize_pinhole_intrinsics( + raw_K: np.ndarray | torch.Tensor, + raw_hw: tuple[int, int], + new_hw: tuple[int, int], +) -> np.ndarray: + raw_h, raw_w = raw_hw + new_h, new_w = new_hw + + scale_x = new_w / raw_w + scale_y = new_h / raw_h + + new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone() + new_K[0, 0] *= scale_x # fx + new_K[0, 2] *= scale_x # cx + new_K[1, 1] *= scale_y # fy + new_K[1, 2] *= scale_y # cy + + return new_K + + +def restore_scene_scale_and_position( + real_height: float, mesh_path: str, gs_path: str +) -> None: + """Scales a mesh and corresponding GS model to match a given real-world height. + + Uses the 1st and 99th percentile of mesh Z-axis to estimate height, + applies scaling and vertical alignment, and updates both the mesh and GS model. + + Args: + real_height (float): Target real-world height among Z axis. + mesh_path (str): Path to the input mesh file. + gs_path (str): Path to the Gaussian Splatting model file. 
+ """ + mesh = trimesh.load(mesh_path) + z_min = np.percentile(mesh.vertices[:, 1], 1) + z_max = np.percentile(mesh.vertices[:, 1], 99) + height = z_max - z_min + scale = real_height / height + + rot = Rotation.from_quat([0, 1, 0, 0]) + mesh.vertices = rot.apply(mesh.vertices) + mesh.vertices[:, 1] -= z_min + mesh.vertices *= scale + mesh.export(mesh_path) + + gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path) + gs_model = gs_model.get_gaussians( + instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0]) + ) + gs_model.rescale(scale) + gs_model.save_to_ply(gs_path) diff --git a/embodied_gen/utils/monkey_patches.py b/embodied_gen/utils/monkey_patches.py new file mode 100644 index 0000000..79e77d5 --- /dev/null +++ b/embodied_gen/utils/monkey_patches.py @@ -0,0 +1,152 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +import os +import sys +import zipfile + +import torch +from huggingface_hub import hf_hub_download +from omegaconf import OmegaConf +from PIL import Image +from torchvision import transforms + + +def monkey_patch_pano2room(): + current_file_path = os.path.abspath(__file__) + current_dir = os.path.dirname(current_file_path) + sys.path.append(os.path.join(current_dir, "../..")) + sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room")) + from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import ( + OmnidataNormalPredictor, + ) + from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import ( + OmnidataPredictor, + ) + + def patched_omni_depth_init(self): + self.img_size = 384 + self.model = torch.hub.load( + 'alexsax/omnidata_models', 'depth_dpt_hybrid_384' + ) + self.model.eval() + self.trans_totensor = transforms.Compose( + [ + transforms.Resize(self.img_size, interpolation=Image.BILINEAR), + transforms.CenterCrop(self.img_size), + transforms.Normalize(mean=0.5, std=0.5), + ] + ) + + OmnidataPredictor.__init__ = patched_omni_depth_init + + def patched_omni_normal_init(self): + self.img_size = 384 + self.model = torch.hub.load( + 'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384' + ) + self.model.eval() + self.trans_totensor = transforms.Compose( + [ + transforms.Resize(self.img_size, interpolation=Image.BILINEAR), + transforms.CenterCrop(self.img_size), + transforms.Normalize(mean=0.5, std=0.5), + ] + ) + + OmnidataNormalPredictor.__init__ = patched_omni_normal_init + + def patched_panojoint_init(self, save_path=None): + self.depth_predictor = OmnidataPredictor() + self.normal_predictor = OmnidataNormalPredictor() + self.save_path = save_path + + from modules.geo_predictors import PanoJointPredictor + + PanoJointPredictor.__init__ = patched_panojoint_init + + # NOTE: We use gsplat instead. 
+ # import depth_diff_gaussian_rasterization_min as ddgr + # from dataclasses import dataclass + # @dataclass + # class PatchedGaussianRasterizationSettings: + # image_height: int + # image_width: int + # tanfovx: float + # tanfovy: float + # bg: torch.Tensor + # scale_modifier: float + # viewmatrix: torch.Tensor + # projmatrix: torch.Tensor + # sh_degree: int + # campos: torch.Tensor + # prefiltered: bool + # debug: bool = False + # ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings + + # disable get_has_ddp_rank print in `BaseInpaintingTrainingModule` + os.environ["NODE_RANK"] = "0" + + from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import ( + load_checkpoint, + ) + from thirdparty.pano2room.modules.inpainters.lama_inpainter import ( + LamaInpainter, + ) + + def patched_lama_inpaint_init(self): + zip_path = hf_hub_download( + repo_id="smartywu/big-lama", + filename="big-lama.zip", + repo_type="model", + ) + extract_dir = os.path.splitext(zip_path)[0] + + if not os.path.exists(extract_dir): + os.makedirs(extract_dir, exist_ok=True) + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_dir) + + config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml') + checkpoint_path = os.path.join( + extract_dir, 'big-lama/models/best.ckpt' + ) + train_config = OmegaConf.load(config_path) + train_config.training_model.predict_only = True + train_config.visualizer.kind = 'noop' + + self.model = load_checkpoint( + train_config, checkpoint_path, strict=False, map_location='cpu' + ) + self.model.freeze() + + LamaInpainter.__init__ = patched_lama_inpaint_init + + from diffusers import StableDiffusionInpaintPipeline + from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import ( + SDFTInpainter, + ) + + def patched_sd_inpaint_init(self, subset_name=None): + super(SDFTInpainter, self).__init__() + pipe = StableDiffusionInpaintPipeline.from_pretrained( + 
"stabilityai/stable-diffusion-2-inpainting", + torch_dtype=torch.float16, + ).to("cuda") + pipe.enable_model_cpu_offload() + self.inpaint_pipe = pipe + + SDFTInpainter.__init__ = patched_sd_inpaint_init diff --git a/embodied_gen/utils/process_media.py b/embodied_gen/utils/process_media.py index edfdcff..3087d4f 100644 --- a/embodied_gen/utils/process_media.py +++ b/embodied_gen/utils/process_media.py @@ -17,6 +17,7 @@ import logging import math +import mimetypes import os import textwrap from glob import glob @@ -27,10 +28,10 @@ import imageio import matplotlib.pyplot as plt import networkx as nx import numpy as np -import PIL.Image as Image import spaces from matplotlib.patches import Patch from moviepy.editor import VideoFileClip, clips_array +from PIL import Image from embodied_gen.data.differentiable_render import entrypoint as render_api from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum @@ -45,6 +46,8 @@ __all__ = [ "filter_image_small_connected_components", "combine_images_to_grid", "SceneTreeVisualizer", + "is_image_file", + "parse_text_prompts", ] @@ -356,6 +359,23 @@ def load_scene_dict(file_path: str) -> dict: return scene_dict +def is_image_file(filename: str) -> bool: + mime_type, _ = mimetypes.guess_type(filename) + + return mime_type is not None and mime_type.startswith('image') + + +def parse_text_prompts(prompts: list[str]) -> list[str]: + if len(prompts) == 1 and prompts[0].endswith(".txt"): + with open(prompts[0], "r") as f: + prompts = [ + line.strip() + for line in f + if line.strip() and not line.strip().startswith("#") + ] + return prompts + + if __name__ == "__main__": merge_video_video( "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa diff --git a/embodied_gen/utils/tags.py b/embodied_gen/utils/tags.py index 56deb45..60de686 100644 --- a/embodied_gen/utils/tags.py +++ b/embodied_gen/utils/tags.py @@ -1 +1 @@ -VERSION = "v0.1.1" +VERSION = "v0.1.2" diff --git 
a/embodied_gen/validators/quality_checkers.py b/embodied_gen/validators/quality_checkers.py index 88636f6..186346f 100644 --- a/embodied_gen/validators/quality_checkers.py +++ b/embodied_gen/validators/quality_checkers.py @@ -33,6 +33,9 @@ __all__ = [ "ImageAestheticChecker", "SemanticConsistChecker", "TextGenAlignChecker", + "PanoImageGenChecker", + "PanoHeightEstimator", + "PanoImageOccChecker", ] @@ -328,6 +331,159 @@ class TextGenAlignChecker(BaseChecker): ) +class PanoImageGenChecker(BaseChecker): + """A checker class that validates the quality and realism of generated panoramic indoor images. + + Attributes: + gpt_client (GPTclient): A GPT client instance used to query for image validation. + prompt (str): The instruction prompt passed to the GPT model. If None, a default prompt is used. + verbose (bool): Whether to print internal processing information for debugging. + """ + + def __init__( + self, + gpt_client: GPTclient, + prompt: str = None, + verbose: bool = False, + ) -> None: + super().__init__(prompt, verbose) + self.gpt_client = gpt_client + if self.prompt is None: + self.prompt = """ + You are a panoramic image analyzer specializing in indoor room structure validation. + + Given a generated panoramic image, assess if it meets all the criteria: + - Floor Space: ≥30 percent of the floor is free of objects or obstructions. + - Visual Clarity: Floor, walls, and ceiling are clear, with no distortion, blur, noise. + - Structural Continuity: Surfaces form plausible, continuous geometry + without breaks, floating parts, or abrupt cuts. + - Spatial Completeness: Full 360° coverage without missing areas, + seams, gaps, or stitching artifacts. + Instructions: + - If all criteria are met, reply with "YES". + - Otherwise, reply with "NO: " (max 20 words). + + Respond exactly as: + "YES" + or + "NO: brief explanation." 
+ """ + + def query(self, image_paths: str | Image.Image) -> str: + + return self.gpt_client.query( + text_prompt=self.prompt, + image_base64=image_paths, + ) + + +class PanoImageOccChecker(BaseChecker): + """Checks for physical obstacles in the bottom-center region of a panoramic image. + + This class crops a specified region from the input panoramic image and uses + a GPT client to determine whether any physical obstacles there. + + Args: + gpt_client (GPTclient): The GPT-based client used for visual reasoning. + box_hw (tuple[int, int]): The height and width of the crop box. + prompt (str, optional): Custom prompt for the GPT client. Defaults to a predefined one. + verbose (bool, optional): Whether to print verbose logs. Defaults to False. + """ + + def __init__( + self, + gpt_client: GPTclient, + box_hw: tuple[int, int], + prompt: str = None, + verbose: bool = False, + ) -> None: + super().__init__(prompt, verbose) + self.gpt_client = gpt_client + self.box_hw = box_hw + if self.prompt is None: + self.prompt = """ + This image is a cropped region from the bottom-center of a panoramic view. + Please determine whether there is any obstacle present — such as furniture, tables, or other physical objects. + Ignore floor textures, rugs, carpets, shadows, and lighting effects — they do not count as obstacles. + Only consider real, physical objects that could block walking or movement. + + Instructions: + - If there is no obstacle, reply: "YES". + - Otherwise, reply: "NO: " (max 20 words). + + Respond exactly as: + "YES" + or + "NO: brief explanation." 
+ """ + + def query(self, image_paths: str | Image.Image) -> str: + if isinstance(image_paths, str): + image_paths = Image.open(image_paths) + + w, h = image_paths.size + image_paths = image_paths.crop( + ( + (w - self.box_hw[1]) // 2, + h - self.box_hw[0], + (w + self.box_hw[1]) // 2, + h, + ) + ) + + return self.gpt_client.query( + text_prompt=self.prompt, + image_base64=image_paths, + ) + + +class PanoHeightEstimator(object): + """Estimate the real ceiling height of an indoor space from a 360° panoramic image. + + Attributes: + gpt_client (GPTclient): The GPT client used to perform image-based reasoning and return height estimates. + default_value (float): The fallback height in meters if parsing the GPT output fails. + prompt (str): The textual instruction used to guide the GPT model for height estimation. + """ + + def __init__( + self, + gpt_client: GPTclient, + default_value: float = 3.5, + ) -> None: + self.gpt_client = gpt_client + self.default_value = default_value + self.prompt = """ + You are an expert in building height estimation and panoramic image analysis. + Your task is to analyze a 360° indoor panoramic image and estimate the **actual height** of the space in meters. + + Consider the following visual cues: + 1. Ceiling visibility and reference objects (doors, windows, furniture, appliances). + 2. Floor features or level differences. + 3. Room type (e.g., residential, office, commercial). + 4. Object-to-ceiling proportions (e.g., height of doors relative to ceiling). + 5. Architectural elements (e.g., chandeliers, shelves, kitchen cabinets). + + Input: A full 360° panoramic indoor photo. + Output: A single number in meters representing the estimated room height. 
Only return the number (e.g., `3.2`) + """ + + def __call__(self, image_paths: str | Image.Image) -> float: + result = self.gpt_client.query( + text_prompt=self.prompt, + image_base64=image_paths, + ) + try: + result = float(result.strip()) + except ValueError: + logger.error( + f"Parser error: failed convert {result} to float, use default value {self.default_value}." + ) + result = self.default_value + + return result + + class SemanticMatcher(BaseChecker): def __init__( self, diff --git a/install.sh b/install.sh index 2568aa9..f063894 100644 --- a/install.sh +++ b/install.sh @@ -1,65 +1,28 @@ #!/bin/bash set -e -RED='\033[0;31m' -GREEN='\033[0;32m' -NC='\033[0m' +STAGE=$1 # "basic" | "extra" | "all" +STAGE=${STAGE:-all} -echo -e "${GREEN}Starting installation process...${NC}" +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) + +source "$SCRIPT_DIR/install/_utils.sh" git config --global http.postBuffer 524288000 +# Patch submodule .gitignore to ignore __pycache__, only if submodule exists +PANO2ROOM_PATH="$SCRIPT_DIR/thirdparty/pano2room" +if [ -d "$PANO2ROOM_PATH" ]; then + echo "__pycache__/" > "$PANO2ROOM_PATH/.gitignore" + log_info "Added .gitignore to ignore __pycache__ in $PANO2ROOM_PATH" +fi -echo -e "${GREEN}Installing flash-attn...${NC}" -pip install flash-attn==2.7.0.post2 --no-build-isolation || { - echo -e "${RED}Failed to install flash-attn${NC}" - exit 1 -} +log_info "===== Starting installation stage: $STAGE =====" -echo -e "${GREEN}Installing dependencies from requirements.txt...${NC}" -pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60 || { - echo -e "${RED}Failed to install requirements${NC}" - exit 1 -} +if [[ "$STAGE" == "basic" || "$STAGE" == "all" ]]; then + bash "$SCRIPT_DIR/install/install_basic.sh" +fi +if [[ "$STAGE" == "extra" || "$STAGE" == "all" ]]; then + bash "$SCRIPT_DIR/install/install_extra.sh" +fi -echo -e "${GREEN}Installing kolors from GitHub...${NC}" -pip install 
kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d || { - echo -e "${RED}Failed to install kolors${NC}" - exit 1 -} - - -echo -e "${GREEN}Installing kaolin from GitHub...${NC}" -pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || { - echo -e "${RED}Failed to install kaolin${NC}" - exit 1 -} - - -echo -e "${GREEN}Installing diff-gaussian-rasterization...${NC}" -TMP_DIR="/tmp/mip-splatting" -rm -rf "$TMP_DIR" -git clone --recursive https://github.com/autonomousvision/mip-splatting.git "$TMP_DIR" && \ -pip install "$TMP_DIR/submodules/diff-gaussian-rasterization" && \ -rm -rf "$TMP_DIR" || { - echo -e "${RED}Failed to clone or install diff-gaussian-rasterization${NC}" - rm -rf "$TMP_DIR" - exit 1 -} - - -echo -e "${GREEN}Installing gsplat from GitHub...${NC}" -pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.0 || { - echo -e "${RED}Failed to install gsplat${NC}" - exit 1 -} - - -echo -e "${GREEN}Installing EmbodiedGen...${NC}" -pip install triton==2.1.0 -pip install -e . || { - echo -e "${RED}Failed to install EmbodiedGen pyproject.toml${NC}" - exit 1 -} - -echo -e "${GREEN}Installation completed successfully!${NC}" - +log_info "===== Installation completed successfully. =====" diff --git a/install/_utils.sh b/install/_utils.sh new file mode 100644 index 0000000..3fa8879 --- /dev/null +++ b/install/_utils.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +RED='\033[0;31m' +GREEN='\033[0;32m' +NC='\033[0m' + +log_info() { + echo -e "${GREEN}[INFO] $1${NC}" +} + +log_error() { + echo -e "${RED}[ERROR] $1${NC}" >&2 +} + +try_install() { + log_info "$1" + eval "$2" || { + log_error "$3" + exit 1 + } +} diff --git a/install/install_basic.sh b/install/install_basic.sh new file mode 100644 index 0000000..26fce27 --- /dev/null +++ b/install/install_basic.sh @@ -0,0 +1,35 @@ +#!/bin/bash +set -e +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +source "$SCRIPT_DIR/_utils.sh" + +try_install "Installing flash-attn..." 
\ + "pip install flash-attn==2.7.0.post2 --no-build-isolation" \ + "flash-attn installation failed." + +try_install "Installing requirements.txt..." \ + "pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60" \ + "requirements installation failed." + +try_install "Installing kolors..." \ + "pip install kolors@git+https://github.com/HochCC/Kolors.git" \ + "kolors installation failed." + +try_install "Installing kaolin..." \ + "pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0" \ + "kaolin installation failed." + +log_info "Installing diff-gaussian-rasterization..." +TMP_DIR="/tmp/mip-splatting" +rm -rf "$TMP_DIR" +git clone --recursive https://github.com/autonomousvision/mip-splatting.git "$TMP_DIR" +pip install "$TMP_DIR/submodules/diff-gaussian-rasterization" +rm -rf "$TMP_DIR" + +try_install "Installing gsplat..." \ + "pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.3" \ + "gsplat installation failed." + +try_install "Installing EmbodiedGen..." \ + "pip install triton==2.1.0 --no-deps && pip install -e ." \ + "EmbodiedGen installation failed." diff --git a/install/install_extra.sh b/install/install_extra.sh new file mode 100644 index 0000000..1af2dd3 --- /dev/null +++ b/install/install_extra.sh @@ -0,0 +1,52 @@ +#!/bin/bash +set -e +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +source "$SCRIPT_DIR/_utils.sh" + +# try_install "Installing txt2panoimg..." \ +# "pip install txt2panoimg@git+https://github.com/HochCC/SD-T2I-360PanoImage --no-deps" \ +# "txt2panoimg installation failed." + +# try_install "Installing fused-ssim..." \ +# "pip install fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98" \ +# "fused-ssim installation failed." + +# try_install "Installing tiny-cuda-nn..." \ +# "pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch" \ +# "tiny-cuda-nn installation failed." 
+ +# try_install "Installing pytorch3d" \ +# "pip install git+https://github.com/facebookresearch/pytorch3d.git@v0.7.7" \ +# "pytorch3d installation failed." + + +PYTHON_PACKAGES_NODEPS=( + timm + txt2panoimg@git+https://github.com/HochCC/SD-T2I-360PanoImage + kornia + kornia_rs +) + +PYTHON_PACKAGES=( + fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98 + git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch + git+https://github.com/facebookresearch/pytorch3d.git@v0.7.7 + h5py + albumentations==0.5.2 + webdataset + icecream + open3d + pyequilib + numpy==1.26.4 + triton==2.1.0 +) + +for pkg in "${PYTHON_PACKAGES_NODEPS[@]}"; do + try_install "Installing $pkg without dependencies..." \ + "pip install --no-deps $pkg" \ + "$pkg installation failed." +done + +try_install "Installing other Python dependencies..." \ + "pip install ${PYTHON_PACKAGES[*]}" \ + "Python dependencies installation failed." diff --git a/pyproject.toml b/pyproject.toml index adfbc51..e50931e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ packages = ["embodied_gen"] [project] name = "embodied_gen" -version = "v0.1.1" +version = "v0.1.2" readme = "README.md" license = "Apache-2.0" license-files = ["LICENSE", "NOTICE"] @@ -31,6 +31,7 @@ drender-cli = "embodied_gen.data.differentiable_render:entrypoint" backproject-cli = "embodied_gen.data.backproject_v2:entrypoint" img3d-cli = "embodied_gen.scripts.imageto3d:entrypoint" text3d-cli = "embodied_gen.scripts.textto3d:text_to_3d" +scene3d-cli = "embodied_gen.scripts.gen_scene3d:entrypoint" [tool.pydocstyle] match = '(?!test_).*(?!_pb2)\.py' diff --git a/requirements.txt b/requirements.txt index 0c62681..47854cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,6 +32,9 @@ vtk==9.3.1 spaces colorlog json-repair +scikit-learn +omegaconf +tyro utils3d@git+https://github.com/EasternJournalist/utils3d.git#egg=9a4eb15 clip@git+https://github.com/openai/CLIP.git 
segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f diff --git a/tests/test_examples/test_quality_checkers.py b/tests/test_examples/test_quality_checkers.py index 4031c01..6cd3fad 100644 --- a/tests/test_examples/test_quality_checkers.py +++ b/tests/test_examples/test_quality_checkers.py @@ -17,6 +17,7 @@ import logging import tempfile +from glob import glob import pytest from embodied_gen.utils.gpt_clients import GPT_CLIENT @@ -25,6 +26,9 @@ from embodied_gen.validators.quality_checkers import ( ImageAestheticChecker, ImageSegChecker, MeshGeoChecker, + PanoHeightEstimator, + PanoImageGenChecker, + PanoImageOccChecker, SemanticConsistChecker, TextGenAlignChecker, ) @@ -57,6 +61,21 @@ def textalign_checker(): return TextGenAlignChecker(GPT_CLIENT) +@pytest.fixture(scope="module") +def pano_checker(): + return PanoImageGenChecker(GPT_CLIENT) + + +@pytest.fixture(scope="module") +def pano_height_estimator(): + return PanoHeightEstimator(GPT_CLIENT) + + +@pytest.fixture(scope="module") +def panoocc_checker(): + return PanoImageOccChecker(GPT_CLIENT, box_hw=[90, 1000]) + + def test_geo_checker(geo_checker): flag, result = geo_checker( [ @@ -117,3 +136,28 @@ def test_textgen_checker(textalign_checker, mesh_path, text_desc): ) flag, result = textalign_checker(text_desc, image_list) logger.info(f"textalign_checker: {flag}, {result}") + + +def test_panoheight_estimator(pano_height_estimator): + image_paths = glob("outputs/bg_v3/test2/*/*.png") + for image_path in image_paths: + result = pano_height_estimator(image_path) + logger.info(f"{type(result)}, {result}") + + +def test_pano_checker(pano_checker): + # image_paths = [ + # "outputs/bg_gen2/scene_0000/pano_image.png", + # "outputs/bg_gen2/scene_0001/pano_image.png", + # ] + image_paths = glob("outputs/bg_gen/*/*.png") + for image_path in image_paths: + flag, result = pano_checker(image_path) + logger.info(f"{image_path} {flag}, {result}") + + +def 
test_panoocc_checker(panoocc_checker): + image_paths = glob("outputs/bg_gen/*/*.png") + for image_path in image_paths: + flag, result = panoocc_checker(image_path) + logger.info(f"{image_path} {flag}, {result}") diff --git a/tests/test_unit/test_agents.py b/tests/test_unit/test_agents.py index f49fab5..79fe9c4 100644 --- a/tests/test_unit/test_agents.py +++ b/tests/test_unit/test_agents.py @@ -23,6 +23,8 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT from embodied_gen.validators.quality_checkers import ( ImageSegChecker, MeshGeoChecker, + PanoHeightEstimator, + PanoImageGenChecker, SemanticConsistChecker, ) @@ -93,3 +95,16 @@ def test_semantic_checker(gptclient_query_case2): ) assert isinstance(flag, (bool, type(None))) assert isinstance(result, str) + + +def test_panoheight_estimator(): + checker = PanoHeightEstimator(GPT_CLIENT, default_value=3.5) + result = checker(image_paths="dummy_path/pano.png") + assert isinstance(result, float) + + +def test_panogen_checker(): + checker = PanoImageGenChecker(GPT_CLIENT) + flag, result = checker(image_paths="dummy_path/pano.png") + assert isinstance(flag, (bool, type(None))) + assert isinstance(result, str) diff --git a/thirdparty/pano2room b/thirdparty/pano2room new file mode 160000 index 0000000..bbf93ae --- /dev/null +++ b/thirdparty/pano2room @@ -0,0 +1 @@ +Subproject commit bbf93ae57086ed700edc6ee445852d4457a9d704