feat(pipe): Release 3D scene generation pipeline. (#25)

Release 3D scene generation pipeline and tag as v0.1.2.

---------

Co-authored-by: xinjie.wang <xinjie.wang@gpu-4090-dev015.hogpu.cc>
This commit is contained in:
Xinjie 2025-07-21 23:31:15 +08:00 committed by GitHub
parent 51759f011a
commit e82f02a9a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 2577 additions and 109 deletions

2
.gitignore vendored
View File

@ -59,4 +59,4 @@ output*
scripts/tools/
weights
apps/sessions/
apps/assets/
apps/assets/

7
.gitmodules vendored
View File

@ -2,4 +2,9 @@
path = thirdparty/TRELLIS
url = https://github.com/microsoft/TRELLIS.git
branch = main
shallow = true
shallow = true
[submodule "thirdparty/pano2room"]
path = thirdparty/pano2room
url = https://github.com/TrickyGo/Pano2Room.git
branch = main
shallow = true

View File

@ -1,6 +1,6 @@
repos:
- repo: git@gitlab.hobot.cc:ptd/3rd/pre-commit/pre-commit-hooks.git
rev: v2.3.0 # Use the ref you want to point at
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0 # Use the ref you want to point at
hooks:
- id: trailing-whitespace
- id: check-added-large-files

View File

@ -30,11 +30,11 @@
```sh
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
cd EmbodiedGen
git checkout v0.1.1
git checkout v0.1.2
git submodule update --init --recursive --progress
conda create -n embodiedgen python=3.10.13 -y
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
conda activate embodiedgen
bash install.sh
bash install.sh basic
```
### ✅ Setup GPT Agent
@ -94,7 +94,7 @@ python apps/text_to_3d.py
### ⚡ API
Text-to-image model based on SD3.5 Medium, English prompts only.
Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically. (ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py`)
Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically.
For large-scale 3D assets generation, set `--n_pipe_retry=2` to ensure high end-to-end 3D asset usability through automatic quality check and retries. For more diverse results, do not set `--seed_img`.
@ -110,6 +110,7 @@ bash embodied_gen/scripts/textto3d.sh \
--prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
--output_root outputs/textto3d_k
```
ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py`
---
@ -146,10 +147,22 @@ bash embodied_gen/scripts/texture_gen.sh \
<h2 id="3d-scene-generation">🌍 3D Scene Generation</h2>
🚧 *Coming Soon*
<img src="apps/assets/scene3d.gif" alt="scene3d" style="width: 640px;">
### ⚡ API
> Run `bash install.sh extra` to install additional requirements if you need to use `scene3d-cli`.
It takes ~30mins to generate a color mesh and 3DGS per scene.
```sh
CUDA_VISIBLE_DEVICES=0 scene3d-cli \
--prompts "Art studio with easel and canvas" \
--output_dir outputs/bg_scenes/ \
--seed 0 \
--gs3d.max_steps 4000 \
--disable_pano_check
```
---
@ -189,7 +202,7 @@ bash embodied_gen/scripts/texture_gen.sh \
## For Developer
```sh
pip install .[dev] && pre-commit install
pip install -e .[dev] && pre-commit install
python -m pytest # Pass all unit-test are required.
```

View File

@ -94,9 +94,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"
MAX_SEED = 100000
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
# IMAGESR_MODEL = ImageStableSR()
def patched_setup_functions(self):
@ -136,6 +133,9 @@ def patched_setup_functions(self):
Gaussian.setup_functions = patched_setup_functions
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
# IMAGESR_MODEL = ImageStableSR()
if os.getenv("GRADIO_APP") == "imageto3d":
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()

View File

@ -19,8 +19,9 @@ import json
import logging
import os
import random
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Literal, Tuple
import numpy as np
import torch
import torch.utils.checkpoint
from PIL import Image
@ -36,6 +37,7 @@ logger = logging.getLogger(__name__)
__all__ = [
"Asset3dGenDataset",
"PanoGSplatDataset",
]
@ -222,6 +224,68 @@ class Asset3dGenDataset(Dataset):
return data
class PanoGSplatDataset(Dataset):
"""A PyTorch Dataset for loading panorama-based 3D Gaussian Splatting data.
This dataset is designed to be compatible with train and eval pipelines
that use COLMAP-style camera conventions.
Args:
data_dir (str): Root directory where the dataset file is located.
split (str): Dataset split to use, either "train" or "eval".
data_name (str, optional): Name of the dataset file (default: "gs_data.pt").
max_sample_num (int, optional): Maximum number of samples to load. If None,
all available samples in the split will be used.
"""
def __init__(
self,
data_dir: str,
split: str = Literal["train", "eval"],
data_name: str = "gs_data.pt",
max_sample_num: int = None,
) -> None:
self.data_path = os.path.join(data_dir, data_name)
self.split = split
self.max_sample_num = max_sample_num
if not os.path.exists(self.data_path):
raise FileNotFoundError(
f"Dataset file {self.data_path} not found. Please provide the correct path."
)
self.data = torch.load(self.data_path, weights_only=False)
self.frames = self.data[split]
if max_sample_num is not None:
self.frames = self.frames[:max_sample_num]
self.points = self.data.get("points", None)
self.points_rgb = self.data.get("points_rgb", None)
def __len__(self) -> int:
return len(self.frames)
def cvt_blender_to_colmap_coord(self, c2w: np.ndarray) -> np.ndarray:
# change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
tranformed_c2w = np.copy(c2w)
tranformed_c2w[:3, 1:3] *= -1
return tranformed_c2w
def __getitem__(self, index: int) -> dict[str, any]:
data = self.frames[index]
c2w = self.cvt_blender_to_colmap_coord(data["camtoworld"])
item = dict(
camtoworld=c2w,
K=data["K"],
image_h=data["image_h"],
image_w=data["image_w"],
)
if "image" in data:
item["image"] = data["image"]
if "image_id" in data:
item["image_id"] = data["image_id"]
return item
if __name__ == "__main__":
index_file = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json" # noqa
target_hw = (512, 512)

View File

@ -158,8 +158,9 @@ class DiffrastRender(object):
return normalized_maps
@staticmethod
def normalize_map_by_mask(
self, map: torch.Tensor, mask: torch.Tensor
map: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
# Normalize all maps in total by mask, normalized map in [0, 1].
foreground = (mask == 1).squeeze(dim=-1)

View File

@ -0,0 +1,191 @@
import logging
import os
import random
import time
import warnings
from dataclasses import dataclass, field
from shutil import copy, rmtree
import torch
import tyro
from huggingface_hub import snapshot_download
from packaging import version
# Suppress warnings
warnings.filterwarnings("ignore", category=FutureWarning)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)

# TorchVision monkey patch for >0.16: `torchvision.transforms.functional_tensor`
# was removed, but some third-party dependencies still import it, so recreate
# it as an alias module.
# BUGFIX: the version gate used to compare `torch.__version__` (e.g. "2.x")
# against "0.16", which was always true and only worked by coincidence; the
# intent (per the comment above) is to check TorchVision's version.
import torchvision  # noqa: E402

if version.parse(torchvision.__version__) >= version.parse("0.16"):
    import sys
    import types

    import torchvision.transforms.functional as TF

    functional_tensor = types.ModuleType(
        "torchvision.transforms.functional_tensor"
    )
    functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
from gsplat.distributed import cli
from txt2panoimg import Text2360PanoramaImagePipeline
from embodied_gen.trainer.gsplat_trainer import (
DefaultStrategy,
GsplatTrainConfig,
)
from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
from embodied_gen.utils.config import Pano2MeshSRConfig
from embodied_gen.utils.gaussian import restore_scene_scale_and_position
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import is_image_file, parse_text_prompts
from embodied_gen.validators.quality_checkers import (
PanoHeightEstimator,
PanoImageOccChecker,
)
__all__ = [
"generate_pano_image",
"entrypoint",
]
@dataclass
class Scene3DGenConfig:
    """CLI configuration (via tyro) for the 3D scene generation pipeline."""

    prompts: list[str]  # Text desc of indoor room or style reference image.
    output_dir: str  # Root folder; each scene goes to `scene_{idx:04d}/`.
    seed: int | None = None  # Random seed; a fresh one is drawn when None.
    real_height: float | None = None  # The real height of the room in meters.
    pano_image_only: bool = False  # Stop after panorama generation.
    disable_pano_check: bool = False  # Skip the GPT occlusion checker.
    keep_middle_result: bool = False  # Keep intermediate gsplat outputs.
    n_retry: int = 7  # Max panorama generation retries per prompt.
    # Nested 3DGS training config; defaults tuned for indoor pano scenes.
    gs3d: GsplatTrainConfig = field(
        default_factory=lambda: GsplatTrainConfig(
            strategy=DefaultStrategy(verbose=True),
            max_steps=4000,
            init_opa=0.9,
            opacity_reg=2e-3,
            sh_degree=0,
            means_lr=1e-4,
            scales_lr=1e-3,
        )
    )
def generate_pano_image(
    prompt: str,
    output_path: str,
    pipeline,
    seed: int,
    n_retry: int,
    checker=None,
    num_inference_steps: int = 40,
) -> None:
    """Generate a panorama image for `prompt` and save it to `output_path`.

    Retries up to `n_retry` times with a fresh random seed whenever the
    optional `checker` rejects the result. The last generated image always
    remains saved at `output_path`, even if no attempt passed the check.
    """
    for attempt in range(n_retry):
        logger.info(
            f"GEN Panorama: Retry {attempt+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
        )
        # Image-style reference input is not supported yet.
        if is_image_file(prompt):
            raise NotImplementedError("Image mode not implemented yet.")
        txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
        pano_image = pipeline(
            {
                "prompt": txt_prompt,
                "num_inference_steps": num_inference_steps,
                "upscale": False,
                "seed": seed,
            }
        )
        pano_image.save(output_path)
        if checker is None:
            return
        flag, response = checker(pano_image)
        logger.warning(f"{response}, image saved in {output_path}")
        # `None` means the checker abstained; treat it like a pass.
        if flag is None or flag is True:
            return
        seed = random.randint(0, 100000)
def entrypoint(*args, **kwargs):
    """CLI entry: generate a 3D scene (pano -> mesh -> 3DGS) per text prompt.

    Parses `Scene3DGenConfig` from the command line via tyro; `*args`/`**kwargs`
    are accepted for console-script compatibility but not used directly.
    """
    cfg = tyro.cli(Scene3DGenConfig)
    # Init global models.
    model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    IMG2PANO_PIPE = Text2360PanoramaImagePipeline(
        model_path, torch_dtype=torch.float16, device="cuda"
    )
    PANOMESH_CFG = Pano2MeshSRConfig()
    PANO2MESH_PIPE = Pano2MeshSRPipeline(PANOMESH_CFG)
    PANO_CHECKER = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
    PANOHEIGHT_ESTOR = PanoHeightEstimator(GPT_CLIENT)
    prompts = parse_text_prompts(cfg.prompts)
    for idx, prompt in enumerate(prompts):
        start_time = time.time()
        output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
        os.makedirs(output_dir, exist_ok=True)
        pano_path = os.path.join(output_dir, "pano_image.png")
        # Record the prompt alongside the generated assets.
        with open(f"{output_dir}/prompt.txt", "w") as f:
            f.write(prompt)
        generate_pano_image(
            prompt,
            pano_path,
            IMG2PANO_PIPE,
            cfg.seed if cfg.seed is not None else random.randint(0, 100000),
            cfg.n_retry,
            checker=None if cfg.disable_pano_check else PANO_CHECKER,
        )
        if cfg.pano_image_only:
            continue
        logger.info("GEN and REPAIR Mesh from Panorama...")
        PANO2MESH_PIPE(pano_path, output_dir)
        logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
        # Point the nested gsplat config at this scene's working directory.
        cfg.gs3d.data_dir = output_dir
        cfg.gs3d.result_dir = f"{output_dir}/gaussian"
        cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)
        torch.set_default_device("cpu")  # recover default setting.
        cli(gsplat_entrypoint, cfg.gs3d, verbose=True)
        # Clean up the middle results.
        gs_path = (
            f"{cfg.gs3d.result_dir}/ply/point_cloud_{cfg.gs3d.max_steps-1}.ply"
        )
        copy(gs_path, f"{output_dir}/gs_model.ply")
        video_path = f"{cfg.gs3d.result_dir}/renders/video_step{cfg.gs3d.max_steps-1}.mp4"
        copy(video_path, f"{output_dir}/video.mp4")
        gs_cfg_path = f"{cfg.gs3d.result_dir}/cfg.yml"
        copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")
        if not cfg.keep_middle_result:
            rmtree(cfg.gs3d.result_dir, ignore_errors=True)
            os.remove(f"{output_dir}/{PANOMESH_CFG.gs_data_file}")
        # Estimate the room height via GPT unless the user supplied it.
        real_height = (
            PANOHEIGHT_ESTOR(pano_path)
            if cfg.real_height is None
            else cfg.real_height
        )
        gs_path = os.path.join(output_dir, "gs_model.ply")
        mesh_path = os.path.join(output_dir, "mesh_model.ply")
        restore_scene_scale_and_position(real_height, mesh_path, gs_path)
        elapsed_time = (time.time() - start_time) / 60
        logger.info(
            f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
        )
if __name__ == "__main__":
    # Script entry; all CLI parsing happens inside entrypoint() via tyro.
    entrypoint()

View File

@ -62,25 +62,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"
random.seed(0)
logger.info("Loading Models...")
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
"microsoft/TRELLIS-image-large"
)
# PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
TMP_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)
def parse_args():
parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
@ -128,6 +109,19 @@ def entrypoint(**kwargs):
if hasattr(args, k) and v is not None:
setattr(args, k, v)
logger.info("Loading Models...")
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover()
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
"microsoft/TRELLIS-image-large"
)
# PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
assert (
args.image_path or args.image_root
), "Please provide either --image_path or --image_root."

View File

@ -31,6 +31,7 @@ from embodied_gen.models.text_model import (
build_text2img_pipeline,
text2img_gen,
)
from embodied_gen.utils.process_media import parse_text_prompts
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@ -101,14 +102,7 @@ def entrypoint(
if hasattr(args, k) and v is not None:
setattr(args, k, v)
prompts = args.prompts
if len(prompts) == 1 and prompts[0].endswith(".txt"):
with open(prompts[0], "r") as f:
prompts = f.readlines()
prompts = [
prompt.strip() for prompt in prompts if prompt.strip() != ""
]
prompts = parse_text_prompts(args.prompts)
os.makedirs(args.output_root, exist_ok=True)
ip_img_paths = args.ref_image

View File

@ -44,13 +44,6 @@ __all__ = [
"text_to_3d",
]
logger.info("Loading Models...")
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
PIPE_IMG = build_hf_image_pipeline("sd35")
BG_REMOVER = RembgRemover()
def text_to_image(
prompt: str,
@ -121,6 +114,14 @@ def text_to_3d(**kwargs) -> dict:
if hasattr(args, k) and v is not None:
setattr(args, k, v)
logger.info("Loading Models...")
global SEMANTIC_CHECKER, SEG_CHECKER, TXTGEN_CHECKER, PIPE_IMG, BG_REMOVER
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
PIPE_IMG = build_hf_image_pipeline(args.text_model)
BG_REMOVER = RembgRemover()
if args.asset_names is None or len(args.asset_names) == 0:
args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
img_save_dir = os.path.join(args.output_root, "images")
@ -260,6 +261,11 @@ def parse_args():
default=0,
help="Random seed for 3D generation",
)
parser.add_argument(
"--text_model",
type=str,
default="sd35",
)
parser.add_argument("--keep_intermediate", action="store_true")
args, unknown = parser.parse_known_args()

View File

@ -0,0 +1,678 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Part of the code comes from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
# Both under the Apache License, Version 2.0.
import json
import os
import time
from collections import defaultdict
from typing import Dict, Optional, Tuple
import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import yaml
from fused_ssim import fused_ssim
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import (
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing_extensions import Literal, assert_never
from embodied_gen.data.datasets import PanoGSplatDataset
from embodied_gen.utils.config import GsplatTrainConfig
from embodied_gen.utils.gaussian import (
create_splats_with_optimizers,
export_splats,
resize_pinhole_intrinsics,
set_random_seed,
)
class Runner:
    """Engine for training and testing from gsplat example.

    Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
    """

    def __init__(
        self,
        local_rank: int,
        world_rank,
        world_size: int,
        cfg: GsplatTrainConfig,
    ) -> None:
        set_random_seed(42 + local_rank)
        self.cfg = cfg
        self.world_rank = world_rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = f"cuda:{local_rank}"
        # Where to dump results.
        os.makedirs(cfg.result_dir, exist_ok=True)
        # Setup output directories.
        self.ckpt_dir = f"{cfg.result_dir}/ckpts"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.stats_dir = f"{cfg.result_dir}/stats"
        os.makedirs(self.stats_dir, exist_ok=True)
        self.render_dir = f"{cfg.result_dir}/renders"
        os.makedirs(self.render_dir, exist_ok=True)
        self.ply_dir = f"{cfg.result_dir}/ply"
        os.makedirs(self.ply_dir, exist_ok=True)
        # Tensorboard
        self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
        self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
        # NOTE(review): the "val" set reuses the train split capped at 6 views,
        # presumably because no held-out split exists — confirm intended.
        self.valset = PanoGSplatDataset(
            cfg.data_dir, split="train", max_sample_num=6
        )
        self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
        self.scene_scale = cfg.scene_scale
        # Model
        self.splats, self.optimizers = create_splats_with_optimizers(
            self.trainset.points,
            self.trainset.points_rgb,
            init_num_pts=cfg.init_num_pts,
            init_extent=cfg.init_extent,
            init_opacity=cfg.init_opa,
            init_scale=cfg.init_scale,
            means_lr=cfg.means_lr,
            scales_lr=cfg.scales_lr,
            opacities_lr=cfg.opacities_lr,
            quats_lr=cfg.quats_lr,
            sh0_lr=cfg.sh0_lr,
            shN_lr=cfg.shN_lr,
            scene_scale=self.scene_scale,
            sh_degree=cfg.sh_degree,
            sparse_grad=cfg.sparse_grad,
            visible_adam=cfg.visible_adam,
            batch_size=cfg.batch_size,
            feature_dim=None,
            device=self.device,
            world_rank=world_rank,
            world_size=world_size,
        )
        print("Model initialized. Number of GS:", len(self.splats["means"]))
        # Densification Strategy
        self.cfg.strategy.check_sanity(self.splats, self.optimizers)
        if isinstance(self.cfg.strategy, DefaultStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state(
                scene_scale=self.scene_scale
            )
        elif isinstance(self.cfg.strategy, MCMCStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state()
        else:
            assert_never(self.cfg.strategy)
        # Losses & Metrics.
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
            self.device
        )
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
        if cfg.lpips_net == "alex":
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="alex", normalize=True
            ).to(self.device)
        elif cfg.lpips_net == "vgg":
            # The 3DGS official repo uses lpips vgg, which is equivalent with the following:
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="vgg", normalize=False
            ).to(self.device)
        else:
            raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")

    def rasterize_splats(
        self,
        camtoworlds: Tensor,
        Ks: Tensor,
        width: int,
        height: int,
        masks: Optional[Tensor] = None,
        rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
        camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Dict]:
        """Rasterize the current splats for the given cameras.

        Returns `(render_colors, render_alphas, info)` from gsplat's
        `rasterization`; pixels outside `masks` (if given) are zeroed.
        """
        means = self.splats["means"]  # [N, 3]
        # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
        # rasterization does normalization internally
        quats = self.splats["quats"]  # [N, 4]
        scales = torch.exp(self.splats["scales"])  # [N, 3]
        opacities = torch.sigmoid(self.splats["opacities"])  # [N,]
        # Pop `image_ids` so it is not forwarded to `rasterization`,
        # which does not accept that keyword.
        image_ids = kwargs.pop("image_ids", None)
        colors = torch.cat(
            [self.splats["sh0"], self.splats["shN"]], 1
        )  # [N, K, 3]
        if rasterize_mode is None:
            rasterize_mode = (
                "antialiased" if self.cfg.antialiased else "classic"
            )
        if camera_model is None:
            camera_model = self.cfg.camera_model
        render_colors, render_alphas, info = rasterization(
            means=means,
            quats=quats,
            scales=scales,
            opacities=opacities,
            colors=colors,
            viewmats=torch.linalg.inv(camtoworlds),  # [C, 4, 4]
            Ks=Ks,  # [C, 3, 3]
            width=width,
            height=height,
            packed=self.cfg.packed,
            absgrad=(
                self.cfg.strategy.absgrad
                if isinstance(self.cfg.strategy, DefaultStrategy)
                else False
            ),
            sparse_grad=self.cfg.sparse_grad,
            rasterize_mode=rasterize_mode,
            distributed=self.world_size > 1,
            # BUGFIX: honor the per-call `camera_model` override; previously
            # this always passed `self.cfg.camera_model`, ignoring the
            # argument resolved above.
            camera_model=camera_model,
            with_ut=self.cfg.with_ut,
            with_eval3d=self.cfg.with_eval3d,
            **kwargs,
        )
        if masks is not None:
            render_colors[~masks] = 0
        return render_colors, render_alphas, info

    def train(self):
        """Run the full 3DGS optimization loop, with periodic checkpoint,
        PLY export, tensorboard logging and evaluation."""
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank
        # Dump cfg.
        if world_rank == 0:
            with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
                yaml.dump(vars(cfg), f)
        max_steps = cfg.max_steps
        init_step = 0
        schedulers = [
            # means has a learning rate schedule, that end at 0.01 of the initial value
            torch.optim.lr_scheduler.ExponentialLR(
                self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
            ),
        ]
        trainloader = torch.utils.data.DataLoader(
            self.trainset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
            pin_memory=True,
        )
        trainloader_iter = iter(trainloader)
        # Training loop.
        global_tic = time.time()
        pbar = tqdm.tqdm(range(init_step, max_steps))
        for step in pbar:
            try:
                data = next(trainloader_iter)
            except StopIteration:
                # Restart the loader when the epoch is exhausted.
                trainloader_iter = iter(trainloader)
                data = next(trainloader_iter)
            camtoworlds = data["camtoworld"].to(device)  # [1, 4, 4]
            Ks = data["K"].to(device)  # [1, 3, 3]
            pixels = data["image"].to(device) / 255.0  # [1, H, W, 3]
            image_ids = data["image_id"].to(device)
            masks = (
                data["mask"].to(device) if "mask" in data else None
            )  # [1, H, W]
            if cfg.depth_loss:
                points = data["points"].to(device)  # [1, M, 2]
                depths_gt = data["depths"].to(device)  # [1, M]
            height, width = pixels.shape[1:3]
            # sh schedule
            sh_degree_to_use = min(
                step // cfg.sh_degree_interval, cfg.sh_degree
            )
            # forward
            renders, alphas, info = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=sh_degree_to_use,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                image_ids=image_ids,
                render_mode="RGB+ED" if cfg.depth_loss else "RGB",
                masks=masks,
            )
            # 4 channels means RGB + expected depth was rendered.
            if renders.shape[-1] == 4:
                colors, depths = renders[..., 0:3], renders[..., 3:4]
            else:
                colors, depths = renders, None
            if cfg.random_bkgd:
                bkgd = torch.rand(1, 3, device=device)
                colors = colors + bkgd * (1.0 - alphas)
            self.cfg.strategy.step_pre_backward(
                params=self.splats,
                optimizers=self.optimizers,
                state=self.strategy_state,
                step=step,
                info=info,
            )
            # loss
            l1loss = F.l1_loss(colors, pixels)
            ssimloss = 1.0 - fused_ssim(
                colors.permute(0, 3, 1, 2),
                pixels.permute(0, 3, 1, 2),
                padding="valid",
            )
            loss = (
                l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
            )
            if cfg.depth_loss:
                # query depths from depth map
                points = torch.stack(
                    [
                        points[:, :, 0] / (width - 1) * 2 - 1,
                        points[:, :, 1] / (height - 1) * 2 - 1,
                    ],
                    dim=-1,
                )  # normalize to [-1, 1]
                grid = points.unsqueeze(2)  # [1, M, 1, 2]
                depths = F.grid_sample(
                    depths.permute(0, 3, 1, 2), grid, align_corners=True
                )  # [1, 1, M, 1]
                depths = depths.squeeze(3).squeeze(1)  # [1, M]
                # calculate loss in disparity space
                disp = torch.where(
                    depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
                )
                disp_gt = 1.0 / depths_gt  # [1, M]
                depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
                loss += depthloss * cfg.depth_lambda
            # regularizations
            if cfg.opacity_reg > 0.0:
                loss += (
                    cfg.opacity_reg
                    * torch.sigmoid(self.splats["opacities"]).mean()
                )
            if cfg.scale_reg > 0.0:
                loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()
            loss.backward()
            desc = (
                f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
            )
            if cfg.depth_loss:
                desc += f"depth loss={depthloss.item():.6f}| "
            pbar.set_description(desc)
            # write images (gt and render)
            # if world_rank == 0 and step % 800 == 0:
            #     canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
            #     canvas = canvas.reshape(-1, *canvas.shape[2:])
            #     imageio.imwrite(
            #         f"{self.render_dir}/train_rank{self.world_rank}.png",
            #         (canvas * 255).astype(np.uint8),
            #     )
            if (
                world_rank == 0
                and cfg.tb_every > 0
                and step % cfg.tb_every == 0
            ):
                mem = torch.cuda.max_memory_allocated() / 1024**3
                self.writer.add_scalar("train/loss", loss.item(), step)
                self.writer.add_scalar("train/l1loss", l1loss.item(), step)
                self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
                self.writer.add_scalar(
                    "train/num_GS", len(self.splats["means"]), step
                )
                self.writer.add_scalar("train/mem", mem, step)
                if cfg.depth_loss:
                    self.writer.add_scalar(
                        "train/depthloss", depthloss.item(), step
                    )
                if cfg.tb_save_image:
                    canvas = (
                        torch.cat([pixels, colors], dim=2)
                        .detach()
                        .cpu()
                        .numpy()
                    )
                    canvas = canvas.reshape(-1, *canvas.shape[2:])
                    self.writer.add_image("train/render", canvas, step)
                self.writer.flush()
            # save checkpoint before updating the model
            if (
                step in [i - 1 for i in cfg.save_steps]
                or step == max_steps - 1
            ):
                mem = torch.cuda.max_memory_allocated() / 1024**3
                stats = {
                    "mem": mem,
                    "ellipse_time": time.time() - global_tic,
                    "num_GS": len(self.splats["means"]),
                }
                print("Step: ", step, stats)
                with open(
                    f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
                    "w",
                ) as f:
                    json.dump(stats, f)
                data = {"step": step, "splats": self.splats.state_dict()}
                torch.save(
                    data,
                    f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
                )
            if (
                step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
            ) and cfg.save_ply:
                sh0 = self.splats["sh0"]
                shN = self.splats["shN"]
                means = self.splats["means"]
                scales = self.splats["scales"]
                quats = self.splats["quats"]
                opacities = self.splats["opacities"]
                export_splats(
                    means=means,
                    scales=scales,
                    quats=quats,
                    opacities=opacities,
                    sh0=sh0,
                    shN=shN,
                    format="ply",
                    save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
                )
            # Turn Gradients into Sparse Tensor before running optimizer
            if cfg.sparse_grad:
                assert (
                    cfg.packed
                ), "Sparse gradients only work with packed mode."
                gaussian_ids = info["gaussian_ids"]
                for k in self.splats.keys():
                    grad = self.splats[k].grad
                    if grad is None or grad.is_sparse:
                        continue
                    self.splats[k].grad = torch.sparse_coo_tensor(
                        indices=gaussian_ids[None],  # [1, nnz]
                        values=grad[gaussian_ids],  # [nnz, ...]
                        size=self.splats[k].size(),  # [N, ...]
                        is_coalesced=len(Ks) == 1,
                    )
            if cfg.visible_adam:
                gaussian_cnt = self.splats.means.shape[0]
                if cfg.packed:
                    visibility_mask = torch.zeros_like(
                        self.splats["opacities"], dtype=bool
                    )
                    visibility_mask.scatter_(0, info["gaussian_ids"], 1)
                else:
                    visibility_mask = (info["radii"] > 0).all(-1).any(0)
            # optimize
            for optimizer in self.optimizers.values():
                if cfg.visible_adam:
                    optimizer.step(visibility_mask)
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for scheduler in schedulers:
                scheduler.step()
            # Run post-backward steps after backward and optimizer
            if isinstance(self.cfg.strategy, DefaultStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    packed=cfg.packed,
                )
            elif isinstance(self.cfg.strategy, MCMCStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    lr=schedulers[0].get_last_lr()[0],
                )
            else:
                assert_never(self.cfg.strategy)
            # eval the full set
            if step in [i - 1 for i in cfg.eval_steps]:
                self.eval(step)
                self.render_video(step)

    @torch.no_grad()
    def eval(
        self,
        step: int,
        stage: str = "val",
        canvas_h: int = 512,
        canvas_w: int = 1024,
    ):
        """Entry for evaluation."""
        print("Running evaluation...")
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank
        valloader = torch.utils.data.DataLoader(
            self.valset, batch_size=1, shuffle=False, num_workers=1
        )
        ellipse_time = 0
        metrics = defaultdict(list)
        for i, data in enumerate(valloader):
            camtoworlds = data["camtoworld"].to(device)
            Ks = data["K"].to(device)
            pixels = data["image"].to(device) / 255.0
            height, width = pixels.shape[1:3]
            masks = data["mask"].to(device) if "mask" in data else None
            pixels = pixels.permute(0, 3, 1, 2)  # NHWC -> NCHW
            pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))
            torch.cuda.synchronize()
            tic = time.time()
            colors, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                masks=masks,
            )  # [1, H, W, 3]
            torch.cuda.synchronize()
            ellipse_time += max(time.time() - tic, 1e-10)
            colors = colors.permute(0, 3, 1, 2)  # NHWC -> NCHW
            colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
            colors = torch.clamp(colors, 0.0, 1.0)
            canvas_list = [pixels, colors]
            if world_rank == 0:
                # Save a side-by-side (GT | render) canvas; cv2 expects BGR.
                canvas = torch.cat(canvas_list, dim=2).squeeze(0)
                canvas = canvas.permute(1, 2, 0)  # CHW -> HWC
                canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
                cv2.imwrite(
                    f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
                    canvas[..., ::-1],
                )
                metrics["psnr"].append(self.psnr(colors, pixels))
                metrics["ssim"].append(self.ssim(colors, pixels))
                metrics["lpips"].append(self.lpips(colors, pixels))
        if world_rank == 0:
            ellipse_time /= len(valloader)
            stats = {
                k: torch.stack(v).mean().item() for k, v in metrics.items()
            }
            stats.update(
                {
                    "ellipse_time": ellipse_time,
                    "num_GS": len(self.splats["means"]),
                }
            )
            print(
                f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
                f"Time: {stats['ellipse_time']:.3f}s/image "
                f"Number of GS: {stats['num_GS']}"
            )
            # save stats as json
            with open(
                f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
            ) as f:
                json.dump(stats, f)
            # save stats to tensorboard
            for k, v in stats.items():
                self.writer.add_scalar(f"{stage}/{k}", v, step)
            self.writer.flush()

    @torch.no_grad()
    def render_video(
        self, step: int, canvas_h: int = 512, canvas_w: int = 1024
    ):
        """Render the eval trajectory to an mp4 of RGB | colorized depth."""
        testloader = torch.utils.data.DataLoader(
            self.testset, batch_size=1, shuffle=False, num_workers=1
        )
        images_cache = []
        # Track the global depth range so all frames share one color scale.
        depth_global_min, depth_global_max = float("inf"), -float("inf")
        for data in testloader:
            camtoworlds = data["camtoworld"].to(self.device)
            Ks = resize_pinhole_intrinsics(
                data["K"].squeeze(),
                raw_hw=(data["image_h"].item(), data["image_w"].item()),
                new_hw=(canvas_h, canvas_w // 2),
            ).to(self.device)
            renders, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks[None, ...],
                width=canvas_w // 2,
                height=canvas_h,
                sh_degree=self.cfg.sh_degree,
                near_plane=self.cfg.near_plane,
                far_plane=self.cfg.far_plane,
                render_mode="RGB+ED",
            )  # [1, H, W, 4]
            colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0)  # [H, W, 3]
            colors = (colors * 255).to(torch.uint8).cpu().numpy()
            depths = renders[0, ..., 3:4]  # [H, W, 1], tensor in device.
            images_cache.append([colors, depths])
            depth_global_min = min(depth_global_min, depths.min().item())
            depth_global_max = max(depth_global_max, depths.max().item())
        video_path = f"{self.render_dir}/video_step{step}.mp4"
        writer = imageio.get_writer(video_path, fps=30)
        for rgb, depth in images_cache:
            depth_normalized = torch.clip(
                (depth - depth_global_min)
                / (depth_global_max - depth_global_min),
                0,
                1,
            )
            depth_normalized = (
                (depth_normalized * 255).to(torch.uint8).cpu().numpy()
            )
            depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET)
            image = np.concatenate([rgb, depth_map], axis=1)
            writer.append_data(image)
        writer.close()
def entrypoint(
    local_rank: int, world_rank, world_size: int, cfg: GsplatTrainConfig
):
    """Per-rank entry: evaluate from checkpoints if given, else train.

    When ``cfg.ckpt`` is provided, the splat parameters from all listed
    checkpoint files are concatenated (one file per training rank) and
    only evaluation + video rendering run. Otherwise a full training run
    is performed, followed by video rendering at the last step.
    """
    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is not None:
        # run eval only
        ckpts = [
            torch.load(file, map_location=runner.device, weights_only=True)
            for file in cfg.ckpt
        ]
        # Merge splats sharded across ranks back into one model.
        for k in runner.splats.keys():
            runner.splats[k].data = torch.cat(
                [ckpt["splats"][k] for ckpt in ckpts]
            )
        step = ckpts[0]["step"]
        runner.eval(step=step)
        runner.render_video(step=step)
    else:
        runner.train()
        runner.render_video(step=cfg.max_steps - 1)
if __name__ == "__main__":
    # Two selectable training presets, chosen on the command line via tyro:
    # the original-paper densification ("default") or MCMC densification.
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            GsplatTrainConfig(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            GsplatTrainConfig(
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    # Rescale all step-based schedules by the global steps_scaler factor.
    cfg.adjust_steps(cfg.steps_scaler)
    # cli() dispatches entrypoint across ranks (single- or multi-GPU).
    cli(entrypoint, cfg, verbose=True)

View File

@ -0,0 +1,538 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from embodied_gen.utils.monkey_patches import monkey_patch_pano2room

# Must run before the thirdparty.pano2room imports below:
# monkey_patch_pano2room() appends the pano2room path to sys.path and
# replaces several predictor __init__ methods that those imports rely on.
monkey_patch_pano2room()
import os
import cv2
import numpy as np
import torch
import trimesh
from equilib import cube2equi, equi2pers
from kornia.morphology import dilation
from PIL import Image
from embodied_gen.models.sr_model import ImageRealESRGAN
from embodied_gen.utils.config import Pano2MeshSRConfig
from embodied_gen.utils.gaussian import compute_pinhole_intrinsics
from embodied_gen.utils.log import logger
from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
PanoFusionDistancePredictor,
)
from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter
from thirdparty.pano2room.modules.mesh_fusion.render import (
features_to_world_space_mesh,
render_mesh,
)
from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool
from thirdparty.pano2room.utils.camera_utils import gen_pano_rays
from thirdparty.pano2room.utils.functions import (
depth_to_distance,
get_cubemap_views_world_to_cam,
resize_image_with_aspect_ratio,
rot_z_world_to_cam,
tensor_to_pil,
)
class Pano2MeshSRPipeline:
    """Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement.

    This class integrates several key components including:
    - Depth estimation from RGB panorama
    - Inpainting of missing regions under offsets
    - RGB-D to mesh conversion
    - Multi-view mesh repair
    - 3D Gaussian Splatting (3DGS) dataset generation

    Args:
        config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.

    Example:
        ```python
        pipeline = Pano2MeshSRPipeline(config)
        pipeline(pano_image='example.png', output_dir='./output')
        ```
    """

    def __init__(self, config: Pano2MeshSRConfig) -> None:
        """Instantiate sub-models and precompute camera poses/kernels."""
        self.cfg = config
        self.device = config.device

        # Init models.
        self.inpainter = PanoPersFusionInpainter(save_path=None)
        self.geo_predictor = PanoJointPredictor(save_path=None)
        self.pano_fusion_distance_predictor = PanoFusionDistancePredictor()
        # Real-ESRGAN super-resolution applied to cubemap faces in __call__.
        self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor)

        # Init poses.
        cubemap_w2cs = get_cubemap_views_world_to_cam()
        self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
        self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)

        # Elliptical dilation kernel used to grow inpainting masks
        # (see inpaint_panorama).
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, self.cfg.kernel_size
        )
        self.kernel = torch.from_numpy(kernel).float().to(self.device)
def init_mesh_params(self) -> None:
    """Reset the mesh accumulation buffers for a fresh run.

    NOTE: this sets the torch default device globally for the process;
    it is restored to "cpu" near the end of __call__.
    """
    torch.set_default_device(self.device)
    self.inpaint_mask = torch.ones(
        (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
    )
    # Mesh buffers are stored column-wise: shape (3, num_elements).
    self.vertices = torch.empty((3, 0), requires_grad=False)
    self.colors = torch.empty((3, 0), requires_grad=False)
    self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)
@staticmethod
def read_camera_pose_file(filepath: str) -> np.ndarray:
    """Parse a whitespace-separated 4x4 camera pose matrix from a text file."""
    with open(filepath, "r") as f:
        numbers = [float(token) for token in f.read().split()]
    return np.array(numbers).reshape(4, 4)
def load_camera_poses(
    self, trajectory_dir: str
) -> list[torch.Tensor]:
    """Load camera_pose* files and convert them to relative poses.

    The first pose found (shifted by cfg.pano_center_offset) serves as
    the panorama reference. The first six files (cubemap views) are
    taken relative to themselves; the remaining ones are relative to the
    pano reference. The x/y rows are negated and a 180-degree z-rotation
    applied to move into the camera convention used by the mesh
    renderer, and translations are scaled by cfg.pose_scale.

    Args:
        trajectory_dir: Directory containing camera_pose* text files.

    Returns:
        List of 4x4 float32 relative pose tensors, in filename order.
    """
    # NOTE: the original return annotation claimed a tuple; only the
    # list of relative poses is returned.
    pose_filenames = sorted(
        [
            fname
            for fname in os.listdir(trajectory_dir)
            if fname.startswith("camera_pose")
        ]
    )
    pano_pose_world = None
    relative_poses = []
    for idx, filename in enumerate(pose_filenames):
        pose_path = os.path.join(trajectory_dir, filename)
        pose_matrix = self.read_camera_pose_file(pose_path)
        if pano_pose_world is None:
            pano_pose_world = pose_matrix.copy()
            pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
            pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]
        # Use different reference for the first 6 cubemap views
        reference_pose = pose_matrix if idx < 6 else pano_pose_world
        relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
        relative_matrix[0:2, :] *= -1  # flip_xy
        relative_matrix = (
            relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
        )
        relative_matrix[:3, 3] *= self.cfg.pose_scale
        relative_matrix = torch.tensor(
            relative_matrix, dtype=torch.float32
        )
        relative_poses.append(relative_matrix)

    return relative_poses
def load_inpaint_poses(
    self, poses: torch.Tensor
) -> dict[int, torch.Tensor]:
    """Subsample poses for inpainting, keeping translation only.

    Every cfg.inpaint_frame_stride-th pose is converted camera-to-world
    and reduced to an identity rotation with the negated c2w translation.

    Args:
        poses: Sequence of 4x4 world-to-camera pose tensors.

    Returns:
        Mapping from sampled index to a 4x4 pose tensor on self.device.
    """
    inpaint_poses = dict()
    sampled_views = poses[:: self.cfg.inpaint_frame_stride]
    init_pose = torch.eye(4)
    for idx, w2c_tensor in enumerate(sampled_views):
        w2c = w2c_tensor.cpu().numpy().astype(np.float32)
        c2w = np.linalg.inv(w2c)
        # Keep only translation; rotation stays identity.
        pose_tensor = init_pose.clone()
        pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
        pose_tensor[:3, 3] *= -1
        inpaint_poses[idx] = pose_tensor.to(self.device)

    return inpaint_poses
def project(self, world_to_cam: torch.Tensor):
    """Rasterize the current mesh from one camera.

    Args:
        world_to_cam: 4x4 world-to-camera pose.

    Returns:
        Tuple of (rgb image [3, H, W] with holes zeroed, inpaint mask
        marking pixels not covered by the mesh, depth map).
    """
    (
        project_image,
        project_depth,
        inpaint_mask,
        _,
        z_buf,
        mesh,
    ) = render_mesh(
        vertices=self.vertices,
        faces=self.faces,
        vertex_features=self.colors,
        H=self.cfg.cubemap_h,
        W=self.cfg.cubemap_w,
        fov_in_degrees=self.cfg.fov,
        RT=world_to_cam,
        blur_radius=self.cfg.blur_radius,
        faces_per_pixel=self.cfg.faces_per_pixel,
    )
    # Zero out pixels that need inpainting so they are unambiguous.
    project_image = project_image * ~inpaint_mask

    return project_image[:3, ...], inpaint_mask, project_depth
def render_pano(self, pose: torch.Tensor):
    """Render a full equirectangular panorama from the mesh at `pose`.

    Renders the six cubemap faces (RGB + distance + inpaint mask) and
    stitches them into an equirect pano via cube2equi.

    Returns:
        Tuple of (pano RGB [3, H, W], pano distance [H, W],
        pano inpaint mask [H, W]).
    """
    cubemap_list = []
    for cubemap_pose in self.cubemap_w2cs:
        project_pose = cubemap_pose @ pose
        rgb, inpaint_mask, depth = self.project(project_pose)
        # Perspective depth -> per-ray distance for equirect stitching.
        distance_map = depth_to_distance(depth[None, ...])
        mask = inpaint_mask[None, ...]
        cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0))

    # Set default tensor type for CPU operation in cube2equi
    with torch.device("cpu"):
        pano_rgbd = cube2equi(
            cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w
        )
    # Channels: 0-2 RGB, 3 distance, 4 mask.
    pano_rgb = pano_rgbd[:3, :, :]
    pano_depth = pano_rgbd[3:4, :, :].squeeze(0)
    pano_mask = pano_rgbd[4:, :, :].squeeze(0)

    return pano_rgb, pano_depth, pano_mask
def rgbd_to_mesh(
    self,
    rgb: torch.Tensor,
    depth: torch.Tensor,
    inpaint_mask: torch.Tensor,
    world_to_cam: torch.Tensor = None,
    using_distance_map: bool = True,
) -> None:
    """Back-project an RGB-D view into the accumulated mesh buffers.

    New vertices/faces/colors are appended in place to self.vertices,
    self.faces, self.colors. No-op when the mask selects no pixels.

    Args:
        rgb: Color image tensor.
        depth: Depth (or distance) map.
        inpaint_mask: Pixels to convert into geometry.
        world_to_cam: Optional 4x4 pose; identity when None.
        using_distance_map: Treat `depth` as per-ray distance.
    """
    if world_to_cam is None:
        world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
    if inpaint_mask.sum() == 0:
        return
    vertices, faces, colors = features_to_world_space_mesh(
        colors=rgb.squeeze(0),
        depth=depth,
        fov_in_degrees=self.cfg.fov,
        world_to_cam=world_to_cam,
        mask=inpaint_mask,
        faces=self.faces,
        vertices=self.vertices,
        using_distance_map=using_distance_map,
        edge_threshold=0.05,
    )
    # Offset face indices past the vertices already in the buffer.
    faces += self.vertices.shape[1]

    self.vertices = torch.cat([self.vertices, vertices], dim=1)
    self.colors = torch.cat([self.colors, colors], dim=1)
    self.faces = torch.cat([self.faces, faces], dim=1)
def get_edge_image_by_depth(
    self, depth: torch.Tensor, dilate_iter: int = 1
) -> np.ndarray:
    """Detect (and optionally dilate) depth-discontinuity edges.

    The depth map is normalized to uint8 grayscale, run through Canny,
    and the resulting edge mask dilated `dilate_iter` times.
    """
    if isinstance(depth, torch.Tensor):
        depth = depth.cpu().detach().numpy()

    grayscale = (depth / depth.max() * 255).astype(np.uint8)
    edge_map = cv2.Canny(grayscale, 60, 150)
    if dilate_iter > 0:
        edge_map = cv2.dilate(
            edge_map, np.ones((3, 3), np.uint8), iterations=dilate_iter
        )

    return edge_map
def mesh_repair_by_greedy_view_selection(
    self, pose_dict: dict[str, torch.Tensor], output_dir: str
) -> list:
    """Iteratively inpaint the mesh from a pool of candidate views.

    Each round: render a pano from every remaining pose, rank poses by
    completeness (fraction of already-covered pixels), pick one, inpaint
    its missing regions, reject geometry conflicting with the
    supervision pool, and fuse the result back into the mesh.

    Args:
        pose_dict: Candidate poses; consumed (popped) as views are used.
        output_dir: Directory for optional debug visualizations.

    Returns:
        List of [pano color tensor [1, 3, H, W], pose] pairs for each
        inpainted view, later used to build the 3DGS training set.
    """
    inpainted_panos_w_pose = []
    while len(pose_dict) > 0:
        logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
        sampled_views = []
        for key, pose in pose_dict.items():
            pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
            # Completeness = fraction of pixels already covered by mesh.
            completeness = torch.sum(1 - pano_mask) / (pano_mask.numel())
            sampled_views.append((key, completeness.item(), pose))

        if len(sampled_views) == 0:
            break

        # Sort by completeness and pick the view at the 2/3 position.
        # NOTE(review): original comment said "least view completeness",
        # but the code selects the 2/3-quantile view — confirm intent.
        sampled_views = sorted(sampled_views, key=lambda x: x[1])
        key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
        pose_dict.pop(key)
        pano_rgb, pano_distance, pano_mask = self.render_pano(pose)

        colors = pano_rgb.permute(1, 2, 0).clone()
        distances = pano_distance.unsqueeze(-1).clone()
        pano_inpaint_mask = pano_mask.clone()
        init_pose = pose.clone()
        normals = None
        # Only inpaint when at least one pixel is missing (mask < 0.5).
        if pano_inpaint_mask.min().item() < 0.5:
            colors, distances, normals = self.inpaint_panorama(
                idx=key,
                colors=colors,
                distances=distances,
                pano_mask=pano_inpaint_mask,
            )

        # Remap translation into the supervision-pool convention.
        init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
            -pose[0, 3],
            pose[2, 3],
            0,
        )
        rays = gen_pano_rays(
            init_pose, self.cfg.pano_h, self.cfg.pano_w
        )
        conflict_mask = self.sup_pool.geo_check(
            rays, distances.unsqueeze(-1)
        )  # 0 is conflict, 1 not conflict
        # Drop inpainted pixels that contradict registered geometry.
        pano_inpaint_mask *= conflict_mask

        self.rgbd_to_mesh(
            colors.permute(2, 0, 1),
            distances,
            pano_inpaint_mask,
            world_to_cam=pose,
        )

        self.sup_pool.register_sup_info(
            pose=init_pose,
            mask=pano_inpaint_mask.clone(),
            rgb=colors,
            distance=distances.unsqueeze(-1),
            normal=normals,
        )
        colors = colors.permute(2, 0, 1).unsqueeze(0)
        inpainted_panos_w_pose.append([colors, pose])

        if self.cfg.visualize:
            from embodied_gen.data.utils import DiffrastRender

            tensor_to_pil(pano_rgb.unsqueeze(0)).save(
                f"{output_dir}/rendered_pano_{key}.jpg"
            )
            tensor_to_pil(colors).save(
                f"{output_dir}/inpainted_pano_{key}.jpg"
            )
            norm_depth = DiffrastRender.normalize_map_by_mask(
                distances, torch.ones_like(distances)
            )
            heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            Image.fromarray(heatmap).save(
                f"{output_dir}/inpainted_depth_{key}.png"
            )

    return inpainted_panos_w_pose
def inpaint_panorama(
    self,
    idx: int,
    colors: torch.Tensor,
    distances: torch.Tensor,
    pano_mask: torch.Tensor,
) -> tuple[torch.Tensor]:
    """Inpaint missing pano regions and re-estimate geometry there.

    The mask is dilated with the elliptical kernel from __init__ so
    inpainting covers a margin around the holes, then RGB is inpainted
    and distances/normals are refined by the geo predictor.

    Args:
        idx: View index forwarded to the inpainter/predictor.
        colors: Pano colors [H, W, 3].
        distances: Pano distances [H, W].
        pano_mask: Mask of missing pixels (>0.5 means missing).

    Returns:
        Tuple of (inpainted colors, refined distances, normals).
    """
    mask = (pano_mask[None, ..., None] > 0.5).float()
    mask = mask.permute(0, 3, 1, 2)
    mask = dilation(mask, kernel=self.kernel)
    mask = mask[0, 0, ..., None]  # hwc

    inpainted_img = self.inpainter.inpaint(idx, colors, mask)
    # Blend: keep known pixels, fill only masked ones.
    inpainted_img = colors * (1 - mask) + inpainted_img * mask
    inpainted_distances, inpainted_normals = self.geo_predictor(
        idx,
        inpainted_img,
        distances[..., None],
        mask=mask,
        reg_loss_weight=0.0,
        normal_loss_weight=5e-2,
        normal_tv_loss_weight=5e-2,
    )

    return inpainted_img, inpainted_distances.squeeze(), inpainted_normals
def preprocess_pano(
    self, image: Image.Image | str
) -> tuple[torch.Tensor, torch.Tensor]:
    """Load a panorama and predict its per-pixel distance map.

    The image is forced to landscape orientation, resized to
    cfg.pano_w, and the predicted distance is normalized so its max
    equals cfg.depth_scale_factor.

    Args:
        image: PIL image or path to the panorama.

    Returns:
        Tuple of (RGB tensor [3, H, W] in [0, 1] on self.device,
        distance map tensor).
    """
    if isinstance(image, str):
        image = Image.open(image)

    image = image.convert("RGB")
    # Panoramas are expected landscape; transpose portrait inputs.
    if image.size[0] < image.size[1]:
        image = image.transpose(Image.TRANSPOSE)

    image = resize_image_with_aspect_ratio(image, self.cfg.pano_w)
    image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255
    image_rgb = image_rgb.to(self.device)
    image_depth = self.pano_fusion_distance_predictor.predict(
        image_rgb.permute(1, 2, 0)
    )
    image_depth = (
        image_depth / image_depth.max() * self.cfg.depth_scale_factor
    )

    return image_rgb, image_depth
def pano_to_perpective(
    self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
) -> torch.Tensor:
    """Extract one perspective view from an equirectangular panorama.

    NOTE: method name keeps the original "perpective" spelling for
    backward compatibility with existing callers.

    Args:
        pano_image: Equirect pano tensor [1, 3, H, W].
        pitch: Camera pitch in radians.
        yaw: Camera yaw in radians.
        fov: Horizontal field of view in degrees.

    Returns:
        Perspective view tensor [1, 3, cubemap_h, cubemap_w].
    """
    rots = dict(
        roll=0,
        pitch=pitch,
        yaw=yaw,
    )
    perspective = equi2pers(
        equi=pano_image.squeeze(0),
        rots=rots,
        height=self.cfg.cubemap_h,
        width=self.cfg.cubemap_w,
        fov_x=fov,
        mode="bilinear",
    ).unsqueeze(0)

    return perspective
def pano_to_cubemap(self, pano_rgb: torch.Tensor):
    """Split an equirect panorama into six cubemap face images.

    Faces are extracted in the order: front, right, back, left, down,
    up (pitch/yaw pairs below), matching self.cubemap_w2cs indexing.

    Args:
        pano_rgb: Equirect pano tensor [1, 3, H, W].

    Returns:
        List of six CPU tensors, one per cube face.
    """
    # Define six canonical cube directions in (pitch, yaw)
    directions = [
        (0, 0),
        (0, 1.5 * np.pi),
        (0, 1.0 * np.pi),
        (0, 0.5 * np.pi),
        (-0.5 * np.pi, 0),
        (0.5 * np.pi, 0),
    ]

    cubemaps_rgb = []
    for pitch, yaw in directions:
        rgb_view = self.pano_to_perpective(
            pano_rgb, pitch, yaw, fov=self.cfg.fov
        )
        cubemaps_rgb.append(rgb_view.cpu())

    return cubemaps_rgb
def save_mesh(self, output_path: str) -> None:
    """Export the accumulated vertex/color/face buffers as a mesh file."""
    mesh = trimesh.Trimesh(
        vertices=self.vertices.T.cpu().numpy(),
        faces=self.faces.T.cpu().numpy(),
        vertex_colors=self.colors.T.cpu().numpy(),
    )
    mesh.export(output_path)
def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
    """Convert a mesh-renderer world-to-camera pose into a 3DGS c2w pose.

    Negates the x/y rows, applies a y/z sign flip, and inverts the
    transform to obtain a 4x4 camera-to-world matrix in the convention
    expected by the gsplat training data.

    Args:
        mesh_pose: 4x4 world-to-camera pose tensor.

    Returns:
        4x4 camera-to-world matrix as a numpy array.
    """
    pose = mesh_pose.clone()
    pose[0, :] *= -1
    pose[1, :] *= -1

    Rw2c = pose[:3, :3].cpu().numpy()
    Tw2c = pose[:3, 3:].cpu().numpy()
    # Flip y and z axes between the two camera conventions.
    yz_reverse = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
    # Invert world-to-camera: R_c2w = (flip @ R_w2c)^T, T accordingly.
    Rc2w = (yz_reverse @ Rw2c).T
    Tc2w = -(Rc2w @ yz_reverse @ Tw2c)
    c2w = np.concatenate((Rc2w, Tc2w), axis=1)
    c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)

    return c2w
def __call__(self, pano_image: Image.Image | str, output_dir: str):
    """Run the full pano -> mesh -> 3DGS-dataset pipeline.

    Stages: predict pano depth, build the initial mesh (excluding depth
    edges), repair it from offset views via inpainting, then optionally
    save the mesh (cfg.mesh_file) and dump a super-resolved cubemap
    dataset for 3DGS training (cfg.gs_data_file) into `output_dir`.

    Args:
        pano_image: Input panorama (PIL image or path).
        output_dir: Directory for the mesh, GS data, and debug images.
    """
    self.init_mesh_params()
    pano_rgb, pano_depth = self.preprocess_pano(pano_image)

    # Register the input pano as geometric supervision for conflict
    # checks during mesh repair.
    self.sup_pool = SupInfoPool()
    self.sup_pool.register_sup_info(
        pose=torch.eye(4).to(self.device),
        mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]),
        rgb=pano_rgb.permute(1, 2, 0),
        distance=pano_depth[..., None],
    )
    self.sup_pool.gen_occ_grid(res=256)

    logger.info("Init mesh from pano RGBD image...")
    # Exclude depth-discontinuity edges from triangulation to avoid
    # stretched faces across depth jumps.
    depth_edge = self.get_edge_image_by_depth(pano_depth)
    inpaint_edge_mask = (
        ~torch.from_numpy(depth_edge).to(self.device).bool()
    )
    self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask)

    repair_poses = self.load_inpaint_poses(self.camera_poses)
    inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection(
        repair_poses, output_dir
    )
    torch.cuda.empty_cache()
    # Restore the global default device changed in init_mesh_params().
    torch.set_default_device("cpu")

    if self.cfg.mesh_file is not None:
        mesh_path = os.path.join(output_dir, self.cfg.mesh_file)
        self.save_mesh(mesh_path)

    if self.cfg.gs_data_file is None:
        return

    logger.info(f"Dump data for 3DGS training...")
    points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8)
    data = {
        "points": self.vertices.permute(1, 0).cpu().numpy(),  # (N, 3)
        "points_rgb": points_rgb.permute(1, 0).cpu().numpy(),  # (N, 3)
        "train": [],
        "eval": [],
    }
    # Training views are super-resolved cubemap faces.
    image_h = self.cfg.cubemap_h * self.cfg.upscale_factor
    image_w = self.cfg.cubemap_w * self.cfg.upscale_factor
    Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov)
    for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses):
        cubemaps = self.pano_to_cubemap(pano_img)
        for i in range(len(cubemaps)):
            cubemap = tensor_to_pil(cubemaps[i])
            cubemap = self.super_model(cubemap)
            mesh_pose = self.cubemap_w2cs[i] @ pano_pose
            c2w = self.mesh_pose_to_gs_pose(mesh_pose)
            data["train"].append(
                {
                    "camtoworld": c2w.astype(np.float32),
                    "K": Ks.astype(np.float32),
                    "image": np.array(cubemap),
                    "image_h": image_h,
                    "image_w": image_w,
                    "image_id": len(cubemaps) * idx + i,
                }
            )

    # Camera poses for evaluation.
    for idx in range(len(self.camera_poses)):
        c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx])
        data["eval"].append(
            {
                "camtoworld": c2w.astype(np.float32),
                "K": Ks.astype(np.float32),
                "image_h": image_h,
                "image_w": image_w,
                "image_id": idx,
            }
        )

    data_path = os.path.join(output_dir, self.cfg.gs_data_file)
    torch.save(data, data_path)

    return
if __name__ == "__main__":
    # Example: convert the bundled example panorama into a mesh and
    # a 3DGS training dataset under output_dir.
    output_dir = "outputs/bg_v2/test3"
    input_pano = "apps/assets/example_scene/result_pano.png"
    config = Pano2MeshSRConfig()
    pipeline = Pano2MeshSRPipeline(config)
    pipeline(input_pano, output_dir)

View File

@ -0,0 +1,190 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from dataclasses import dataclass, field
from typing import List, Optional, Union
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from typing_extensions import Literal, assert_never
__all__ = [
"Pano2MeshSRConfig",
"GsplatTrainConfig",
]
@dataclass
class Pano2MeshSRConfig:
    """Configuration for the pano-to-mesh + super-resolution pipeline."""

    # Output mesh filename (set to None to skip mesh export).
    mesh_file: str = "mesh_model.ply"
    # Output file with the 3DGS training data (None to skip dumping).
    gs_data_file: str = "gs_data.pt"
    device: str = "cuda"
    # Rasterization settings forwarded to render_mesh.
    blur_radius: int = 0
    faces_per_pixel: int = 8
    # Horizontal field of view (degrees) for cubemap/perspective views.
    fov: int = 90
    # Equirectangular panorama resolution.
    pano_w: int = 2048
    pano_h: int = 1024
    # Per-face cubemap resolution.
    cubemap_w: int = 512
    cubemap_h: int = 512
    # Scale applied to relative pose translations.
    pose_scale: float = 0.6
    # (x, z) offset added to the panorama reference pose.
    pano_center_offset: tuple = (-0.2, 0.3)
    # Take every N-th trajectory pose as an inpainting candidate.
    inpaint_frame_stride: int = 20
    trajectory_dir: str = "apps/assets/example_scene/camera_trajectory"
    # Save intermediate pano/depth debug images when True.
    visualize: bool = False
    # Predicted distance maps are normalized so max equals this value.
    # NOTE(review): magic constant — presumably tuned to the example
    # scene scale; confirm before reuse.
    depth_scale_factor: float = 3.4092
    # Elliptical dilation kernel size for inpainting masks.
    kernel_size: tuple = (9, 9)
    # Real-ESRGAN upscaling factor for cubemap training views.
    upscale_factor: int = 4
@dataclass
class GsplatTrainConfig:
    """Configuration for 3D Gaussian Splatting training with gsplat."""

    # Path to the .pt files. If provide, it will skip training and run evaluation only.
    ckpt: Optional[List[str]] = None
    # Render trajectory path
    render_traj_path: str = "interp"
    # Path to the Mip-NeRF 360 dataset
    data_dir: str = "outputs/bg"
    # Downsample factor for the dataset
    data_factor: int = 4
    # Directory to save results
    result_dir: str = "outputs/bg"
    # Every N images there is a test image
    test_every: int = 8
    # Random crop size for training (experimental)
    patch_size: Optional[int] = None
    # A global scaler that applies to the scene size related parameters
    global_scale: float = 1.0
    # Normalize the world space
    normalize_world_space: bool = True
    # Camera model
    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
    # Port for the viewer server
    port: int = 8080
    # Batch size for training. Learning rates are scaled automatically
    batch_size: int = 1
    # A global factor to scale the number of training steps
    steps_scaler: float = 1.0
    # Number of training steps
    max_steps: int = 30_000
    # Steps to evaluate the model
    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Steps to save the model
    save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to save ply file (storage size can be large)
    save_ply: bool = True
    # Steps to save the model as ply
    ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to disable video generation during training and evaluation
    disable_video: bool = False
    # Initial number of GSs. Ignored if using sfm
    init_num_pts: int = 100_000
    # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
    init_extent: float = 3.0
    # Degree of spherical harmonics
    sh_degree: int = 1
    # Turn on another SH degree every this steps
    sh_degree_interval: int = 1000
    # Initial opacity of GS
    init_opa: float = 0.1
    # Initial scale of GS
    init_scale: float = 1.0
    # Weight for SSIM loss
    ssim_lambda: float = 0.2
    # Near plane clipping distance
    near_plane: float = 0.01
    # Far plane clipping distance
    far_plane: float = 1e10
    # Strategy for GS densification
    strategy: Union[DefaultStrategy, MCMCStrategy] = field(
        default_factory=DefaultStrategy
    )
    # Use packed mode for rasterization, this leads to less memory usage but slightly slower.
    packed: bool = False
    # Use sparse gradients for optimization. (experimental)
    sparse_grad: bool = False
    # Use visible adam from Taming 3DGS. (experimental)
    visible_adam: bool = False
    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
    antialiased: bool = False
    # Use random background for training to discourage transparency
    random_bkgd: bool = False
    # LR for 3D point positions
    means_lr: float = 1.6e-4
    # LR for Gaussian scale factors
    scales_lr: float = 5e-3
    # LR for alpha blending weights
    opacities_lr: float = 5e-2
    # LR for orientation (quaternions)
    quats_lr: float = 1e-3
    # LR for SH band 0 (brightness)
    sh0_lr: float = 2.5e-3
    # LR for higher-order SH (detail)
    shN_lr: float = 2.5e-3 / 20
    # Opacity regularization
    opacity_reg: float = 0.0
    # Scale regularization
    scale_reg: float = 0.0
    # Enable depth loss. (experimental)
    depth_loss: bool = False
    # Weight for depth loss
    depth_lambda: float = 1e-2
    # Dump information to tensorboard every this steps
    tb_every: int = 200
    # Save training images to tensorboard
    tb_save_image: bool = False
    # Backbone network for the LPIPS evaluation metric
    lpips_net: Literal["vgg", "alex"] = "alex"
    # 3DGUT (unscented transform + eval 3D)
    with_ut: bool = False
    with_eval3d: bool = False
    # Scene scale factor. NOTE(review): usage not visible here —
    # presumably multiplies scene-extent-dependent parameters; confirm
    # against the trainer.
    scene_scale: float = 1.0

    def adjust_steps(self, factor: float):
        """Scale every step-based schedule (and the strategy's) by `factor`."""
        self.eval_steps = [int(i * factor) for i in self.eval_steps]
        self.save_steps = [int(i * factor) for i in self.save_steps]
        self.ply_steps = [int(i * factor) for i in self.ply_steps]
        self.max_steps = int(self.max_steps * factor)
        self.sh_degree_interval = int(self.sh_degree_interval * factor)

        strategy = self.strategy
        if isinstance(strategy, DefaultStrategy):
            strategy.refine_start_iter = int(
                strategy.refine_start_iter * factor
            )
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.reset_every = int(strategy.reset_every * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        elif isinstance(strategy, MCMCStrategy):
            strategy.refine_start_iter = int(
                strategy.refine_start_iter * factor
            )
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        else:
            assert_never(strategy)

View File

@ -0,0 +1,331 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Part of the code comes from https://github.com/nerfstudio-project/gsplat
# Both under the Apache License, Version 2.0.
import math
import random
from io import BytesIO
from typing import Dict, Literal, Optional, Tuple
import numpy as np
import torch
import trimesh
from gsplat.optimizers import SelectiveAdam
from scipy.spatial.transform import Rotation
from sklearn.neighbors import NearestNeighbors
from torch import Tensor
from embodied_gen.models.gs_model import GaussianOperator
__all__ = [
"set_random_seed",
"export_splats",
"create_splats_with_optimizers",
"compute_pinhole_intrinsics",
"resize_pinhole_intrinsics",
"restore_scene_scale_and_position",
]
def knn(x: Tensor, K: int = 4) -> Tensor:
    """Return distances to the K nearest neighbors of each point.

    Neighbor search includes the point itself, so column 0 is always 0.

    Args:
        x: Point tensor of shape (N, D).
        K: Number of neighbors (including self).

    Returns:
        Tensor of shape (N, K) with ascending Euclidean distances,
        cast to x's dtype and device.
    """
    # scipy's KD-tree replaces the sklearn dependency; scipy is already
    # used elsewhere in this module.
    from scipy.spatial import cKDTree

    x_np = x.cpu().numpy()
    distances, _ = cKDTree(x_np).query(x_np, k=K)
    # query() returns shape (N,) when K == 1; normalize to (N, K).
    distances = distances.reshape(len(x_np), K)
    return torch.from_numpy(distances).to(x)
def rgb_to_sh(rgb: Tensor) -> Tensor:
C0 = 0.28209479177387814
return (rgb - 0.5) / C0
def set_random_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def splat2ply_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
) -> bytes:
    """Serialize Gaussian Splat parameters into binary PLY bytes.

    Args:
        means: (N, 3) positions.
        scales: (N, 3) log-scales.
        quats: (N, 4) rotations.
        opacities: (N,) opacity logits.
        sh0: (N, 3) DC spherical-harmonic coefficients.
        shN: (N, K*3) flattened higher-order SH coefficients.

    Returns:
        Little-endian binary PLY file content.
    """
    num_splats = means.shape[0]
    buffer = BytesIO()

    # Write PLY header
    buffer.write(b"ply\n")
    buffer.write(b"format binary_little_endian 1.0\n")
    buffer.write(f"element vertex {num_splats}\n".encode())
    buffer.write(b"property float x\n")
    buffer.write(b"property float y\n")
    buffer.write(b"property float z\n")
    for i, data in enumerate([sh0, shN]):
        prefix = "f_dc" if i == 0 else "f_rest"
        for j in range(data.shape[1]):
            buffer.write(f"property float {prefix}_{j}\n".encode())
    buffer.write(b"property float opacity\n")
    for i in range(scales.shape[1]):
        buffer.write(f"property float scale_{i}\n".encode())
    for i in range(quats.shape[1]):
        buffer.write(f"property float rot_{i}\n".encode())
    buffer.write(b"end_header\n")

    # Concatenate all tensors in the order declared in the header.
    splat_data = torch.cat(
        [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
    )
    # Ensure correct dtype
    splat_data = splat_data.to(torch.float32)

    # Write binary data
    float_dtype = np.dtype(np.float32).newbyteorder("<")
    buffer.write(
        splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()
    )

    return buffer.getvalue()


def export_splats(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    format: Literal["ply"] = "ply",
    save_to: Optional[str] = None,
) -> bytes:
    """Export a Gaussian Splats model to bytes in PLY file format.

    Splats containing NaN/Inf in any parameter are dropped. If `save_to`
    is given, the bytes are also written to that path.

    Raises:
        ValueError: If `format` is not "ply".
    """
    total_splats = means.shape[0]
    assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
    assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
    assert quats.shape == (
        total_splats,
        4,
    ), "Quaternions must be of shape (N, 4)"
    assert opacities.shape == (
        total_splats,
    ), "Opacities must be of shape (N,)"
    assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
    assert (
        shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
    ), f"shN must be of shape (N, K, 3), got {shN.shape}"

    # Reshape spherical harmonics
    sh0 = sh0.squeeze(1)  # Shape (N, 3)
    shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1)  # Shape (N, K * 3)

    # Check for NaN or Inf values, per splat.
    # BUGFIX: opacities is 1-D, so `.any(dim=0)` reduced it to a single
    # scalar that broadcast over the whole mask — one bad opacity used
    # to discard every splat. Use the elementwise mask instead.
    invalid_mask = (
        torch.isnan(means).any(dim=1)
        | torch.isinf(means).any(dim=1)
        | torch.isnan(scales).any(dim=1)
        | torch.isinf(scales).any(dim=1)
        | torch.isnan(quats).any(dim=1)
        | torch.isinf(quats).any(dim=1)
        | torch.isnan(opacities)
        | torch.isinf(opacities)
        | torch.isnan(sh0).any(dim=1)
        | torch.isinf(sh0).any(dim=1)
        | torch.isnan(shN).any(dim=1)
        | torch.isinf(shN).any(dim=1)
    )

    # Filter out invalid entries
    valid_mask = ~invalid_mask
    means = means[valid_mask]
    scales = scales[valid_mask]
    quats = quats[valid_mask]
    opacities = opacities[valid_mask]
    sh0 = sh0[valid_mask]
    shN = shN[valid_mask]

    if format == "ply":
        data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
    else:
        raise ValueError(f"Unsupported format: {format}")

    if save_to:
        with open(save_to, "wb") as binary_file:
            binary_file.write(data)

    return data
def create_splats_with_optimizers(
    points: np.ndarray = None,
    points_rgb: np.ndarray = None,
    init_num_pts: int = 100_000,
    init_extent: float = 3.0,
    init_opacity: float = 0.1,
    init_scale: float = 1.0,
    means_lr: float = 1.6e-4,
    scales_lr: float = 5e-3,
    opacities_lr: float = 5e-2,
    quats_lr: float = 1e-3,
    sh0_lr: float = 2.5e-3,
    shN_lr: float = 2.5e-3 / 20,
    scene_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
    world_rank: int = 0,
    world_size: int = 1,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
    """Initialize Gaussian Splat parameters and one optimizer per group.

    Points/colors come from the given arrays (e.g. a mesh or SfM) or,
    when absent, from a random cube of init_num_pts points. Points are
    sharded across ranks via strided slicing. Colors become SH
    coefficients unless feature_dim is set, in which case per-point
    features plus logit colors are used instead.

    Returns:
        Tuple of (ParameterDict on `device`, dict of per-group
        optimizers with batch-size-scaled learning rates).
    """
    if points is not None and points_rgb is not None:
        points = torch.from_numpy(points).float()
        rgbs = torch.from_numpy(points_rgb / 255.0).float()
    else:
        # Random init inside a cube of half-extent init_extent * scene_scale.
        points = (
            init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
        )
        rgbs = torch.rand((init_num_pts, 3))

    # Initialize the GS size to be the average dist of the 3 nearest neighbors
    dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
    dist_avg = torch.sqrt(dist2_avg)
    scales = (
        torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3)
    )  # [N, 3]

    # Distribute the GSs to different ranks (also works for single rank)
    points = points[world_rank::world_size]
    rgbs = rgbs[world_rank::world_size]
    scales = scales[world_rank::world_size]

    N = points.shape[0]
    quats = torch.rand((N, 4))  # [N, 4]
    # Store opacity as a logit so a sigmoid recovers init_opacity.
    opacities = torch.logit(torch.full((N,), init_opacity))  # [N,]

    params = [
        # name, value, lr
        ("means", torch.nn.Parameter(points), means_lr * scene_scale),
        ("scales", torch.nn.Parameter(scales), scales_lr),
        ("quats", torch.nn.Parameter(quats), quats_lr),
        ("opacities", torch.nn.Parameter(opacities), opacities_lr),
    ]

    if feature_dim is None:
        # color is SH coefficients.
        colors = torch.zeros((N, (sh_degree + 1) ** 2, 3))  # [N, K, 3]
        colors[:, 0, :] = rgb_to_sh(rgbs)
        params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr))
        params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr))
    else:
        # features will be used for appearance and view-dependent shading
        features = torch.rand(N, feature_dim)  # [N, feature_dim]
        params.append(("features", torch.nn.Parameter(features), sh0_lr))
        colors = torch.logit(rgbs)  # [N, 3]
        params.append(("colors", torch.nn.Parameter(colors), sh0_lr))

    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
    # Scale learning rate based on batch size, reference:
    # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
    # Note that this would not make the training exactly equivalent, see
    # https://arxiv.org/pdf/2402.18824v1
    BS = batch_size * world_size
    optimizer_class = None
    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam
    optimizers = {
        name: optimizer_class(
            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
            eps=1e-15 / math.sqrt(BS),
            # TODO: check betas logic when BS is larger than 10 betas[0] will be zero.
            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
        )
        for name, _, lr in params
    }
    return splats, optimizers
def compute_pinhole_intrinsics(
    image_w: int, image_h: int, fov_deg: float
) -> np.ndarray:
    """Build a 3x3 pinhole intrinsic matrix from image size and horizontal FoV."""
    half_fov = np.deg2rad(fov_deg) / 2
    focal = image_w / (2 * np.tan(half_fov))
    return np.array(
        [
            [focal, 0, image_w / 2],
            [0, focal, image_h / 2],  # square pixels: fy == fx
            [0, 0, 1],
        ]
    )
def resize_pinhole_intrinsics(
raw_K: np.ndarray | torch.Tensor,
raw_hw: tuple[int, int],
new_hw: tuple[int, int],
) -> np.ndarray:
raw_h, raw_w = raw_hw
new_h, new_w = new_hw
scale_x = new_w / raw_w
scale_y = new_h / raw_h
new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
new_K[0, 0] *= scale_x # fx
new_K[0, 2] *= scale_x # cx
new_K[1, 1] *= scale_y # fy
new_K[1, 2] *= scale_y # cy
return new_K
def restore_scene_scale_and_position(
    real_height: float, mesh_path: str, gs_path: str
) -> None:
    """Scales a mesh and corresponding GS model to match a given real-world height.

    Uses the 1st and 99th percentile of mesh Z-axis to estimate height,
    applies scaling and vertical alignment, and updates both the mesh and GS model.

    Both files are modified IN PLACE.

    Args:
        real_height (float): Target real-world height among Z axis.
        mesh_path (str): Path to the input mesh file.
        gs_path (str): Path to the Gaussian Splatting model file.
    """
    mesh = trimesh.load(mesh_path)
    # NOTE(review): height is measured along vertex axis 1 (Y) while the
    # docstring says Z — confirm the axis convention of these meshes.
    # Percentiles make the estimate robust to outlier vertices.
    z_min = np.percentile(mesh.vertices[:, 1], 1)
    z_max = np.percentile(mesh.vertices[:, 1], 99)
    height = z_max - z_min
    scale = real_height / height
    # Quaternion [0, 1, 0, 0]: 180-degree rotation about the Y axis.
    rot = Rotation.from_quat([0, 1, 0, 0])
    mesh.vertices = rot.apply(mesh.vertices)
    mesh.vertices[:, 1] -= z_min
    mesh.vertices *= scale
    mesh.export(mesh_path)

    # Apply the matching translation and scale to the GS model.
    gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
    gs_model = gs_model.get_gaussians(
        instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0])
    )
    gs_model.rescale(scale)
    gs_model.save_to_ply(gs_path)

View File

@ -0,0 +1,152 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import sys
import zipfile
import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
def monkey_patch_pano2room():
    """Patch the vendored Pano2Room modules so they run inside EmbodiedGen.

    Replaces several ``__init__`` methods of the Pano2Room third-party code
    in place:

    - Omnidata depth/normal predictors: load weights via ``torch.hub``
      instead of expecting local checkpoint files.
    - ``PanoJointPredictor``: wire it up to the patched predictors.
    - ``LamaInpainter``: download and unpack the big-lama checkpoint from
      the Hugging Face Hub.
    - ``SDFTInpainter``: use the Stable Diffusion 2 inpainting pipeline.

    Must be called before any of these Pano2Room classes are instantiated,
    because it mutates the imported classes directly.
    """
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    # Make both the repo root and the pano2room submodule importable.
    sys.path.append(os.path.join(current_dir, "../.."))
    sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room"))
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
        OmnidataNormalPredictor,
    )
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
        OmnidataPredictor,
    )

    def patched_omni_depth_init(self):
        # Load the Omnidata DPT-hybrid depth model from torch.hub rather
        # than from a bundled checkpoint path.
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )

    OmnidataPredictor.__init__ = patched_omni_depth_init

    def patched_omni_normal_init(self):
        # Same hub-based loading for the surface-normal model.
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )

    OmnidataNormalPredictor.__init__ = patched_omni_normal_init

    def patched_panojoint_init(self, save_path=None):
        # Build the joint predictor from the two patched predictors above.
        self.depth_predictor = OmnidataPredictor()
        self.normal_predictor = OmnidataNormalPredictor()
        self.save_path = save_path

    from modules.geo_predictors import PanoJointPredictor

    PanoJointPredictor.__init__ = patched_panojoint_init

    # NOTE: We use gsplat instead.
    # import depth_diff_gaussian_rasterization_min as ddgr
    # from dataclasses import dataclass
    # @dataclass
    # class PatchedGaussianRasterizationSettings:
    #     image_height: int
    #     image_width: int
    #     tanfovx: float
    #     tanfovy: float
    #     bg: torch.Tensor
    #     scale_modifier: float
    #     viewmatrix: torch.Tensor
    #     projmatrix: torch.Tensor
    #     sh_degree: int
    #     campos: torch.Tensor
    #     prefiltered: bool
    #     debug: bool = False
    # ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings

    # disable get_has_ddp_rank print in `BaseInpaintingTrainingModule`
    os.environ["NODE_RANK"] = "0"
    from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
        load_checkpoint,
    )
    from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
        LamaInpainter,
    )

    def patched_lama_inpaint_init(self):
        # Download the big-lama archive once; extraction is skipped when the
        # target directory already exists (hf_hub_download caches the zip).
        zip_path = hf_hub_download(
            repo_id="smartywu/big-lama",
            filename="big-lama.zip",
            repo_type="model",
        )
        extract_dir = os.path.splitext(zip_path)[0]
        if not os.path.exists(extract_dir):
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_dir)
        config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
        checkpoint_path = os.path.join(
            extract_dir, 'big-lama/models/best.ckpt'
        )
        train_config = OmegaConf.load(config_path)
        # Inference-only configuration: no training state, no visualizer.
        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'
        self.model = load_checkpoint(
            train_config, checkpoint_path, strict=False, map_location='cpu'
        )
        self.model.freeze()

    LamaInpainter.__init__ = patched_lama_inpaint_init

    from diffusers import StableDiffusionInpaintPipeline
    from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
        SDFTInpainter,
    )

    def patched_sd_inpaint_init(self, subset_name=None):
        super(SDFTInpainter, self).__init__()
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        ).to("cuda")
        # NOTE(review): enable_model_cpu_offload() after .to("cuda") —
        # diffusers normally expects offload to be enabled instead of a
        # manual device move; confirm this combination is intended.
        pipe.enable_model_cpu_offload()
        self.inpaint_pipe = pipe

    SDFTInpainter.__init__ = patched_sd_inpaint_init

View File

@ -17,6 +17,7 @@
import logging
import math
import mimetypes
import os
import textwrap
from glob import glob
@ -27,10 +28,10 @@ import imageio
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import PIL.Image as Image
import spaces
from matplotlib.patches import Patch
from moviepy.editor import VideoFileClip, clips_array
from PIL import Image
from embodied_gen.data.differentiable_render import entrypoint as render_api
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
@ -45,6 +46,8 @@ __all__ = [
"filter_image_small_connected_components",
"combine_images_to_grid",
"SceneTreeVisualizer",
"is_image_file",
"parse_text_prompts",
]
@ -356,6 +359,23 @@ def load_scene_dict(file_path: str) -> dict:
return scene_dict
def is_image_file(filename: str) -> bool:
    """Return True when *filename*'s extension maps to an image MIME type."""
    guessed = mimetypes.guess_type(filename)[0]
    if guessed is None:
        return False
    return guessed.startswith('image')
def parse_text_prompts(prompts: list[str]) -> list[str]:
    """Expand a single ``*.txt`` argument into the prompts listed in that file.

    Blank lines and lines starting with ``#`` are skipped. Any other input
    is returned unchanged.
    """
    if len(prompts) != 1 or not prompts[0].endswith(".txt"):
        return prompts
    parsed = []
    with open(prompts[0], "r") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text and not text.startswith("#"):
                parsed.append(text)
    return parsed
if __name__ == "__main__":
merge_video_video(
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa

View File

@ -1 +1 @@
VERSION = "v0.1.1"
VERSION = "v0.1.2"

View File

@ -33,6 +33,9 @@ __all__ = [
"ImageAestheticChecker",
"SemanticConsistChecker",
"TextGenAlignChecker",
"PanoImageGenChecker",
"PanoHeightEstimator",
"PanoImageOccChecker",
]
@ -328,6 +331,159 @@ class TextGenAlignChecker(BaseChecker):
)
class PanoImageGenChecker(BaseChecker):
    """A checker class that validates the quality and realism of generated panoramic indoor images.

    The image is sent together with a fixed instruction prompt to a GPT
    vision model, which replies "YES" or "NO: <reason>".

    Attributes:
        gpt_client (GPTclient): A GPT client instance used to query for image validation.
        prompt (str): The instruction prompt passed to the GPT model. If None, a default prompt is used.
        verbose (bool): Whether to print internal processing information for debugging.
    """

    def __init__(
        self,
        gpt_client: GPTclient,
        prompt: str = None,
        verbose: bool = False,
    ) -> None:
        super().__init__(prompt, verbose)
        self.gpt_client = gpt_client
        # Fall back to the built-in validation prompt when the caller
        # (via BaseChecker) did not supply one.
        if self.prompt is None:
            self.prompt = """
        You are a panoramic image analyzer specializing in indoor room structure validation.
        Given a generated panoramic image, assess if it meets all the criteria:
        - Floor Space: 30 percent of the floor is free of objects or obstructions.
        - Visual Clarity: Floor, walls, and ceiling are clear, with no distortion, blur, noise.
        - Structural Continuity: Surfaces form plausible, continuous geometry
            without breaks, floating parts, or abrupt cuts.
        - Spatial Completeness: Full 360° coverage without missing areas,
            seams, gaps, or stitching artifacts.
        Instructions:
        - If all criteria are met, reply with "YES".
        - Otherwise, reply with "NO: <brief explanation>" (max 20 words).
        Respond exactly as:
        "YES"
        or
        "NO: brief explanation."
        """

    def query(self, image_paths: str | Image.Image) -> str:
        """Send the image and validation prompt to GPT; return the raw reply."""
        return self.gpt_client.query(
            text_prompt=self.prompt,
            image_base64=image_paths,
        )
class PanoImageOccChecker(BaseChecker):
    """Checks for physical obstacles in the bottom-center region of a panoramic image.

    This class crops a specified region from the input panoramic image and uses
    a GPT client to determine whether any physical obstacles are there.

    Args:
        gpt_client (GPTclient): The GPT-based client used for visual reasoning.
        box_hw (tuple[int, int]): The height and width of the crop box.
        prompt (str, optional): Custom prompt for the GPT client. Defaults to a predefined one.
        verbose (bool, optional): Whether to print verbose logs. Defaults to False.
    """

    def __init__(
        self,
        gpt_client: GPTclient,
        box_hw: tuple[int, int],
        prompt: str = None,
        verbose: bool = False,
    ) -> None:
        super().__init__(prompt, verbose)
        self.gpt_client = gpt_client
        self.box_hw = box_hw
        # Fall back to the built-in obstacle-detection prompt when none given.
        if self.prompt is None:
            self.prompt = """
        This image is a cropped region from the bottom-center of a panoramic view.
        Please determine whether there is any obstacle present such as furniture, tables, or other physical objects.
        Ignore floor textures, rugs, carpets, shadows, and lighting effects they do not count as obstacles.
        Only consider real, physical objects that could block walking or movement.
        Instructions:
        - If there is no obstacle, reply: "YES".
        - Otherwise, reply: "NO: <brief explanation>" (max 20 words).
        Respond exactly as:
        "YES"
        or
        "NO: brief explanation."
        """

    def query(self, image_paths: str | Image.Image) -> str:
        """Crop the bottom-center ``box_hw`` region and ask GPT about obstacles.

        Accepts either a file path or a PIL image; paths are opened first.
        """
        if isinstance(image_paths, str):
            image_paths = Image.open(image_paths)
        w, h = image_paths.size
        # Crop a (box_h x box_w) window centered horizontally and flush with
        # the bottom edge of the panorama.
        image_paths = image_paths.crop(
            (
                (w - self.box_hw[1]) // 2,
                h - self.box_hw[0],
                (w + self.box_hw[1]) // 2,
                h,
            )
        )
        return self.gpt_client.query(
            text_prompt=self.prompt,
            image_base64=image_paths,
        )
class PanoHeightEstimator(object):
    """Estimate the real ceiling height of an indoor space from a 360° panoramic image.

    Attributes:
        gpt_client (GPTclient): The GPT client used to perform image-based
            reasoning and return height estimates.
        default_value (float): The fallback height in meters if parsing the
            GPT output fails.
        prompt (str): The textual instruction used to guide the GPT model
            for height estimation.
    """

    def __init__(
        self,
        gpt_client: GPTclient,
        default_value: float = 3.5,
    ) -> None:
        self.gpt_client = gpt_client
        self.default_value = default_value
        self.prompt = """
        You are an expert in building height estimation and panoramic image analysis.
        Your task is to analyze a 360° indoor panoramic image and estimate the **actual height** of the space in meters.
        Consider the following visual cues:
        1. Ceiling visibility and reference objects (doors, windows, furniture, appliances).
        2. Floor features or level differences.
        3. Room type (e.g., residential, office, commercial).
        4. Object-to-ceiling proportions (e.g., height of doors relative to ceiling).
        5. Architectural elements (e.g., chandeliers, shelves, kitchen cabinets).
        Input: A full 360° panoramic indoor photo.
        Output: A single number in meters representing the estimated room height. Only return the number (e.g., `3.2`)
        """

    def __call__(self, image_paths: str | Image.Image) -> float:
        """Query the GPT client and parse its reply into a height in meters.

        Parsing is robust to noisy replies: a bare float is accepted first;
        otherwise the first numeric token in the reply is used (the model may
        wrap the number in backticks or append units such as "3.2 m"); as a
        last resort ``default_value`` is returned.

        Args:
            image_paths (str | Image.Image): Panoramic image (path or PIL).

        Returns:
            float: Estimated room height in meters.
        """
        import re  # local import keeps the module's import surface unchanged

        result = self.gpt_client.query(
            text_prompt=self.prompt,
            image_base64=image_paths,
        )
        try:
            # Fast path: the model obeyed the prompt and returned a number.
            return float(result.strip())
        except (ValueError, AttributeError, TypeError):
            # AttributeError/TypeError cover a None or non-str reply, which
            # previously crashed before the fallback could run.
            pass
        # Robust path: extract the first numeric token, e.g. from "`3.2`"
        # or "about 3.2 meters".
        match = (
            re.search(r"[-+]?\d+(?:\.\d+)?", str(result))
            if result is not None
            else None
        )
        if match is not None:
            return float(match.group(0))
        logger.error(
            f"Parser error: failed convert {result} to float, use default value {self.default_value}."
        )
        return self.default_value
class SemanticMatcher(BaseChecker):
def __init__(
self,

View File

@ -1,65 +1,28 @@
#!/bin/bash
set -e
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m'
STAGE=$1 # "basic" | "extra" | "all"
STAGE=${STAGE:-all}
echo -e "${GREEN}Starting installation process...${NC}"
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
source "$SCRIPT_DIR/install/_utils.sh"
git config --global http.postBuffer 524288000
# Patch submodule .gitignore to ignore __pycache__, only if submodule exists
PANO2ROOM_PATH="$SCRIPT_DIR/thirdparty/pano2room"
if [ -d "$PANO2ROOM_PATH" ]; then
echo "__pycache__/" > "$PANO2ROOM_PATH/.gitignore"
log_info "Added .gitignore to ignore __pycache__ in $PANO2ROOM_PATH"
fi
echo -e "${GREEN}Installing flash-attn...${NC}"
pip install flash-attn==2.7.0.post2 --no-build-isolation || {
echo -e "${RED}Failed to install flash-attn${NC}"
exit 1
}
log_info "===== Starting installation stage: $STAGE ====="
echo -e "${GREEN}Installing dependencies from requirements.txt...${NC}"
pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60 || {
echo -e "${RED}Failed to install requirements${NC}"
exit 1
}
if [[ "$STAGE" == "basic" || "$STAGE" == "all" ]]; then
bash "$SCRIPT_DIR/install/install_basic.sh"
fi
if [[ "$STAGE" == "extra" || "$STAGE" == "all" ]]; then
bash "$SCRIPT_DIR/install/install_extra.sh"
fi
echo -e "${GREEN}Installing kolors from GitHub...${NC}"
pip install kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d || {
echo -e "${RED}Failed to install kolors${NC}"
exit 1
}
echo -e "${GREEN}Installing kaolin from GitHub...${NC}"
pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || {
echo -e "${RED}Failed to install kaolin${NC}"
exit 1
}
echo -e "${GREEN}Installing diff-gaussian-rasterization...${NC}"
TMP_DIR="/tmp/mip-splatting"
rm -rf "$TMP_DIR"
git clone --recursive https://github.com/autonomousvision/mip-splatting.git "$TMP_DIR" && \
pip install "$TMP_DIR/submodules/diff-gaussian-rasterization" && \
rm -rf "$TMP_DIR" || {
echo -e "${RED}Failed to clone or install diff-gaussian-rasterization${NC}"
rm -rf "$TMP_DIR"
exit 1
}
echo -e "${GREEN}Installing gsplat from GitHub...${NC}"
pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.0 || {
echo -e "${RED}Failed to install gsplat${NC}"
exit 1
}
echo -e "${GREEN}Installing EmbodiedGen...${NC}"
pip install triton==2.1.0
pip install -e . || {
echo -e "${RED}Failed to install EmbodiedGen pyproject.toml${NC}"
exit 1
}
echo -e "${GREEN}Installation completed successfully!${NC}"
log_info "===== Installation completed successfully. ====="

21
install/_utils.sh Normal file
View File

@ -0,0 +1,21 @@
#!/bin/bash
# Shared helpers for the EmbodiedGen install scripts:
# colored logging and a fail-fast install wrapper.

RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m'  # reset color

# log_info MESSAGE — print an informational message in green.
log_info() {
    echo -e "${GREEN}[INFO] $1${NC}"
}

# log_error MESSAGE — print an error message in red to stderr.
log_error() {
    echo -e "${RED}[ERROR] $1${NC}" >&2
}

# try_install INFO_MSG COMMAND ERROR_MSG — log INFO_MSG, eval COMMAND,
# and abort the whole script with ERROR_MSG if the command fails.
# NOTE: COMMAND is passed through `eval`, so callers must pass trusted,
# properly quoted strings.
try_install() {
    log_info "$1"
    eval "$2" || {
        log_error "$3"
        exit 1
    }
}

35
install/install_basic.sh Normal file
View File

@ -0,0 +1,35 @@
#!/bin/bash
# Install the core ("basic") EmbodiedGen dependencies: flash-attn,
# requirements.txt, kolors, kaolin, diff-gaussian-rasterization
# (built from mip-splatting), gsplat, and the EmbodiedGen package itself.
set -e

SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
source "$SCRIPT_DIR/_utils.sh"  # log_info / log_error / try_install

try_install "Installing flash-attn..." \
    "pip install flash-attn==2.7.0.post2 --no-build-isolation" \
    "flash-attn installation failed."

# legacy-resolver avoids long backtracking on the pinned requirement set.
try_install "Installing requirements.txt..." \
    "pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60" \
    "requirements installation failed."

try_install "Installing kolors..." \
    "pip install kolors@git+https://github.com/HochCC/Kolors.git" \
    "kolors installation failed."

try_install "Installing kaolin..." \
    "pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0" \
    "kaolin installation failed."

# diff-gaussian-rasterization is vendored inside mip-splatting; clone it
# to a temp dir, install the submodule, then clean up. `set -e` aborts
# the script if any step fails.
log_info "Installing diff-gaussian-rasterization..."
TMP_DIR="/tmp/mip-splatting"
rm -rf "$TMP_DIR"
git clone --recursive https://github.com/autonomousvision/mip-splatting.git "$TMP_DIR"
pip install "$TMP_DIR/submodules/diff-gaussian-rasterization"
rm -rf "$TMP_DIR"

try_install "Installing gsplat..." \
    "pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.3" \
    "gsplat installation failed."

# triton is installed --no-deps to avoid it dragging in a conflicting torch.
try_install "Installing EmbodiedGen..." \
    "pip install triton==2.1.0 --no-deps && pip install -e ." \
    "EmbodiedGen installation failed."

52
install/install_extra.sh Normal file
View File

@ -0,0 +1,52 @@
#!/bin/bash
# Install the optional ("extra") EmbodiedGen dependencies used by the
# 3D scene generation pipeline (panorama generation, inpainting, etc.).
set -e

SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
source "$SCRIPT_DIR/_utils.sh"  # log_info / log_error / try_install

# try_install "Installing txt2panoimg..." \
#     "pip install txt2panoimg@git+https://github.com/HochCC/SD-T2I-360PanoImage --no-deps" \
#     "txt2panoimg installation failed."
# try_install "Installing fused-ssim..." \
#     "pip install fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98" \
#     "fused-ssim installation failed."
# try_install "Installing tiny-cuda-nn..." \
#     "pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch" \
#     "tiny-cuda-nn installation failed."
# try_install "Installing pytorch3d" \
#     "pip install git+https://github.com/facebookresearch/pytorch3d.git@v0.7.7" \
#     "pytorch3d installation failed."

# Packages installed with --no-deps to avoid pulling in conflicting
# versions of torch and friends.
PYTHON_PACKAGES_NODEPS=(
    timm
    txt2panoimg@git+https://github.com/HochCC/SD-T2I-360PanoImage
    kornia
    kornia_rs
)

# Remaining extras installed with normal dependency resolution.
PYTHON_PACKAGES=(
    fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98
    git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
    git+https://github.com/facebookresearch/pytorch3d.git@v0.7.7
    h5py
    albumentations==0.5.2
    webdataset
    icecream
    open3d
    pyequilib
    numpy==1.26.4
    triton==2.1.0
)

for pkg in "${PYTHON_PACKAGES_NODEPS[@]}"; do
    try_install "Installing $pkg without dependencies..." \
        "pip install --no-deps $pkg" \
        "$pkg installation failed."
done

try_install "Installing other Python dependencies..." \
    "pip install ${PYTHON_PACKAGES[*]}" \
    "Python dependencies installation failed."

View File

@ -7,7 +7,7 @@ packages = ["embodied_gen"]
[project]
name = "embodied_gen"
version = "v0.1.1"
version = "v0.1.2"
readme = "README.md"
license = "Apache-2.0"
license-files = ["LICENSE", "NOTICE"]
@ -31,6 +31,7 @@ drender-cli = "embodied_gen.data.differentiable_render:entrypoint"
backproject-cli = "embodied_gen.data.backproject_v2:entrypoint"
img3d-cli = "embodied_gen.scripts.imageto3d:entrypoint"
text3d-cli = "embodied_gen.scripts.textto3d:text_to_3d"
scene3d-cli = "embodied_gen.scripts.gen_scene3d:entrypoint"
[tool.pydocstyle]
match = '(?!test_).*(?!_pb2)\.py'

View File

@ -32,6 +32,9 @@ vtk==9.3.1
spaces
colorlog
json-repair
scikit-learn
omegaconf
tyro
utils3d@git+https://github.com/EasternJournalist/utils3d.git#egg=9a4eb15
clip@git+https://github.com/openai/CLIP.git
segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f

View File

@ -17,6 +17,7 @@
import logging
import tempfile
from glob import glob
import pytest
from embodied_gen.utils.gpt_clients import GPT_CLIENT
@ -25,6 +26,9 @@ from embodied_gen.validators.quality_checkers import (
ImageAestheticChecker,
ImageSegChecker,
MeshGeoChecker,
PanoHeightEstimator,
PanoImageGenChecker,
PanoImageOccChecker,
SemanticConsistChecker,
TextGenAlignChecker,
)
@ -57,6 +61,21 @@ def textalign_checker():
return TextGenAlignChecker(GPT_CLIENT)
@pytest.fixture(scope="module")
def pano_checker():
    # Shared PanoImageGenChecker instance for panorama-quality tests.
    return PanoImageGenChecker(GPT_CLIENT)


@pytest.fixture(scope="module")
def pano_height_estimator():
    # Shared PanoHeightEstimator (uses its default fallback height).
    return PanoHeightEstimator(GPT_CLIENT)


@pytest.fixture(scope="module")
def panoocc_checker():
    # Checker inspecting a 90x1000 box at the bottom-center of the pano.
    return PanoImageOccChecker(GPT_CLIENT, box_hw=[90, 1000])
def test_geo_checker(geo_checker):
flag, result = geo_checker(
[
@ -117,3 +136,28 @@ def test_textgen_checker(textalign_checker, mesh_path, text_desc):
)
flag, result = textalign_checker(text_desc, image_list)
logger.info(f"textalign_checker: {flag}, {result}")
def test_panoheight_estimator(pano_height_estimator):
    """Smoke-run the height estimator on locally generated panoramas.

    NOTE(review): depends on images under outputs/; the loop is a no-op
    (test silently passes) when the glob matches nothing.
    """
    image_paths = glob("outputs/bg_v3/test2/*/*.png")
    for image_path in image_paths:
        result = pano_height_estimator(image_path)
        logger.info(f"{type(result)}, {result}")


def test_pano_checker(pano_checker):
    """Smoke-run the panorama quality checker and log verdicts."""
    # image_paths = [
    #     "outputs/bg_gen2/scene_0000/pano_image.png",
    #     "outputs/bg_gen2/scene_0001/pano_image.png",
    # ]
    image_paths = glob("outputs/bg_gen/*/*.png")
    for image_path in image_paths:
        flag, result = pano_checker(image_path)
        logger.info(f"{image_path} {flag}, {result}")


def test_panoocc_checker(panoocc_checker):
    """Smoke-run the bottom-center obstacle checker and log verdicts."""
    image_paths = glob("outputs/bg_gen/*/*.png")
    for image_path in image_paths:
        flag, result = panoocc_checker(image_path)
        logger.info(f"{image_path} {flag}, {result}")

View File

@ -23,6 +23,8 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.validators.quality_checkers import (
ImageSegChecker,
MeshGeoChecker,
PanoHeightEstimator,
PanoImageGenChecker,
SemanticConsistChecker,
)
@ -93,3 +95,16 @@ def test_semantic_checker(gptclient_query_case2):
)
assert isinstance(flag, (bool, type(None)))
assert isinstance(result, str)
def test_panoheight_estimator():
    """Estimator returns a float even for an unreadable image path.

    NOTE(review): queries the live GPT endpoint with a nonexistent path;
    only the fallback-to-float contract is asserted.
    """
    checker = PanoHeightEstimator(GPT_CLIENT, default_value=3.5)
    result = checker(image_paths="dummy_path/pano.png")
    assert isinstance(result, float)


def test_panogen_checker():
    """Gen checker returns (flag, message) with the expected types."""
    checker = PanoImageGenChecker(GPT_CLIENT)
    flag, result = checker(image_paths="dummy_path/pano.png")
    assert isinstance(flag, (bool, type(None)))
    assert isinstance(result, str)

1
thirdparty/pano2room vendored Submodule

@ -0,0 +1 @@
Subproject commit bbf93ae57086ed700edc6ee445852d4457a9d704