diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a47aa67
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,62 @@
+build/
+dummy/
+!scripts/build
+builddir/
+conan-deps/
+distribute/
+lib/
+bin/
+dist/
+deps/*
+docs/build
+python/dist/
+docs/index.rst
+python/MANIFEST.in
+*.egg-info
+*.pyc
+*.pyi
+*.json
+*.bak
+*.zip
+wheels*/
+
+# Compiled Object files
+*.slo
+*.lo
+*.o
+
+# Compiled Dynamic libraries
+*.so
+*.dylib
+
+# Compiled Static libraries
+*.lai
+*.la
+*.a
+
+.cproject
+.project
+.settings/
+*.db
+*.bak
+.arcconfig
+.vscode/
+
+# files
+*.pack
+*.pcd
+*.html
+*.ply
+*.mp4
+# node
+node_modules
+
+# local files
+build.sh
+__pycache__/
+output*
+*.log
+scripts/tools/
+weights/
+apps/assets/example_texture/
+apps/sessions/
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..2ccd752
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,4 @@
+[submodule "thirdparty/TRELLIS"]
+ path = thirdparty/TRELLIS
+ url = https://github.com/microsoft/TRELLIS.git
+ branch = main
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..28d6ba6
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,78 @@
+repos:
+ - repo: git@gitlab.hobot.cc:ptd/3rd/pre-commit/pre-commit-hooks.git
+ rev: v2.3.0 # Use the ref you want to point at
+ hooks:
+ - id: trailing-whitespace
+ - id: check-added-large-files
+ name: Check for added large files
+ description: Prevent giant files from being committed
+ entry: check-added-large-files
+ language: python
+ args: ["--maxkb=1024"]
+ - repo: local
+ hooks:
+ - id: cpplint-cpp-source
+ name: cpplint
+ description: Check cpp code style.
+ entry: python3 scripts/lint_src/lint.py
+ language: system
+ exclude: (?x)(^tools/|^thirdparty/|^patch_files/)
+ files: \.(c|cc|cxx|cpp|cu|h|hpp)$
+ args: [--project=asset_recons, --path]
+ - repo: local
+ hooks:
+ - id: pycodestyle-python
+ name: pep8-exclude-docs
+ description: Check python code style.
+ entry: pycodestyle
+ language: system
+ exclude: (?x)(^docs/|^thirdparty/|^scripts/build/)
+ files: \.(py)$
+ types: [file, python]
+ args: [--config=setup.cfg]
+
+
+ # pre-commit install --hook-type commit-msg to enable it
+ # - repo: local
+ # hooks:
+ # - id: commit-check
+ # name: check for commit msg format
+ # language: pygrep
+ # entry: '\A(?!(feat|fix|docs|style|refactor|perf|test|chore)\(.*\): (\[[a-zA-Z][a-zA-Z0-9_]+-[1-9][0-9]*\]|\[cr_id_skip\]) [A-Z]+.*)'
+ # args: [--multiline]
+ # stages: [commit-msg]
+
+ - repo: local
+ hooks:
+ - id: pydocstyle-python
+ name: pydocstyle-change-exclude-docs
+ description: Check python doc style.
+ entry: pydocstyle
+ language: system
+ exclude: (?x)(^docs/|^thirdparty/)
+ files: \.(py)$
+ types: [file, python]
+ args: [--config=pyproject.toml]
+ - repo: local
+ hooks:
+ - id: black
+ name: black-exclude-docs
+ description: black format
+ entry: black
+ language: system
+ exclude: (?x)(^docs/|^thirdparty/)
+ files: \.(py)$
+ types: [file, python]
+ args: [--config=pyproject.toml]
+
+ - repo: local
+ hooks:
+ - id: isort
+ name: isort
+ description: isort format
+ entry: isort
+ language: system
+ exclude: (?x)(^thirdparty/)
+ files: \.(py)$
+ types: [file, python]
+ args: [--settings-file=pyproject.toml]
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index 261eeb9..2132bc3 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,3 +1,5 @@
+Copyright (c) 2024 Horizon Robotics and EmbodiedGen Contributors. All rights reserved.
+
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
@@ -186,7 +188,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright [yyyy] [name of copyright owner]
+ Copyright 2024 Horizon Robotics and EmbodiedGen Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..35d708e
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+graft embodied_gen
\ No newline at end of file
diff --git a/README.md b/README.md
index c4ae5c8..c86525d 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,178 @@
-# EmbodiedGen
-Towards a Generative 3D World Engine for Embodied Intelligence
+# EmbodiedGen: Towards a Generative 3D World Engine for Embodied Intelligence
+
+[](https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html)
+[](#)
+[](https://www.youtube.com/watch?v=SnHhzHeb_aI)
+[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D)
+[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D)
+[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen)
+
+
+
+
+**EmbodiedGen** generates interactive 3D worlds with real-world scale and physical realism at low cost.
+
+---
+
+## ✨ Table of Contents of EmbodiedGen
+- [๐ผ๏ธ Image-to-3D](#image-to-3d)
+- [๐ Text-to-3D](#text-to-3d)
+- [๐จ Texture Generation](#texture-generation)
+- [๐ 3D Scene Generation](#3d-scene-generation)
+- [โ๏ธ Articulated Object Generation](#articulated-object-generation)
+- [๐๏ธ Layout Generation](#layout-generation)
+
+## 🚀 Quick Start
+
+```sh
+git clone https://github.com/HorizonRobotics/EmbodiedGen
+cd EmbodiedGen
+conda create -n embodiedgen python=3.10.13 -y
+conda activate embodiedgen
+pip install -r requirements.txt --use-deprecated=legacy-resolver
+pip install -e .
+```
+
+---
+
+## ๐ข Setup GPT Agent
+
+Update the API key in file: `embodied_gen/utils/gpt_config.yaml`.
+
+You can choose between two backends for the GPT agent:
+
+- **`gpt-4o`** (Recommended) โ Use this if you have access to **Azure OpenAI**.
+- **`qwen2.5-vl`** โ An open alternative with free usage via [OpenRouter](https://openrouter.ai/settings/keys) (50 free requests per day)
+
+
+---
+
+
๐ผ๏ธ Image-to-3D
+
+[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) Generate physically plausible 3D asset from input image.
+
+### Local Service
+Run the image-to-3D generation service locally. The first run will download required models.
+
+```sh
+# Run in foreground
+python apps/image_to_3d.py
+# Or run in the background
+CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
+```
+
+### Local API
+Generate a 3D model from an image using the command-line API.
+
+```sh
+python3 embodied_gen/scripts/imageto3d.py \
+ --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
+ --output_root outputs/imageto3d/
+
+# See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result
+```
+
+---
+
+
+๐ Text-to-3D
+
+[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) Create 3D assets from text descriptions for a wide range of geometry and styles.
+
+### Local Service
+Run the text-to-3D generation service locally.
+
+```sh
+python apps/text_to_3d.py
+```
+
+### Local API
+
+```sh
+bash embodied_gen/scripts/textto3d.sh \
+ --prompts "small bronze figurine of a lion" "ๅธฆๆจ่ดจๅบๅบง๏ผๅ
ทๆ็ป็บฌ็บฟ็ๅฐ็ไปช" "ๆฉ่ฒ็ตๅจๆ้ป๏ผๆ็ฃจๆ็ป่" \
+ --output_root outputs/textto3d/
+```
+
+---
+
+
+๐จ Texture Generation
+
+[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) Generate visually rich textures for 3D mesh.
+
+### Local Service
+Run the texture generation service locally.
+
+```sh
+python apps/texture_edit.py
+```
+
+### Local API
+Generate textures for a 3D mesh using a text prompt.
+
+```sh
+bash embodied_gen/scripts/texture_gen.sh \
+ --mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \
+ --prompt "ไธพ็็ๅญ็็บข่ฒๅๅฎ้ฃๆ ผๆบๅจไบบ๏ผ็ๅญไธๅ็โHelloโ" \
+ --output_root "outputs/texture_gen/" \
+ --uuid "robot_text"
+```
+
+---
+
+๐ 3D Scene Generation
+
+๐ง *Coming Soon*
+
+---
+
+
+โ๏ธ Articulated Object Generation
+
+๐ง *Coming Soon*
+
+---
+
+
+๐๏ธ Layout Generation
+
+๐ง *Coming Soon*
+
+---
+
+## ๐ Citation
+
+If you use EmbodiedGen in your research or projects, please cite:
+
+```bibtex
+Coming Soon
+```
+
+---
+
+## ๐ Acknowledgement
+
+EmbodiedGen builds upon the following amazing projects and models:
+
+- ๐ [Trellis](https://github.com/microsoft/TRELLIS)
+- ๐ [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0)
+- ๐ [Segment Anything Model](https://github.com/facebookresearch/segment-anything)
+- ๐ [Rembg: a tool to remove images background](https://github.com/danielgatis/rembg)
+- ๐ [RMBG-1.4: BRIA Background Removal](https://huggingface.co/briaai/RMBG-1.4)
+- ๐ [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)
+- ๐ [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN)
+- ๐ [Kolors](https://github.com/Kwai-Kolors/Kolors)
+- ๐ [ChatGLM3](https://github.com/THUDM/ChatGLM3)
+- ๐ [Aesthetic Score Model](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html)
+- ๐ [Pano2Room](https://github.com/TrickyGo/Pano2Room)
+- ๐ [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage)
+- ๐ [kaolin](https://github.com/NVIDIAGameWorks/kaolin)
+- ๐ [diffusers](https://github.com/huggingface/diffusers)
+- ๐ GPT: QWEN2.5VL, GPT4o
+
+---
+
+## โ๏ธ License
+
+This project is licensed under the [Apache License 2.0](LICENSE). See the `LICENSE` file for details.
\ No newline at end of file
diff --git a/apps/common.py b/apps/common.py
new file mode 100644
index 0000000..11b7d6a
--- /dev/null
+++ b/apps/common.py
@@ -0,0 +1,899 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import gc
+import logging
+import os
+import shutil
+import subprocess
+import sys
+from glob import glob
+
+import cv2
+import gradio as gr
+import numpy as np
+import spaces
+import torch
+import torch.nn.functional as F
+import trimesh
+from easydict import EasyDict as edict
+from gradio.themes import Soft
+from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
+from PIL import Image
+from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
+from embodied_gen.data.differentiable_render import entrypoint as render_api
+from embodied_gen.data.utils import trellis_preprocess
+from embodied_gen.models.delight_model import DelightingModel
+from embodied_gen.models.gs_model import GaussianOperator
+from embodied_gen.models.segment_model import (
+ BMGG14Remover,
+ RembgRemover,
+ SAMPredictor,
+)
+from embodied_gen.models.sr_model import ImageRealESRGAN, ImageStableSR
+from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
+from embodied_gen.scripts.render_mv import build_texture_gen_pipe, infer_pipe
+from embodied_gen.scripts.text2image import (
+ build_text2img_ip_pipeline,
+ build_text2img_pipeline,
+ text2img_gen,
+)
+from embodied_gen.utils.gpt_clients import GPT_CLIENT
+from embodied_gen.utils.process_media import (
+ filter_image_small_connected_components,
+ merge_images_video,
+ render_video,
+)
+from embodied_gen.utils.tags import VERSION
+from embodied_gen.validators.quality_checkers import (
+ BaseChecker,
+ ImageAestheticChecker,
+ ImageSegChecker,
+ MeshGeoChecker,
+)
+from embodied_gen.validators.urdf_convertor import URDFGenerator, zip_files
+
+current_file_path = os.path.abspath(__file__)
+current_dir = os.path.dirname(current_file_path)
+sys.path.append(os.path.join(current_dir, ".."))
+from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+from thirdparty.TRELLIS.trellis.representations import (
+ Gaussian,
+ MeshExtractResult,
+)
+from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
+ build_scaling_rotation,
+ inverse_sigmoid,
+ strip_symmetric,
+)
+from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
+os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
+ "~/.cache/torch_extensions"
+)
+os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
+os.environ["SPCONV_ALGO"] = "native"
+
+MAX_SEED = 100000
+DELIGHT = DelightingModel()
+IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
+# IMAGESR_MODEL = ImageStableSR()
+
+
def patched_setup_functions(self):
    """Replacement for ``Gaussian.setup_functions`` from TRELLIS.

    Installs the activation/inverse-activation pairs and bias tensors on a
    Gaussian instance. Compared to the upstream implementation, this patch
    moves the bias tensors onto ``self.device`` so the representation can be
    built on CPU as well as GPU (the pipeline here toggles between the two
    to save GPU memory).
    """

    def inverse_softplus(x):
        # Analytic inverse of F.softplus: log(e^x - 1), written stably.
        return x + torch.log(-torch.expm1(-x))

    def build_covariance_from_scaling_rotation(
        scaling, scaling_modifier, rotation
    ):
        # Covariance = L L^T with L = R * diag(scale); keep only the
        # 6 unique entries of the symmetric matrix.
        L = build_scaling_rotation(scaling_modifier * scaling, rotation)
        actual_covariance = L @ L.transpose(1, 2)
        symm = strip_symmetric(actual_covariance)
        return symm

    # Pick the scaling activation declared on the instance; only "exp" and
    # "softplus" are supported (other values leave the attributes unset).
    if self.scaling_activation_type == "exp":
        self.scaling_activation = torch.exp
        self.inverse_scaling_activation = torch.log
    elif self.scaling_activation_type == "softplus":
        self.scaling_activation = F.softplus
        self.inverse_scaling_activation = inverse_softplus

    self.covariance_activation = build_covariance_from_scaling_rotation
    self.opacity_activation = torch.sigmoid
    self.inverse_opacity_activation = inverse_sigmoid
    self.rotation_activation = F.normalize

    # Pre-compute bias terms in the *pre-activation* domain, placed on the
    # instance's device (this device placement is the point of the patch).
    self.scale_bias = self.inverse_scaling_activation(
        torch.tensor(self.scaling_bias)
    ).to(self.device)
    # Identity quaternion (w=1, x=y=z=0) as the rotation bias.
    self.rots_bias = torch.zeros((4)).to(self.device)
    self.rots_bias[0] = 1
    self.opacity_bias = self.inverse_opacity_activation(
        torch.tensor(self.opacity_bias)
    ).to(self.device)


# Monkey-patch the TRELLIS Gaussian class so all instances created below
# use the device-aware setup above.
Gaussian.setup_functions = patched_setup_functions
+
+
def download_kolors_weights() -> None:
    """Download Kolors base and IP-Adapter weights from Hugging Face.

    Fetches both repos via the ``huggingface-cli`` tool into ``weights/``
    using ``--resume-download`` so interrupted downloads can resume.

    Raises:
        subprocess.CalledProcessError: if any download command fails
            (``check=True``).
    """
    # Plain string instead of a placeholder-less f-string (lint F541).
    logger.info("Download kolors weights from huggingface...")
    # The two downloads only differ in repo id / target dir; loop instead
    # of duplicating the subprocess invocation.
    repos = [
        ("Kwai-Kolors/Kolors", "weights/Kolors"),
        ("Kwai-Kolors/Kolors-IP-Adapter-Plus", "weights/Kolors-IP-Adapter-Plus"),
    ]
    for repo_id, local_dir in repos:
        subprocess.run(
            [
                "huggingface-cli",
                "download",
                "--resume-download",
                repo_id,
                "--local-dir",
                local_dir,
            ],
            check=True,
        )
+
+
# Per-app global initialization: the GRADIO_APP environment variable (set by
# each app entry script before importing this module) selects which heavy
# models are instantiated at import time.
# NOTE(review): if GRADIO_APP is unset or has an unexpected value, TMP_DIR is
# never defined and the os.makedirs call below raises NameError — confirm
# this module is only ever imported by the app scripts that set it.
if os.getenv("GRADIO_APP") == "imageto3d":
    RBG_REMOVER = RembgRemover()
    RBG14_REMOVER = BMGG14Remover()
    # SAM is kept on CPU here; inference moves data as needed.
    SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
    PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
        "microsoft/TRELLIS-image-large"
    )
    # PIPELINE.cuda()
    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
    AESTHETIC_CHECKER = ImageAestheticChecker()
    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
    # Session artifacts for this app live under apps/sessions/imageto3d.
    TMP_DIR = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
    )
elif os.getenv("GRADIO_APP") == "textto3d":
    RBG_REMOVER = RembgRemover()
    RBG14_REMOVER = BMGG14Remover()
    PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
        "microsoft/TRELLIS-image-large"
    )
    # PIPELINE.cuda()
    # Text-to-image (Kolors) weights are downloaded on first run.
    text_model_dir = "weights/Kolors"
    if not os.path.exists(text_model_dir):
        download_kolors_weights()

    PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
    PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
    AESTHETIC_CHECKER = ImageAestheticChecker()
    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
    TMP_DIR = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
    )
elif os.getenv("GRADIO_APP") == "texture_edit":
    if not os.path.exists("weights/Kolors"):
        download_kolors_weights()

    # Two texture pipelines: with and without IP-Adapter conditioning.
    PIPELINE_IP = build_texture_gen_pipe(
        base_ckpt_dir="./weights",
        ip_adapt_scale=0.7,
        device="cuda",
    )
    PIPELINE = build_texture_gen_pipe(
        base_ckpt_dir="./weights",
        ip_adapt_scale=0,
        device="cuda",
    )
    TMP_DIR = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
    )

os.makedirs(TMP_DIR, exist_ok=True)
+
+
# Shared UI snippets injected into each Gradio app via gr.HTML.
# NOTE(review): both CSS strings are empty here — possibly stripped during
# extraction; verify against the deployed apps whether styles are expected.
lighting_css = """

"""

image_css = """

"""

# Muted stone/gray theme shared by all EmbodiedGen apps.
custom_theme = Soft(
    primary_hue=stone,
    secondary_hue=gray,
    radius_size="md",
    text_size="sm",
    spacing_size="sm",
)
+
+
def start_session(req: gr.Request) -> None:
    """Create the per-session scratch directory when a Gradio session opens."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
+
+
def end_session(req: gr.Request) -> None:
    """Delete the per-session scratch directory when a Gradio session closes."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if os.path.exists(session_dir):
        shutil.rmtree(session_dir)
+
+
@spaces.GPU
def preprocess_image_fn(
    image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg"
) -> tuple[Image.Image, Image.Image]:
    """Remove the background from an input image and prep it for TRELLIS.

    Args:
        image: Path, numpy array, or PIL image.
        rmbg_tag: "rembg" selects the Rembg remover; anything else selects
            the RMBG-1.4 remover.

    Returns:
        Tuple of (segmented + trellis-preprocessed image, 512x512 copy of
        the raw input kept for later quality checks).
    """
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Snapshot the raw input before segmentation mutates it.
    raw_preview = image.copy().resize((512, 512))

    remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
    segmented = trellis_preprocess(remover(image))

    return segmented, raw_preview
+
+
def preprocess_sam_image_fn(
    image: Image.Image,
) -> tuple[Image.Image, Image.Image]:
    """Prepare an image for interactive SAM segmentation.

    Runs the SAM preprocessor, primes the predictor with the prepared image
    (side effect), and returns the prepared image plus a 512x512 preview.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    prepared = SAM_PREDICTOR.preprocess_image(image)
    preview = Image.fromarray(prepared).resize((512, 512))
    # Side effect: the predictor caches embeddings for subsequent clicks.
    SAM_PREDICTOR.predictor.set_image(prepared)

    return prepared, preview
+
+
def active_btn_by_content(content: gr.Image) -> gr.Button:
    """Return a button that is interactive only when image content exists."""
    return gr.Button(interactive=content is not None)
+
+
def active_btn_by_text_content(content: gr.Textbox) -> gr.Button:
    """Return a button that is interactive only for a non-empty text value."""
    has_text = content is not None and len(content) > 0
    return gr.Button(interactive=has_text)
+
+
def get_selected_image(
    choice: str, sample1: str, sample2: str, sample3: str
) -> str:
    """Map a sample-selector choice to the corresponding image path.

    Args:
        choice: One of "sample1", "sample2", "sample3".
        sample1/sample2/sample3: Candidate image paths.

    Raises:
        ValueError: If ``choice`` is not one of the three sample keys.
    """
    samples = {"sample1": sample1, "sample2": sample2, "sample3": sample3}
    try:
        return samples[choice]
    except KeyError:
        raise ValueError(f"Invalid choice: {choice}") from None
+
+
def get_cached_image(image_path: str) -> Image.Image:
    """Load ``image_path`` resized to 512x512; PIL inputs pass through as-is."""
    already_loaded = isinstance(image_path, Image.Image)
    return image_path if already_loaded else Image.open(image_path).resize((512, 512))
+
+
@spaces.GPU
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Serialize a Gaussian splat and extracted mesh into a numpy-only dict.

    All tensors are moved to CPU so the result can be stored in Gradio
    session state and later rebuilt with ``unpack_state``.
    """
    gaussian_state = {**gs.init_params}
    for attr in ("_xyz", "_features_dc", "_scaling", "_rotation", "_opacity"):
        gaussian_state[attr] = getattr(gs, attr).cpu().numpy()

    return {
        "gaussian": gaussian_state,
        "mesh": {
            "vertices": mesh.vertices.cpu().numpy(),
            "faces": mesh.faces.cpu().numpy(),
        },
    }
+
+
def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
    """Rebuild a Gaussian and a lightweight mesh from ``pack_state`` output.

    Args:
        state: Dict produced by ``pack_state``.
        device: Target device for the rebuilt tensors.

    Returns:
        Tuple of (Gaussian instance, edict with ``vertices``/``faces``).
    """
    gs_state = state["gaussian"]
    gs = Gaussian(
        aabb=gs_state["aabb"],
        sh_degree=gs_state["sh_degree"],
        mininum_kernel_size=gs_state["mininum_kernel_size"],
        scaling_bias=gs_state["scaling_bias"],
        opacity_bias=gs_state["opacity_bias"],
        scaling_activation=gs_state["scaling_activation"],
        device=device,
    )
    for attr in ("_xyz", "_features_dc", "_scaling", "_rotation", "_opacity"):
        setattr(gs, attr, torch.tensor(gs_state[attr], device=device))

    mesh = edict(
        vertices=torch.tensor(state["mesh"]["vertices"], device=device),
        faces=torch.tensor(state["mesh"]["faces"], device=device),
    )

    return gs, mesh
+
+
def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
    """Return a fresh random seed in [0, max_seed) if requested, else ``seed``."""
    if randomize_seed:
        return np.random.randint(0, max_seed)
    return seed
+
+
def select_point(
    image: np.ndarray,
    sel_pix: list,
    point_type: str,
    evt: gr.SelectData,
):
    """Handle a click on the SAM segmentation canvas.

    Appends the clicked pixel (with a foreground=1 / background=0 label) to
    ``sel_pix``, re-runs SAM mask generation over all collected points, and
    draws click markers onto ``image`` in place.

    Args:
        image: Current canvas image; mutated in place by the drawn markers.
        sel_pix: Accumulated (point, label) clicks; mutated in place.
        point_type: "foreground_point" or "background_point"; anything else
            falls back to foreground.
        evt: Gradio select event carrying the clicked pixel index.

    Returns:
        Tuple of ((annotated image, masks), segmented image).
    """
    if point_type == "foreground_point":
        sel_pix.append((evt.index, 1))  # append the foreground_point
    elif point_type == "background_point":
        sel_pix.append((evt.index, 0))  # append the background_point
    else:
        sel_pix.append((evt.index, 1))  # default foreground_point

    masks = SAM_PREDICTOR.generate_masks(image, sel_pix)
    seg_image = SAM_PREDICTOR.get_segmented_image(image, masks)

    # Draw all click markers: green for foreground, red for background.
    for point, label in sel_pix:
        color = (255, 0, 0) if label == 0 else (0, 255, 0)
        marker_type = 1 if label == 0 else 5
        cv2.drawMarker(
            image,
            point,
            color,
            markerType=marker_type,
            markerSize=15,
            thickness=10,
        )

    torch.cuda.empty_cache()

    return (image, masks), seg_image
+
+
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    raw_image_cache: Image.Image,
    sam_image: Image.Image = None,
    is_sam_image: bool = False,
    req: gr.Request = None,
) -> tuple[dict, str]:
    """Run the TRELLIS image-to-3D pipeline and render a preview video.

    Args:
        image: Auto-segmented input (used when ``is_sam_image`` is False).
        seed: Sampler seed.
        ss_guidance_strength / ss_sampling_steps: Sparse-structure sampler
            settings.
        slat_guidance_strength / slat_sampling_steps: SLAT sampler settings.
        raw_image_cache: Unsegmented input, saved for later quality checks.
        sam_image: Interactive SAM segmentation result (RGBA), used when
            ``is_sam_image`` is True.
        is_sam_image: Selects which segmentation source to use.
        req: Gradio request; its session hash namespaces the output dir.

    Returns:
        Tuple of (packed generation state dict, preview video path).
    """
    if is_sam_image:
        # Clean up stray specks in the SAM mask before TRELLIS preprocessing.
        seg_image = filter_image_small_connected_components(sam_image)
        seg_image = Image.fromarray(seg_image, mode="RGBA")
        seg_image = trellis_preprocess(seg_image)
    else:
        seg_image = image

    if isinstance(seg_image, np.ndarray):
        seg_image = Image.fromarray(seg_image)

    output_root = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(output_root, exist_ok=True)
    # Persist both images so ImageSegChecker can compare them in extract_urdf.
    seg_image.save(f"{output_root}/seg_image.png")
    raw_image_cache.save(f"{output_root}/raw_image.png")
    # Move the pipeline to GPU only for the duration of inference.
    PIPELINE.cuda()
    outputs = PIPELINE.run(
        seg_image,
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )
    # Set to cpu for memory saving.
    PIPELINE.cpu()

    gs_model = outputs["gaussian"][0]
    mesh_model = outputs["mesh"][0]
    # Side-by-side preview: GS color renders next to mesh normal renders.
    color_images = render_video(gs_model)["color"]
    normal_images = render_video(mesh_model)["normal"]

    video_path = os.path.join(output_root, "gs_mesh.mp4")
    merge_images_video(color_images, normal_images, video_path)
    state = pack_state(gs_model, mesh_model)

    gc.collect()
    torch.cuda.empty_cache()

    return state, video_path
+
+
@spaces.GPU
def extract_3d_representations(
    state: dict, enable_delight: bool, texture_size: int, req: gr.Request
):
    """Bake the packed generation state into exportable 3D assets (v1).

    Produces a textured GLB/OBJ via TRELLIS postprocessing plus raw and
    mesh-aligned Gaussian-splat PLY files.

    Args:
        state: Dict produced by ``pack_state``.
        enable_delight: Accepted for UI parity with v2; not used by this
            baking path.
        texture_size: Baked texture resolution (pixels per side).
        req: Gradio request; its session hash namespaces the output dir.

    Returns:
        Tuple of (glb path, gs ply path, obj path, aligned gs ply path).
    """
    output_root = os.path.join(TMP_DIR, str(req.session_hash))
    gs_model, mesh_model = unpack_state(state, device="cuda")

    mesh = postprocessing_utils.to_glb(
        gs_model,
        mesh_model,
        simplify=0.9,
        # Honor the user-selected resolution (was hard-coded to 1024).
        texture_size=texture_size,
        verbose=True,
    )
    filename = "sample"
    gs_path = os.path.join(output_root, f"{filename}_gs.ply")
    gs_model.save_ply(gs_path)

    # Rotate mesh and GS by 90 degrees around Z-axis.
    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
    # Additional rotation for GS to align with the mesh.
    gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array(
        rot_matrix
    )
    pose = GaussianOperator.trans_to_quatpose(gs_rot)
    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
    )

    mesh.vertices = mesh.vertices @ np.array(rot_matrix)
    mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
    mesh.export(mesh_obj_path)
    mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
    mesh.export(mesh_glb_path)

    torch.cuda.empty_cache()

    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
+
+
def extract_3d_representations_v2(
    state: dict,
    enable_delight: bool,
    texture_size: int,
    req: gr.Request,
):
    """Bake the packed generation state into exportable 3D assets (v2).

    Unlike v1, this path textures the mesh by back-projecting rendered GS
    colors (optionally de-lighted and super-resolved) instead of TRELLIS
    ``to_glb`` baking, and runs entirely on CPU-friendly devices.

    Args:
        state: Dict produced by ``pack_state``.
        enable_delight: Whether to run the delighting model during
            back-projection.
        texture_size: Baked texture resolution (pixels per side).
        req: Gradio request; its session hash namespaces the output dir.

    Returns:
        Tuple of (glb path, gs ply path, obj path, aligned gs ply path).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs_model, mesh_model = unpack_state(state, device="cpu")

    filename = "sample"
    gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
    gs_model.save_ply(gs_path)

    # Rotate mesh and GS by 90 degrees around Z-axis.
    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
    gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
    mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

    # Additional rotation for GS to align with the mesh.
    gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
    pose = GaussianOperator.trans_to_quatpose(gs_rot)
    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
        device="cpu",
    )
    # Render the aligned GS to a color atlas used for texture back-projection.
    color_path = os.path.join(user_dir, "color.png")
    render_gs_api(aligned_gs_path, color_path)

    mesh = trimesh.Trimesh(
        vertices=mesh_model.vertices.cpu().numpy(),
        faces=mesh_model.faces.cpu().numpy(),
    )
    mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
    mesh.vertices = mesh.vertices @ np.array(rot_matrix)

    mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
    mesh.export(mesh_obj_path)

    mesh = backproject_api(
        delight_model=DELIGHT,
        imagesr_model=IMAGESR_MODEL,
        color_path=color_path,
        mesh_path=mesh_obj_path,
        output_path=mesh_obj_path,
        skip_fix_mesh=False,
        delight=enable_delight,
        texture_wh=[texture_size, texture_size],
    )

    mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
    mesh.export(mesh_glb_path)

    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
+
+
def extract_urdf(
    gs_path: str,
    mesh_obj_path: str,
    asset_cat_text: str,
    height_range_text: str,
    mass_range_text: str,
    asset_version_text: str,
    req: gr.Request = None,
):
    """Convert a textured mesh into a zipped URDF asset with GPT-estimated
    physical attributes and quality-check tags.

    Args:
        gs_path: Gaussian-splat PLY to rescale and bundle with the URDF.
        mesh_obj_path: Textured OBJ mesh to convert.
        asset_cat_text: Optional category override (lowercased).
        height_range_text: Optional "min-max" height override in meters.
        mass_range_text: Optional "min-max" mass override in kg.
        asset_version_text: Optional version override.
        req: Gradio request; its session hash namespaces the output dir.

    Returns:
        Tuple of (zip path, estimated category, height, mass, friction), or
        an error-message string for malformed range inputs.
    """
    output_root = TMP_DIR
    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))

    # Convert to URDF and recover attrs by GPT.
    filename = "sample"
    urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
    asset_attrs = {
        "version": VERSION,
        "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
    }
    if asset_version_text:
        asset_attrs["version"] = asset_version_text
    if asset_cat_text:
        asset_attrs["category"] = asset_cat_text.lower()
    if height_range_text:
        try:
            min_height, max_height = map(float, height_range_text.split("-"))
            asset_attrs["min_height"] = min_height
            asset_attrs["max_height"] = max_height
        except ValueError:
            return "Invalid height input format. Use the format: min-max."
    if mass_range_text:
        try:
            min_mass, max_mass = map(float, mass_range_text.split("-"))
            asset_attrs["min_mass"] = min_mass
            asset_attrs["max_mass"] = max_mass
        except ValueError:
            return "Invalid mass input format. Use the format: min-max."

    urdf_path = urdf_convertor(
        mesh_path=mesh_obj_path,
        output_root=f"{output_root}/URDF_{filename}",
        **asset_attrs,
    )

    # Rescale GS and save to URDF/mesh folder.
    real_height = urdf_convertor.get_attr_from_urdf(
        urdf_path, attr_name="real_height"
    )
    out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"  # noqa
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=out_gs,
        real_height=real_height,
        device="cpu",
    )

    # Quality check and update .urdf file.
    mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj"  # noqa
    trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))

    # Reuse the renders produced during URDF conversion for the checkers;
    # the seg checker instead compares the raw vs segmented inputs saved
    # earlier in image_to_3d.
    image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color"  # noqa
    image_paths = glob(f"{image_dir}/*.png")
    images_list = []
    for checker in CHECKERS:
        images = image_paths
        if isinstance(checker, ImageSegChecker):
            images = [
                f"{TMP_DIR}/{req.session_hash}/raw_image.png",
                f"{TMP_DIR}/{req.session_hash}/seg_image.png",
            ]
        images_list.append(images)

    results = BaseChecker.validate(CHECKERS, images_list)
    urdf_convertor.add_quality_tag(urdf_path, results)

    # Zip urdf files
    urdf_zip = zip_files(
        input_paths=[
            f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}",
            f"{output_root}/URDF_{filename}/{filename}.urdf",
        ],
        output_zip=f"{output_root}/urdf_{filename}.zip",
    )

    estimated_type = urdf_convertor.estimated_attrs["category"]
    estimated_height = urdf_convertor.estimated_attrs["height"]
    estimated_mass = urdf_convertor.estimated_attrs["mass"]
    estimated_mu = urdf_convertor.estimated_attrs["mu"]

    return (
        urdf_zip,
        estimated_type,
        estimated_height,
        estimated_mass,
        estimated_mu,
    )
+
+
@spaces.GPU
def text2image_fn(
    prompt: str,
    guidance_scale: float,
    infer_step: int = 50,
    ip_image: Image.Image | str = None,
    ip_adapt_scale: float = 0.3,
    image_wh: int | tuple[int, int] = (1024, 1024),
    rmbg_tag: str = "rembg",
    seed: int = None,
    n_sample: int = 3,
    req: gr.Request = None,
):
    """Generate candidate images from a text prompt and segment each one.

    Uses the IP-Adapter pipeline when a reference image is supplied,
    otherwise the plain text-to-image pipeline.

    Args:
        prompt: Text prompt.
        guidance_scale: Classifier-free guidance strength.
        infer_step: Diffusion steps.
        ip_image: Optional reference image enabling IP-Adapter conditioning.
        ip_adapt_scale: IP-Adapter strength (used only with ``ip_image``).
        image_wh: Output size, an int (square) or (width, height) pair.
            Default is an immutable tuple (was a mutable list default).
        rmbg_tag: Background remover selector passed to preprocessing.
        seed: Optional sampler seed.
        n_sample: Number of candidates to generate.
        req: Gradio request; its session hash namespaces the output dir.

    Returns:
        Saved image paths, duplicated (gallery + state outputs).
    """
    if isinstance(image_wh, int):
        image_wh = (image_wh, image_wh)
    output_root = TMP_DIR
    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))
        os.makedirs(output_root, exist_ok=True)

    pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP
    if ip_image is not None:
        pipeline.set_ip_adapter_scale([ip_adapt_scale])

    images = text2img_gen(
        prompt=prompt,
        n_sample=n_sample,
        guidance_scale=guidance_scale,
        pipeline=pipeline,
        ip_image=ip_image,
        image_wh=image_wh,
        infer_step=infer_step,
        seed=seed,
    )

    # Segment each candidate; enumerate instead of range(len(...)).
    for idx, image in enumerate(images):
        images[idx], _ = preprocess_image_fn(image, rmbg_tag)

    save_paths = []
    for idx, image in enumerate(images):
        save_path = f"{output_root}/sample_{idx}.png"
        image.save(save_path)
        save_paths.append(save_path)

    logger.info(f"Images saved to {output_root}")

    gc.collect()
    torch.cuda.empty_cache()

    # Duplicated on purpose: the UI binds the same list to two outputs.
    return save_paths + save_paths
+
+
@spaces.GPU
def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
    """Render multi-view condition images for texture generation.

    Writes renders under the session's ``condition`` directory and returns
    three ``None`` values to clear the bound UI outputs.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))

    render_api(
        mesh_path=mesh_path,
        output_root=f"{session_dir}/condition",
        uuid=str(uuid),
    )

    gc.collect()
    torch.cuda.empty_cache()

    return None, None, None
+
+
@spaces.GPU
def generate_texture_mvimages(
    prompt: str,
    controlnet_cond_scale: float = 0.55,
    guidance_scale: float = 9,
    strength: float = 0.9,
    num_inference_steps: int = 50,
    seed: int = 0,
    ip_adapt_scale: float = 0,
    ip_img_path: str = None,
    uid: str = "sample",
    sub_idxs: tuple[tuple[int]] = ((0, 1, 2), (3, 4, 5)),
    req: gr.Request = None,
) -> list[str]:
    """Generate multi-view texture images for a mesh from a text prompt.

    Selects the IP-Adapter pipeline only when a reference image is given
    together with a positive adapter scale. Returns the saved image paths
    duplicated, matching the two UI outputs bound to this callback.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    use_ip_adapter = bool(ip_img_path) and ip_adapt_scale > 0
    PIPELINE_IP.set_ip_adapter_scale([ip_adapt_scale])

    img_save_paths = infer_pipe(
        index_file=f"{session_dir}/condition/index.json",
        controlnet_cond_scale=controlnet_cond_scale,
        guidance_scale=guidance_scale,
        strength=strength,
        num_inference_steps=num_inference_steps,
        ip_adapt_scale=ip_adapt_scale,
        ip_img_path=ip_img_path,
        uid=uid,
        prompt=prompt,
        save_dir=f"{session_dir}/multi_view",
        sub_idxs=sub_idxs,
        pipeline=PIPELINE_IP if use_ip_adapter else PIPELINE,
        seed=seed,
    )

    gc.collect()
    torch.cuda.empty_cache()

    return img_save_paths + img_save_paths
+
+
def backproject_texture(
    mesh_path: str,
    input_image: str,
    texture_size: int,
    uuid: str = "sample",
    req: gr.Request = None,
) -> str:
    """Back-project a multi-view texture image onto a mesh via the CLI tool.

    Runs ``backproject-cli`` as a subprocess, converts the resulting OBJ to
    GLB, and bundles the artifacts into a zip.

    Args:
        mesh_path: Input mesh to texture.
        input_image: Multi-view color image to back-project.
        texture_size: Texture resolution (pixels per side).
        uuid: Basename for the output files.
        req: Gradio request; its session hash namespaces the output dir.

    Returns:
        Tuple of (glb path, obj path, zip path).
    """
    output_root = os.path.join(TMP_DIR, str(req.session_hash))
    output_dir = os.path.join(output_root, "texture_mesh")
    os.makedirs(output_dir, exist_ok=True)
    command = [
        "backproject-cli",
        "--mesh_path",
        mesh_path,
        "--input_image",
        input_image,
        "--output_root",
        output_dir,
        "--uuid",
        f"{uuid}",
        "--texture_size",
        str(texture_size),
        "--skip_fix_mesh",
    ]

    result = subprocess.run(
        command, capture_output=True, text=True, encoding="utf-8"
    )
    # The CLI's failure used to be silently discarded; log it so a missing
    # output mesh below can be diagnosed (trimesh.load will still raise).
    if result.returncode != 0:
        logger.error(
            "backproject-cli failed (code %s): %s",
            result.returncode,
            result.stderr,
        )

    output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
    output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
    _ = trimesh.load(output_obj_mesh).export(output_glb_mesh)

    zip_file = zip_files(
        input_paths=[
            output_glb_mesh,
            output_obj_mesh,
            os.path.join(output_dir, "material.mtl"),
            os.path.join(output_dir, "material_0.png"),
        ],
        output_zip=os.path.join(output_dir, f"{uuid}.zip"),
    )

    gc.collect()
    torch.cuda.empty_cache()

    return output_glb_mesh, output_obj_mesh, zip_file
+
+
@spaces.GPU
def backproject_texture_v2(
    mesh_path: str,
    input_image: str,
    texture_size: int,
    enable_delight: bool = True,
    fix_mesh: bool = False,
    uuid: str = "sample",
    req: gr.Request = None,
) -> str:
    """Back-project a multi-view texture onto a mesh via the in-process API.

    Unlike ``backproject_texture``, this calls the back-projection entrypoint
    directly (with optional delighting and mesh fixing) instead of shelling
    out to the CLI, then exports GLB and bundles the artifacts into a zip.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    output_dir = os.path.join(session_dir, "texture_mesh")
    os.makedirs(output_dir, exist_ok=True)

    output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
    output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")

    textured_mesh = backproject_api(
        delight_model=DELIGHT,
        imagesr_model=IMAGESR_MODEL,
        color_path=input_image,
        mesh_path=mesh_path,
        output_path=f"{output_dir}/{uuid}.obj",
        skip_fix_mesh=not fix_mesh,
        delight=enable_delight,
        texture_wh=[texture_size, texture_size],
    )
    textured_mesh.export(output_glb_mesh)

    bundle = [
        output_glb_mesh,
        output_obj_mesh,
        os.path.join(output_dir, "material.mtl"),
        os.path.join(output_dir, "material_0.png"),
    ]
    zip_file = zip_files(
        input_paths=bundle,
        output_zip=os.path.join(output_dir, f"{uuid}.zip"),
    )

    gc.collect()
    torch.cuda.empty_cache()

    return output_glb_mesh, output_obj_mesh, zip_file
+
+
@spaces.GPU
def render_result_video(
    mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
) -> str:
    """Render a 90-frame turntable color video of the textured mesh.

    Returns the path of the color MP4 written under the session's
    ``texture_mesh`` directory.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    output_dir = os.path.join(session_dir, "texture_mesh")

    render_api(
        mesh_path=mesh_path,
        output_root=output_dir,
        num_images=90,
        elevation=[20],
        with_mtl=True,
        pbr_light_factor=1,
        uuid=str(uuid),
        gen_color_mp4=True,
        gen_glonormal_mp4=True,
        distance=5.5,
        resolution_hw=(video_size, video_size),
    )

    gc.collect()
    torch.cuda.empty_cache()

    return f"{output_dir}/color.mp4"
diff --git a/apps/image_to_3d.py b/apps/image_to_3d.py
new file mode 100644
index 0000000..88883f4
--- /dev/null
+++ b/apps/image_to_3d.py
@@ -0,0 +1,501 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import os
+
+os.environ["GRADIO_APP"] = "imageto3d"
+from glob import glob
+
+import gradio as gr
+from common import (
+ MAX_SEED,
+ VERSION,
+ active_btn_by_content,
+ custom_theme,
+ end_session,
+ extract_3d_representations_v2,
+ extract_urdf,
+ get_seed,
+ image_css,
+ image_to_3d,
+ lighting_css,
+ preprocess_image_fn,
+ preprocess_sam_image_fn,
+ select_point,
+ start_session,
+)
+
+with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
+ gr.Markdown(
+ """
+ ## ***EmbodiedGen***: Image-to-3D Asset
+ **๐ Version**: {VERSION}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ๐ผ๏ธ Generate physically plausible 3D asset from single input image.
+
+ """.format(
+ VERSION=VERSION
+ ),
+ elem_classes=["header"],
+ )
+
+ gr.HTML(image_css)
+ gr.HTML(lighting_css)
+ with gr.Row():
+ with gr.Column(scale=2):
+ with gr.Tabs() as input_tabs:
+ with gr.Tab(
+ label="Image(auto seg)", id=0
+ ) as single_image_input_tab:
+ raw_image_cache = gr.Image(
+ format="png",
+ image_mode="RGB",
+ type="pil",
+ visible=False,
+ )
+ image_prompt = gr.Image(
+ label="Input Image",
+ format="png",
+ image_mode="RGBA",
+ type="pil",
+ height=400,
+ elem_classes=["image_fit"],
+ )
+ gr.Markdown(
+ """
+ If you are not satisfied with the auto segmentation
+ result, please switch to the `Image(SAM seg)` tab."""
+ )
+ with gr.Tab(
+ label="Image(SAM seg)", id=1
+ ) as samimage_input_tab:
+ with gr.Row():
+ with gr.Column(scale=1):
+ image_prompt_sam = gr.Image(
+ label="Input Image",
+ type="numpy",
+ height=400,
+ elem_classes=["image_fit"],
+ )
+ image_seg_sam = gr.Image(
+ label="SAM Seg Image",
+ image_mode="RGBA",
+ type="pil",
+ height=400,
+ visible=False,
+ )
+ with gr.Column(scale=1):
+ image_mask_sam = gr.AnnotatedImage(
+ elem_classes=["image_fit"]
+ )
+
+ fg_bg_radio = gr.Radio(
+ ["foreground_point", "background_point"],
+ label="Select foreground(green) or background(red) points, by default foreground", # noqa
+ value="foreground_point",
+ )
+ gr.Markdown(
+ """ Click the `Input Image` to select SAM points,
+                        after getting a satisfactory segmentation, click `Generate`
+ button to generate the 3D asset. \n
+ Note: If the segmented foreground is too small relative
+ to the entire image area, the generation will fail.
+ """
+ )
+
+ with gr.Accordion(label="Generation Settings", open=False):
+ with gr.Row():
+ seed = gr.Slider(
+ 0, MAX_SEED, label="Seed", value=0, step=1
+ )
+ texture_size = gr.Slider(
+ 1024,
+ 4096,
+ label="UV texture size",
+ value=2048,
+ step=256,
+ )
+ rmbg_tag = gr.Radio(
+ choices=["rembg", "rmbg14"],
+ value="rembg",
+ label="Background Removal Model",
+ )
+ with gr.Row():
+ randomize_seed = gr.Checkbox(
+ label="Randomize Seed", value=False
+ )
+ project_delight = gr.Checkbox(
+ label="Backproject delighting",
+ value=False,
+ )
+ gr.Markdown("Geo Structure Generation")
+ with gr.Row():
+ ss_guidance_strength = gr.Slider(
+ 0.0,
+ 10.0,
+ label="Guidance Strength",
+ value=7.5,
+ step=0.1,
+ )
+ ss_sampling_steps = gr.Slider(
+ 1, 50, label="Sampling Steps", value=12, step=1
+ )
+ gr.Markdown("Visual Appearance Generation")
+ with gr.Row():
+ slat_guidance_strength = gr.Slider(
+ 0.0,
+ 10.0,
+ label="Guidance Strength",
+ value=3.0,
+ step=0.1,
+ )
+ slat_sampling_steps = gr.Slider(
+ 1, 50, label="Sampling Steps", value=12, step=1
+ )
+
+ generate_btn = gr.Button(
+ "๐ 1. Generate(~0.5 mins)",
+ variant="primary",
+ interactive=False,
+ )
+ model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
+ with gr.Row():
+ extract_rep3d_btn = gr.Button(
+ "๐ 2. Extract 3D Representation(~2 mins)",
+ variant="primary",
+ interactive=False,
+ )
+ with gr.Accordion(
+ label="Enter Asset Attributes(optional)", open=False
+ ):
+ asset_cat_text = gr.Textbox(
+ label="Enter Asset Category (e.g., chair)"
+ )
+ height_range_text = gr.Textbox(
+ label="Enter **Height Range** in meter (e.g., 0.5-0.6)"
+ )
+ mass_range_text = gr.Textbox(
+ label="Enter **Mass Range** in kg (e.g., 1.1-1.2)"
+ )
+ asset_version_text = gr.Textbox(
+ label=f"Enter version (e.g., {VERSION})"
+ )
+ with gr.Row():
+ extract_urdf_btn = gr.Button(
+ "๐งฉ 3. Extract URDF with physics(~1 mins)",
+ variant="primary",
+ interactive=False,
+ )
+ with gr.Row():
+ gr.Markdown(
+ "#### Estimated Asset 3D Attributes(No input required)"
+ )
+ with gr.Row():
+ est_type_text = gr.Textbox(
+ label="Asset category", interactive=False
+ )
+ est_height_text = gr.Textbox(
+ label="Real height(.m)", interactive=False
+ )
+ est_mass_text = gr.Textbox(
+ label="Mass(.kg)", interactive=False
+ )
+ est_mu_text = gr.Textbox(
+ label="Friction coefficient", interactive=False
+ )
+ with gr.Row():
+ download_urdf = gr.DownloadButton(
+ label="โฌ๏ธ 4. Download URDF",
+ variant="primary",
+ interactive=False,
+ )
+
+ gr.Markdown(
+ """ NOTE: If `Asset Attributes` are provided, the provided
+ properties will be used; otherwise, the GPT-preset properties
+ will be applied. \n
+            The `Download URDF` file is restored to the real scale and
+            has passed quality inspection; open it with an editor to view details.
+ """
+ )
+
+ with gr.Row() as single_image_example:
+ examples = gr.Examples(
+ label="Image Gallery",
+ examples=[
+ [image_path]
+ for image_path in sorted(
+ glob("apps/assets/example_image/*")
+ )
+ ],
+ inputs=[image_prompt, rmbg_tag],
+ fn=preprocess_image_fn,
+ outputs=[image_prompt, raw_image_cache],
+ run_on_click=True,
+ examples_per_page=10,
+ )
+
+ with gr.Row(visible=False) as single_sam_image_example:
+ examples = gr.Examples(
+ label="Image Gallery",
+ examples=[
+ [image_path]
+ for image_path in sorted(
+ glob("apps/assets/example_image/*")
+ )
+ ],
+ inputs=[image_prompt_sam],
+ fn=preprocess_sam_image_fn,
+ outputs=[image_prompt_sam, raw_image_cache],
+ run_on_click=True,
+ examples_per_page=10,
+ )
+ with gr.Column(scale=1):
+ video_output = gr.Video(
+ label="Generated 3D Asset",
+ autoplay=True,
+ loop=True,
+ height=300,
+ )
+ model_output_gs = gr.Model3D(
+ label="Gaussian Representation", height=300, interactive=False
+ )
+ aligned_gs = gr.Textbox(visible=False)
+ gr.Markdown(
+            """ The rendering of `Gaussian Representation` takes an additional 10s. """ # noqa
+ )
+ with gr.Row():
+ model_output_mesh = gr.Model3D(
+ label="Mesh Representation",
+ height=300,
+ interactive=False,
+ clear_color=[0.8, 0.8, 0.8, 1],
+ elem_id="lighter_mesh",
+ )
+
+ is_samimage = gr.State(False)
+ output_buf = gr.State()
+ selected_points = gr.State(value=[])
+
+ demo.load(start_session)
+ demo.unload(end_session)
+
+ single_image_input_tab.select(
+ lambda: tuple(
+ [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
+ ),
+ outputs=[is_samimage, single_image_example, single_sam_image_example],
+ )
+ samimage_input_tab.select(
+ lambda: tuple(
+ [True, gr.Row.update(visible=True), gr.Row.update(visible=False)]
+ ),
+ outputs=[is_samimage, single_sam_image_example, single_image_example],
+ )
+
+ image_prompt.upload(
+ preprocess_image_fn,
+ inputs=[image_prompt, rmbg_tag],
+ outputs=[image_prompt, raw_image_cache],
+ )
+ image_prompt.change(
+ lambda: tuple(
+ [
+ gr.Button(interactive=False),
+ gr.Button(interactive=False),
+ gr.Button(interactive=False),
+ None,
+ "",
+ None,
+ None,
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ]
+ ),
+ outputs=[
+ extract_rep3d_btn,
+ extract_urdf_btn,
+ download_urdf,
+ model_output_gs,
+ aligned_gs,
+ model_output_mesh,
+ video_output,
+ asset_cat_text,
+ height_range_text,
+ mass_range_text,
+ asset_version_text,
+ est_type_text,
+ est_height_text,
+ est_mass_text,
+ est_mu_text,
+ ],
+ )
+ image_prompt.change(
+ active_btn_by_content,
+ inputs=image_prompt,
+ outputs=generate_btn,
+ )
+
+ image_prompt_sam.upload(
+ preprocess_sam_image_fn,
+ inputs=[image_prompt_sam],
+ outputs=[image_prompt_sam, raw_image_cache],
+ )
+ image_prompt_sam.change(
+ lambda: tuple(
+ [
+ gr.Button(interactive=False),
+ gr.Button(interactive=False),
+ gr.Button(interactive=False),
+ None,
+ None,
+ None,
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ None,
+ [],
+ ]
+ ),
+ outputs=[
+ extract_rep3d_btn,
+ extract_urdf_btn,
+ download_urdf,
+ model_output_gs,
+ model_output_mesh,
+ video_output,
+ asset_cat_text,
+ height_range_text,
+ mass_range_text,
+ asset_version_text,
+ est_type_text,
+ est_height_text,
+ est_mass_text,
+ est_mu_text,
+ image_mask_sam,
+ selected_points,
+ ],
+ )
+
+ image_prompt_sam.select(
+ select_point,
+ [
+ image_prompt_sam,
+ selected_points,
+ fg_bg_radio,
+ ],
+ [image_mask_sam, image_seg_sam],
+ )
+ image_seg_sam.change(
+ active_btn_by_content,
+ inputs=image_seg_sam,
+ outputs=generate_btn,
+ )
+
+ generate_btn.click(
+ get_seed,
+ inputs=[randomize_seed, seed],
+ outputs=[seed],
+ ).success(
+ image_to_3d,
+ inputs=[
+ image_prompt,
+ seed,
+ ss_guidance_strength,
+ ss_sampling_steps,
+ slat_guidance_strength,
+ slat_sampling_steps,
+ raw_image_cache,
+ image_seg_sam,
+ is_samimage,
+ ],
+ outputs=[output_buf, video_output],
+ ).success(
+ lambda: gr.Button(interactive=True),
+ outputs=[extract_rep3d_btn],
+ )
+
+ extract_rep3d_btn.click(
+ extract_3d_representations_v2,
+ inputs=[
+ output_buf,
+ project_delight,
+ texture_size,
+ ],
+ outputs=[
+ model_output_mesh,
+ model_output_gs,
+ model_output_obj,
+ aligned_gs,
+ ],
+ ).success(
+ lambda: gr.Button(interactive=True),
+ outputs=[extract_urdf_btn],
+ )
+
+ extract_urdf_btn.click(
+ extract_urdf,
+ inputs=[
+ aligned_gs,
+ model_output_obj,
+ asset_cat_text,
+ height_range_text,
+ mass_range_text,
+ asset_version_text,
+ ],
+ outputs=[
+ download_urdf,
+ est_type_text,
+ est_height_text,
+ est_mass_text,
+ est_mu_text,
+ ],
+ queue=True,
+ show_progress="full",
+ ).success(
+ lambda: gr.Button(interactive=True),
+ outputs=[download_urdf],
+ )
+
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=8081)
diff --git a/apps/text_to_3d.py b/apps/text_to_3d.py
new file mode 100644
index 0000000..a36a99f
--- /dev/null
+++ b/apps/text_to_3d.py
@@ -0,0 +1,481 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import os
+
+os.environ["GRADIO_APP"] = "textto3d"
+
+
+import gradio as gr
+from common import (
+ MAX_SEED,
+ VERSION,
+ active_btn_by_text_content,
+ custom_theme,
+ end_session,
+ extract_3d_representations_v2,
+ extract_urdf,
+ get_cached_image,
+ get_seed,
+ get_selected_image,
+ image_css,
+ image_to_3d,
+ lighting_css,
+ start_session,
+ text2image_fn,
+)
+
+with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
+ gr.Markdown(
+ """
+ ## ***EmbodiedGen***: Text-to-3D Asset
+ **๐ Version**: {VERSION}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ๐ Create 3D assets from text descriptions for a wide range of geometry and styles.
+
+ """.format(
+ VERSION=VERSION
+ ),
+ elem_classes=["header"],
+ )
+ gr.HTML(image_css)
+ gr.HTML(lighting_css)
+ with gr.Row():
+ with gr.Column(scale=1):
+ raw_image_cache = gr.Image(
+ format="png",
+ image_mode="RGB",
+ type="pil",
+ visible=False,
+ )
+ text_prompt = gr.Textbox(
+ label="Text Prompt (Chinese or English)",
+ placeholder="Input text prompt here",
+ )
+ ip_image = gr.Image(
+ label="Reference Image(optional)",
+ format="png",
+ image_mode="RGB",
+ type="filepath",
+ height=250,
+ elem_classes=["image_fit"],
+ )
+ gr.Markdown(
+                "Note: The `reference image` is optional; if used, "
+                "please provide an image in nearly square resolution."
+ )
+
+ with gr.Accordion(label="Image Generation Settings", open=False):
+ with gr.Row():
+ seed = gr.Slider(
+ 0, MAX_SEED, label="Seed", value=0, step=1
+ )
+ randomize_seed = gr.Checkbox(
+ label="Randomize Seed", value=False
+ )
+ rmbg_tag = gr.Radio(
+ choices=["rembg", "rmbg14"],
+ value="rembg",
+ label="Background Removal Model",
+ )
+ ip_adapt_scale = gr.Slider(
+ 0, 1, label="IP-adapter Scale", value=0.3, step=0.05
+ )
+ img_guidance_scale = gr.Slider(
+ 1, 30, label="Text Guidance Scale", value=12, step=0.2
+ )
+ img_inference_steps = gr.Slider(
+ 10, 100, label="Sampling Steps", value=50, step=5
+ )
+ img_resolution = gr.Slider(
+ 512,
+ 1536,
+ label="Image Resolution",
+ value=1024,
+ step=128,
+ )
+
+ generate_img_btn = gr.Button(
+ "๐จ 1. Generate Images(~1min)",
+ variant="primary",
+ interactive=False,
+ )
+ dropdown = gr.Radio(
+ choices=["sample1", "sample2", "sample3"],
+ value="sample1",
+ label="Choose your favorite sample style.",
+ )
+ select_img = gr.Image(
+ visible=False,
+ format="png",
+ image_mode="RGBA",
+ type="pil",
+ height=300,
+ )
+
+ # text to 3d
+ with gr.Accordion(label="Generation Settings", open=False):
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+ texture_size = gr.Slider(
+ 1024, 4096, label="UV texture size", value=2048, step=256
+ )
+ with gr.Row():
+ randomize_seed = gr.Checkbox(
+ label="Randomize Seed", value=False
+ )
+ project_delight = gr.Checkbox(
+ label="backproject delight", value=False
+ )
+ gr.Markdown("Geo Structure Generation")
+ with gr.Row():
+ ss_guidance_strength = gr.Slider(
+ 0.0,
+ 10.0,
+ label="Guidance Strength",
+ value=7.5,
+ step=0.1,
+ )
+ ss_sampling_steps = gr.Slider(
+ 1, 50, label="Sampling Steps", value=12, step=1
+ )
+ gr.Markdown("Visual Appearance Generation")
+ with gr.Row():
+ slat_guidance_strength = gr.Slider(
+ 0.0,
+ 10.0,
+ label="Guidance Strength",
+ value=3.0,
+ step=0.1,
+ )
+ slat_sampling_steps = gr.Slider(
+ 1, 50, label="Sampling Steps", value=12, step=1
+ )
+
+ generate_btn = gr.Button(
+ "๐ 2. Generate 3D(~0.5 mins)",
+ variant="primary",
+ interactive=False,
+ )
+ model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
+ with gr.Row():
+ extract_rep3d_btn = gr.Button(
+ "๐ 3. Extract 3D Representation(~1 mins)",
+ variant="primary",
+ interactive=False,
+ )
+ with gr.Accordion(
+ label="Enter Asset Attributes(optional)", open=False
+ ):
+ asset_cat_text = gr.Textbox(
+ label="Enter Asset Category (e.g., chair)"
+ )
+ height_range_text = gr.Textbox(
+ label="Enter Height Range in meter (e.g., 0.5-0.6)"
+ )
+ mass_range_text = gr.Textbox(
+ label="Enter Mass Range in kg (e.g., 1.1-1.2)"
+ )
+ asset_version_text = gr.Textbox(
+ label=f"Enter version (e.g., {VERSION})"
+ )
+ with gr.Row():
+ extract_urdf_btn = gr.Button(
+ "๐งฉ 4. Extract URDF with physics(~1 mins)",
+ variant="primary",
+ interactive=False,
+ )
+ with gr.Row():
+ download_urdf = gr.DownloadButton(
+ label="โฌ๏ธ 5. Download URDF",
+ variant="primary",
+ interactive=False,
+ )
+
+ with gr.Column(scale=3):
+ with gr.Row():
+ image_sample1 = gr.Image(
+ label="sample1",
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ height=300,
+ interactive=False,
+ elem_classes=["image_fit"],
+ )
+ image_sample2 = gr.Image(
+ label="sample2",
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ height=300,
+ interactive=False,
+ elem_classes=["image_fit"],
+ )
+ image_sample3 = gr.Image(
+ label="sample3",
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ height=300,
+ interactive=False,
+ elem_classes=["image_fit"],
+ )
+ usample1 = gr.Image(
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ visible=False,
+ )
+ usample2 = gr.Image(
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ visible=False,
+ )
+ usample3 = gr.Image(
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ visible=False,
+ )
+ gr.Markdown(
+ "The generated image may be of poor quality due to auto "
+ "segmentation. Try adjusting the text prompt or seed."
+ )
+ with gr.Row():
+ video_output = gr.Video(
+ label="Generated 3D Asset",
+ autoplay=True,
+ loop=True,
+ height=300,
+ interactive=False,
+ )
+ model_output_gs = gr.Model3D(
+ label="Gaussian Representation",
+ height=300,
+ interactive=False,
+ )
+ aligned_gs = gr.Textbox(visible=False)
+
+ model_output_mesh = gr.Model3D(
+ label="Mesh Representation",
+ clear_color=[0.8, 0.8, 0.8, 1],
+ height=300,
+ interactive=False,
+ elem_id="lighter_mesh",
+ )
+
+ gr.Markdown("Estimated Asset 3D Attributes(No input required)")
+ with gr.Row():
+ est_type_text = gr.Textbox(
+ label="Asset category", interactive=False
+ )
+ est_height_text = gr.Textbox(
+ label="Real height(.m)", interactive=False
+ )
+ est_mass_text = gr.Textbox(
+ label="Mass(.kg)", interactive=False
+ )
+ est_mu_text = gr.Textbox(
+ label="Friction coefficient", interactive=False
+ )
+
+ prompt_examples = [
+ "satin gold tea cup with saucer",
+ "small bronze figurine of a lion",
+ "brown leather bag",
+ "Miniature cup with floral design",
+ "ๅธฆๆจ่ดจๅบๅบง, ๅ
ทๆ็ป็บฌ็บฟ็ๅฐ็ไปช",
+ "ๆฉ่ฒ็ตๅจๆ้ป, ๆ็ฃจๆ็ป่",
+ "ๆๅทฅๅถไฝ็็ฎ้ฉ็ฌ่ฎฐๆฌ",
+ ]
+ examples = gr.Examples(
+ label="Gallery",
+ examples=prompt_examples,
+ inputs=[text_prompt],
+ examples_per_page=10,
+ )
+
+ output_buf = gr.State()
+
+ demo.load(start_session)
+ demo.unload(end_session)
+
+ text_prompt.change(
+ active_btn_by_text_content,
+ inputs=[text_prompt],
+ outputs=[generate_img_btn],
+ )
+
+ generate_img_btn.click(
+ lambda: tuple(
+ [
+ gr.Button(interactive=False),
+ gr.Button(interactive=False),
+ gr.Button(interactive=False),
+ gr.Button(interactive=False),
+ None,
+ "",
+ None,
+ None,
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ None,
+ None,
+ None,
+ ]
+ ),
+ outputs=[
+ extract_rep3d_btn,
+ extract_urdf_btn,
+ download_urdf,
+ generate_btn,
+ model_output_gs,
+ aligned_gs,
+ model_output_mesh,
+ video_output,
+ asset_cat_text,
+ height_range_text,
+ mass_range_text,
+ asset_version_text,
+ est_type_text,
+ est_height_text,
+ est_mass_text,
+ est_mu_text,
+ image_sample1,
+ image_sample2,
+ image_sample3,
+ ],
+ ).success(
+ text2image_fn,
+ inputs=[
+ text_prompt,
+ img_guidance_scale,
+ img_inference_steps,
+ ip_image,
+ ip_adapt_scale,
+ img_resolution,
+ rmbg_tag,
+ seed,
+ ],
+ outputs=[
+ image_sample1,
+ image_sample2,
+ image_sample3,
+ usample1,
+ usample2,
+ usample3,
+ ],
+ ).success(
+ lambda: gr.Button(interactive=True),
+ outputs=[generate_btn],
+ )
+
+ generate_btn.click(
+ get_seed,
+ inputs=[randomize_seed, seed],
+ outputs=[seed],
+ ).success(
+ get_selected_image,
+ inputs=[dropdown, usample1, usample2, usample3],
+ outputs=select_img,
+ ).success(
+ get_cached_image,
+ inputs=[select_img],
+ outputs=[raw_image_cache],
+ ).success(
+ image_to_3d,
+ inputs=[
+ select_img,
+ seed,
+ ss_guidance_strength,
+ ss_sampling_steps,
+ slat_guidance_strength,
+ slat_sampling_steps,
+ raw_image_cache,
+ ],
+ outputs=[output_buf, video_output],
+ ).success(
+ lambda: gr.Button(interactive=True),
+ outputs=[extract_rep3d_btn],
+ )
+
+ extract_rep3d_btn.click(
+ extract_3d_representations_v2,
+ inputs=[
+ output_buf,
+ project_delight,
+ texture_size,
+ ],
+ outputs=[
+ model_output_mesh,
+ model_output_gs,
+ model_output_obj,
+ aligned_gs,
+ ],
+ ).success(
+ lambda: gr.Button(interactive=True),
+ outputs=[extract_urdf_btn],
+ )
+
+ extract_urdf_btn.click(
+ extract_urdf,
+ inputs=[
+ aligned_gs,
+ model_output_obj,
+ asset_cat_text,
+ height_range_text,
+ mass_range_text,
+ asset_version_text,
+ ],
+ outputs=[
+ download_urdf,
+ est_type_text,
+ est_height_text,
+ est_mass_text,
+ est_mu_text,
+ ],
+ queue=True,
+ show_progress="full",
+ ).success(
+ lambda: gr.Button(interactive=True),
+ outputs=[download_urdf],
+ )
+
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=8082)
diff --git a/apps/texture_edit.py b/apps/texture_edit.py
new file mode 100644
index 0000000..8d14bca
--- /dev/null
+++ b/apps/texture_edit.py
@@ -0,0 +1,382 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import os
+
+os.environ["GRADIO_APP"] = "texture_edit"
+import gradio as gr
+from common import (
+ MAX_SEED,
+ VERSION,
+ backproject_texture_v2,
+ custom_theme,
+ end_session,
+ generate_condition,
+ generate_texture_mvimages,
+ get_seed,
+ get_selected_image,
+ image_css,
+ lighting_css,
+ render_result_video,
+ start_session,
+)
+
+
+def active_btn_by_content(mesh_content: gr.Model3D, text_content: gr.Textbox):
+ if (
+ mesh_content is not None
+ and text_content is not None
+ and len(text_content) > 0
+ ):
+ interactive = True
+ else:
+ interactive = False
+
+ return gr.Button(interactive=interactive)
+
+
+with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
+ gr.Markdown(
+ """
+ ## ***EmbodiedGen***: Texture Generation
+ **๐ Version**: {VERSION}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ๐จ Generate visually rich textures for 3D mesh.
+
+ """.format(
+ VERSION=VERSION
+ ),
+ elem_classes=["header"],
+ )
+ gr.HTML(image_css)
+ gr.HTML(lighting_css)
+ with gr.Row():
+ with gr.Column(scale=1):
+ mesh_input = gr.Model3D(
+ label="Upload Mesh File(.obj or .glb)", height=300
+ )
+ local_mesh = gr.Textbox(visible=False)
+ text_prompt = gr.Textbox(
+ label="Text Prompt (Chinese or English)",
+ placeholder="Input text prompt here",
+ )
+ ip_image = gr.Image(
+ label="Reference Image(optional)",
+ format="png",
+ image_mode="RGB",
+ type="filepath",
+ height=250,
+ elem_classes=["image_fit"],
+ )
+ gr.Markdown(
+ "Note: The `reference image` is optional. If provided, please "
+ "increase the `Condition Scale` in Generation Settings."
+ )
+
+ with gr.Accordion(label="Generation Settings", open=False):
+ with gr.Row():
+ seed = gr.Slider(
+ 0, MAX_SEED, label="Seed", value=0, step=1
+ )
+ randomize_seed = gr.Checkbox(
+ label="Randomize Seed", value=False
+ )
+ ip_adapt_scale = gr.Slider(
+ 0, 1, label="IP-adapter Scale", value=0.7, step=0.05
+ )
+ cond_scale = gr.Slider(
+ 0.0,
+ 1.0,
+ label="Geo Condition Scale",
+ value=0.60,
+ step=0.01,
+ )
+ guidance_scale = gr.Slider(
+ 1, 30, label="Text Guidance Scale", value=9, step=0.2
+ )
+ guidance_strength = gr.Slider(
+ 0.0,
+ 1.0,
+ label="Strength",
+ value=0.9,
+ step=0.05,
+ )
+ num_inference_steps = gr.Slider(
+ 10, 100, label="Sampling Steps", value=50, step=5
+ )
+ texture_size = gr.Slider(
+ 1024, 4096, label="UV texture size", value=2048, step=256
+ )
+ video_size = gr.Slider(
+ 512, 2048, label="Video Resolution", value=512, step=256
+ )
+
+ generate_mv_btn = gr.Button(
+ "๐จ 1. Generate MV Images(~1min)",
+ variant="primary",
+ interactive=False,
+ )
+
+ with gr.Column(scale=3):
+ with gr.Row():
+ image_sample1 = gr.Image(
+ label="sample1",
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ height=300,
+ interactive=False,
+ elem_classes=["image_fit"],
+ )
+ image_sample2 = gr.Image(
+ label="sample2",
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ height=300,
+ interactive=False,
+ elem_classes=["image_fit"],
+ )
+ image_sample3 = gr.Image(
+ label="sample3",
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ height=300,
+ interactive=False,
+ elem_classes=["image_fit"],
+ )
+
+ usample1 = gr.Image(
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ visible=False,
+ )
+ usample2 = gr.Image(
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ visible=False,
+ )
+ usample3 = gr.Image(
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ visible=False,
+ )
+
+ gr.Markdown(
+ "Note: Select samples with consistent textures from various "
+ "perspectives and no obvious reflections."
+ )
+ with gr.Row():
+ with gr.Column(scale=1):
+ with gr.Row():
+ dropdown = gr.Radio(
+ choices=["sample1", "sample2", "sample3"],
+ value="sample1",
+ label="Choose your favorite sample style.",
+ )
+ select_img = gr.Image(
+ visible=False,
+ format="png",
+ image_mode="RGBA",
+ type="filepath",
+ height=300,
+ )
+ with gr.Row():
+ project_delight = gr.Checkbox(
+ label="delight", value=True
+ )
+ fix_mesh = gr.Checkbox(
+ label="simplify mesh", value=False
+ )
+
+ with gr.Column(scale=1):
+ texture_bake_btn = gr.Button(
+ "๐ ๏ธ 2. Texture Baking(~2min)",
+ variant="primary",
+ interactive=False,
+ )
+ download_btn = gr.DownloadButton(
+ label="โฌ๏ธ 3. Download Mesh",
+ variant="primary",
+ interactive=False,
+ )
+
+ with gr.Row():
+ mesh_output = gr.Model3D(
+ label="Mesh Edit Result",
+ clear_color=[0.8, 0.8, 0.8, 1],
+ height=380,
+ interactive=False,
+ elem_id="lighter_mesh",
+ )
+ mesh_outpath = gr.Textbox(visible=False)
+ video_output = gr.Video(
+ label="Mesh Edit Video",
+ autoplay=True,
+ loop=True,
+ height=380,
+ )
+
+ with gr.Row():
+ prompt_examples = []
+ with open("apps/assets/example_texture/text_prompts.txt", "r") as f:
+ for line in f:
+ parts = line.strip().split("\\")
+ prompt_examples.append([parts[0].strip(), parts[1].strip()])
+
+ examples = gr.Examples(
+ label="Mesh Gallery",
+ examples=prompt_examples,
+ inputs=[mesh_input, text_prompt],
+ examples_per_page=10,
+ )
+
+ demo.load(start_session)
+ demo.unload(end_session)
+
+ mesh_input.change(
+ lambda: tuple(
+ [
+ None,
+ None,
+ None,
+ gr.Button(interactive=False),
+ gr.Button(interactive=False),
+ None,
+ None,
+ None,
+ ]
+ ),
+ outputs=[
+ mesh_outpath,
+ mesh_output,
+ video_output,
+ texture_bake_btn,
+ download_btn,
+ image_sample1,
+ image_sample2,
+ image_sample3,
+ ],
+ ).success(
+ active_btn_by_content,
+ inputs=[mesh_input, text_prompt],
+ outputs=[generate_mv_btn],
+ )
+
+ text_prompt.change(
+ active_btn_by_content,
+ inputs=[mesh_input, text_prompt],
+ outputs=[generate_mv_btn],
+ )
+
+ generate_mv_btn.click(
+ get_seed,
+ inputs=[randomize_seed, seed],
+ outputs=[seed],
+ ).success(
+ lambda: tuple(
+ [
+ None,
+ None,
+ None,
+ gr.Button(interactive=False),
+ gr.Button(interactive=False),
+ ]
+ ),
+ outputs=[
+ mesh_outpath,
+ mesh_output,
+ video_output,
+ texture_bake_btn,
+ download_btn,
+ ],
+ ).success(
+ generate_condition,
+ inputs=[mesh_input],
+ outputs=[image_sample1, image_sample2, image_sample3],
+ ).success(
+ generate_texture_mvimages,
+ inputs=[
+ text_prompt,
+ cond_scale,
+ guidance_scale,
+ guidance_strength,
+ num_inference_steps,
+ seed,
+ ip_adapt_scale,
+ ip_image,
+ ],
+ outputs=[
+ image_sample1,
+ image_sample2,
+ image_sample3,
+ usample1,
+ usample2,
+ usample3,
+ ],
+ ).success(
+ lambda: gr.Button(interactive=True),
+ outputs=[texture_bake_btn],
+ )
+
+ texture_bake_btn.click(
+ lambda: tuple([None, None, None, gr.Button(interactive=False)]),
+ outputs=[mesh_outpath, mesh_output, video_output, download_btn],
+ ).success(
+ get_selected_image,
+ inputs=[dropdown, usample1, usample2, usample3],
+ outputs=select_img,
+ ).success(
+ backproject_texture_v2,
+ inputs=[
+ mesh_input,
+ select_img,
+ texture_size,
+ project_delight,
+ fix_mesh,
+ ],
+ outputs=[mesh_output, mesh_outpath, download_btn],
+ ).success(
+ lambda: gr.DownloadButton(interactive=True),
+ outputs=[download_btn],
+ ).success(
+ render_result_video,
+ inputs=[mesh_outpath, video_size],
+ outputs=[video_output],
+ )
+
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=8083)
diff --git a/embodied_gen/data/backproject.py b/embodied_gen/data/backproject.py
new file mode 100644
index 0000000..b02ae27
--- /dev/null
+++ b/embodied_gen/data/backproject.py
@@ -0,0 +1,518 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import argparse
+import logging
+import math
+import os
+from typing import List, Literal, Union
+
+import cv2
+import numpy as np
+import nvdiffrast.torch as dr
+import torch
+import trimesh
+import utils3d
+import xatlas
+from tqdm import tqdm
+from embodied_gen.data.mesh_operator import MeshFixer
+from embodied_gen.data.utils import (
+ CameraSetting,
+ get_images_from_grid,
+ init_kal_camera,
+ normalize_vertices_array,
+ post_process_texture,
+ save_mesh_with_mtl,
+)
+from embodied_gen.models.delight_model import DelightingModel
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
class TextureBaker(object):
    """Baking textures onto a mesh from multiple observations.

    This class takes 3D mesh data, camera settings and texture baking
    parameters to generate a texture map by projecting images onto the mesh
    from different views. It supports both a fast texture baking approach and
    a more optimized method with total variation regularization.

    Attributes:
        vertices (torch.Tensor): The vertices of the mesh.
        faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
        uvs (torch.Tensor): The UV coordinates of the mesh.
        camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
        device (str): The device to run computations on ("cpu" or "cuda").
        w2cs (torch.Tensor): World-to-camera transformation matrices.
        projections (torch.Tensor): Camera projection matrices.

    Example:
        >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces) # noqa
        >>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params)
        >>> images = get_images_from_grid(args.color_path, image_size)
        >>> texture = texture_backer.bake_texture(
        ...     images, texture_size=args.texture_size, mode=args.baker_mode
        ... )
        >>> texture = post_process_texture(texture)
    """

    def __init__(
        self,
        vertices: np.ndarray,
        faces: np.ndarray,
        uvs: np.ndarray,
        camera_params: CameraSetting,
        device: str = "cuda",
    ) -> None:
        # Accept either numpy arrays or torch tensors; everything ends up
        # as a tensor on `device`. Faces are cast to int32 for nvdiffrast.
        self.vertices = (
            torch.tensor(vertices, device=device)
            if isinstance(vertices, np.ndarray)
            else vertices.to(device)
        )
        self.faces = (
            torch.tensor(faces.astype(np.int32), device=device)
            if isinstance(faces, np.ndarray)
            else faces.to(device)
        )
        self.uvs = (
            torch.tensor(uvs, device=device)
            if isinstance(uvs, np.ndarray)
            else uvs.to(device)
        )
        self.camera_params = camera_params
        self.device = device

        # Precompute per-view world-to-camera and projection matrices once;
        # kaolin's axis convention is converted to OpenCV for rasterization.
        camera = init_kal_camera(camera_params)
        matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
        matrix_mv = kaolin_to_opencv_view(matrix_mv)
        matrix_p = (
            camera.intrinsics.projection_matrix()
        )  # (n_cam 4 4) cam2pixel
        self.w2cs = matrix_mv.to(self.device)
        self.projections = matrix_p.to(self.device)

    @staticmethod
    def parametrize_mesh(
        vertices: np.array, faces: np.array
    ) -> Union[np.array, np.array, np.array]:
        """UV-unwrap the mesh with xatlas.

        Returns the re-indexed vertices, the new face indices, and the UV
        coordinates. Vertices may be duplicated along UV seams (vmapping
        maps each new vertex back to an original vertex).
        """
        vmapping, indices, uvs = xatlas.parametrize(vertices, faces)

        vertices = vertices[vmapping]
        faces = indices

        return vertices, faces, uvs

    def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
        # Accumulate per-texel color sums and hit counts over all views,
        # then average. Flat (texture_size^2, C) buffers for scatter_add.
        # NOTE(review): buffers are created with .cuda() regardless of
        # self.device — confirm whether "cpu" should be supported here.
        texture = torch.zeros(
            (texture_size * texture_size, 3), dtype=torch.float32
        ).cuda()
        texture_weights = torch.zeros(
            (texture_size * texture_size), dtype=torch.float32
        ).cuda()
        rastctx = utils3d.torch.RastContext(backend="cuda")
        for observation, w2c, projection in tqdm(
            zip(observations, w2cs, projections),
            total=len(observations),
            desc="Texture baking (fast)",
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
            uv_map = rast["uv"][0].detach().flip(0)
            # NOTE(review): masks[0] is used for every view — the per-view
            # mask (matching `observation`) may have been intended; confirm.
            mask = rast["mask"][0].detach().bool() & masks[0]

            # nearest neighbor interpolation
            uv_map = (uv_map * texture_size).floor().long()
            obs = observation[mask]
            uv_map = uv_map[mask]
            # Flatten (u, v) into a linear texel index; v is flipped so that
            # UV origin maps to the bottom of the texture image.
            idx = (
                uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            )
            texture = texture.scatter_add(
                0, idx.view(-1, 1).expand(-1, 3), obs
            )
            texture_weights = texture_weights.scatter_add(
                0,
                idx,
                torch.ones(
                    (obs.shape[0]), dtype=torch.float32, device=texture.device
                ),
            )

        # Average accumulated colors by hit count, convert to uint8 image.
        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(
            texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
            0,
            255,
        ).astype(np.uint8)

        # inpaint texels never hit by any view
        mask = (
            (texture_weights == 0)
            .cpu()
            .numpy()
            .astype(np.uint8)
            .reshape(texture_size, texture_size)
        )
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

        return texture

    def _bake_opt(
        self,
        observations,
        w2cs,
        projections,
        texture_size,
        lambda_tv,
        masks,
        total_steps,
    ):
        # Optimize the texture directly by minimizing an L1 reprojection loss
        # against randomly sampled views, plus optional TV regularization.
        rastctx = utils3d.torch.RastContext(backend="cuda")
        # Flip vertically to match nvdiffrast's rasterization orientation.
        observations = [observations.flip(0) for observations in observations]
        masks = [m.flip(0) for m in masks]
        # Precompute per-view UV rasterization (fixed geometry → no grad).
        _uv = []
        _uv_dr = []
        for observation, w2c, projection in tqdm(
            zip(observations, w2cs, projections),
            total=len(w2cs),
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
            _uv.append(rast["uv"].detach())
            _uv_dr.append(rast["uv_dr"].detach())

        texture = torch.nn.Parameter(
            torch.zeros(
                (1, texture_size, texture_size, 3), dtype=torch.float32
            ).cuda()
        )
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def cosine_anealing(step, total_steps, start_lr, end_lr):
            # Cosine learning-rate decay from start_lr to end_lr.
            return end_lr + 0.5 * (start_lr - end_lr) * (
                1 + np.cos(np.pi * step / total_steps)
            )

        def tv_loss(texture):
            # Total-variation prior: L1 difference between neighboring texels.
            return torch.nn.functional.l1_loss(
                texture[:, :-1, :, :], texture[:, 1:, :, :]
            ) + torch.nn.functional.l1_loss(
                texture[:, :, :-1, :], texture[:, :, 1:, :]
            )

        with tqdm(total=total_steps, desc="Texture baking") as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                # One random view per step (stochastic optimization).
                selected = np.random.randint(0, len(w2cs))
                uv, uv_dr, observation, mask = (
                    _uv[selected],
                    _uv_dr[selected],
                    observations[selected],
                    masks[selected],
                )
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(
                    render[mask], observation[mask]
                )
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()

                optimizer.param_groups[0]["lr"] = cosine_anealing(
                    step, total_steps, 1e-2, 1e-5
                )
                pbar.set_postfix({"loss": loss.item()})
                pbar.update()
        texture = np.clip(
            texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)
        # Inpaint texels outside the UV coverage of the mesh.
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx,
            (self.uvs * 2 - 1)[None],
            self.faces,
            texture_size,
            texture_size,
        )["mask"][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

        return texture

    def bake_texture(
        self,
        images: List[np.array],
        texture_size: int = 1024,
        mode: Literal["fast", "opt"] = "opt",
        lambda_tv: float = 1e-2,
        opt_step: int = 2000,
    ):
        """Bake a texture from per-view images.

        Args:
            images: Per-view uint8 color images (non-black pixels are
                treated as foreground).
            texture_size: Output texture resolution (square).
            mode: "fast" scatter/average baking or "opt" optimization-based.
            lambda_tv: TV-regularization weight ("opt" mode only).
            opt_step: Number of optimization steps ("opt" mode only).

        Raises:
            ValueError: If `mode` is neither "fast" nor "opt".
        """
        # Foreground masks from any non-zero channel.
        masks = [np.any(img > 0, axis=-1) for img in images]
        masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
        images = [
            torch.tensor(obs / 255.0).float().to(self.device) for obs in images
        ]

        if mode == "fast":
            return self._bake_fast(
                images, self.w2cs, self.projections, texture_size, masks
            )
        elif mode == "opt":
            return self._bake_opt(
                images,
                self.w2cs,
                self.projections,
                texture_size,
                lambda_tv,
                masks,
                opt_step,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")
+
+
def kaolin_to_opencv_view(raw_matrix):
    """Convert batched kaolin view matrices to the OpenCV axis convention.

    The rotation columns are permuted — new columns (0, 1, 2) take the old
    columns (2, 0, 1) — while the translation column is kept unchanged and
    the homogeneous bottom row stays (0, 0, 0, 1).

    Args:
        raw_matrix: Batched (N, 4, 4) world-to-camera view matrices.

    Returns:
        torch.Tensor: Batched (N, 4, 4) view matrices in OpenCV convention.
    """
    batch = raw_matrix.size(0)
    converted = (
        torch.eye(4, device=raw_matrix.device)
        .unsqueeze(0)
        .repeat(batch, 1, 1)
    )
    # Column permutation via fancy indexing: col0<-old col2, col1<-old col0,
    # col2<-old col1.
    converted[:, :3, :3] = raw_matrix[:, :3, :3][:, :, [2, 0, 1]]
    converted[:, :3, 3] = raw_matrix[:, :3, 3]

    return converted
+
+
def parse_args():
    """Parse CLI arguments for offline texture baking.

    Fixes over the original:
    - `--color_path` help text no longer duplicates the `--mesh_path` text.
    - `--skip_fix_mesh` help text no longer reads as the opposite of what
      the flag does.
    - Mismatched list lengths between `--mesh_path`, `--color_path` and
      `--uuid` now fail fast instead of being silently truncated by `zip`
      downstream.

    Returns:
        argparse.Namespace: Parsed arguments with `uuid` always populated
        (derived from mesh file basenames when not given).
    """
    parser = argparse.ArgumentParser(description="Render settings")

    parser.add_argument(
        "--mesh_path",
        type=str,
        nargs="+",
        required=True,
        help="Paths to the mesh files for rendering.",
    )
    parser.add_argument(
        "--color_path",
        type=str,
        nargs="+",
        required=True,
        help="Paths to the multiview color grid images, one per mesh.",
    )
    parser.add_argument(
        "--output_root",
        type=str,
        default="./outputs",
        help="Root directory for output",
    )
    parser.add_argument(
        "--uuid",
        type=str,
        nargs="+",
        default=None,
        help="uuid for rendering saving.",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        type=float,
        nargs="+",
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the output images (default: (512, 512))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--texture_size",
        type=int,
        default=1024,
        help="Texture size for texture baking (default: 1024)",
    )
    parser.add_argument(
        "--baker_mode",
        type=str,
        default="opt",
        help="Texture baking mode, `fast` or `opt` (default: opt)",
    )
    parser.add_argument(
        "--opt_step",
        type=int,
        default=2500,
        help="Optimization steps for texture baking (default: 2500)",
    )
    # NOTE: flag name keeps the historical misspelling ("sipmlify") for
    # backward compatibility with existing invocations.
    parser.add_argument(
        "--mesh_sipmlify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
    parser.add_argument(
        "--no_coor_trans",
        action="store_true",
        help="Do not transform the asset coordinate system.",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
    )
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip mesh geometry fixing.",
    )

    args = parser.parse_args()

    # Derive one uuid per mesh from the file basename when not provided.
    if args.uuid is None:
        args.uuid = []
        for path in args.mesh_path:
            uuid = os.path.basename(path).split(".")[0]
            args.uuid.append(uuid)

    # Fail fast on mismatched list lengths instead of silently truncating
    # the shorter lists when they are zipped together in `entrypoint`.
    if len(args.color_path) != len(args.mesh_path):
        parser.error("--color_path must have one entry per --mesh_path.")
    if len(args.uuid) != len(args.mesh_path):
        parser.error("--uuid must have one entry per --mesh_path.")

    return args
+
+
def entrypoint() -> None:
    """CLI driver: bake textures for each (mesh, color-grid) pair.

    For every input mesh: normalize and optionally rotate it into the
    rendering coordinate frame, optionally repair/simplify the geometry,
    UV-unwrap it, bake a texture from the multiview color grid, then undo
    the coordinate normalization and export an .obj with material.
    """
    args = parse_args()
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )

    for mesh_path, uuid, img_path in zip(
        args.mesh_path, args.uuid, args.color_path
    ):
        mesh = trimesh.load(mesh_path)
        # A multi-geometry scene is flattened into a single mesh.
        if isinstance(mesh, trimesh.Scene):
            mesh = mesh.dump(concatenate=True)
        # Normalize into a unit-scale, centered frame; scale/center are kept
        # to restore the original pose after baking.
        vertices, scale, center = normalize_vertices_array(mesh.vertices)

        if not args.no_coor_trans:
            # 90-degree rotations about X then Z to match the renderer's
            # coordinate convention; inverted again before export.
            x_rot = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
            z_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
            vertices = vertices @ x_rot
            vertices = vertices @ z_rot

        faces = mesh.faces.astype(np.int32)
        vertices = vertices.astype(np.float32)

        if not args.skip_fix_mesh:
            mesh_fixer = MeshFixer(vertices, faces, args.device)
            vertices, faces = mesh_fixer(
                filter_ratio=args.mesh_sipmlify_ratio,
                max_hole_size=0.04,
                resolution=1024,
                num_views=1000,
                norm_mesh_ratio=0.5,
            )

        vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
        texture_backer = TextureBaker(
            vertices,
            faces,
            uvs,
            camera_params,
        )
        images = get_images_from_grid(
            img_path, img_size=camera_params.resolution_hw[0]
        )
        if args.delight:
            # NOTE(review): DelightingModel is re-instantiated for every mesh
            # in the loop — hoisting it above the loop would avoid repeated
            # model loading; confirm it is safe to reuse.
            delight_model = DelightingModel()
            images = [delight_model(img) for img in images]

        images = [np.array(img) for img in images]
        texture = texture_backer.bake_texture(
            images=[img[..., :3] for img in images],
            texture_size=args.texture_size,
            mode=args.baker_mode,
            opt_step=args.opt_step,
        )
        texture = post_process_texture(texture)

        if not args.no_coor_trans:
            # Undo the Z then X rotations (reverse application order).
            vertices = vertices @ np.linalg.inv(z_rot)
            vertices = vertices @ np.linalg.inv(x_rot)
        # Restore the original scale and position.
        vertices = vertices / scale
        vertices = vertices + center

        output_path = os.path.join(args.output_root, f"{uuid}.obj")
        mesh = save_mesh_with_mtl(vertices, faces, uvs, texture, output_path)

    return
+
+
# Allow running this module directly as a CLI script.
if __name__ == "__main__":
    entrypoint()
diff --git a/embodied_gen/data/backproject_v2.py b/embodied_gen/data/backproject_v2.py
new file mode 100644
index 0000000..efe1b34
--- /dev/null
+++ b/embodied_gen/data/backproject_v2.py
@@ -0,0 +1,702 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import argparse
+import logging
+import math
+import os
+
+import cv2
+import numpy as np
+import nvdiffrast.torch as dr
+import spaces
+import torch
+import torch.nn.functional as F
+import trimesh
+import xatlas
+from PIL import Image
+from embodied_gen.data.mesh_operator import MeshFixer
+from embodied_gen.data.utils import (
+ CameraSetting,
+ DiffrastRender,
+ get_images_from_grid,
+ init_kal_camera,
+ normalize_vertices_array,
+ post_process_texture,
+ save_mesh_with_mtl,
+)
+from embodied_gen.models.delight_model import DelightingModel
+from embodied_gen.models.sr_model import ImageRealESRGAN
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
+__all__ = [
+ "TextureBacker",
+]
+
+
+def _transform_vertices(
+ mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
+) -> torch.Tensor:
+ """Transform 3D vertices using a projection matrix."""
+ t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
+ if pos.size(-1) == 3:
+ pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
+
+ result = pos @ t_mtx.T
+
+ return result if keepdim else result.unsqueeze(0)
+
+
+def _bilinear_interpolation_scattering(
+ image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
+) -> torch.Tensor:
+ """Bilinear interpolation scattering for grid-based value accumulation."""
+ device = values.device
+ dtype = values.dtype
+ C = values.shape[-1]
+
+ indices = coords * torch.tensor(
+ [image_h - 1, image_w - 1], dtype=dtype, device=device
+ )
+ i, j = indices.unbind(-1)
+
+ i0, j0 = (
+ indices.floor()
+ .long()
+ .clamp(0, image_h - 2)
+ .clamp(0, image_w - 2)
+ .unbind(-1)
+ )
+ i1, j1 = i0 + 1, j0 + 1
+
+ w_i = i - i0.float()
+ w_j = j - j0.float()
+ weights = torch.stack(
+ [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
+ dim=1,
+ )
+
+ indices_comb = torch.stack(
+ [
+ torch.stack([i0, j0], dim=1),
+ torch.stack([i0, j1], dim=1),
+ torch.stack([i1, j0], dim=1),
+ torch.stack([i1, j1], dim=1),
+ ],
+ dim=1,
+ )
+
+ grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
+ cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
+
+ for k in range(4):
+ idx = indices_comb[:, k]
+ w = weights[:, k].unsqueeze(-1)
+
+ stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
+ flat_idx = (idx * stride).sum(-1)
+
+ grid.view(-1, C).scatter_add_(
+ 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
+ )
+ cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
+
+ mask = cnt.squeeze(-1) > 0
+ grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
+
+ return grid
+
+
def _texture_inpaint_smooth(
    texture: np.ndarray,
    mask: np.ndarray,
    vertices: np.ndarray,
    faces: np.ndarray,
    uv_map: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Perform texture inpainting using vertex-based color propagation.

    Colors are sampled from valid texels onto mesh vertices, propagated to
    uncolored vertices over the mesh adjacency graph (weighted by inverse
    squared 3D distance), and finally written back into the texture at each
    vertex's UV location.

    Args:
        texture: (H, W, C) float texture image.
        mask: (H, W) validity mask; nonzero marks known texels.
        vertices: (N, 3) mesh vertex positions.
        faces: (F, 3) triangle vertex indices (also indexing `uv_map`).
        uv_map: (N, 2) per-vertex UV coordinates in [0, 1].

    Returns:
        tuple[np.ndarray, np.ndarray]: The inpainted texture and the
        updated validity mask.
    """
    image_h, image_w, C = texture.shape
    N = vertices.shape[0]

    # Initialize vertex data structures
    vtx_mask = np.zeros(N, dtype=np.float32)
    vtx_colors = np.zeros((N, C), dtype=np.float32)
    unprocessed = []
    adjacency = [[] for _ in range(N)]

    # Build adjacency graph and initial color assignment
    for face_idx in range(faces.shape[0]):
        for k in range(3):
            # Face index k addresses both the vertex and its UV entry here
            # (the mesh is assumed to share vertex/UV indexing).
            uv_idx_k = faces[face_idx, k]
            v_idx = faces[face_idx, k]

            # Convert UV to pixel coordinates with boundary clamping
            u = np.clip(
                int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )

            if mask[v, u]:
                vtx_mask[v_idx] = 1.0
                vtx_colors[v_idx] = texture[v, u]
            # NOTE(review): list membership here is O(len(unprocessed)) per
            # check; a set alongside the list would be faster for big meshes.
            elif v_idx not in unprocessed:
                unprocessed.append(v_idx)

            # Build undirected adjacency graph
            neighbor = faces[face_idx, (k + 1) % 3]
            if neighbor not in adjacency[v_idx]:
                adjacency[v_idx].append(neighbor)
            if v_idx not in adjacency[neighbor]:
                adjacency[neighbor].append(v_idx)

    # Color propagation with dynamic stopping: keep iterating while progress
    # is made; allow up to 2 consecutive stalled passes before giving up.
    remaining_iters, prev_count = 2, 0
    while remaining_iters > 0:
        current_unprocessed = []

        for v_idx in unprocessed:
            valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
            if not valid_neighbors:
                current_unprocessed.append(v_idx)
                continue

            # Calculate inverse square distance weights
            neighbors_pos = vertices[valid_neighbors]
            dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
            weights = 1 / np.maximum(dist_sq, 1e-8)

            vtx_colors[v_idx] = np.average(
                vtx_colors[valid_neighbors], weights=weights, axis=0
            )
            vtx_mask[v_idx] = 1.0

        # Update iteration control
        if len(current_unprocessed) == prev_count:
            remaining_iters -= 1
        else:
            remaining_iters = min(remaining_iters + 1, 2)
        prev_count = len(current_unprocessed)
        unprocessed = current_unprocessed

    # Generate output texture: splat propagated vertex colors back to UV
    # space and mark those texels as valid.
    inpainted_texture, updated_mask = texture.copy(), mask.copy()
    for face_idx in range(faces.shape[0]):
        for k in range(3):
            v_idx = faces[face_idx, k]
            if not vtx_mask[v_idx]:
                continue

            # UV coordinate conversion
            uv_idx_k = faces[face_idx, k]
            u = np.clip(
                int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )

            inpainted_texture[v, u] = vtx_colors[v_idx]
            updated_mask[v, u] = 255
+
+
class TextureBacker:
    """Texture baking pipeline for multi-view projection and fusion.

    This class performs UV-based texture generation for a 3D mesh using
    multi-view color images, depth, and normal information. The pipeline
    includes mesh normalization and UV unwrapping, visibility-aware
    back-projection, confidence-weighted texture fusion, and inpainting
    of missing texture regions.

    Args:
        camera_params (CameraSetting): Camera intrinsics and extrinsics used
            for rendering each view.
        view_weights (list[float]): A list of weights for each view, used
            to blend confidence maps during texture fusion.
        render_wh (tuple[int, int], optional): Resolution (width, height) for
            intermediate rendering passes. Defaults to (2048, 2048).
        texture_wh (tuple[int, int], optional): Output texture resolution
            (width, height). Defaults to (2048, 2048).
        bake_angle_thresh (int, optional): Maximum angle (in degrees) between
            view direction and surface normal for projection to be considered valid.
            Defaults to 75.
        mask_thresh (float, optional): Threshold applied to visibility masks
            during rendering. Defaults to 0.5.
        smooth_texture (bool, optional): If True, apply post-processing (e.g.,
            blurring) to the final texture. Defaults to True.
    """

    def __init__(
        self,
        camera_params: CameraSetting,
        view_weights: list[float],
        render_wh: tuple[int, int] = (2048, 2048),
        texture_wh: tuple[int, int] = (2048, 2048),
        bake_angle_thresh: int = 75,
        mask_thresh: float = 0.5,
        smooth_texture: bool = True,
    ) -> None:
        self.camera_params = camera_params
        # Renderer is created lazily on first use (see _lazy_init_render).
        self.renderer = None
        self.view_weights = view_weights
        self.device = camera_params.device
        self.render_wh = render_wh
        # NOTE(review): fast_bake_texture does `self.texture_wh + [channel]`,
        # which requires a list — the tuple default documented above would
        # raise TypeError there; callers appear to pass lists. Confirm.
        self.texture_wh = texture_wh
        self.mask_thresh = mask_thresh
        self.smooth_texture = smooth_texture

        self.bake_angle_thresh = bake_angle_thresh
        # Dilation kernel size for unreliable regions, scaled with render
        # resolution (2 px at a 512 render).
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.render_wh[0], self.render_wh[1])
        )

    def _lazy_init_render(self, camera_params, mask_thresh):
        # Build the differentiable rasterizer once; reused across calls.
        if self.renderer is None:
            camera = init_kal_camera(camera_params)
            mv = camera.view_matrix()  # (n 4 4) world2cam
            p = camera.intrinsics.projection_matrix()
            # NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
            p[:, 1, 1] = -p[:, 1, 1]
            self.renderer = DiffrastRender(
                p_matrix=p,
                mv_matrix=mv,
                resolution_hw=camera_params.resolution_hw,
                context=dr.RasterizeCudaContext(),
                mask_thresh=mask_thresh,
                grad_db=False,
                device=self.device,
                antialias_mask=True,
            )

    def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
        """Normalize the mesh and UV-unwrap it with xatlas (in place)."""
        # Keep scale/center so get_mesh_np_attrs can restore the pose later.
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        self.scale, self.center = scale, center

        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
        # Flip V so UVs match the renderer's image orientation.
        uvs[:, 1] = 1 - uvs[:, 1]
        mesh.vertices = mesh.vertices[vmapping]
        mesh.faces = indices
        mesh.visual.uv = uvs

        return mesh

    def get_mesh_np_attrs(
        self,
        mesh: trimesh.Trimesh,
        scale: float = None,
        center: np.ndarray = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return (vertices, faces, uv_map) as numpy copies.

        The V axis is flipped back, and the normalization is undone when
        `scale`/`center` are provided.
        """
        vertices = mesh.vertices.copy()
        faces = mesh.faces.copy()
        uv_map = mesh.visual.uv.copy()
        uv_map[:, 1] = 1.0 - uv_map[:, 1]

        if scale is not None:
            vertices = vertices / scale
        if center is not None:
            vertices = vertices + center

        return vertices, faces, uv_map

    def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
        """Canny edge map of the depth image as an (H, W, 1) float tensor."""
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        sketch_image = (
            torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
        )
        sketch_image = sketch_image.unsqueeze(-1)

        return sketch_image

    def compute_enhanced_viewnormal(
        self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
    ) -> torch.Tensor:
        """Render per-view camera-space vertex normals for all views."""
        rast, _ = self.renderer.compute_dr_raster(vertices, faces)
        rendered_view_normals = []
        for idx in range(len(mv_mtx)):
            # Vertices in camera space for this view.
            pos_cam = _transform_vertices(mv_mtx[idx], vertices, keepdim=True)
            pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
            v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
            # Per-face normals, then averaged to per-vertex normals.
            face_norm = F.normalize(
                torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
            )
            vertex_norm = (
                torch.from_numpy(
                    trimesh.geometry.mean_vertex_normals(
                        len(pos_cam), faces.cpu(), face_norm.cpu()
                    )
                )
                .to(vertices.device)
                .contiguous()
            )
            # Interpolate vertex normals across the rasterized pixels.
            im_base_normals, _ = dr.interpolate(
                vertex_norm[None, ...].float(),
                rast[idx : idx + 1],
                faces.to(torch.int32),
            )
            rendered_view_normals.append(im_base_normals)

        rendered_view_normals = torch.cat(rendered_view_normals, dim=0)

        return rendered_view_normals

    def back_project(
        self, image, vis_mask, depth, normal, uv
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Back-project one view's image into UV space with a confidence map.

        Reliability is reduced near depth discontinuities and at grazing
        view angles (beyond `bake_angle_thresh`).
        """
        image = np.array(image)
        image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
        if image.ndim == 2:
            image = image.unsqueeze(-1)
        image = image / 255

        # Depth edges mark occlusion boundaries (unreliable projections).
        depth_inv = (1.0 - depth) * vis_mask
        sketch_image = self._render_depth_edges(depth_inv)

        # Cosine between view direction (+z in camera space) and normal;
        # zero-out grazing angles beyond the threshold.
        cos = F.cosine_similarity(
            torch.tensor([[0, 0, 1]], device=self.device),
            normal.view(-1, 3),
        ).view_as(normal[..., :1])
        cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0

        k = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones((1, 1, k, k), device=self.device)

        # Erode the visibility mask so silhouette borders are excluded.
        vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
        vis_mask = F.conv2d(
            1.0 - vis_mask,
            kernel,
            padding=k // 2,
        )
        vis_mask = 1.0 - (vis_mask > 0).float()
        vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)

        # Dilate the depth-edge sketch and subtract it from visibility.
        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
        sketch_image = (sketch_image > 0).float()
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
        vis_mask = vis_mask * (sketch_image < 0.5)

        cos[vis_mask == 0] = 0
        valid_pixels = (vis_mask != 0).view(-1)

        # Scatter color and confidence into UV space for valid pixels only.
        return (
            self._scatter_texture(uv, image, valid_pixels),
            self._scatter_texture(uv, cos, valid_pixels),
        )

    def _scatter_texture(self, uv, data, mask):
        # Scatter masked per-pixel data to the texture grid; UV is swapped
        # to (row, col) order for the bilinear scattering helper.
        def __filter_data(data, mask):
            return data.view(-1, data.shape[-1])[mask]

        return _bilinear_interpolation_scattering(
            self.texture_wh[1],
            self.texture_wh[0],
            __filter_data(uv, mask)[..., [1, 0]],
            __filter_data(data, mask),
        )

    @torch.no_grad()
    def fast_bake_texture(
        self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fuse per-view UV textures by confidence-weighted averaging.

        Views whose coverage is already >99% painted by previous views are
        skipped to avoid diluting sharper earlier projections.
        """
        channel = textures[0].shape[-1]
        texture_merge = torch.zeros(self.texture_wh + [channel]).to(
            self.device
        )
        trust_map_merge = torch.zeros(self.texture_wh + [1]).to(self.device)
        for texture, cos_map in zip(textures, confidence_maps):
            view_sum = (cos_map > 0).sum()
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            if painted_sum / view_sum > 0.99:
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)

        return texture_merge, trust_map_merge > 1e-8

    def uv_inpaint(
        self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
    ) -> np.ndarray:
        """Fill unpainted texels: vertex propagation, then cv2 inpainting."""
        vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)

        texture, mask = _texture_inpaint_smooth(
            texture, mask, vertices, faces, uv_map
        )
        texture = texture.clip(0, 1)
        texture = cv2.inpaint(
            (texture * 255).astype(np.uint8),
            255 - mask,
            3,
            cv2.INPAINT_NS,
        )

        return texture

    @spaces.GPU
    def compute_texture(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
    ) -> trimesh.Trimesh:
        """Render geometry buffers and fuse all views into one UV texture.

        Returns:
            tuple[np.ndarray, np.ndarray]: The fused float texture and the
            uint8 coverage mask (255 where painted).
        """
        self._lazy_init_render(self.camera_params, self.mask_thresh)

        vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
        faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
        uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()

        rendered_depth, masks = self.renderer.render_depth(vertices, faces)
        norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
        render_uvs, _ = self.renderer.render_uv(vertices, faces, uv_map)
        view_normals = self.compute_enhanced_viewnormal(
            self.renderer.mv_mtx, vertices, faces
        )

        textures, weighted_cos_maps = [], []
        for color, mask, dep, normal, uv, weight in zip(
            colors,
            masks,
            norm_deps,
            view_normals,
            render_uvs,
            self.view_weights,
        ):
            texture, cos_map = self.back_project(color, mask, dep, normal, uv)
            textures.append(texture)
            # cos^4 sharpens the confidence falloff toward grazing angles.
            weighted_cos_maps.append(weight * (cos_map**4))

        texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)

        texture_np = texture.cpu().numpy()
        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)

        return texture_np, mask_np

    def __call__(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
        output_path: str,
    ) -> trimesh.Trimesh:
        """Runs the texture baking and exports the textured mesh.

        Args:
            colors (list[Image.Image]): List of input view images.
            mesh (trimesh.Trimesh): Input mesh to be textured.
            output_path (str): Path to save the output textured mesh (.obj or .glb).

        Returns:
            trimesh.Trimesh: The textured mesh with UV and texture image.
        """
        mesh = self.load_mesh(mesh)
        texture_np, mask_np = self.compute_texture(colors, mesh)

        texture_np = self.uv_inpaint(mesh, texture_np, mask_np)
        if self.smooth_texture:
            texture_np = post_process_texture(texture_np)

        # Export in the original (denormalized) coordinate frame.
        vertices, faces, uv_map = self.get_mesh_np_attrs(
            mesh, self.scale, self.center
        )
        textured_mesh = save_mesh_with_mtl(
            vertices, faces, uv_map, texture_np, output_path
        )

        return textured_mesh
+
+
def parse_args():
    """Parse CLI arguments for texture back-projection (v2 pipeline).

    Fix over the original: the `--skip_fix_mesh` help text read
    "Fix mesh geometry." — the opposite of what the flag does.

    Unknown arguments are tolerated (`parse_known_args`) so this parser can
    be reused when the module is driven programmatically via `entrypoint`.

    Returns:
        argparse.Namespace: Parsed known arguments.
    """
    parser = argparse.ArgumentParser(description="Backproject texture")
    parser.add_argument(
        "--color_path",
        type=str,
        help="Multiview color image in 6x512x512 file path",
    )
    parser.add_argument(
        "--mesh_path",
        type=str,
        help="Mesh path, .obj, .glb or .ply",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output mesh path with suffix",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        nargs=2,
        type=float,
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(2048, 2048),
        help="Resolution of the output images (default: (2048, 2048))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip mesh geometry fixing.",
    )
    parser.add_argument(
        "--texture_wh",
        nargs=2,
        type=int,
        default=[2048, 2048],
        help="Texture resolution width and height",
    )
    # NOTE: flag name keeps the historical misspelling ("sipmlify") for
    # backward compatibility with existing invocations.
    parser.add_argument(
        "--mesh_sipmlify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
    )
    parser.add_argument(
        "--no_smooth_texture",
        action="store_true",
        help="Do not smooth the texture.",
    )
    parser.add_argument(
        "--save_glb_path", type=str, default=None, help="Save glb path."
    )
    parser.add_argument(
        "--no_save_delight_img",
        action="store_true",
        help="Disable saving delight image",
    )

    args, unknown = parser.parse_known_args()

    return args
+
+
def entrypoint(
    delight_model: DelightingModel = None,
    imagesr_model: ImageRealESRGAN = None,
    **kwargs,
) -> trimesh.Trimesh:
    """Bake a texture onto a mesh from a 6-view color grid image.

    CLI arguments may be overridden programmatically via `kwargs` (only
    keys that already exist on the parsed namespace and are not None are
    applied). Pre-loaded delighting / super-resolution models can be passed
    in to avoid repeated model loading across calls.

    Returns:
        trimesh.Trimesh: The textured mesh (also written to output_path,
        and optionally exported as .glb).
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    # Setup camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    # Per-view fusion weights; front/back views dominate, side/top views
    # contribute less. NOTE(review): hardcoded for num_images == 6.
    view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]

    color_grid = Image.open(args.color_path)
    if args.delight:
        if delight_model is None:
            delight_model = DelightingModel()
        # NOTE(review): if output_path has no directory component,
        # os.makedirs("") would raise — confirm callers pass a path with
        # a directory.
        save_dir = os.path.dirname(args.output_path)
        os.makedirs(save_dir, exist_ok=True)
        color_grid = delight_model(color_grid)
        if not args.no_save_delight_img:
            color_grid.save(f"{save_dir}/color_grid_delight.png")

    multiviews = get_images_from_grid(color_grid, img_size=512)

    # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
    if imagesr_model is None:
        imagesr_model = ImageRealESRGAN(outscale=4)
    multiviews = [imagesr_model(img) for img in multiviews]
    multiviews = [img.convert("RGB") for img in multiviews]
    mesh = trimesh.load(args.mesh_path)
    # Flatten a multi-geometry scene into a single mesh.
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    if not args.skip_fix_mesh:
        # Normalize for the fixer, then restore the original scale below.
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
        mesh.vertices, mesh.faces = mesh_fixer(
            filter_ratio=args.mesh_sipmlify_ratio,
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
            norm_mesh_ratio=0.5,
        )
        # Restore scale.
        mesh.vertices = mesh.vertices / scale
        mesh.vertices = mesh.vertices + center

    # Baking texture to mesh.
    texture_backer = TextureBacker(
        camera_params=camera_params,
        view_weights=view_weights,
        render_wh=camera_params.resolution_hw,
        texture_wh=args.texture_wh,
        smooth_texture=not args.no_smooth_texture,
    )

    textured_mesh = texture_backer(multiviews, mesh, args.output_path)

    if args.save_glb_path is not None:
        os.makedirs(os.path.dirname(args.save_glb_path), exist_ok=True)
        textured_mesh.export(args.save_glb_path)

    return textured_mesh
+
+
# Allow running this module directly as a CLI script.
if __name__ == "__main__":
    entrypoint()
diff --git a/embodied_gen/data/datasets.py b/embodied_gen/data/datasets.py
new file mode 100644
index 0000000..4a9563a
--- /dev/null
+++ b/embodied_gen/data/datasets.py
@@ -0,0 +1,256 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import json
+import logging
+import os
+import random
+from typing import Any, Callable, Dict, List, Tuple
+
+import torch
+import torch.utils.checkpoint
+from PIL import Image
+from torch import nn
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
+__all__ = [
+ "Asset3dGenDataset",
+]
+
+
+class Asset3dGenDataset(Dataset):
+ def __init__(
+ self,
+ index_file: str,
+ target_hw: Tuple[int, int],
+ transform: Callable = None,
+ control_transform: Callable = None,
+ max_train_samples: int = None,
+ sub_idxs: List[List[int]] = None,
+ seed: int = 79,
+ ) -> None:
+ if not os.path.exists(index_file):
+ raise FileNotFoundError(f"{index_file} index_file not found.")
+
+ self.index_file = index_file
+ self.target_hw = target_hw
+ self.transform = transform
+ self.control_transform = control_transform
+ self.max_train_samples = max_train_samples
+ self.meta_info = self.prepare_data_index(index_file)
+ self.data_list = sorted(self.meta_info.keys())
+ self.sub_idxs = sub_idxs # sub_idxs [[0,1,2], [3,4,5], [...], ...]
+ self.image_num = 6 # hardcode temp.
+ random.seed(seed)
+ logger.info(f"Trainset: {len(self)} asset3d instances.")
+
+ def __len__(self) -> int:
+ return len(self.meta_info)
+
+ def prepare_data_index(self, index_file: str) -> Dict[str, Any]:
+ with open(index_file, "r") as fin:
+ meta_info = json.load(fin)
+
+ meta_info_filtered = dict()
+ for idx, uid in enumerate(meta_info):
+ if "status" not in meta_info[uid]:
+ continue
+ if meta_info[uid]["status"] != "success":
+ continue
+ if self.max_train_samples and idx >= self.max_train_samples:
+ break
+
+ meta_info_filtered[uid] = meta_info[uid]
+
+ logger.info(
+            f"Loaded {len(meta_info)} assets, kept {len(meta_info_filtered)} valid." # noqa
+ )
+
+ return meta_info_filtered
+
+ def fetch_sample_images(
+ self,
+ uid: str,
+ attrs: List[str],
+ sub_index: int = None,
+ transform: Callable = None,
+ ) -> torch.Tensor:
+ sample = self.meta_info[uid]
+ images = []
+ for attr in attrs:
+ item = sample[attr]
+ if sub_index is not None:
+ item = item[sub_index]
+ mode = "L" if attr == "image_mask" else "RGB"
+ image = Image.open(item).convert(mode)
+ if transform is not None:
+ image = transform(image)
+ if len(image.shape) == 2:
+ image = image[..., None]
+ images.append(image)
+
+ images = torch.cat(images, dim=0)
+
+ return images
+
+ def fetch_sample_grid_images(
+ self,
+ uid: str,
+ attrs: List[str],
+ sub_idxs: List[List[int]],
+ transform: Callable = None,
+ ) -> torch.Tensor:
+ assert transform is not None
+
+ grid_image = []
+ for row_idxs in sub_idxs:
+ row_image = []
+ for row_idx in row_idxs:
+ image = self.fetch_sample_images(
+ uid, attrs, row_idx, transform
+ )
+ row_image.append(image)
+ row_image = torch.cat(row_image, dim=2) # (c h w)
+ grid_image.append(row_image)
+
+ grid_image = torch.cat(grid_image, dim=1)
+
+ return grid_image
+
+ def compute_text_embeddings(
+ self, embed_path: str, original_size: Tuple[int, int]
+ ) -> Dict[str, nn.Module]:
+ data_dict = torch.load(embed_path)
+ prompt_embeds = data_dict["prompt_embeds"][0]
+ add_text_embeds = data_dict["pooled_prompt_embeds"][0]
+
+ # Need changed if random crop, set as crop_top_left [y1, x1], center crop as [0, 0]. # noqa
+ crops_coords_top_left = (0, 0)
+ add_time_ids = list(
+ original_size + crops_coords_top_left + self.target_hw
+ )
+ add_time_ids = torch.tensor([add_time_ids])
+ # add_time_ids = add_time_ids.repeat((len(add_text_embeds), 1))
+
+ unet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds,
+ "time_ids": add_time_ids,
+ }
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ def visualize_item(
+ self,
+ control: torch.Tensor,
+ color: torch.Tensor,
+ save_dir: str = None,
+ ) -> List[Image.Image]:
+ to_pil = transforms.ToPILImage()
+
+ color = (color + 1) / 2
+ color_pil = to_pil(color)
+ normal_pil = to_pil(control[0:3])
+ position_pil = to_pil(control[3:6])
+ mask_pil = to_pil(control[6:])
+
+ if save_dir is not None:
+ os.makedirs(save_dir, exist_ok=True)
+ color_pil.save(f"{save_dir}/rgb.jpg")
+ normal_pil.save(f"{save_dir}/normal.jpg")
+ position_pil.save(f"{save_dir}/position.jpg")
+ mask_pil.save(f"{save_dir}/mask.jpg")
+ logger.info(f"Visualization in {save_dir}")
+
+ return normal_pil, position_pil, mask_pil, color_pil
+
+ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
+ uid = self.data_list[index]
+
+ sub_idxs = self.sub_idxs
+ if sub_idxs is None:
+ sub_idxs = [[random.randint(0, self.image_num - 1)]]
+
+ input_image = self.fetch_sample_grid_images(
+ uid,
+ attrs=["image_view_normal", "image_position", "image_mask"],
+ sub_idxs=sub_idxs,
+ transform=self.control_transform,
+ )
+ assert input_image.shape[1:] == self.target_hw
+
+ output_image = self.fetch_sample_grid_images(
+ uid,
+ attrs=["image_color"],
+ sub_idxs=sub_idxs,
+ transform=self.transform,
+ )
+
+ sample = self.meta_info[uid]
+ text_feats = self.compute_text_embeddings(
+ sample["text_feat"], tuple(sample["image_hw"])
+ )
+
+ data = dict(
+ pixel_values=output_image,
+ conditioning_pixel_values=input_image,
+ prompt_embeds=text_feats["prompt_embeds"],
+ text_embeds=text_feats["text_embeds"],
+ time_ids=text_feats["time_ids"],
+ )
+
+ return data
+
+
+if __name__ == "__main__":
+ index_file = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json" # noqa
+ target_hw = (512, 512)
+ transform_list = [
+ transforms.Resize(
+ target_hw, interpolation=transforms.InterpolationMode.BILINEAR
+ ),
+ transforms.CenterCrop(target_hw),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ image_transform = transforms.Compose(transform_list)
+ control_transform = transforms.Compose(transform_list[:-1])
+
+ sub_idxs = [[0, 1, 2], [3, 4, 5]] # None
+ if sub_idxs is not None:
+ target_hw = (
+ target_hw[0] * len(sub_idxs),
+ target_hw[1] * len(sub_idxs[0]),
+ )
+
+ dataset = Asset3dGenDataset(
+ index_file,
+ target_hw,
+ image_transform,
+ control_transform,
+ sub_idxs=sub_idxs,
+ )
+ data = dataset[0]
+ dataset.visualize_item(
+ data["conditioning_pixel_values"], data["pixel_values"], save_dir="./"
+ )
diff --git a/embodied_gen/data/differentiable_render.py b/embodied_gen/data/differentiable_render.py
new file mode 100644
index 0000000..18b0a86
--- /dev/null
+++ b/embodied_gen/data/differentiable_render.py
@@ -0,0 +1,526 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import argparse
+import json
+import logging
+import math
+import os
+from collections import defaultdict
+from typing import List, Union
+
+import cv2
+import nvdiffrast.torch as dr
+import torch
+from tqdm import tqdm
+from embodied_gen.data.utils import (
+ CameraSetting,
+ DiffrastRender,
+ RenderItems,
+ as_list,
+ calc_vertex_normals,
+ import_kaolin_mesh,
+ init_kal_camera,
+ normalize_vertices_array,
+ render_pbr,
+ save_images,
+)
+from embodied_gen.utils.process_media import (
+ create_gif_from_images,
+ create_mp4_from_images,
+)
+
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
+ "~/.cache/torch_extensions"
+)
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
+__all__ = ["ImageRender"]
+
+
+class ImageRender(object):
+ """A differentiable mesh renderer supporting multi-view rendering.
+
+ This class wraps a differentiable rasterization using `nvdiffrast` to
+ render mesh geometry to various maps (normal, depth, alpha, albedo, etc.).
+
+ Args:
+ render_items (list[RenderItems]): A list of rendering targets to
+ generate (e.g., IMAGE, DEPTH, NORMAL, etc.).
+ camera_params (CameraSetting): The camera parameters for rendering,
+ including intrinsic and extrinsic matrices.
+ recompute_vtx_normal (bool, optional): If True, recomputes
+ vertex normals from the mesh geometry. Defaults to True.
+ with_mtl (bool, optional): Whether to load `.mtl` material files
+ for meshes. Defaults to False.
+ gen_color_gif (bool, optional): Generate a GIF of rendered
+ color images. Defaults to False.
+ gen_color_mp4 (bool, optional): Generate an MP4 video of rendered
+ color images. Defaults to False.
+ gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of
+ view-space normals. Defaults to False.
+ gen_glonormal_mp4 (bool, optional): Generate an MP4 video of
+ global-space normals. Defaults to False.
+ no_index_file (bool, optional): If True, skip saving the `index.json`
+ summary file. Defaults to False.
+ light_factor (float, optional): A scalar multiplier for
+ PBR light intensity. Defaults to 1.0.
+ """
+
+ def __init__(
+ self,
+ render_items: list[RenderItems],
+ camera_params: CameraSetting,
+ recompute_vtx_normal: bool = True,
+ with_mtl: bool = False,
+ gen_color_gif: bool = False,
+ gen_color_mp4: bool = False,
+ gen_viewnormal_mp4: bool = False,
+ gen_glonormal_mp4: bool = False,
+ no_index_file: bool = False,
+ light_factor: float = 1.0,
+ ) -> None:
+ camera = init_kal_camera(camera_params)
+ self.camera = camera
+
+ # Setup MVP matrix and renderer.
+ mv = camera.view_matrix() # (n 4 4) world2cam
+ p = camera.intrinsics.projection_matrix()
+ # NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
+ p[:, 1, 1] = -p[:, 1, 1]
+ # mvp = torch.bmm(p, mv) # camera.view_projection_matrix()
+ self.mv = mv
+ self.p = p
+
+ renderer = DiffrastRender(
+ p_matrix=p,
+ mv_matrix=mv,
+ resolution_hw=camera_params.resolution_hw,
+ context=dr.RasterizeCudaContext(),
+ mask_thresh=0.5,
+ grad_db=False,
+ device=camera_params.device,
+ antialias_mask=True,
+ )
+ self.renderer = renderer
+ self.recompute_vtx_normal = recompute_vtx_normal
+ self.render_items = render_items
+ self.device = camera_params.device
+ self.with_mtl = with_mtl
+ self.gen_color_gif = gen_color_gif
+ self.gen_color_mp4 = gen_color_mp4
+ self.gen_viewnormal_mp4 = gen_viewnormal_mp4
+ self.gen_glonormal_mp4 = gen_glonormal_mp4
+ self.light_factor = light_factor
+ self.no_index_file = no_index_file
+
+ def render_mesh(
+ self,
+ mesh_path: Union[str, List[str]],
+ output_root: str,
+ uuid: Union[str, List[str]] = None,
+ prompts: List[str] = None,
+ ) -> None:
+ mesh_path = as_list(mesh_path)
+ if uuid is None:
+ uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
+ uuid = as_list(uuid)
+ assert len(mesh_path) == len(uuid)
+ os.makedirs(output_root, exist_ok=True)
+
+ meta_info = dict()
+ for idx, (path, uid) in tqdm(
+ enumerate(zip(mesh_path, uuid)), total=len(mesh_path)
+ ):
+ output_dir = os.path.join(output_root, uid)
+ os.makedirs(output_dir, exist_ok=True)
+ prompt = prompts[idx] if prompts else None
+ data_dict = self(path, output_dir, prompt)
+ meta_info[uid] = data_dict
+
+ if self.no_index_file:
+ return
+
+ index_file = os.path.join(output_root, "index.json")
+ with open(index_file, "w") as fout:
+ json.dump(meta_info, fout)
+
+ logger.info(f"Rendering meta info logged in {index_file}")
+
+ def __call__(
+ self, mesh_path: str, output_dir: str, prompt: str = None
+ ) -> dict[str, str]:
+ """Render a single mesh and return paths to the rendered outputs.
+
+ Processes the input mesh, renders multiple modalities (e.g., normals,
+ depth, albedo), and optionally saves video or image sequences.
+
+ Args:
+ mesh_path (str): Path to the mesh file (.obj/.glb).
+ output_dir (str): Directory to save rendered outputs.
+ prompt (str, optional): Optional caption prompt for MP4 metadata.
+
+ Returns:
+ dict[str, str]: A mapping render types to the saved image paths.
+ """
+ try:
+ mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
+ except Exception as e:
+ logger.error(f"[ERROR MESH LOAD]: {e}, skip {mesh_path}")
+ return
+
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
+ if self.recompute_vtx_normal:
+ mesh.vertex_normals = calc_vertex_normals(
+ mesh.vertices, mesh.faces
+ )
+
+ mesh = mesh.to(self.device)
+ vertices, faces, vertex_normals = (
+ mesh.vertices,
+ mesh.faces,
+ mesh.vertex_normals,
+ )
+
+ # Perform rendering.
+ data_dict = defaultdict(list)
+ if RenderItems.ALPHA.value in self.render_items:
+ masks, _ = self.renderer.render_rast_alpha(vertices, faces)
+ render_paths = save_images(
+ masks, f"{output_dir}/{RenderItems.ALPHA}"
+ )
+ data_dict[RenderItems.ALPHA.value] = render_paths
+
+ if RenderItems.GLOBAL_NORMAL.value in self.render_items:
+ rendered_normals, masks = self.renderer.render_global_normal(
+ vertices, faces, vertex_normals
+ )
+ if self.gen_glonormal_mp4:
+ if isinstance(rendered_normals, torch.Tensor):
+ rendered_normals = rendered_normals.detach().cpu().numpy()
+ create_mp4_from_images(
+ rendered_normals,
+ output_path=f"{output_dir}/normal.mp4",
+ fps=15,
+ prompt=prompt,
+ )
+ else:
+ render_paths = save_images(
+ rendered_normals,
+ f"{output_dir}/{RenderItems.GLOBAL_NORMAL}",
+ cvt_color=cv2.COLOR_BGR2RGB,
+ )
+ data_dict[RenderItems.GLOBAL_NORMAL.value] = render_paths
+
+ if RenderItems.VIEW_NORMAL.value in self.render_items:
+ assert (
+                RenderItems.GLOBAL_NORMAL.value in self.render_items
+ ), f"Must render global normal firstly, got render_items: {self.render_items}." # noqa
+ rendered_view_normals = self.renderer.transform_normal(
+ rendered_normals, self.mv, masks, to_view=True
+ )
+
+ if self.gen_viewnormal_mp4:
+ create_mp4_from_images(
+ rendered_view_normals,
+ output_path=f"{output_dir}/view_normal.mp4",
+ fps=15,
+ prompt=prompt,
+ )
+ else:
+ render_paths = save_images(
+ rendered_view_normals,
+ f"{output_dir}/{RenderItems.VIEW_NORMAL}",
+ cvt_color=cv2.COLOR_BGR2RGB,
+ )
+ data_dict[RenderItems.VIEW_NORMAL.value] = render_paths
+
+ if RenderItems.POSITION_MAP.value in self.render_items:
+ rendered_position, masks = self.renderer.render_position(
+ vertices, faces
+ )
+ norm_position = self.renderer.normalize_map_by_mask(
+ rendered_position, masks
+ )
+ render_paths = save_images(
+ norm_position,
+ f"{output_dir}/{RenderItems.POSITION_MAP}",
+ cvt_color=cv2.COLOR_BGR2RGB,
+ )
+ data_dict[RenderItems.POSITION_MAP.value] = render_paths
+
+ if RenderItems.DEPTH.value in self.render_items:
+ rendered_depth, masks = self.renderer.render_depth(vertices, faces)
+ norm_depth = self.renderer.normalize_map_by_mask(
+ rendered_depth, masks
+ )
+ render_paths = save_images(
+ norm_depth,
+ f"{output_dir}/{RenderItems.DEPTH}",
+ )
+ data_dict[RenderItems.DEPTH.value] = render_paths
+
+ render_paths = save_images(
+ rendered_depth,
+ f"{output_dir}/{RenderItems.DEPTH}_exr",
+ to_uint8=False,
+ format=".exr",
+ )
+ data_dict[f"{RenderItems.DEPTH.value}_exr"] = render_paths
+
+ if RenderItems.IMAGE.value in self.render_items:
+ images = []
+ albedos = []
+ diffuses = []
+ masks, _ = self.renderer.render_rast_alpha(vertices, faces)
+ try:
+ for idx, cam in enumerate(self.camera):
+ image, albedo, diffuse, _ = render_pbr(
+ mesh, cam, light_factor=self.light_factor
+ )
+ image = torch.cat([image[0], masks[idx]], axis=-1)
+ images.append(image.detach().cpu().numpy())
+
+ if RenderItems.ALBEDO.value in self.render_items:
+ albedo = torch.cat([albedo[0], masks[idx]], axis=-1)
+ albedos.append(albedo.detach().cpu().numpy())
+
+ if RenderItems.DIFFUSE.value in self.render_items:
+ diffuse = torch.cat([diffuse[0], masks[idx]], axis=-1)
+ diffuses.append(diffuse.detach().cpu().numpy())
+
+ except Exception as e:
+ logger.error(f"[ERROR pbr render]: {e}, skip {mesh_path}")
+ return
+
+ if self.gen_color_gif:
+ create_gif_from_images(
+ images,
+ output_path=f"{output_dir}/color.gif",
+ fps=15,
+ )
+
+ if self.gen_color_mp4:
+ create_mp4_from_images(
+ images,
+ output_path=f"{output_dir}/color.mp4",
+ fps=15,
+ prompt=prompt,
+ )
+
+ if self.gen_color_mp4 or self.gen_color_gif:
+ return data_dict
+
+ render_paths = save_images(
+ images,
+ f"{output_dir}/{RenderItems.IMAGE}",
+ cvt_color=cv2.COLOR_BGRA2RGBA,
+ )
+ data_dict[RenderItems.IMAGE.value] = render_paths
+
+ render_paths = save_images(
+ albedos,
+ f"{output_dir}/{RenderItems.ALBEDO}",
+ cvt_color=cv2.COLOR_BGRA2RGBA,
+ )
+ data_dict[RenderItems.ALBEDO.value] = render_paths
+
+ render_paths = save_images(
+ diffuses,
+ f"{output_dir}/{RenderItems.DIFFUSE}",
+ cvt_color=cv2.COLOR_BGRA2RGBA,
+ )
+ data_dict[RenderItems.DIFFUSE.value] = render_paths
+
+ data_dict["status"] = "success"
+
+ logger.info(f"Finish rendering in {output_dir}")
+
+ return data_dict
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Render settings")
+
+ parser.add_argument(
+ "--mesh_path",
+ type=str,
+ nargs="+",
+ help="Paths to the mesh files for rendering.",
+ )
+ parser.add_argument(
+ "--output_root",
+ type=str,
+ help="Root directory for output",
+ )
+ parser.add_argument(
+ "--uuid",
+ type=str,
+ nargs="+",
+ default=None,
+ help="uuid for rendering saving.",
+ )
+ parser.add_argument(
+ "--num_images", type=int, default=6, help="Number of images to render."
+ )
+ parser.add_argument(
+ "--elevation",
+ type=float,
+ nargs="+",
+ default=[20.0, -10.0],
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
+ )
+ parser.add_argument(
+ "--distance",
+ type=float,
+ default=5,
+ help="Camera distance (default: 5)",
+ )
+ parser.add_argument(
+ "--resolution_hw",
+ type=int,
+ nargs=2,
+ default=(512, 512),
+ help="Resolution of the output images (default: (512, 512))",
+ )
+ parser.add_argument(
+ "--fov",
+ type=float,
+ default=30,
+ help="Field of view in degrees (default: 30)",
+ )
+ parser.add_argument(
+ "--pbr_light_factor",
+ type=float,
+ default=1.0,
+        help="Light factor for mesh PBR rendering (default: 1.0)",
+ )
+ parser.add_argument(
+ "--with_mtl",
+ action="store_true",
+ help="Whether to render with mesh material.",
+ )
+ parser.add_argument(
+ "--gen_color_gif",
+ action="store_true",
+ help="Whether to generate color .gif rendering file.",
+ )
+ parser.add_argument(
+ "--gen_color_mp4",
+ action="store_true",
+ help="Whether to generate color .mp4 rendering file.",
+ )
+ parser.add_argument(
+ "--gen_viewnormal_mp4",
+ action="store_true",
+ help="Whether to generate view normal .mp4 rendering file.",
+ )
+ parser.add_argument(
+ "--gen_glonormal_mp4",
+ action="store_true",
+ help="Whether to generate global normal .mp4 rendering file.",
+ )
+ parser.add_argument(
+ "--prompts",
+ type=str,
+ nargs="+",
+ default=None,
+ help="Text prompts for the rendering.",
+ )
+
+ args = parser.parse_args()
+
+ if args.uuid is None and args.mesh_path is not None:
+ args.uuid = []
+ for path in args.mesh_path:
+ uuid = os.path.basename(path).split(".")[0]
+ args.uuid.append(uuid)
+
+ return args
+
+
+def entrypoint(**kwargs) -> None:
+ args = parse_args()
+ for k, v in kwargs.items():
+ if hasattr(args, k) and v is not None:
+ setattr(args, k, v)
+
+ camera_settings = CameraSetting(
+ num_images=args.num_images,
+ elevation=args.elevation,
+ distance=args.distance,
+ resolution_hw=args.resolution_hw,
+ fov=math.radians(args.fov),
+ device="cuda",
+ )
+
+ render_items = [
+ RenderItems.ALPHA.value,
+ RenderItems.GLOBAL_NORMAL.value,
+ RenderItems.VIEW_NORMAL.value,
+ RenderItems.POSITION_MAP.value,
+ RenderItems.IMAGE.value,
+ RenderItems.DEPTH.value,
+ # RenderItems.ALBEDO.value,
+ # RenderItems.DIFFUSE.value,
+ ]
+
+ gen_video = (
+ args.gen_color_gif
+ or args.gen_color_mp4
+ or args.gen_viewnormal_mp4
+ or args.gen_glonormal_mp4
+ )
+ if gen_video:
+ render_items = []
+ if args.gen_color_gif or args.gen_color_mp4:
+ render_items.append(RenderItems.IMAGE.value)
+ if args.gen_glonormal_mp4:
+ render_items.append(RenderItems.GLOBAL_NORMAL.value)
+ if args.gen_viewnormal_mp4:
+ render_items.append(RenderItems.VIEW_NORMAL.value)
+ if RenderItems.GLOBAL_NORMAL.value not in render_items:
+ render_items.append(RenderItems.GLOBAL_NORMAL.value)
+
+ image_render = ImageRender(
+ render_items=render_items,
+ camera_params=camera_settings,
+ with_mtl=args.with_mtl,
+ gen_color_gif=args.gen_color_gif,
+ gen_color_mp4=args.gen_color_mp4,
+ gen_viewnormal_mp4=args.gen_viewnormal_mp4,
+ gen_glonormal_mp4=args.gen_glonormal_mp4,
+ light_factor=args.pbr_light_factor,
+ no_index_file=gen_video,
+ )
+ image_render.render_mesh(
+ mesh_path=args.mesh_path,
+ output_root=args.output_root,
+ uuid=args.uuid,
+ prompts=args.prompts,
+ )
+
+ return
+
+
+if __name__ == "__main__":
+ entrypoint()
diff --git a/embodied_gen/data/mesh_operator.py b/embodied_gen/data/mesh_operator.py
new file mode 100644
index 0000000..888b203
--- /dev/null
+++ b/embodied_gen/data/mesh_operator.py
@@ -0,0 +1,452 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import logging
+from typing import Tuple, Union
+
+import igraph
+import numpy as np
+import pyvista as pv
+import spaces
+import torch
+import utils3d
+from pymeshfix import _meshfix
+from tqdm import tqdm
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
+__all__ = ["MeshFixer"]
+
+
+def _radical_inverse(base, n):
+ val = 0
+ inv_base = 1.0 / base
+ inv_base_n = inv_base
+ while n > 0:
+ digit = n % base
+ val += digit * inv_base_n
+ n //= base
+ inv_base_n *= inv_base
+ return val
+
+
+def _halton_sequence(dim, n):
+ PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
+    return [_radical_inverse(PRIMES[d], n) for d in range(dim)]
+
+
+def _hammersley_sequence(dim, n, num_samples):
+ return [n / num_samples] + _halton_sequence(dim - 1, n)
+
+
+def _sphere_hammersley_seq(n, num_samples, offset=(0, 0), remap=False):
+ """Generate a point on a unit sphere using the Hammersley sequence.
+
+ Args:
+ n (int): The index of the sample.
+ num_samples (int): The total number of samples.
+ offset (tuple, optional): Offset for the u and v coordinates.
+ remap (bool, optional): Whether to remap the u coordinate.
+
+ Returns:
+ list: A list containing the spherical coordinates [phi, theta].
+ """
+ u, v = _hammersley_sequence(2, n, num_samples)
+ u += offset[0] / num_samples
+ v += offset[1]
+
+ if remap:
+ u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
+
+ theta = np.arccos(1 - 2 * u) - np.pi / 2
+ phi = v * 2 * np.pi
+ return [phi, theta]
+
+
+class MeshFixer(object):
+ """MeshFixer simplifies and repairs 3D triangle meshes by TSDF.
+
+ Attributes:
+ vertices (torch.Tensor): A tensor of shape (V, 3) representing vertex positions.
+ faces (torch.Tensor): A tensor of shape (F, 3) representing face indices.
+ device (str): Device to run computations on, typically "cuda" or "cpu".
+
+ Main logic reference: https://github.com/microsoft/TRELLIS/blob/main/trellis/utils/postprocessing_utils.py#L22
+ """
+
+ def __init__(
+ self,
+ vertices: Union[torch.Tensor, np.ndarray],
+ faces: Union[torch.Tensor, np.ndarray],
+ device: str = "cuda",
+ ) -> None:
+ self.device = device
+ if isinstance(vertices, np.ndarray):
+ vertices = torch.tensor(vertices)
+ self.vertices = vertices
+
+ if isinstance(faces, np.ndarray):
+ faces = torch.tensor(faces)
+ self.faces = faces
+
+ @staticmethod
+ def log_mesh_changes(method):
+ def wrapper(self, *args, **kwargs):
+ logger.info(
+ f"Before {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
+ )
+ result = method(self, *args, **kwargs)
+ logger.info(
+ f"After {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
+ )
+ return result
+
+ return wrapper
+
+ @log_mesh_changes
+ def fill_holes(
+ self,
+ max_hole_size: float,
+ max_hole_nbe: int,
+ resolution: int,
+ num_views: int,
+ norm_mesh_ratio: float = 1.0,
+ ) -> None:
+ self.vertices = self.vertices * norm_mesh_ratio
+ vertices, self.faces = self._fill_holes(
+ self.vertices,
+ self.faces,
+ max_hole_size,
+ max_hole_nbe,
+ resolution,
+ num_views,
+ )
+ self.vertices = vertices / norm_mesh_ratio
+
+ @staticmethod
+ @torch.no_grad()
+ def _fill_holes(
+ vertices: torch.Tensor,
+ faces: torch.Tensor,
+ max_hole_size: float,
+ max_hole_nbe: int,
+ resolution: int,
+ num_views: int,
+ ) -> Union[torch.Tensor, torch.Tensor]:
+ yaws, pitchs = [], []
+ for i in range(num_views):
+ y, p = _sphere_hammersley_seq(i, num_views)
+ yaws.append(y)
+ pitchs.append(p)
+
+ yaws, pitchs = torch.tensor(yaws).to(vertices), torch.tensor(
+ pitchs
+ ).to(vertices)
+ radius, fov = 2.0, torch.deg2rad(torch.tensor(40)).to(vertices)
+ projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
+
+ views = []
+ for yaw, pitch in zip(yaws, pitchs):
+ orig = (
+ torch.tensor(
+ [
+ torch.sin(yaw) * torch.cos(pitch),
+ torch.cos(yaw) * torch.cos(pitch),
+ torch.sin(pitch),
+ ]
+ ).to(vertices)
+ * radius
+ )
+ view = utils3d.torch.view_look_at(
+ orig,
+ torch.tensor([0, 0, 0]).to(vertices),
+ torch.tensor([0, 0, 1]).to(vertices),
+ )
+ views.append(view)
+ views = torch.stack(views, dim=0)
+
+ # Rasterize the mesh
+ visibility = torch.zeros(
+ faces.shape[0], dtype=torch.int32, device=faces.device
+ )
+ rastctx = utils3d.torch.RastContext(backend="cuda")
+
+ for i in tqdm(
+ range(views.shape[0]), total=views.shape[0], desc="Rasterizing"
+ ):
+ view = views[i]
+ buffers = utils3d.torch.rasterize_triangle_faces(
+ rastctx,
+ vertices[None],
+ faces,
+ resolution,
+ resolution,
+ view=view,
+ projection=projection,
+ )
+ face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1
+ face_id = torch.unique(face_id).long()
+ visibility[face_id] += 1
+
+ # Normalize visibility by the number of views
+ visibility = visibility.float() / num_views
+
+ # Mincut: Identify outer and inner faces
+ edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
+ boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
+ connected_components = utils3d.torch.compute_connected_components(
+ faces, edges, face2edge
+ )
+
+ outer_face_indices = torch.zeros(
+ faces.shape[0], dtype=torch.bool, device=faces.device
+ )
+ for i in range(len(connected_components)):
+ outer_face_indices[connected_components[i]] = visibility[
+ connected_components[i]
+ ] > min(
+ max(
+ visibility[connected_components[i]].quantile(0.75).item(),
+ 0.25,
+ ),
+ 0.5,
+ )
+
+ outer_face_indices = outer_face_indices.nonzero().reshape(-1)
+ inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
+
+ if inner_face_indices.shape[0] == 0:
+ return vertices, faces
+
+ # Construct dual graph (faces as nodes, edges as edges)
+ dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(
+ face2edge
+ )
+ dual_edge2edge = edges[dual_edge2edge]
+ dual_edges_weights = torch.norm(
+ vertices[dual_edge2edge[:, 0]] - vertices[dual_edge2edge[:, 1]],
+ dim=1,
+ )
+
+ # Mincut: Construct main graph and solve the mincut problem
+ g = igraph.Graph()
+ g.add_vertices(faces.shape[0])
+ g.add_edges(dual_edges.cpu().numpy())
+ g.es["weight"] = dual_edges_weights.cpu().numpy()
+
+ g.add_vertex("s") # source
+ g.add_vertex("t") # target
+
+ g.add_edges(
+ [(f, "s") for f in inner_face_indices],
+ attributes={
+ "weight": torch.ones(
+ inner_face_indices.shape[0], dtype=torch.float32
+ )
+ .cpu()
+ .numpy()
+ },
+ )
+ g.add_edges(
+ [(f, "t") for f in outer_face_indices],
+ attributes={
+ "weight": torch.ones(
+ outer_face_indices.shape[0], dtype=torch.float32
+ )
+ .cpu()
+ .numpy()
+ },
+ )
+
+ cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist())
+ remove_face_indices = torch.tensor(
+ [v for v in cut.partition[0] if v < faces.shape[0]],
+ dtype=torch.long,
+ device=faces.device,
+ )
+
+ # Check if the cut is valid with each connected component
+ to_remove_cc = utils3d.torch.compute_connected_components(
+ faces[remove_face_indices]
+ )
+ valid_remove_cc = []
+ cutting_edges = []
+ for cc in to_remove_cc:
+ # Check visibility median for connected component
+ visibility_median = visibility[remove_face_indices[cc]].median()
+ if visibility_median > 0.25:
+ continue
+
+ # Check if the cutting loop is small enough
+ cc_edge_indices, cc_edges_degree = torch.unique(
+ face2edge[remove_face_indices[cc]], return_counts=True
+ )
+ cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
+ cc_new_boundary_edge_indices = cc_boundary_edge_indices[
+ ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)
+ ]
+ if len(cc_new_boundary_edge_indices) > 0:
+ cc_new_boundary_edge_cc = (
+ utils3d.torch.compute_edge_connected_components(
+ edges[cc_new_boundary_edge_indices]
+ )
+ )
+ cc_new_boundary_edges_cc_center = [
+ vertices[edges[cc_new_boundary_edge_indices[edge_cc]]]
+ .mean(dim=1)
+ .mean(dim=0)
+ for edge_cc in cc_new_boundary_edge_cc
+ ]
+ cc_new_boundary_edges_cc_area = []
+ for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
+ _e1 = (
+ vertices[
+ edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]
+ ]
+ - cc_new_boundary_edges_cc_center[i]
+ )
+ _e2 = (
+ vertices[
+ edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]
+ ]
+ - cc_new_boundary_edges_cc_center[i]
+ )
+ cc_new_boundary_edges_cc_area.append(
+ torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum()
+ * 0.5
+ )
+ cutting_edges.append(cc_new_boundary_edge_indices)
+ if any(
+ [
+ _l > max_hole_size
+ for _l in cc_new_boundary_edges_cc_area
+ ]
+ ):
+ continue
+
+ valid_remove_cc.append(cc)
+
+ if len(valid_remove_cc) > 0:
+ remove_face_indices = remove_face_indices[
+ torch.cat(valid_remove_cc)
+ ]
+ mask = torch.ones(
+ faces.shape[0], dtype=torch.bool, device=faces.device
+ )
+ mask[remove_face_indices] = 0
+ faces = faces[mask]
+ faces, vertices = utils3d.torch.remove_unreferenced_vertices(
+ faces, vertices
+ )
+
+ tqdm.write(f"Removed {(~mask).sum()} faces by mincut")
+ else:
+            tqdm.write("Removed 0 faces by mincut")
+
+ # Fill small boundaries (holes)
+ mesh = _meshfix.PyTMesh()
+ mesh.load_array(vertices.cpu().numpy(), faces.cpu().numpy())
+ mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
+
+ _vertices, _faces = mesh.return_arrays()
+ vertices = torch.tensor(_vertices).to(vertices)
+ faces = torch.tensor(_faces).to(faces)
+
+ return vertices, faces
+
+ @property
+ def vertices_np(self) -> np.ndarray:
+ return self.vertices.cpu().numpy()
+
+ @property
+ def faces_np(self) -> np.ndarray:
+ return self.faces.cpu().numpy()
+
+ @log_mesh_changes
+ def simplify(self, ratio: float) -> None:
+ """Simplify the mesh using quadric edge collapse decimation.
+
+ Args:
+ ratio (float): Ratio of faces to filter out.
+ """
+ if ratio <= 0 or ratio >= 1:
+ raise ValueError("Simplify ratio must be between 0 and 1.")
+
+ # Convert to PyVista format for simplification
+ mesh = pv.PolyData(
+ self.vertices_np,
+ np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
+ )
+ mesh = mesh.decimate(ratio, progress_bar=True)
+
+ # Update vertices and faces
+ self.vertices = torch.tensor(
+ mesh.points, device=self.device, dtype=torch.float32
+ )
+ self.faces = torch.tensor(
+ mesh.faces.reshape(-1, 4)[:, 1:],
+ device=self.device,
+ dtype=torch.int32,
+ )
+
+ @spaces.GPU
+ def __call__(
+ self,
+ filter_ratio: float,
+ max_hole_size: float,
+ resolution: int,
+ num_views: int,
+ norm_mesh_ratio: float = 1.0,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Post-process the mesh by simplifying and filling holes.
+
+ This method performs a two-step process:
+ 1. Simplifies mesh by reducing faces using quadric edge decimation.
+ 2. Fills holes by removing invisible faces, repairing small boundaries.
+
+ Args:
+ filter_ratio (float): Ratio of faces to simplify out.
+ Must be in the range (0, 1).
+ max_hole_size (float): Maximum area of a hole to fill. Connected
+ components of holes larger than this size will not be repaired.
+ resolution (int): Resolution of the rasterization buffer.
+ num_views (int): Number of viewpoints to sample for rasterization.
+ norm_mesh_ratio (float, optional): A scaling factor applied to the
+ vertices of the mesh during processing.
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]:
+ - vertices: Simplified and repaired vertex array of (V, 3).
+ - faces: Simplified and repaired face array of (F, 3).
+ """
+ self.vertices = self.vertices.to(self.device)
+ self.faces = self.faces.to(self.device)
+
+ self.simplify(ratio=filter_ratio)
+ self.fill_holes(
+ max_hole_size=max_hole_size,
+ max_hole_nbe=int(250 * np.sqrt(1 - filter_ratio)),
+ resolution=resolution,
+ num_views=num_views,
+ norm_mesh_ratio=norm_mesh_ratio,
+ )
+
+ return self.vertices_np, self.faces_np
diff --git a/embodied_gen/data/utils.py b/embodied_gen/data/utils.py
new file mode 100644
index 0000000..31a65f8
--- /dev/null
+++ b/embodied_gen/data/utils.py
@@ -0,0 +1,1009 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import math
+import os
+import random
+import zipfile
+from shutil import rmtree
+from typing import List, Tuple, Union
+
+import cv2
+import kaolin as kal
+import numpy as np
+import nvdiffrast.torch as dr
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+try:
+ from kolors.models.modeling_chatglm import ChatGLMModel
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
+except ImportError:
+ ChatGLMTokenizer = None
+ ChatGLMModel = None
+import logging
+from dataclasses import dataclass, field
+from enum import Enum
+
+import trimesh
+from kaolin.render.camera import Camera
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
# Public API of this module.
# NOTE(review): compute_cam_pts_by_views, init_kal_camera, zip_files and
# load_llm_models are defined below but not exported — confirm intentional.
__all__ = [
    "DiffrastRender",
    "save_images",
    "render_pbr",
    "prelabel_text_feature",
    "calc_vertex_normals",
    "normalize_vertices_array",
    "load_mesh_to_unit_cube",
    "as_list",
    "CameraSetting",
    "RenderItems",
    "import_kaolin_mesh",
    "save_mesh_with_mtl",
    "get_images_from_grid",
    "post_process_texture",
    "quat_mult",
    "quat_to_rotmat",
    "gamma_shs",
    "resize_pil",
    "trellis_preprocess",
    "delete_dir",
]
+
+
class DiffrastRender(object):
    """A class to handle differentiable rendering using nvdiffrast.

    This class provides methods to render position, depth, and normal maps
    with optional anti-aliasing and gradient disabling for rasterization.

    Attributes:
        p_mtx (torch.Tensor): Projection matrix.
        mv_mtx (torch.Tensor): Model-view matrix.
        mvp_mtx (torch.Tensor): Model-view-projection matrix, calculated as
            p_mtx @ mv_mtx if not provided.
        resolution_hw (Tuple[int, int]): Height and width of the rendering resolution. # noqa
        _ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): Rasterization context. # noqa
        mask_thresh (float): Threshold for mask creation.
        grad_db (bool): Whether to disable gradients during rasterization.
        antialias_mask (bool): Whether to apply anti-aliasing to the mask.
        align_coordinate (bool): Whether to flip axes so position/depth maps
            match the blender coordinate convention.
        device (str): Device used for rendering ('cuda' or 'cpu').
    """

    def __init__(
        self,
        p_matrix: torch.Tensor,
        mv_matrix: torch.Tensor,
        resolution_hw: Tuple[int, int],
        context: "Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]" = None,  # noqa
        mvp_matrix: torch.Tensor = None,
        mask_thresh: float = 0.5,
        grad_db: bool = False,
        antialias_mask: bool = True,
        align_coordinate: bool = True,
        device: str = "cuda",
    ) -> None:
        self.p_mtx = p_matrix
        self.mv_mtx = mv_matrix
        # Bug fix: the original only assigned `self.mvp_mtx` when
        # `mvp_matrix` was None, so passing a precomputed MVP matrix made
        # later calls (e.g. compute_dr_raster) fail with AttributeError.
        if mvp_matrix is None:
            mvp_matrix = torch.bmm(p_matrix, mv_matrix)
        self.mvp_mtx = mvp_matrix

        self.resolution_hw = resolution_hw
        if context is None:
            context = dr.RasterizeCudaContext(device=device)
        self._ctx = context
        self.mask_thresh = mask_thresh
        self.grad_db = grad_db
        self.antialias_mask = antialias_mask
        self.align_coordinate = align_coordinate
        self.device = device

    def compute_dr_raster(
        self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rasterize the mesh and return (rast buffer, clip-space vertices)."""
        vertices_clip = self.transform_vertices(vertices, matrix=self.mvp_mtx)
        rast, _ = dr.rasterize(
            self._ctx,
            vertices_clip,
            faces.int(),
            resolution=self.resolution_hw,
            grad_db=self.grad_db,
        )

        return rast, vertices_clip

    def transform_vertices(
        self,
        vertices: torch.Tensor,
        matrix: torch.Tensor,
    ) -> torch.Tensor:
        """Apply a batched 4x4 transform to (V, 3) vertices.

        Vertices are promoted to homogeneous coordinates; the result keeps
        the homogeneous component (shape (..., V, 4)).
        """
        verts_ones = torch.ones((len(vertices), 1)).to(vertices)
        verts_homo = torch.cat([vertices, verts_ones], dim=-1)
        trans_vertices = torch.matmul(verts_homo, matrix.permute(0, 2, 1))

        return trans_vertices

    def normalize_map_by_mask_separately(
        self, map: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Normalize each map in the batch independently by its own mask.

        Each normalized map lies in [0, 1] over its foreground pixels.
        """
        normalized_maps = []
        for map_item, mask_item in zip(map, mask):
            normalized_map = self.normalize_map_by_mask(map_item, mask_item)
            normalized_maps.append(normalized_map)

        normalized_maps = torch.stack(normalized_maps, dim=0)

        return normalized_maps

    def normalize_map_by_mask(
        self, map: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Min-max normalize `map` over pixels where `mask == 1`.

        Returns `map` unchanged when the mask has no foreground. Background
        pixels are zeroed via lerp; the output lies in [0, 1].
        """
        foreground = (mask == 1).squeeze(dim=-1)
        foreground_elements = map[foreground]
        if len(foreground_elements) == 0:
            return map

        min_val, _ = foreground_elements.min(dim=0)
        max_val, _ = foreground_elements.max(dim=0)
        # Guard against division by zero for constant maps.
        val_range = (max_val - min_val).clip(min=1e-6)

        normalized_map = (map - min_val) / val_range
        normalized_map = torch.lerp(
            torch.zeros_like(normalized_map), normalized_map, mask
        )
        normalized_map[normalized_map < 0] = 0

        return normalized_map

    def _compute_mask(
        self,
        rast: torch.Tensor,
        vertices_clip: torch.Tensor,
        faces: torch.Tensor,
    ) -> torch.Tensor:
        """Derive a foreground mask from the rasterizer's triangle-id channel."""
        mask = (rast[..., 3:] > 0).float()
        mask = mask.clip(min=0, max=1)

        if self.antialias_mask is True:
            mask = dr.antialias(mask, rast, vertices_clip, faces)
        else:
            # Hard-threshold the mask instead of anti-aliasing it.
            foreground = mask > self.mask_thresh
            mask[foreground] = 1
            mask[~foreground] = 0

        return mask

    def render_rast_alpha(
        self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
    ):
        """Rasterize and return (mask, rast buffer)."""
        faces = faces.to(torch.int32)
        rast, vertices_clip = self.compute_dr_raster(vertices, faces)
        mask = self._compute_mask(rast, vertices_clip, faces)

        return mask, rast

    def render_position(
        self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
    ) -> Union[torch.Tensor, torch.Tensor]:
        """Render per-pixel model-space positions.

        Vertices are in the model coordinate system; the map holds real
        position coordinates (not normalized).
        """
        faces = faces.to(torch.int32)
        mask, rast = self.render_rast_alpha(vertices, faces)

        vertices_model = vertices[None, ...].contiguous().float()
        position_map, _ = dr.interpolate(vertices_model, rast, faces)
        # Swap y/z and negate to align with the blender convention.
        if self.align_coordinate:
            position_map = position_map[..., [0, 2, 1]]
            position_map[..., 1] = -position_map[..., 1]

        position_map = torch.lerp(
            torch.zeros_like(position_map), position_map, mask
        )

        return position_map, mask

    def render_uv(
        self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
        vtx_uv: torch.Tensor,
    ) -> Union[torch.Tensor, torch.Tensor]:
        """Render the interpolated UV map together with its mask."""
        faces = faces.to(torch.int32)
        mask, rast = self.render_rast_alpha(vertices, faces)
        uv_map, _ = dr.interpolate(vtx_uv, rast, faces)
        uv_map = torch.lerp(torch.zeros_like(uv_map), uv_map, mask)

        return uv_map, mask

    def render_depth(
        self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
    ) -> Union[torch.Tensor, torch.Tensor]:
        """Render a camera-space depth map (real depth values)."""
        faces = faces.to(torch.int32)
        mask, rast = self.render_rast_alpha(vertices, faces)

        vertices_camera = self.transform_vertices(vertices, matrix=self.mv_mtx)
        vertices_camera = vertices_camera[..., 2:3].contiguous().float()
        depth_map, _ = dr.interpolate(vertices_camera, rast, faces)
        # Camera looks down -z: negate so depth is positive.
        if self.align_coordinate:
            depth_map = -depth_map
        depth_map = torch.lerp(torch.zeros_like(depth_map), depth_map, mask)

        return depth_map, mask

    def render_global_normal(
        self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
        vertice_normals: torch.Tensor,
    ) -> Union[torch.Tensor, torch.Tensor]:
        """Render world-space normals.

        NOTE: `vertice_normals` are in [-1, 1]; the returned normal map is
        remapped to [0, 1]. Vertices/normals are in the model coordinate
        system. Back-facing triangles get their normals flipped using the
        sign of the camera-space face normal's z component.
        """
        faces = faces.to(torch.int32)
        mask, rast = self.render_rast_alpha(vertices, faces)
        im_base_normals, _ = dr.interpolate(
            vertice_normals[None, ...].float(), rast, faces
        )

        if im_base_normals is not None:
            faces = faces.to(torch.int64)
            vertices_cam = self.transform_vertices(
                vertices, matrix=self.mv_mtx
            )
            face_vertices_ndc = kal.ops.mesh.index_vertices_by_faces(
                vertices_cam[..., :3], faces
            )
            face_normal_sign = kal.ops.mesh.face_normals(face_vertices_ndc)[
                ..., 2
            ]
            for idx in range(len(im_base_normals)):
                # rast's last channel stores triangle_id + 1 (0 = empty).
                face_idx = (rast[idx, ..., -1].long() - 1).contiguous()
                im_normal_sign = torch.sign(face_normal_sign[idx, face_idx])
                im_normal_sign[face_idx == -1] = 0
                im_base_normals[idx] *= im_normal_sign.unsqueeze(-1)

        normal = (im_base_normals + 1) / 2
        normal = normal.clip(min=0, max=1)
        normal = torch.lerp(torch.zeros_like(normal), normal, mask)

        return normal, mask

    def transform_normal(
        self,
        normals: torch.Tensor,
        trans_matrix: torch.Tensor,
        masks: torch.Tensor,
        to_view: bool,
    ) -> torch.Tensor:
        """Transform normal maps with per-image 4x4 matrices.

        NOTE: input normals are in [0, 1]; output normals are in [0, 1].
        """
        normals = normals.clone()
        assert len(normals) == len(trans_matrix)

        if not to_view:
            # Flip the sign on the x-axis to match inv bae system for global transformation. # noqa
            normals[..., 0] = 1 - normals[..., 0]

        normals = 2 * normals - 1
        b, h, w, c = normals.shape

        transformed_normals = []
        for normal, matrix in zip(normals, trans_matrix):
            # Transform normals using the transformation matrix (4x4),
            # padding with 0 so translation is ignored (direction vector).
            reshaped_normals = normal.view(-1, c)  # (h w 3) -> (hw 3)
            padded_vectors = torch.nn.functional.pad(
                reshaped_normals, pad=(0, 1), mode="constant", value=0.0
            )
            transformed_normal = torch.matmul(
                padded_vectors, matrix.transpose(0, 1)
            )[..., :3]

            # Normalize and remap the normals to the [0, 1] range.
            transformed_normal = F.normalize(transformed_normal, p=2, dim=-1)
            transformed_normal = (transformed_normal + 1) / 2

            if to_view:
                # Flip the sign on the x-axis to match bae system for view transformation. # noqa
                transformed_normal[..., 0] = 1 - transformed_normal[..., 0]

            transformed_normals.append(transformed_normal.view(h, w, c))

        transformed_normals = torch.stack(transformed_normals, dim=0)

        if masks is not None:
            transformed_normals = torch.lerp(
                torch.zeros_like(transformed_normals),
                transformed_normals,
                masks,
            )

        return transformed_normals
+
+
+def _az_el_to_points(
+ azimuths: np.ndarray, elevations: np.ndarray
+) -> np.ndarray:
+ x = np.cos(azimuths) * np.cos(elevations)
+ y = np.sin(azimuths) * np.cos(elevations)
+ z = np.sin(elevations)
+
+ return np.stack([x, y, z], axis=-1)
+
+
+def _compute_az_el_by_views(
+ num_view: int, el: float
+) -> Tuple[np.ndarray, np.ndarray]:
+ azimuths = np.arange(num_view) / num_view * np.pi * 2
+ elevations = np.deg2rad(np.array([el] * num_view))
+
+ return azimuths, elevations
+
+
def _compute_cam_pts_by_az_el(
    azs: np.ndarray,
    els: np.ndarray,
    distance: float,
    extra_pts: np.ndarray = None,
) -> np.ndarray:
    """Turn azimuth/elevation samples into camera positions at `distance`.

    Optionally appends `extra_pts` before the axis swap that aligns the
    coordinate system (xyz -> xzy with the new z negated).
    """
    cam_pts = _az_el_to_points(azs, els) * distance
    if extra_pts is not None:
        cam_pts = np.concatenate([cam_pts, extra_pts], axis=0)

    # Align coordinate system: swap y/z, then negate the new z.
    cam_pts = cam_pts[:, [0, 2, 1]]
    cam_pts[..., 2] = -cam_pts[..., 2]

    return cam_pts
+
+
def compute_cam_pts_by_views(
    num_view: int, el: float, distance: float, extra_pts: np.ndarray = None
) -> np.ndarray:
    """Computes object-center camera points for a given number of views.

    Args:
        num_view (int): The number of views (camera positions) to compute.
        el (float): The elevation angle in degrees.
        distance (float): The distance from the origin to the camera.
        extra_pts (np.ndarray): Extra camera point positions appended after
            the sampled orbit points.

    Returns:
        np.ndarray: Camera points for each view, shape `(num_view, 3)`
            (plus any `extra_pts` rows). Note: the helpers return numpy
            arrays, so the previous `torch.Tensor` annotation was wrong.
    """
    azimuths, elevations = _compute_az_el_by_views(num_view, el)
    cam_pts = _compute_cam_pts_by_az_el(
        azimuths, elevations, distance, extra_pts
    )

    return cam_pts
+
+
def save_images(
    images: Union[list[np.ndarray], list[torch.Tensor]],
    output_dir: str,
    cvt_color: str = None,
    format: str = ".png",
    to_uint8: bool = True,
    verbose: bool = False,
) -> List[str]:
    """Write a list of images (values in [0, 1]) to `output_dir`.

    Files are named `0000.png`, `0001.png`, ... using `format` as the
    extension. Tensors are detached to CPU first; `cvt_color` is an
    optional cv2 color-conversion code applied before writing.
    """
    os.makedirs(output_dir, exist_ok=True)
    save_paths = []
    for idx, img in enumerate(images):
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        if to_uint8:
            # Clamp to [0, 1] before scaling to 8-bit.
            img = (255.0 * img.clip(min=0, max=1)).astype(np.uint8)
        if cvt_color is not None:
            img = cv2.cvtColor(img, cvt_color)
        save_path = os.path.join(output_dir, f"{idx:04d}{format}")
        cv2.imwrite(save_path, img)
        save_paths.append(save_path)

    if verbose:
        logger.info(f"Images saved in {output_dir}")

    return save_paths
+
+
def _current_lighting(
    azimuths: List[float],
    elevations: List[float],
    light_factor: float = 1.0,
    device: str = "cuda",
):
    """Build spherical-gaussian lighting from azimuth/elevation pairs.

    Angles are given in degrees; one light direction is created per pair
    and all lights share the same amplitude (`light_factor`).
    """
    directions = torch.cat(
        [
            kal.render.lighting.sg_direction_from_azimuth_elevation(
                math.radians(az), math.radians(el)
            )
            for az, el in zip(azimuths, elevations)
        ],
        dim=0,
    )

    amplitude = torch.ones_like(directions) * light_factor
    light_condition = kal.render.lighting.SgLightingParameters(
        amplitude=amplitude,
        direction=directions,
        sharpness=3,
    ).to(device)

    return light_condition
+
+
def render_pbr(
    mesh,
    camera,
    device="cuda",
    cxt=None,
    custom_materials=None,
    light_factor=1.0,
):
    """Render a mesh with PBR shading under a fixed 4-light rig.

    Returns the (render, albedo, diffuse, normal) passes; color passes are
    clipped to [0, 1], normals to [-1, 1].
    """
    if cxt is None:
        cxt = dr.RasterizeCudaContext()

    lighting = _current_lighting(
        azimuths=[0, 90, 180, 270],
        elevations=[90, 60, 30, 20],
        light_factor=light_factor,
        device=device,
    )
    passes = kal.render.easy_render.render_mesh(
        camera,
        mesh,
        lighting=lighting,
        nvdiffrast_context=cxt,
        custom_materials=custom_materials,
    )

    render_pass = kal.render.easy_render.RenderPass
    image = passes[render_pass.render].clip(0, 1)
    albedo = passes[render_pass.albedo].clip(0, 1)
    diffuse = passes[render_pass.diffuse].clip(0, 1)
    normal = passes[render_pass.normals].clip(-1, 1)

    return image, albedo, diffuse, normal
+
+
+def _move_to_target_device(data, device: str):
+ if isinstance(data, dict):
+ for key, value in data.items():
+ data[key] = _move_to_target_device(value, device)
+ elif isinstance(data, torch.Tensor):
+ return data.to(device)
+
+ return data
+
+
def _encode_prompt(
    prompt_batch,
    text_encoders,
    tokenizers,
    proportion_empty_prompts=0,
    is_train=True,
):
    """Tokenize and encode a batch of prompts with the given text encoders.

    Args:
        prompt_batch: Prompts; each entry is a string or a list/array of
            candidate strings (one candidate is picked per entry).
        text_encoders: Text-encoder modules, aligned with `tokenizers`.
        tokenizers: Tokenizers, one per encoder.
        proportion_empty_prompts: Probability of replacing a prompt with an
            empty string (classifier-free-guidance style dropout).
        is_train: If True, sample randomly from list-valued prompts;
            otherwise always take the first candidate.

    Returns:
        Tuple of (prompt_embeds, pooled_prompt_embeds), kept on the
        encoder's device.
    """
    prompt_embeds_list = []

    # Resolve each batch entry to a single caption string.
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=256,
                truncation=True,
                return_tensors="pt",
            ).to(text_encoder.device)

            output = text_encoder(
                input_ids=text_inputs.input_ids,
                attention_mask=text_inputs.attention_mask,
                position_ids=text_inputs.position_ids,
                output_hidden_states=True,
            )

            # We are only interested in the pooled output of the text encoder.
            # Penultimate hidden state reordered to (batch, seq, dim).
            prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
            pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    # NOTE(review): `pooled_prompt_embeds` and `bs_embed` come from the last
    # loop iteration only — fine for a single encoder; confirm for multiple.
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)

    return prompt_embeds, pooled_prompt_embeds
+
+
def load_llm_models(pretrained_model_name_or_path: str, device: str):
    """Load the ChatGLM tokenizer and text encoder used by Kolors.

    Args:
        pretrained_model_name_or_path (str): HF repo id or local path
            containing a `text_encoder` subfolder.
        device (str): Device the text encoder is moved to.

    Returns:
        tuple: ([tokenizer], [text_encoder]) lists, aligned by index.

    Raises:
        ImportError: If the optional `kolors` package is not installed
            (the module-level imports fall back to None in that case).
    """
    # Fail loudly instead of with an opaque "'NoneType' has no attribute
    # 'from_pretrained'" when the optional kolors dependency is missing.
    if ChatGLMTokenizer is None or ChatGLMModel is None:
        raise ImportError(
            "The `kolors` package is required by load_llm_models "
            "but is not installed."
        )

    tokenizer = ChatGLMTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
    )
    text_encoder = ChatGLMModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
    ).to(device)

    text_encoders = [
        text_encoder,
    ]
    tokenizers = [
        tokenizer,
    ]

    logger.info(f"Load model from {pretrained_model_name_or_path} done.")

    return tokenizers, text_encoders
+
+
def prelabel_text_feature(
    prompt_batch: List[str],
    output_dir: str,
    tokenizers: nn.Module,
    text_encoders: nn.Module,
) -> List[str]:
    """Precompute text embeddings for `prompt_batch` and cache them to disk.

    The embeddings are moved to CPU so the cache file can be loaded on
    machines without a GPU. Returns the path of the saved `.pth` file.
    """
    os.makedirs(output_dir, exist_ok=True)

    prompt_embeds, pooled_prompt_embeds = _encode_prompt(
        prompt_batch, text_encoders, tokenizers
    )

    # Store on CPU so torch.load works without CUDA.
    data_dict = dict(
        prompt_embeds=_move_to_target_device(prompt_embeds, device="cpu"),
        pooled_prompt_embeds=_move_to_target_device(
            pooled_prompt_embeds, device="cpu"
        ),
    )

    save_path = os.path.join(output_dir, "text_feat.pth")
    torch.save(data_dict, save_path)

    return save_path
+
+
+def _calc_face_normals(
+ vertices: torch.Tensor, # V,3 first vertex may be unreferenced
+ faces: torch.Tensor, # F,3 long, first face may be all zero
+ normalize: bool = False,
+) -> torch.Tensor: # F,3
+ full_vertices = vertices[faces] # F,C=3,3
+ v0, v1, v2 = full_vertices.unbind(dim=1) # F,3
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=1) # F,3
+ if normalize:
+ face_normals = F.normalize(
+ face_normals, eps=1e-6, dim=1
+ ) # TODO inplace?
+ return face_normals # F,3
+
+
def calc_vertex_normals(
    vertices: torch.Tensor,  # V,3 first vertex may be unreferenced
    faces: torch.Tensor,  # F,3 long, first face may be all zero
    face_normals: torch.Tensor = None,  # F,3, not normalized
) -> torch.Tensor:  # F,3
    """Average (area-weighted) face normals onto vertices and normalize."""
    num_faces = faces.shape[0]

    if face_normals is None:
        face_normals = _calc_face_normals(vertices, faces)

    # Accumulate each face normal into the three vertices it touches.
    accum = torch.zeros(
        (vertices.shape[0], 3, 3), dtype=vertices.dtype, device=vertices.device
    )  # (V, corner, 3)
    accum.scatter_add_(
        dim=0,
        index=faces[:, :, None].expand(num_faces, 3, 3),
        src=face_normals[:, None, :].expand(num_faces, 3, 3),
    )
    return F.normalize(accum.sum(dim=1), eps=1e-6, dim=1)
+
+
def normalize_vertices_array(
    vertices: Union[torch.Tensor, np.ndarray],
    mesh_scale: float = 1.0,
    exec_norm: bool = True,
):
    """Center vertices and scale the longest bbox side to 2 * mesh_scale.

    Works with both torch tensors and numpy arrays. When `exec_norm` is
    False, the vertices are returned untouched along with the scale and
    center that *would* be applied.
    """
    if isinstance(vertices, torch.Tensor):
        bbmin = vertices.min(0)[0]
        bbmax = vertices.max(0)[0]
    else:
        bbmin = vertices.min(0)  # (3,)
        bbmax = vertices.max(0)

    center = 0.5 * (bbmin + bbmax)
    scale = 2 * mesh_scale / (bbmax - bbmin).max()

    if exec_norm:
        vertices = (vertices - center) * scale

    return vertices, scale, center
+
+
def load_mesh_to_unit_cube(
    mesh_file: str,
    mesh_scale: float = 1.0,
) -> tuple[trimesh.Trimesh, float, list[float]]:
    """Load a mesh file and normalize it into a cube of half-extent `mesh_scale`.

    Args:
        mesh_file (str): Path to the mesh file.
        mesh_scale (float): Half-extent of the target bounding cube.

    Returns:
        tuple: (normalized mesh, applied scale factor, original center).

    Raises:
        FileNotFoundError: If `mesh_file` does not exist.
    """
    if not os.path.exists(mesh_file):
        raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.")

    mesh = trimesh.load(mesh_file)
    if isinstance(mesh, trimesh.Scene):
        # Bug fix: the module is `trimesh.util` (the original called the
        # non-existent `trimesh.utils.concatenate`, raising AttributeError).
        # `Scene.dump()` yields transformed geometries to concatenate.
        mesh = trimesh.util.concatenate(mesh.dump())

    vertices, scale, center = normalize_vertices_array(
        mesh.vertices, mesh_scale
    )
    mesh.vertices = vertices

    return mesh, scale, center
+
+
def as_list(obj):
    """Coerce `obj` to a sequence: sets become lists, lists/tuples pass
    through unchanged, and anything else is wrapped in a one-item list."""
    if isinstance(obj, set):
        return list(obj)
    if isinstance(obj, (list, tuple)):
        return obj
    return [obj]
+
+
@dataclass
class CameraSetting:
    """Camera settings for images rendering.

    Describes a pinhole orbit-camera configuration; the intrinsic matrix
    `Ks` is derived from the resolution and fov in `__post_init__`.
    """

    # Total number of rendered views.
    num_images: int
    # Elevation angles in degrees, one orbit ring per entry.
    elevation: list[float]
    # Camera-to-target distance.
    distance: float
    # (height, width) of the render buffer.
    resolution_hw: tuple[int, int]
    # Field of view; used as tan(fov / 2) below, so it is presumably in
    # radians — TODO confirm with callers.
    fov: float
    # Look-at target point.
    at: tuple[float, float, float] = field(
        default_factory=lambda: (0.0, 0.0, 0.0)
    )
    # World up-vector.
    up: tuple[float, float, float] = field(
        default_factory=lambda: (0.0, 1.0, 0.0)
    )
    device: str = "cuda"
    near: float = 1e-2
    far: float = 1e2

    def __post_init__(
        self,
    ):
        # Derive a 3x3 pinhole intrinsic matrix. The same focal length is
        # used for both axes (square pixels), based on image height.
        h = self.resolution_hw[0]
        f = (h / 2) / math.tan(self.fov / 2)
        cx = self.resolution_hw[1] / 2
        cy = self.resolution_hw[0] / 2
        Ks = [
            [f, 0, cx],
            [0, f, cy],
            [0, 0, 1],
        ]

        self.Ks = Ks
+
+
class RenderItems(str, Enum):
    """Canonical image-key names for each render pass.

    Members are str-valued, so they compare equal to their raw string keys
    (e.g. ``RenderItems.IMAGE == "image_color"``).

    Note: the original decorated this Enum with ``@dataclass``, which
    injects a field-less ``__eq__`` (making *all* members compare equal)
    and sets ``__hash__ = None`` (members unusable as dict keys); the
    decorator is removed.
    """

    IMAGE = "image_color"
    ALPHA = "image_mask"
    VIEW_NORMAL = "image_view_normal"
    GLOBAL_NORMAL = "image_global_normal"
    POSITION_MAP = "image_position"
    DEPTH = "image_depth"
    ALBEDO = "image_albedo"
    DIFFUSE = "image_diffuse"
+
+
def _compute_az_el_by_camera_params(
    camera_params: CameraSetting, flip_az: bool = False
):
    """Sample azimuth/elevation pairs for every view in `camera_params`.

    Views are split evenly across the configured elevation rings; each
    successive ring is phase-shifted by half a view step so viewpoints
    interleave instead of stacking vertically.
    """
    num_view = camera_params.num_images // len(camera_params.elevation)
    # Half of one azimuth step, used as the per-ring phase offset.
    view_interval = 2 * np.pi / num_view / 2

    azimuths = []
    elevations = []
    for ring_idx, el in enumerate(camera_params.elevation):
        ring_az = (
            np.arange(num_view) / num_view * np.pi * 2
            + ring_idx * view_interval
        )
        if flip_az:
            ring_az *= -1
        azimuths.append(ring_az)
        elevations.append(np.deg2rad(np.full(num_view, el)))

    return np.concatenate(azimuths, axis=0), np.concatenate(elevations, axis=0)
+
+
def init_kal_camera(camera_params: CameraSetting) -> Camera:
    """Build a kaolin Camera batch covering the configured orbit viewpoints."""
    azimuths, elevations = _compute_az_el_by_camera_params(camera_params)
    cam_pts = _compute_cam_pts_by_az_el(
        azimuths, elevations, camera_params.distance
    )

    # The same up-vector is shared by every view.
    up = torch.tensor(camera_params.up).repeat(camera_params.num_images, 1)

    return Camera.from_args(
        eye=torch.tensor(cam_pts),
        at=torch.tensor(camera_params.at),
        up=up,
        fov=camera_params.fov,
        height=camera_params.resolution_hw[0],
        width=camera_params.resolution_hw[1],
        near=camera_params.near,
        far=camera_params.far,
        device=camera_params.device,
    )
+
+
def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False):
    """Load a mesh file into a kaolin mesh object.

    Supported extensions: `.glb`, `.obj` (optionally with materials),
    `.ply` (converted to `.obj` on disk first), `.off`.
    """
    if mesh_path.endswith(".glb"):
        return kal.io.gltf.import_mesh(mesh_path)

    if mesh_path.endswith(".obj"):
        mesh = kal.io.obj.import_mesh(
            mesh_path, with_materials=bool(with_mtl)
        )
        if with_mtl and mesh.materials and len(mesh.materials) > 0:
            # Wrap the diffuse texture into a PBR material in [0, 1].
            material = kal.render.materials.PBRMaterial()
            assert (
                "map_Kd" in mesh.materials[0]
            ), "'map_Kd' not found in materials."
            material.diffuse_texture = mesh.materials[0]["map_Kd"] / 255.0
            mesh.materials = [material]
        return mesh

    if mesh_path.endswith(".ply"):
        # kaolin has no direct .ply reader here; round-trip through .obj.
        tri_mesh = trimesh.load(mesh_path)
        obj_path = mesh_path.replace(".ply", ".obj")
        tri_mesh.export(obj_path)
        return kal.io.obj.import_mesh(obj_path)

    if mesh_path.endswith(".off"):
        return kal.io.off.import_mesh(mesh_path)

    raise RuntimeError(
        f"{mesh_path} mesh type not supported, "
        "supported mesh type `.glb`, `.obj`, `.ply`, `.off`."
    )
+
+
def save_mesh_with_mtl(
    vertices: np.ndarray,
    faces: np.ndarray,
    uvs: np.ndarray,
    texture: Union[Image.Image, np.ndarray],
    output_path: str,
    material_base=(250, 250, 250, 255),
) -> trimesh.Trimesh:
    """Export a UV-textured mesh with a simple MTL material.

    `material_base` is the RGBA color used for the diffuse/ambient/specular
    material components. Returns the constructed trimesh object.
    """
    if isinstance(texture, np.ndarray):
        texture = Image.fromarray(texture)

    visual = trimesh.visual.TextureVisuals(uv=uvs, image=texture)
    mesh = trimesh.Trimesh(vertices, faces, visual=visual)
    mesh.visual.material = trimesh.visual.material.SimpleMaterial(
        image=texture,
        diffuse=material_base,
        ambient=material_base,
        specular=material_base,
    )

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    _ = mesh.export(output_path)
    # texture.save(os.path.join(dir_name, f"{file_name}_texture.png"))

    logger.info(f"Saved mesh with texture to {output_path}")

    return mesh
+
+
def get_images_from_grid(
    image: Union[str, Image.Image], img_size: int
) -> list[Image.Image]:
    """Split a 2-row grid image into individual `img_size`-wide tiles.

    The two grid rows are laid side by side first, then cut into equal
    vertical slices; the grid width must be a multiple of `img_size`.
    """
    if isinstance(image, str):
        image = Image.open(image)

    grid = np.array(image)
    flat = np.concatenate(
        [grid[:img_size, ...], grid[img_size:, ...]], axis=1
    )
    tiles = np.split(flat, flat.shape[1] // img_size, axis=1)
    return [Image.fromarray(tile) for tile in tiles]
+
+
def post_process_texture(texture: np.ndarray, iter: int = 1) -> np.ndarray:
    """Denoise and edge-preserving-smooth a texture image `iter` times.

    NOTE: `iter` shadows the builtin, but the name is kept for
    backward compatibility with keyword callers.
    """
    result = texture
    for _ in range(iter):
        result = cv2.fastNlMeansDenoisingColored(result, None, 2, 2, 7, 15)
        result = cv2.bilateralFilter(
            result, d=5, sigmaColor=20, sigmaSpace=20
        )

    return result
+
+
def quat_mult(q1, q2):
    """Hamilton product of batched quaternions in (w, x, y, z) row order.

    NOTE: q1 is the rotating quaternion (applied first), q2 is the
    quaternion being rotated.
    """
    w1, x1, y1, z1 = q1.T
    w2, x2, y2, z2 = q2.T
    return torch.stack(
        [
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ]
    ).T
+
+
def quat_to_rotmat(quats: torch.Tensor, mode="wxyz") -> torch.Tensor:
    """Convert quaternion to rotation matrix.

    Quaternions are normalized first; `mode` selects the component order
    of the input ("wxyz" or "xyzw"). Output shape is `quats.shape[:-1] + (3, 3)`.
    """
    quats = F.normalize(quats, p=2, dim=-1)

    if mode == "wxyz":
        w, x, y, z = torch.unbind(quats, dim=-1)
    elif mode == "xyzw":
        x, y, z, w = torch.unbind(quats, dim=-1)
    else:
        raise ValueError(f"Invalid mode: {mode}.")

    xx, yy, zz = x * x, y * y, z * z
    rotmat = torch.stack(
        [
            1 - 2 * (yy + zz),
            2 * (x * y - w * z),
            2 * (x * z + w * y),
            2 * (x * y + w * z),
            1 - 2 * (xx + zz),
            2 * (y * z - w * x),
            2 * (x * z - w * y),
            2 * (y * z + w * x),
            1 - 2 * (xx + yy),
        ],
        dim=-1,
    )

    return rotmat.reshape(quats.shape[:-1] + (3, 3))
+
+
def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
    """Apply gamma correction to DC spherical-harmonic coefficients.

    Coefficients are mapped to [0, 1] via the zeroth-order SH basis
    constant, gamma-corrected, then mapped back to coefficient space.
    """
    # Zeroth-order SH basis constant used for the color <-> SH mapping.
    C0 = 0.28209479177387814
    as_color = torch.clip(shs * C0 + 0.5, 0.0, 1.0)
    return (torch.pow(as_color, gamma) - 0.5) / C0
+
+
def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
    """Downscale `image` so its longer side is at most `max_size` pixels.

    Images already within the limit are returned unchanged; the aspect
    ratio is preserved.

    Args:
        image (Image.Image): Input image.
        max_size (int): Maximum allowed size of the longer side. (Bug fix:
            the original immediately overwrote this parameter and compared
            against a hard-coded 1024, making the argument dead.)

    Returns:
        Image.Image: The (possibly) resized image.
    """
    longest_side = max(image.size)
    scale = min(1, max_size / longest_side)
    if scale < 1:
        new_size = (int(image.width * scale), int(image.height * scale))
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image
+
+
def trellis_preprocess(image: Image.Image) -> Image.Image:
    """Process the input image as trellis done.

    Crops to the alpha foreground with 20% padding, resizes to 518x518 and
    premultiplies RGB by alpha onto a black background.
    """
    arr = np.array(image)
    # NOTE(review): assumes an RGBA input; channel 3 is alpha — confirm.
    alpha = arr[:, :, 3]
    fg = np.argwhere(alpha > 0.8 * 255)
    left, top = np.min(fg[:, 1]), np.min(fg[:, 0])
    right, bottom = np.max(fg[:, 1]), np.max(fg[:, 0])

    cx = (left + right) / 2
    cy = (top + bottom) / 2
    # Pad the square crop by 20% around the foreground.
    size = int(max(right - left, bottom - top) * 1.2)
    crop_box = (
        cx - size // 2,
        cy - size // 2,
        cx + size // 2,
        cy + size // 2,
    )
    image = image.crop(crop_box)
    image = image.resize((518, 518), Image.Resampling.LANCZOS)

    # Premultiply RGB by alpha (black background) and rebuild the image.
    data = np.array(image).astype(np.float32) / 255
    data = data[:, :, :3] * data[:, :, 3:4]
    return Image.fromarray((data * 255).astype(np.uint8))
+
+
def zip_files(input_paths: list[str], output_zip: str) -> str:
    """Zip the given files and directories into `output_zip`.

    Archive names are made relative to the common path of all inputs so
    the archive layout mirrors the on-disk layout.

    Args:
        input_paths (list[str]): Files or directories to add; each must exist.
        output_zip (str): Destination `.zip` file path.

    Returns:
        str: `output_zip`.

    Raises:
        FileNotFoundError: If any input path does not exist (checked up
            front, before the archive is created).
    """
    for input_path in input_paths:
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"File not found: {input_path}")

    # Hoisted out of the loop: the common root is loop-invariant.
    common_root = os.path.commonpath(input_paths)

    def _arcname(path: str) -> str:
        # A single input equal to the common root yields ".", which would
        # create a useless entry; fall back to the basename in that case.
        rel = os.path.relpath(path, start=common_root)
        return os.path.basename(path) if rel == "." else rel

    with zipfile.ZipFile(output_zip, "w", zipfile.ZIP_DEFLATED) as zipf:
        for input_path in input_paths:
            if os.path.isdir(input_path):
                for root, _, files in os.walk(input_path):
                    for file in files:
                        file_path = os.path.join(root, file)
                        zipf.write(file_path, arcname=_arcname(file_path))
            else:
                zipf.write(input_path, arcname=_arcname(input_path))

    return output_zip
+
+
def delete_dir(folder_path: str, keep_subs: list[str] = None) -> None:
    """Remove every entry inside `folder_path` except names in `keep_subs`.

    The folder itself is kept; subdirectories are removed recursively.
    """
    keep = set(keep_subs) if keep_subs is not None else set()
    for entry in os.listdir(folder_path):
        if entry in keep:
            continue
        full_path = os.path.join(folder_path, entry)
        if os.path.isdir(full_path):
            rmtree(full_path)
        else:
            os.remove(full_path)
diff --git a/embodied_gen/models/delight_model.py b/embodied_gen/models/delight_model.py
new file mode 100644
index 0000000..645b4c5
--- /dev/null
+++ b/embodied_gen/models/delight_model.py
@@ -0,0 +1,200 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import os
+from typing import Union
+
+import cv2
+import numpy as np
+import spaces
+import torch
+from diffusers import (
+ EulerAncestralDiscreteScheduler,
+ StableDiffusionInstructPix2PixPipeline,
+)
+from huggingface_hub import snapshot_download
+from PIL import Image
+from embodied_gen.models.segment_model import RembgRemover
+
# Public API of this module.
__all__ = [
    "DelightingModel",
]
+
+
class DelightingModel(object):
    """A model to remove the lighting in image space.

    This model is encapsulated based on the Hunyuan3D-Delight model
    from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa

    Attributes:
        image_guide_scale (float): Weight of image guidance in diffusion process.
        text_guide_scale (float): Weight of text (prompt) guidance in diffusion process.
        num_infer_step (int): Number of inference steps for diffusion model.
        mask_erosion_size (int): Size of erosion kernel for alpha mask cleanup.
        device (str): Device used for inference, e.g., 'cuda' or 'cpu'.
        seed (int): Random seed for diffusion model reproducibility.
        model_path (str): Filesystem path to pretrained model weights.
        pipeline: Lazy-loaded diffusion pipeline instance.
    """

    def __init__(
        self,
        model_path: str = None,
        num_infer_step: int = 50,
        mask_erosion_size: int = 3,
        image_guide_scale: float = 1.5,
        text_guide_scale: float = 1.0,
        device: str = "cuda",
        seed: int = 0,
    ) -> None:
        self.image_guide_scale = image_guide_scale
        self.text_guide_scale = text_guide_scale
        self.num_infer_step = num_infer_step
        self.mask_erosion_size = mask_erosion_size
        # Square structuring element used to erode the alpha mask in
        # __call__ (trims thin fringe pixels at the silhouette).
        self.kernel = np.ones(
            (self.mask_erosion_size, self.mask_erosion_size), np.uint8
        )
        self.seed = seed
        self.device = device
        self.pipeline = None  # lazy load model adapt to @spaces.GPU

        if model_path is None:
            # Download only the delight subfolder of the HF repository.
            suffix = "hunyuan3d-delight-v2-0"
            model_path = snapshot_download(
                repo_id="tencent/Hunyuan3D-2", allow_patterns=f"{suffix}/*"
            )
            model_path = os.path.join(model_path, suffix)

        self.model_path = model_path

    def _lazy_init_pipeline(self):
        """Instantiate the InstructPix2Pix pipeline on first use only."""
        if self.pipeline is None:
            pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                safety_checker=None,
            )
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            pipeline.set_progress_bar_config(disable=True)

            pipeline.to(self.device, torch.float16)
            self.pipeline = pipeline

    def recenter_image(
        self, image: Image.Image, border_ratio: float = 0.2
    ) -> Image.Image:
        """Crop to the alpha foreground and pad it into a centered square.

        RGB images are returned unchanged and "L" images are only converted
        to RGB; all other modes are assumed to carry an alpha channel at
        index 3 (RGBA) — other modes (e.g. "P") are not handled.

        Args:
            image (Image.Image): Input image.
            border_ratio (float): Padding added around the cropped
                foreground, as a fraction of its width/height.

        Returns:
            Image.Image: Square RGBA image with the subject centered.

        Raises:
            ValueError: If the alpha channel is fully transparent.
        """
        if image.mode == "RGB":
            return image
        elif image.mode == "L":
            image = image.convert("RGB")
            return image

        alpha_channel = np.array(image)[:, :, 3]
        non_zero_indices = np.argwhere(alpha_channel > 0)
        if non_zero_indices.size == 0:
            raise ValueError("Image is fully transparent")

        min_row, min_col = non_zero_indices.min(axis=0)
        max_row, max_col = non_zero_indices.max(axis=0)

        cropped_image = image.crop(
            (min_col, min_row, max_col + 1, max_row + 1)
        )

        width, height = cropped_image.size
        border_width = int(width * border_ratio)
        border_height = int(height * border_ratio)

        new_width = width + 2 * border_width
        new_height = height + 2 * border_height

        # Use the larger padded side so the output canvas is square.
        square_size = max(new_width, new_height)

        new_image = Image.new(
            "RGBA", (square_size, square_size), (255, 255, 255, 0)
        )

        paste_x = (square_size - new_width) // 2 + border_width
        paste_y = (square_size - new_height) // 2 + border_height

        new_image.paste(cropped_image, (paste_x, paste_y))

        return new_image

    @spaces.GPU
    @torch.no_grad()
    def __call__(
        self,
        image: Union[str, np.ndarray, Image.Image],
        preprocess: bool = False,
        target_wh: tuple[int, int] = None,
    ) -> Image.Image:
        """Remove baked-in lighting from an image.

        Args:
            image: Path, numpy array, or PIL image. Must have an alpha
                channel by the time the pipeline runs (use `preprocess`
                for plain photos).
            preprocess (bool): If True, strip the background with rembg
                and recenter the subject first.
            target_wh (tuple[int, int]): Optional (width, height) to resize
                the input to; defaults to the input's own size.

        Returns:
            Image.Image: Delighted RGBA image of size `target_wh`, with the
            (eroded) input alpha re-applied.
        """
        self._lazy_init_pipeline()

        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if preprocess:
            bg_remover = RembgRemover()
            image = bg_remover(image)
            image = self.recenter_image(image)

        if target_wh is not None:
            image = image.resize(target_wh)
        else:
            target_wh = image.size

        image_array = np.array(image)
        assert image_array.shape[-1] == 4, "Image must have alpha channel"

        # Erode the alpha mask, then flatten the background to white —
        # the delight model expects a white backdrop.
        raw_alpha_channel = image_array[:, :, 3]
        alpha_channel = cv2.erode(raw_alpha_channel, self.kernel, iterations=1)
        image_array[alpha_channel == 0, :3] = 255  # must be white background
        image_array[:, :, 3] = alpha_channel

        image = self.pipeline(
            prompt="",
            image=Image.fromarray(image_array).convert("RGB"),
            generator=torch.manual_seed(self.seed),
            num_inference_steps=self.num_infer_step,
            image_guidance_scale=self.image_guide_scale,
            guidance_scale=self.text_guide_scale,
        ).images[0]

        # Re-attach the eroded alpha to the delighted RGB result.
        alpha_channel = Image.fromarray(alpha_channel)
        rgba_image = image.convert("RGBA").resize(target_wh)
        rgba_image.putalpha(alpha_channel)

        return rgba_image
+
+
if __name__ == "__main__":
    # Example usage: delight a sample image and save the result.
    delighting_model = DelightingModel()
    image_path = "apps/assets/example_image/sample_12.jpg"
    image = delighting_model(
        image_path, preprocess=True, target_wh=(512, 512)
    )  # noqa
    image.save("delight.png")

    # image_path = "embodied_gen/scripts/test_robot.png"
    # image = delighting_model(image_path)
    # image.save("delighting_image_a2.png")
diff --git a/embodied_gen/models/gs_model.py b/embodied_gen/models/gs_model.py
new file mode 100644
index 0000000..b866f46
--- /dev/null
+++ b/embodied_gen/models/gs_model.py
@@ -0,0 +1,526 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import logging
+import os
+import struct
+from dataclasses import dataclass
+from typing import Optional
+
+import cv2
+import numpy as np
+import torch
+from gsplat.cuda._wrapper import spherical_harmonics
+from gsplat.rendering import rasterization
+from plyfile import PlyData
+from scipy.spatial.transform import Rotation
+from embodied_gen.data.utils import gamma_shs, quat_mult, quat_to_rotmat
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+__all__ = [
+ "RenderResult",
+ "GaussianOperator",
+]
+
+
@dataclass
class RenderResult:
    """Rasterization output bundle: RGB, depth, opacity, mask and RGBA."""

    rgb: np.ndarray
    depth: np.ndarray
    opacity: np.ndarray
    mask_threshold: float = 10
    mask: Optional[np.ndarray] = None
    rgba: Optional[np.ndarray] = None

    def __post_init__(self):
        # Tensors fresh from the renderer are converted to uint8 numpy
        # images; numpy inputs are assumed to be already converted.
        if isinstance(self.rgb, torch.Tensor):
            rgb_np = (self.rgb.detach().cpu().numpy() * 255).astype(np.uint8)
            self.rgb = cv2.cvtColor(rgb_np, cv2.COLOR_BGR2RGB)
        if isinstance(self.depth, torch.Tensor):
            self.depth = self.depth.detach().cpu().numpy()
        if isinstance(self.opacity, torch.Tensor):
            alpha = self.opacity.detach().cpu().numpy()
            alpha = (alpha * 255).astype(np.uint8)
            self.opacity = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        # Binarize opacity into a single-channel mask, then build RGBA.
        binary = np.where(self.opacity > self.mask_threshold, 255, 0)
        self.mask = binary[..., 0:1].astype(np.uint8)
        self.rgba = np.concatenate([self.rgb, self.mask], axis=-1)
+
+
@dataclass
class GaussianBase:
    """Base container for 3D Gaussian-splatting parameters.

    Tensors hold raw (pre-activation) values: ``_scales`` are log-scales
    and ``_opacities`` are logits — downstream code applies ``exp`` /
    ``sigmoid`` when ``apply_activate`` is requested. ``_means`` are the
    3D centers and ``_quats`` the per-Gaussian rotation quaternions.
    """

    _opacities: torch.Tensor
    _means: torch.Tensor
    _scales: torch.Tensor
    _quats: torch.Tensor
    _rgbs: Optional[torch.Tensor] = None
    _features_dc: Optional[torch.Tensor] = None  # degree-0 (DC) SH color coefficients
    _features_rest: Optional[torch.Tensor] = None  # higher-order SH coefficients
    sh_degree: Optional[int] = 0
    device: str = "cuda"

    def __post_init__(self):
        # Mirror sh_degree; kept separate so the active degree could be
        # lowered later without losing the loaded maximum.
        self.active_sh_degree: int = self.sh_degree
        self.to(self.device)
+
+ def to(self, device: str) -> None:
+ for k, v in self.__dict__.items():
+ if not isinstance(v, torch.Tensor):
+ continue
+ self.__dict__[k] = v.to(device)
+
+ def get_numpy_data(self):
+ data = {}
+ for k, v in self.__dict__.items():
+ if not isinstance(v, torch.Tensor):
+ continue
+ data[k] = v.detach().cpu().numpy()
+
+ return data
+
+ def quat_norm(self, x: torch.Tensor) -> torch.Tensor:
+ return x / x.norm(dim=-1, keepdim=True)
+
    @classmethod
    def load_from_ply(
        cls,
        path: str,
        gamma: float = 1.0,
        device: str = "cuda",
    ) -> "GaussianBase":
        """Load Gaussian-splat parameters from a PLY file.

        Args:
            path: Path to the ``.ply`` file.
            gamma: Gamma applied to the DC color features; any value other
                than 1.0 also zeroes the higher-order SH terms and damps
                opacities (presumably to soften baked-in lighting — see the
                gamma branch below).
            device: Torch device the returned tensors are moved to.

        Returns:
            A populated instance of ``cls``.
        """
        plydata = PlyData.read(path)
        # Gaussian centers, shape (N, 3).
        xyz = torch.stack(
            (
                torch.tensor(plydata.elements[0]["x"], dtype=torch.float32),
                torch.tensor(plydata.elements[0]["y"], dtype=torch.float32),
                torch.tensor(plydata.elements[0]["z"], dtype=torch.float32),
            ),
            dim=1,
        )

        # Raw (logit) opacities, shape (N, 1).
        opacities = torch.tensor(
            plydata.elements[0]["opacity"], dtype=torch.float32
        ).unsqueeze(-1)
        # Degree-0 (DC) spherical-harmonics color coefficients, (N, 3).
        features_dc = torch.zeros((xyz.shape[0], 3), dtype=torch.float32)
        features_dc[:, 0] = torch.tensor(
            plydata.elements[0]["f_dc_0"], dtype=torch.float32
        )
        features_dc[:, 1] = torch.tensor(
            plydata.elements[0]["f_dc_1"], dtype=torch.float32
        )
        features_dc[:, 2] = torch.tensor(
            plydata.elements[0]["f_dc_2"], dtype=torch.float32
        )

        # Per-axis log-scales from properties named scale_0, scale_1, ...
        # sorted numerically (string sort would break past scale_9).
        scale_names = [
            p.name
            for p in plydata.elements[0].properties
            if p.name.startswith("scale_")
        ]
        scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
        scales = torch.zeros(
            (xyz.shape[0], len(scale_names)), dtype=torch.float32
        )
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = torch.tensor(
                plydata.elements[0][attr_name], dtype=torch.float32
            )

        # Rotation quaternions from rot_0..rot_3, normalized to unit length.
        rot_names = [
            p.name
            for p in plydata.elements[0].properties
            if p.name.startswith("rot_")
        ]
        rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
        rots = torch.zeros((xyz.shape[0], len(rot_names)), dtype=torch.float32)
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = torch.tensor(
                plydata.elements[0][attr_name], dtype=torch.float32
            )

        rots = rots / torch.norm(rots, dim=-1, keepdim=True)

        # extra features: higher-order SH coefficients f_rest_*.
        extra_f_names = [
            p.name
            for p in plydata.elements[0].properties
            if p.name.startswith("f_rest_")
        ]
        extra_f_names = sorted(
            extra_f_names, key=lambda x: int(x.split("_")[-1])
        )

        # Infer SH degree from the coefficient count:
        # total per-channel coeffs = (deg + 1)^2, minus the 1 DC term.
        max_sh_degree = int(np.sqrt((len(extra_f_names) + 3) / 3) - 1)
        if max_sh_degree != 0:
            features_extra = torch.zeros(
                (xyz.shape[0], len(extra_f_names)), dtype=torch.float32
            )
            for idx, attr_name in enumerate(extra_f_names):
                features_extra[:, idx] = torch.tensor(
                    plydata.elements[0][attr_name], dtype=torch.float32
                )

            # Reshape flat coeffs to (N, num_coeffs, 3) channel-last layout.
            features_extra = features_extra.view(
                (features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)
            )
            features_extra = features_extra.permute(0, 2, 1)

            if abs(gamma - 1.0) > 1e-3:
                # Gamma-correct DC colors, drop view-dependent SH terms,
                # and damp opacities.
                features_dc = gamma_shs(features_dc, gamma)
                features_extra[..., :] = 0.0
                opacities *= 0.8

            shs = torch.cat(
                [
                    features_dc.reshape(-1, 3),
                    features_extra.reshape(len(features_dc), -1),
                ],
                dim=-1,
            )
        else:
            # sh_dim is 0, only dc features
            shs = features_dc
            features_extra = None

        return cls(
            sh_degree=max_sh_degree,
            _means=xyz,
            _opacities=opacities,
            _rgbs=shs,
            _scales=scales,
            _quats=rots,
            _features_dc=features_dc,
            _features_rest=features_extra,
            device=device,
        )
+
+ def save_to_ply(
+ self, path: str, colors: torch.Tensor = None, enable_mask: bool = False
+ ):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ numpy_data = self.get_numpy_data()
+ means = numpy_data["_means"]
+ scales = numpy_data["_scales"]
+ quats = numpy_data["_quats"]
+ opacities = numpy_data["_opacities"]
+ sh0 = numpy_data["_features_dc"]
+ shN = numpy_data.get("_features_rest", np.zeros((means.shape[0], 0)))
+ shN = shN.reshape(means.shape[0], -1)
+
+ # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays # noqa
+ if enable_mask:
+ invalid_mask = (
+ np.isnan(means).any(axis=1)
+ | np.isinf(means).any(axis=1)
+ | np.isnan(scales).any(axis=1)
+ | np.isinf(scales).any(axis=1)
+ | np.isnan(quats).any(axis=1)
+ | np.isinf(quats).any(axis=1)
+ | np.isnan(opacities).any(axis=0)
+ | np.isinf(opacities).any(axis=0)
+ | np.isnan(sh0).any(axis=1)
+ | np.isinf(sh0).any(axis=1)
+ | np.isnan(shN).any(axis=1)
+ | np.isinf(shN).any(axis=1)
+ )
+
+ # Filter out rows with NaNs or Infs from all data arrays
+ means = means[~invalid_mask]
+ scales = scales[~invalid_mask]
+ quats = quats[~invalid_mask]
+ opacities = opacities[~invalid_mask]
+ sh0 = sh0[~invalid_mask]
+ shN = shN[~invalid_mask]
+
+ num_points = means.shape[0]
+
+ with open(path, "wb") as f:
+ # Write PLY header
+ f.write(b"ply\n")
+ f.write(b"format binary_little_endian 1.0\n")
+ f.write(f"element vertex {num_points}\n".encode())
+ f.write(b"property float x\n")
+ f.write(b"property float y\n")
+ f.write(b"property float z\n")
+ f.write(b"property float nx\n")
+ f.write(b"property float ny\n")
+ f.write(b"property float nz\n")
+
+ if colors is not None:
+ for j in range(colors.shape[1]):
+ f.write(f"property float f_dc_{j}\n".encode())
+ else:
+ for i, data in enumerate([sh0, shN]):
+ prefix = "f_dc" if i == 0 else "f_rest"
+ for j in range(data.shape[1]):
+ f.write(f"property float {prefix}_{j}\n".encode())
+
+ f.write(b"property float opacity\n")
+
+ for i in range(scales.shape[1]):
+ f.write(f"property float scale_{i}\n".encode())
+ for i in range(quats.shape[1]):
+ f.write(f"property float rot_{i}\n".encode())
+
+ f.write(b"end_header\n")
+
+ # Write vertex data
+ for i in range(num_points):
+ f.write(struct.pack(" (x y z qw qx qy qz)
+ instance_pose = instance_pose[[0, 1, 2, 6, 3, 4, 5]]
+ cur_instances_quats = self.quat_norm(instance_pose[3:])
+ rot_cur = quat_to_rotmat(cur_instances_quats, mode="wxyz")
+
+ # update the means
+ num_gs = means.shape[0]
+ trans_per_pts = torch.stack([instance_pose[:3]] * num_gs, dim=0)
+ quat_per_pts = torch.stack([instance_pose[3:]] * num_gs, dim=0)
+ rot_per_pts = torch.stack([rot_cur] * num_gs, dim=0) # (num_gs, 3, 3)
+
+ # update the means
+ cur_means = (
+ torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1)
+ + trans_per_pts
+ )
+
+ # update the quats
+ _quats = self.quat_norm(quats)
+ cur_quats = quat_mult(quat_per_pts, _quats)
+
+ return cur_means, cur_quats
+
    def get_gaussians(
        self,
        c2w: torch.Tensor = None,
        instance_pose: torch.Tensor = None,
        apply_activate: bool = False,
    ) -> "GaussianBase":
        """Get Gaussian data under the given instance_pose.

        Args:
            c2w: Optional (4, 4) camera-to-world matrix; only its translation
                is used, as the viewpoint for evaluating SH colors.
            instance_pose: Optional pose applied to means/quats via
                ``_compute_transform``; identity when None.
            apply_activate: If True, apply sigmoid to opacities and exp to
                scales (raw values are logits / log-scales).

        Returns:
            A new ``GaussianOperator`` holding the (possibly transformed,
            possibly activated) parameters with baked RGB colors.
        """
        if c2w is None:
            c2w = torch.eye(4).to(self.device)

        if instance_pose is not None:
            # compute the transformed gs means and quats
            world_means, world_quats = self._compute_transform(
                self._means, self._quats, instance_pose.float().to(self.device)
            )
        else:
            world_means, world_quats = self._means, self._quats

        # get colors of gaussians: concatenate DC and higher-order SH terms.
        if self._features_rest is not None:
            colors = torch.cat(
                (self._features_dc[:, None, :], self._features_rest), dim=1
            )
        else:
            colors = self._features_dc[:, None, :]

        if self.sh_degree > 0:
            # View-dependent color: evaluate SH along each view direction.
            viewdirs = world_means.detach() - c2w[..., :3, 3]  # (N, 3)
            viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
            rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors)
            rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
        else:
            # Degree 0: DC term only, squashed to [0, 1].
            rgbs = torch.sigmoid(colors[:, 0, :])

        gs_dict = dict(
            _means=world_means,
            _opacities=(
                torch.sigmoid(self._opacities)
                if apply_activate
                else self._opacities
            ),
            _rgbs=rgbs,
            _scales=(
                torch.exp(self._scales) if apply_activate else self._scales
            ),
            _quats=self.quat_norm(world_quats),
            _features_dc=self._features_dc,
            _features_rest=self._features_rest,
            sh_degree=self.sh_degree,
            device=self.device,
        )

        return GaussianOperator(**gs_dict)
+
+ def rescale(self, scale: float):
+ if scale != 1.0:
+ self._means *= scale
+ self._scales += torch.log(self._scales.new_tensor(scale))
+
+ def set_scale_by_height(self, real_height: float) -> None:
+ def _ptp(tensor, dim):
+ val = tensor.max(dim=dim).values - tensor.min(dim=dim).values
+ return val.tolist()
+
+ xyz_scale = max(_ptp(self._means, dim=0))
+ self.rescale(1 / (xyz_scale + 1e-6)) # Normalize to [-0.5, 0.5]
+ raw_height = _ptp(self._means, dim=0)[1]
+ scale = real_height / raw_height
+
+ self.rescale(scale)
+
+ return
+
+ @staticmethod
+ def resave_ply(
+ in_ply: str,
+ out_ply: str,
+ real_height: float = None,
+ instance_pose: np.ndarray = None,
+ device: str = "cuda",
+ ) -> None:
+ gs_model = GaussianOperator.load_from_ply(in_ply, device=device)
+
+ if instance_pose is not None:
+ gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
+
+ if real_height is not None:
+ gs_model.set_scale_by_height(real_height)
+
+ gs_model.save_to_ply(out_ply)
+
+ return
+
+ @staticmethod
+ def trans_to_quatpose(
+ rot_matrix: list[list[float]],
+ trans_matrix: list[float] = [0, 0, 0],
+ ) -> torch.Tensor:
+ if isinstance(rot_matrix, list):
+ rot_matrix = np.array(rot_matrix)
+
+ rot = Rotation.from_matrix(rot_matrix)
+ qx, qy, qz, qw = rot.as_quat()
+ instance_pose = torch.tensor([*trans_matrix, qx, qy, qz, qw])
+
+ return instance_pose
+
+ def render(
+ self,
+ c2w: torch.Tensor,
+ Ks: torch.Tensor,
+ image_width: int,
+ image_height: int,
+ ) -> RenderResult:
+ gs = self.get_gaussians(c2w, apply_activate=True)
+ renders, alphas, _ = rasterization(
+ means=gs._means,
+ quats=gs._quats,
+ scales=gs._scales,
+ opacities=gs._opacities.squeeze(),
+ colors=gs._rgbs,
+ viewmats=torch.linalg.inv(c2w)[None, ...],
+ Ks=Ks[None, ...],
+ width=image_width,
+ height=image_height,
+ packed=False,
+ absgrad=True,
+ sparse_grad=False,
+ # rasterize_mode="classic",
+ rasterize_mode="antialiased",
+ **{
+ "near_plane": 0.01,
+ "far_plane": 1000000000,
+ "radius_clip": 0.0,
+ "render_mode": "RGB+ED",
+ },
+ )
+ renders = renders[0]
+ alphas = alphas[0].squeeze(-1)
+
+ assert renders.shape[-1] == 4, f"Must render rgb, depth and alpha"
+ rendered_rgb, rendered_depth = torch.split(renders, [3, 1], dim=-1)
+
+ return RenderResult(
+ torch.clamp(rendered_rgb, min=0, max=1),
+ rendered_depth,
+ alphas[..., None],
+ )
+
+
+if __name__ == "__main__":
+ input_gs = "outputs/test/debug.ply"
+ output_gs = "./debug_v3.ply"
+ gs_model: GaussianOperator = GaussianOperator.load_from_ply(input_gs)
+
+ # ็ป x ่ฝดๆ่ฝฌ 180ยฐ
+ R_x = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
+ instance_pose = gs_model.trans_to_quatpose(R_x)
+ gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
+
+ gs_model.rescale(2)
+
+ gs_model.set_scale_by_height(1.3)
+
+ gs_model.save_to_ply(output_gs)
diff --git a/embodied_gen/models/segment_model.py b/embodied_gen/models/segment_model.py
new file mode 100644
index 0000000..ab92c0c
--- /dev/null
+++ b/embodied_gen/models/segment_model.py
@@ -0,0 +1,379 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import logging
+import os
+from typing import Literal, Union
+
+import cv2
+import numpy as np
+import rembg
+import torch
+from huggingface_hub import snapshot_download
+from PIL import Image
+from segment_anything import (
+ SamAutomaticMaskGenerator,
+ SamPredictor,
+ sam_model_registry,
+)
+from transformers import pipeline
+from embodied_gen.data.utils import resize_pil, trellis_preprocess
+from embodied_gen.utils.process_media import filter_small_connected_components
+from embodied_gen.validators.quality_checkers import ImageSegChecker
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+__all__ = [
+ "SAMRemover",
+ "SAMPredictor",
+ "RembgRemover",
+ "get_segmented_image_by_agent",
+]
+
+
class SAMRemover(object):
    """Loading SAM models and performing background removal on images.

    Attributes:
        checkpoint (str): Path to the model checkpoint.
        model_type (str): Type of the SAM model to load (default: "vit_h").
        area_ratio (float): Area ratio filtering small connected components.
    """

    def __init__(
        self,
        checkpoint: str = None,
        model_type: str = "vit_h",
        area_ratio: float = 15,
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_type = model_type
        self.area_ratio = area_ratio

        if checkpoint is None:
            # Fetch the default SAM weights from the hub when none are given.
            suffix = "sam"
            repo_dir = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            checkpoint = os.path.join(repo_dir, suffix, "sam_vit_h_4b8939.pth")

        self.mask_generator = self._load_sam_model(checkpoint)

    def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
        """Instantiate SAM from the registry and wrap it as a mask generator."""
        sam_model = sam_model_registry[self.model_type](checkpoint=checkpoint)
        sam_model.to(device=self.device)
        return SamAutomaticMaskGenerator(sam_model)

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ) -> Image.Image:
        """Removes the background from an image using the SAM model.

        Args:
            image (Union[str, Image.Image, np.ndarray]): Input image,
                can be a file path, PIL Image, or numpy array.
            save_path (str): Path to save the output image (default: None).

        Returns:
            Image.Image: The image with background removed,
                including an alpha channel.
        """
        # Normalize the input into an RGB numpy array.
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        image = resize_pil(image)
        image = np.array(image.convert("RGB"))

        # Generate candidate masks, largest area first.
        candidates = sorted(
            self.mask_generator.generate(image),
            key=lambda m: m["area"],
            reverse=True,
        )

        if candidates:
            # Keep the largest mask and drop small connected components.
            mask = (candidates[0]["segmentation"] * 255).astype(np.uint8)
            mask = filter_small_connected_components(
                mask, area_ratio=self.area_ratio
            )
            # Apply the mask and attach it as the alpha channel.
            foreground = cv2.bitwise_and(image, image, mask=mask)
            output_image = Image.fromarray(
                np.dstack((foreground, mask)), mode="RGBA"
            )
        else:
            logger.warning(
                "Segmentation failed: No mask generated, return raw image."
            )
            output_image = Image.fromarray(image, mode="RGB")

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output_image.save(save_path)

        return output_image
+
+
class SAMPredictor(object):
    """Point-prompted segmentation built on the SAM predictor interface.

    Unlike ``SAMRemover`` (fully automatic masks), this class takes
    user-selected points with positive/negative labels and predicts a
    mask around them.
    """

    def __init__(
        self,
        checkpoint: str = None,
        model_type: str = "vit_h",
        binary_thresh: float = 0.1,
        device: str = "cuda",
    ):
        self.device = device
        self.model_type = model_type

        if checkpoint is None:
            # Fetch the default SAM weights from the hub when none are given.
            suffix = "sam"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            checkpoint = os.path.join(
                model_path, suffix, "sam_vit_h_4b8939.pth"
            )

        self.predictor = self._load_sam_model(checkpoint)
        self.binary_thresh = binary_thresh

    def _load_sam_model(self, checkpoint: str) -> SamPredictor:
        """Build a ``SamPredictor`` for ``self.model_type`` on ``self.device``."""
        sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
        sam.to(device=self.device)

        return SamPredictor(sam)

    def preprocess_image(self, image: Image.Image) -> np.ndarray:
        """Load/convert the input into a resized RGB numpy array."""
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")

        image = resize_pil(image)
        image = np.array(image.convert("RGB"))

        return image

    def generate_masks(
        self,
        image: np.ndarray,
        selected_points: list[list[int]],
    ) -> np.ndarray:
        """Predict binary masks from point prompts.

        Args:
            image: RGB image array; ``predictor.set_image`` must already
                have been called with it (see ``__call__``).
            selected_points: Sequence of (point, label) pairs; label 1 marks
                foreground, 0 marks background.

        Returns:
            List of (mask, name) tuples, or an empty list for no points.
        """
        if len(selected_points) == 0:
            return []

        # One prompt per row: (num_prompts, 1, 2) coordinates ...
        points = (
            torch.Tensor([p for p, _ in selected_points])
            .to(self.predictor.device)
            .unsqueeze(1)
        )

        # ... and matching (num_prompts, 1) labels.
        labels = (
            torch.Tensor([int(l) for _, l in selected_points])
            .to(self.predictor.device)
            .unsqueeze(1)
        )

        # Map pixel coordinates into the model's input resolution.
        transformed_points = self.predictor.transform.apply_coords_torch(
            points, image.shape[:2]
        )

        masks, scores, _ = self.predictor.predict_torch(
            point_coords=transformed_points,
            point_labels=labels,
            multimask_output=True,
        )
        # NOTE(review): fancy-indexing with the per-prompt argmax — verify
        # this selects the best-scoring mask per prompt rather than a
        # cross-product over prompts; the broadcasting here is subtle.
        valid_mask = masks[:, torch.argmax(scores, dim=1)]
        # Union positive-point masks, union negative-point masks, then
        # subtract the negatives from the positives.
        masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
        masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
        if len(masks_neg) == 0:
            masks_neg = np.zeros_like(masks_pos)
        if len(masks_pos) == 0:
            masks_pos = np.zeros_like(masks_neg)
        masks_neg = masks_neg.max(axis=0, keepdims=True)
        masks_pos = masks_pos.max(axis=0, keepdims=True)
        valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)

        binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)

        return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]

    def get_segmented_image(
        self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
    ) -> Image.Image:
        """Combine masks into an alpha channel and return an RGBA image."""
        seg_image = Image.fromarray(image, mode="RGB")
        alpha_channel = np.zeros(
            (seg_image.height, seg_image.width), dtype=np.uint8
        )
        for mask, _ in masks:
            # Use the maximum to combine multiple masks
            alpha_channel = np.maximum(alpha_channel, mask)

        alpha_channel = np.clip(alpha_channel, 0, 1)
        alpha_channel = (alpha_channel * 255).astype(np.uint8)
        alpha_image = Image.fromarray(alpha_channel, mode="L")
        r, g, b = seg_image.split()
        seg_image = Image.merge("RGBA", (r, g, b, alpha_image))

        return seg_image

    def __call__(
        self,
        image: Union[str, Image.Image, np.ndarray],
        selected_points: list[list[int]],
    ) -> Image.Image:
        """Segment *image* around *selected_points* and return RGBA output."""
        image = self.preprocess_image(image)
        # set_image must precede generate_masks (it embeds the image once).
        self.predictor.set_image(image)
        masks = self.generate_masks(image, selected_points)

        return self.get_segmented_image(image, masks)
+
+
class RembgRemover(object):
    """Background removal backed by a ``rembg`` u2net session."""

    def __init__(self):
        self.rembg_session = rembg.new_session("u2net")

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ) -> Image.Image:
        """Remove the background and optionally save the result.

        Args:
            image: File path, PIL image, or numpy array.
            save_path: Optional output path for the segmented image.

        Returns:
            The background-removed image.
        """
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        resized = resize_pil(image)
        result = rembg.remove(resized, session=self.rembg_session)

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            result.save(save_path)

        return result
+
+
class BMGG14Remover(object):
    """Background removal via the briaai/RMBG-1.4 segmentation pipeline."""

    def __init__(self) -> None:
        self.model = pipeline(
            "image-segmentation",
            model="briaai/RMBG-1.4",
            trust_remote_code=True,
        )

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ):
        """Segment the foreground and optionally save the result."""
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        resized = resize_pil(image)
        result = self.model(resized)

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            result.save(save_path)

        return result
+
+
def invert_rgba_pil(
    image: Image.Image, mask: Image.Image, save_path: str = None
) -> Image.Image:
    """Attach the inverted *mask* to *image* as its alpha channel."""
    inverted_alpha = (255 - np.array(mask))[..., None]
    rgba = np.concatenate([np.array(image), inverted_alpha], axis=-1)
    result = Image.fromarray(rgba, "RGBA")

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        result.save(save_path)

    return result
+
+
def get_segmented_image_by_agent(
    image: Image.Image,
    sam_remover: SAMRemover,
    rbg_remover: RembgRemover,
    seg_checker: ImageSegChecker = None,
    save_path: str = None,
    mode: Literal["loose", "strict"] = "loose",
) -> Image.Image:
    """Segment *image*, trying SAM, then SAM's inverse, then rembg.

    In "loose" mode the raw image is used as a last resort; "strict" mode
    raises instead. The accepted result is passed through
    ``trellis_preprocess`` before being returned.
    """

    def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
        # Without a checker every candidate is accepted.
        if seg_checker is None:
            return True
        return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]

    out_sam = f"{save_path}_sam.png" if save_path else None
    out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
    out_rbg = f"{save_path}_rbg.png" if save_path else None

    # Candidate 1: SAM. Candidate 2: the SAM mask inverted (in case SAM
    # segmented the background instead). Candidate 3: rembg.
    seg_sam = sam_remover(image, out_sam).convert("RGBA")
    _, _, _, alpha = seg_sam.split()
    seg_sam_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
    seg_rbg = rbg_remover(image, out_rbg)

    if _is_valid_seg(image, seg_sam):
        final_image = seg_sam
    elif _is_valid_seg(image, seg_sam_inv):
        final_image = seg_sam_inv
    elif _is_valid_seg(image, seg_rbg):
        logger.warning("Failed to segment by `SAM`, retry with `rembg`.")
        final_image = seg_rbg
    elif mode == "strict":
        raise RuntimeError("Failed to segment by `SAM` or `rembg`, abort.")
    else:
        logger.warning("Failed to segment by SAM or rembg, use raw image.")
        final_image = image.convert("RGBA")

    if save_path:
        final_image.save(save_path)

    return trellis_preprocess(final_image)
+
+
+if __name__ == "__main__":
+ input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg"
+ output_image = "sample_0_seg2.png"
+
+ # input_image = "outputs/text2image/tmp/coffee_machine.jpeg"
+ # output_image = "outputs/text2image/tmp/coffee_machine_seg.png"
+
+ # input_image = "outputs/text2image/tmp/bucket.jpeg"
+ # output_image = "outputs/text2image/tmp/bucket_seg.png"
+
+ remover = SAMRemover(model_type="vit_h")
+ remover = RembgRemover()
+ clean_image = remover(input_image)
+ clean_image.save(output_image)
+ get_segmented_image_by_agent(
+ Image.open(input_image), remover, remover, None, "./test_seg.png"
+ )
+
+ remover = BMGG14Remover()
+ remover("embodied_gen/models/test_seg.jpg", "./seg.png")
diff --git a/embodied_gen/models/sr_model.py b/embodied_gen/models/sr_model.py
new file mode 100644
index 0000000..40310bb
--- /dev/null
+++ b/embodied_gen/models/sr_model.py
@@ -0,0 +1,174 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import logging
+import os
+from typing import Union
+
+import numpy as np
+import spaces
+import torch
+from huggingface_hub import snapshot_download
+from PIL import Image
+from embodied_gen.data.utils import get_images_from_grid
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
+__all__ = [
+ "ImageStableSR",
+ "ImageRealESRGAN",
+]
+
+
class ImageStableSR:
    """x4 image super-resolution via StabilityAI's Stable Diffusion upscaler."""

    def __init__(
        self,
        model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
        device="cuda",
    ) -> None:
        from diffusers import StableDiffusionUpscalePipeline

        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
        ).to(device)
        pipeline.set_progress_bar_config(disable=True)
        pipeline.enable_model_cpu_offload()
        self.up_pipeline_x4 = pipeline

    @spaces.GPU
    def __call__(
        self,
        image: Union[Image.Image, np.ndarray],
        prompt: str = "",
        infer_step: int = 20,
    ) -> Image.Image:
        """Upscale *image* x4; an optional text *prompt* guides the model."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        rgb_image = image.convert("RGB")

        with torch.no_grad():
            result = self.up_pipeline_x4(
                image=rgb_image,
                prompt=[prompt],
                num_inference_steps=infer_step,
            ).images[0]

        return result
+
+
class ImageRealESRGAN:
    """A wrapper for Real-ESRGAN-based image super-resolution.

    This class uses the RealESRGAN model to perform image upscaling,
    typically by a factor of 4.

    Attributes:
        outscale (int): The output image scale factor (e.g., 2, 4).
        model_path (str): Path to the pre-trained model weights.
    """

    def __init__(self, outscale: int, model_path: str = None) -> None:
        # monkey patch to support torchvision>=0.16: basicsr still imports
        # `torchvision.transforms.functional_tensor`, which newer torchvision
        # removed, so alias the needed symbol into a stub module beforehand.
        import torchvision
        from packaging import version

        if version.parse(torchvision.__version__) > version.parse("0.16"):
            import sys
            import types

            import torchvision.transforms.functional as TF

            functional_tensor = types.ModuleType(
                "torchvision.transforms.functional_tensor"
            )
            functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
            sys.modules["torchvision.transforms.functional_tensor"] = (
                functional_tensor
            )

        self.outscale = outscale
        # The upsampler is constructed lazily on first call (_lazy_init),
        # so creating this object stays cheap.
        self.upsampler = None

        if model_path is None:
            # Fetch the default RealESRGAN weights from the hub.
            suffix = "super_resolution"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            model_path = os.path.join(
                model_path, suffix, "RealESRGAN_x4plus.pth"
            )

        self.model_path = model_path

    def _lazy_init(self):
        """Build the RealESRGAN upsampler on first use."""
        if self.upsampler is None:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer

            # RRDBNet backbone matching the RealESRGAN_x4plus weights.
            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=4,
            )

            self.upsampler = RealESRGANer(
                scale=4,
                model_path=self.model_path,
                model=model,
                pre_pad=0,
                half=True,  # fp16 inference
            )

    @spaces.GPU
    def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Upscale *image* by ``self.outscale`` and return a PIL image."""
        self._lazy_init()

        if isinstance(image, Image.Image):
            image = np.array(image)

        with torch.no_grad():
            output, _ = self.upsampler.enhance(image, outscale=self.outscale)

        return Image.fromarray(output)
+
+
+if __name__ == "__main__":
+ color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
+
+ # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
+ super_model = ImageRealESRGAN(outscale=4)
+ multiviews = get_images_from_grid(color_path, img_size=512)
+ multiviews = [super_model(img.convert("RGB")) for img in multiviews]
+ for idx, img in enumerate(multiviews):
+ img.save(f"sr{idx}.png")
+
+ # # Use stable diffusion for x4 (512->2048) image super resolution.
+ # super_model = ImageStableSR()
+ # multiviews = get_images_from_grid(color_path, img_size=512)
+ # multiviews = [super_model(img) for img in multiviews]
+ # for idx, img in enumerate(multiviews):
+ # img.save(f"sr_stable{idx}.png")
diff --git a/embodied_gen/models/text_model.py b/embodied_gen/models/text_model.py
new file mode 100644
index 0000000..109762b
--- /dev/null
+++ b/embodied_gen/models/text_model.py
@@ -0,0 +1,171 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import logging
+import random
+
+import numpy as np
+import torch
+from diffusers import (
+ AutoencoderKL,
+ EulerDiscreteScheduler,
+ UNet2DConditionModel,
+)
+from kolors.models.modeling_chatglm import ChatGLMModel
+from kolors.models.tokenization_chatglm import ChatGLMTokenizer
+from kolors.models.unet_2d_condition import (
+ UNet2DConditionModel as UNet2DConditionModelIP,
+)
+from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
+ StableDiffusionXLPipeline,
+)
+from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
+ StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
+)
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+__all__ = [
+ "build_text2img_ip_pipeline",
+ "build_text2img_pipeline",
+ "text2img_gen",
+]
+
+
def build_text2img_ip_pipeline(
    ckpt_dir: str,
    ref_scale: float,
    device: str = "cuda",
) -> StableDiffusionXLPipelineIP:
    """Build a Kolors SDXL text-to-image pipeline with IP-Adapter support.

    Args:
        ckpt_dir: Directory holding the Kolors checkpoint subfolders
            (``text_encoder``, ``vae``, ``scheduler``, ``unet``); the
            IP-Adapter weights are expected in a sibling
            ``Kolors-IP-Adapter-Plus`` directory.
        ref_scale: IP-Adapter conditioning scale for the reference image.
        device: Target device for the assembled pipeline.

    Returns:
        The assembled ``StableDiffusionXLPipelineIP`` with fp16 weights.
    """
    # All submodules are loaded/cast to fp16 to halve GPU memory.
    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(
        f"{ckpt_dir}/vae", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModelIP.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    # IP-Adapter weights live next to (not inside) the Kolors checkpoint dir.
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus/image_encoder",
        ignore_mismatched_sizes=True,
    ).to(dtype=torch.float16)
    clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)

    pipe = StableDiffusionXLPipelineIP(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )

    # NOTE(review): mirrors encoder_hid_proj under the attribute name the
    # Kolors IP-Adapter loader looks up — confirm against the kolors
    # pipeline implementation.
    if hasattr(pipe.unet, "encoder_hid_proj"):
        pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj

    pipe.load_ip_adapter(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus",
        subfolder="",
        weight_name=["ip_adapter_plus_general.bin"],
    )
    pipe.set_ip_adapter_scale([ref_scale])

    pipe = pipe.to(device)
    pipe.enable_model_cpu_offload()
    # pipe.enable_xformers_memory_efficient_attention()
    # pipe.enable_vae_slicing()

    return pipe
+
+
def build_text2img_pipeline(
    ckpt_dir: str,
    device: str = "cuda",
) -> StableDiffusionXLPipeline:
    """Build the plain Kolors SDXL text-to-image pipeline (no IP-Adapter).

    Args:
        ckpt_dir: Directory holding the Kolors checkpoint subfolders
            (``text_encoder``, ``vae``, ``scheduler``, ``unet``).
        device: Target device for the assembled pipeline.

    Returns:
        The assembled ``StableDiffusionXLPipeline`` with fp16 weights.
    """
    # All submodules are loaded/cast to fp16 to halve GPU memory.
    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(
        f"{ckpt_dir}/vae", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False,
    )
    pipe = pipe.to(device)
    pipe.enable_model_cpu_offload()
    pipe.enable_xformers_memory_efficient_attention()

    return pipe
+
+
def text2img_gen(
    prompt: str,
    n_sample: int,
    guidance_scale: float,
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
    ip_image: Image.Image | str | None = None,
    image_wh: tuple[int, int] = (1024, 1024),
    infer_step: int = 50,
    ip_image_size: int = 512,
    seed: int | None = None,
) -> list[Image.Image]:
    """Generate object-centric images from a text prompt with Kolors SDXL.

    Args:
        prompt: Object description; boilerplate quality/background tags are
            appended automatically before generation.
        n_sample: Number of images to generate per prompt.
        guidance_scale: Classifier-free guidance strength.
        pipeline: Text-to-image pipeline (plain or IP-Adapter variant).
        ip_image: Optional reference image (or its file path) for the
            IP adapter; only used when not ``None``.
        image_wh: Output (width, height) in pixels. Default is a tuple —
            a mutable list default would be shared across calls.
        infer_step: Number of denoising steps.
        ip_image_size: Side length the reference image is resized to.
        seed: Optional seed; also seeds torch/numpy/random globally for
            reproducibility.

    Returns:
        List of generated PIL images.
    """
    prompt = "Single " + prompt + ", in the center of the image"
    prompt += ", high quality, high resolution, best quality, white background, 3D style"  # noqa
    logger.info(f"Processing prompt: {prompt}")

    generator = None
    if seed is not None:
        generator = torch.Generator(pipeline.device).manual_seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    kwargs = dict(
        prompt=prompt,
        height=image_wh[1],
        width=image_wh[0],
        num_inference_steps=infer_step,
        guidance_scale=guidance_scale,
        num_images_per_prompt=n_sample,
        generator=generator,
    )
    if ip_image is not None:
        if isinstance(ip_image, str):
            ip_image = Image.open(ip_image)
        ip_image = ip_image.resize((ip_image_size, ip_image_size))
        kwargs.update(ip_adapter_image=[ip_image])

    return pipeline(**kwargs).images
diff --git a/embodied_gen/models/texture_model.py b/embodied_gen/models/texture_model.py
new file mode 100644
index 0000000..dd2b5d5
--- /dev/null
+++ b/embodied_gen/models/texture_model.py
@@ -0,0 +1,108 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import os
+
+import torch
+from diffusers import AutoencoderKL, DiffusionPipeline, EulerDiscreteScheduler
+from huggingface_hub import snapshot_download
+from kolors.models.controlnet import ControlNetModel
+from kolors.models.modeling_chatglm import ChatGLMModel
+from kolors.models.tokenization_chatglm import ChatGLMTokenizer
+from kolors.models.unet_2d_condition import UNet2DConditionModel
+from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
+ StableDiffusionXLControlNetImg2ImgPipeline,
+)
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+__all__ = [
+ "build_texture_gen_pipe",
+]
+
+
def build_texture_gen_pipe(
    base_ckpt_dir: str,
    controlnet_ckpt: str = None,
    ip_adapt_scale: float = 0,
    device: str = "cuda",
) -> DiffusionPipeline:
    """Build the ControlNet img2img pipeline used for texture generation.

    Args:
        base_ckpt_dir: Directory containing ``Kolors`` (and, when
            ``ip_adapt_scale > 0``, ``Kolors-IP-Adapter-Plus``) subfolders.
        controlnet_ckpt: Path to the geometry-conditioned ControlNet
            weights; when ``None``, the ``geo_cond_mv`` snapshot is
            downloaded from the ``xinjjj/RoboAssetGen`` hub repo.
        ip_adapt_scale: IP-Adapter scale; 0 disables the adapter entirely.
        device: Target device for the assembled pipeline.

    Returns:
        The assembled ``StableDiffusionXLControlNetImg2ImgPipeline``.
    """
    # Kolors base modules, all cast to fp16.
    tokenizer = ChatGLMTokenizer.from_pretrained(
        f"{base_ckpt_dir}/Kolors/text_encoder"
    )
    text_encoder = ChatGLMModel.from_pretrained(
        f"{base_ckpt_dir}/Kolors/text_encoder", torch_dtype=torch.float16
    ).half()
    vae = AutoencoderKL.from_pretrained(
        f"{base_ckpt_dir}/Kolors/vae", revision=None
    ).half()
    unet = UNet2DConditionModel.from_pretrained(
        f"{base_ckpt_dir}/Kolors/unet", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(
        f"{base_ckpt_dir}/Kolors/scheduler"
    )

    if controlnet_ckpt is None:
        # Fetch the default multi-view geometry ControlNet from the hub.
        suffix = "geo_cond_mv"
        model_path = snapshot_download(
            repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
        )
        controlnet_ckpt = os.path.join(model_path, suffix)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_ckpt, use_safetensors=True
    ).half()

    # IP-Adapter model
    image_encoder = None
    clip_image_processor = None
    if ip_adapt_scale > 0:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus/image_encoder",
            # ignore_mismatched_sizes=True,
        ).to(dtype=torch.float16)
        ip_img_size = 336
        clip_image_processor = CLIPImageProcessor(
            size=ip_img_size, crop_size=ip_img_size
        )

    pipe = StableDiffusionXLControlNetImg2ImgPipeline(
        vae=vae,
        controlnet=controlnet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )

    if ip_adapt_scale > 0:
        # NOTE(review): mirrors encoder_hid_proj under the attribute name
        # the Kolors IP-Adapter loader looks up — confirm against kolors.
        if hasattr(pipe.unet, "encoder_hid_proj"):
            pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
        pipe.load_ip_adapter(
            f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus",
            subfolder="",
            weight_name=["ip_adapter_plus_general.bin"],
        )
        pipe.set_ip_adapter_scale([ip_adapt_scale])

    pipe = pipe.to(device)
    pipe.enable_model_cpu_offload()

    return pipe
diff --git a/embodied_gen/scripts/imageto3d.py b/embodied_gen/scripts/imageto3d.py
new file mode 100644
index 0000000..4ebcca8
--- /dev/null
+++ b/embodied_gen/scripts/imageto3d.py
@@ -0,0 +1,311 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import argparse
+import logging
+import os
+import sys
+from glob import glob
+from shutil import copy, copytree
+
+import numpy as np
+import trimesh
+from PIL import Image
+from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
+from embodied_gen.data.utils import delete_dir, trellis_preprocess
+from embodied_gen.models.delight_model import DelightingModel
+from embodied_gen.models.gs_model import GaussianOperator
+from embodied_gen.models.segment_model import (
+ BMGG14Remover,
+ RembgRemover,
+ SAMPredictor,
+)
+from embodied_gen.models.sr_model import ImageRealESRGAN
+from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
+from embodied_gen.utils.gpt_clients import GPT_CLIENT
+from embodied_gen.utils.process_media import merge_images_video, render_video
+from embodied_gen.utils.tags import VERSION
+from embodied_gen.validators.quality_checkers import (
+ BaseChecker,
+ ImageAestheticChecker,
+ ImageSegChecker,
+ MeshGeoChecker,
+)
+from embodied_gen.validators.urdf_convertor import URDFGenerator
+
# Make the repo root importable so the vendored TRELLIS submodule resolves.
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


# Environment knobs set before any model loads: keep torch extension builds
# in the user cache, disable gradio telemetry; SPCONV_ALGO=native presumably
# skips spconv autotuning — TODO confirm.
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
    "~/.cache/torch_extensions"
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"


# Module-level singletons: these models are heavy, so they are constructed
# once at import time and shared by the batch loop below.
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)

RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
    "microsoft/TRELLIS-image-large"
)
PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
# Scratch/session workspace located next to this script.
TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)
+
+
def _str2bool(value) -> bool:
    """Parse common boolean CLI spellings into a real bool.

    ``argparse`` with ``type=bool`` treats any non-empty string — including
    the literal ``"False"`` — as True; this helper fixes that.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Invalid boolean value: {value!r}")


def parse_args():
    """Parse CLI arguments for the image-to-3D batch pipeline.

    Returns:
        Parsed namespace; ``image_path`` is always populated — either from
        ``--image_path`` directly or by globbing ``--image_root`` for
        png/jpg/jpeg files.
    """
    parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
    parser.add_argument(
        "--image_path", type=str, nargs="+", help="Path to the input images."
    )
    parser.add_argument(
        "--image_root", type=str, help="Path to the input images folder."
    )
    parser.add_argument(
        "--output_root",
        type=str,
        required=True,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
        "--height_range",
        type=str,
        default=None,
        help="The height in meters to restore the mesh real size.",
    )
    parser.add_argument(
        "--mass_range",
        type=str,
        default=None,
        help="The mass in kg to restore the mesh real weight.",
    )
    parser.add_argument("--asset_type", type=str, default=None)
    parser.add_argument("--skip_exists", action="store_true")
    parser.add_argument("--strict_seg", action="store_true")
    parser.add_argument("--version", type=str, default=VERSION)
    # `type=bool` would parse "--remove_intermediate False" as True;
    # use a real boolean parser while keeping the same flag shape.
    parser.add_argument(
        "--remove_intermediate", type=_str2bool, default=True
    )
    args = parser.parse_args()

    assert (
        args.image_path or args.image_root
    ), "Please provide either --image_path or --image_root."
    if not args.image_path:
        # Collect all common image extensions under image_root.
        args.image_path = glob(os.path.join(args.image_root, "*.png"))
        args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
        args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))

    return args
+
+
if __name__ == "__main__":
    args = parse_args()

    for image_path in args.image_path:
        try:
            filename = os.path.basename(image_path).split(".")[0]
            output_root = args.output_root
            # Give each image its own subfolder when processing a batch.
            if args.image_root is not None or len(args.image_path) > 1:
                output_root = os.path.join(output_root, filename)
            os.makedirs(output_root, exist_ok=True)

            mesh_out = f"{output_root}/{filename}.obj"
            if args.skip_exists and os.path.exists(mesh_out):
                logger.info(
                    f"Skip {image_path}, already processed in {mesh_out}"
                )
                continue

            image = Image.open(image_path)
            image.save(f"{output_root}/{filename}_raw.png")

            # Segmentation: Get segmented image using SAM or Rembg.
            seg_path = f"{output_root}/{filename}_cond.png"
            if image.mode != "RGBA":
                seg_image = RBG_REMOVER(image, save_path=seg_path)
                seg_image = trellis_preprocess(seg_image)
            else:
                # An alpha channel means the input is already segmented.
                seg_image = image
            seg_image.save(seg_path)

            # Run the pipeline
            try:
                outputs = PIPELINE.run(
                    seg_image,
                    preprocess_image=False,
                    # Optional parameters
                    # seed=1,
                    # sparse_structure_sampler_params={
                    #     "steps": 12,
                    #     "cfg_strength": 7.5,
                    # },
                    # slat_sampler_params={
                    #     "steps": 12,
                    #     "cfg_strength": 3,
                    # },
                )
            except Exception as e:
                logger.error(
                    f"[Pipeline Failed] process {image_path}: {e}, skip."
                )
                continue

            # Render and save color and mesh videos
            gs_model = outputs["gaussian"][0]
            mesh_model = outputs["mesh"][0]
            color_images = render_video(gs_model)["color"]
            normal_images = render_video(mesh_model)["normal"]
            video_path = os.path.join(output_root, "gs_mesh.mp4")
            merge_images_video(color_images, normal_images, video_path)

            # Save the raw Gaussian model
            gs_path = mesh_out.replace(".obj", "_gs.ply")
            gs_model.save_ply(gs_path)

            # Rotate mesh and GS by 90 degrees around Z-axis.
            rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
            gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
            mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

            # Additional rotation for GS to align mesh.
            gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
            pose = GaussianOperator.trans_to_quatpose(gs_rot)
            aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
            GaussianOperator.resave_ply(
                in_ply=gs_path,
                out_ply=aligned_gs_path,
                instance_pose=pose,
                device="cpu",
            )
            color_path = os.path.join(output_root, "color.png")
            render_gs_api(aligned_gs_path, color_path)

            mesh = trimesh.Trimesh(
                vertices=mesh_model.vertices.cpu().numpy(),
                faces=mesh_model.faces.cpu().numpy(),
            )
            mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
            mesh.vertices = mesh.vertices @ np.array(rot_matrix)

            mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
            mesh.export(mesh_obj_path)

            # Back-project the rendered GS colors onto the mesh texture.
            mesh = backproject_api(
                delight_model=DELIGHT,
                imagesr_model=IMAGESR_MODEL,
                color_path=color_path,
                mesh_path=mesh_obj_path,
                output_path=mesh_obj_path,
                skip_fix_mesh=False,
                delight=True,
                texture_wh=[2048, 2048],
            )

            mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
            mesh.export(mesh_glb_path)

            urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
            asset_attrs = {
                "version": VERSION,
                "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",  # noqa
            }
            # Physical attributes are passed as "min-max" strings.
            if args.height_range:
                min_height, max_height = map(
                    float, args.height_range.split("-")
                )
                asset_attrs["min_height"] = min_height
                asset_attrs["max_height"] = max_height
            if args.mass_range:
                min_mass, max_mass = map(float, args.mass_range.split("-"))
                asset_attrs["min_mass"] = min_mass
                asset_attrs["max_mass"] = max_mass
            if args.asset_type:
                asset_attrs["category"] = args.asset_type
            if args.version:
                asset_attrs["version"] = args.version

            urdf_root = f"{output_root}/URDF_{filename}"
            urdf_path = urdf_convertor(
                mesh_path=mesh_obj_path,
                output_root=urdf_root,
                **asset_attrs,
            )

            # Rescale GS and save to URDF/mesh folder.
            real_height = urdf_convertor.get_attr_from_urdf(
                urdf_path, attr_name="real_height"
            )
            out_gs = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"  # noqa
            GaussianOperator.resave_ply(
                in_ply=aligned_gs_path,
                out_ply=out_gs,
                real_height=real_height,
                device="cpu",
            )

            # Quality check and update .urdf file.
            mesh_out = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}.obj"  # noqa
            trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))

            image_dir = f"{urdf_root}/{urdf_convertor.output_render_dir}/image_color"  # noqa
            image_paths = glob(f"{image_dir}/*.png")
            images_list = []
            for checker in CHECKERS:
                images = image_paths
                if isinstance(checker, ImageSegChecker):
                    # Seg checker compares raw input vs segmented cond image.
                    images = [
                        f"{output_root}/{filename}_raw.png",
                        f"{output_root}/{filename}_cond.png",
                    ]
                images_list.append(images)

            results = BaseChecker.validate(CHECKERS, images_list)
            urdf_convertor.add_quality_tag(urdf_path, results)

            # Organize the final result files
            result_dir = f"{output_root}/result"
            os.makedirs(result_dir, exist_ok=True)
            copy(urdf_path, f"{result_dir}/{os.path.basename(urdf_path)}")
            copytree(
                f"{urdf_root}/{urdf_convertor.output_mesh_dir}",
                f"{result_dir}/{urdf_convertor.output_mesh_dir}",
            )
            copy(video_path, f"{result_dir}/video.mp4")
            if args.remove_intermediate:
                delete_dir(output_root, keep_subs=["result"])

        except Exception as e:
            logger.error(f"Failed to process {image_path}: {e}, skip.")
            continue

    logger.info(f"Processing complete. Outputs saved to {args.output_root}")
diff --git a/embodied_gen/scripts/render_gs.py b/embodied_gen/scripts/render_gs.py
new file mode 100644
index 0000000..16e7f37
--- /dev/null
+++ b/embodied_gen/scripts/render_gs.py
@@ -0,0 +1,175 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import argparse
+import logging
+import math
+import os
+
+import cv2
+import numpy as np
+import spaces
+import torch
+from tqdm import tqdm
+from embodied_gen.data.utils import (
+ CameraSetting,
+ init_kal_camera,
+ normalize_vertices_array,
+)
+from embodied_gen.models.gs_model import GaussianOperator
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+logger = logging.getLogger(__name__)
+
+
def parse_args():
    """Parse CLI options for rendering GS color-image grids."""
    arg_parser = argparse.ArgumentParser(description="Render GS color images")

    # Input / output paths.
    arg_parser.add_argument(
        "--input_gs", type=str, help="Input render GS.ply path."
    )
    arg_parser.add_argument(
        "--output_path",
        type=str,
        help="Output grid image path for rendered GS color images.",
    )

    # Camera sampling options.
    arg_parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    arg_parser.add_argument(
        "--elevation",
        type=float,
        nargs="+",
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    arg_parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    arg_parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the output images (default: (512, 512))",
    )
    arg_parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    arg_parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    arg_parser.add_argument(
        "--image_size",
        type=int,
        default=512,
        help="Output image size for single view in color grid (default: 512)",
    )

    # Ignore unrecognized flags so this parser can run inside other CLIs.
    known_args, _unknown = arg_parser.parse_known_args()
    return known_args
+
+
def load_gs_model(
    input_gs: str, pre_quat: tuple[float, ...] = (0.0, 0.7071, 0.0, -0.7071)
) -> GaussianOperator:
    """Load a GS ``.ply``, normalize it, and apply a pre-rotation.

    Args:
        input_gs: Path to the Gaussian-Splat ``.ply`` file.
        pre_quat: Pre-rotation quaternion appended to the instance pose.
            Default is a tuple rather than a list so the default cannot be
            mutated and shared across calls.

    Returns:
        The normalized, re-posed ``GaussianOperator`` model.
    """
    gs_model = GaussianOperator.load_from_ply(input_gs)
    # Normalize vertices to [-1, 1], center to (0, 0, 0).
    _, scale, center = normalize_vertices_array(gs_model._means)
    scale, center = float(scale), center.tolist()
    # Pose = (negated center translation, quaternion); component order
    # follows GaussianOperator's convention — TODO confirm.
    transpose = [*[-v for v in center], *pre_quat]
    instance_pose = torch.tensor(transpose).to(gs_model.device)
    gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
    gs_model.rescale(scale)

    return gs_model
+
+
@spaces.GPU
def entrypoint(input_gs: str = None, output_path: str = None) -> None:
    """Render a GS model into a 2x3 grid of color views and save it.

    Args:
        input_gs: Optional GS ``.ply`` path; overrides the ``--input_gs``
            CLI flag when given (used when called from other scripts).
        output_path: Optional output image path; overrides
            ``--output_path``.
    """
    args = parse_args()
    if isinstance(input_gs, str):
        args.input_gs = input_gs
    if isinstance(output_path, str):
        args.output_path = output_path

    # Setup camera parameters
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    camera = init_kal_camera(camera_params)
    matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
    # NOTE(review): negating the translation column appears to adapt the
    # kaolin view-matrix convention to the one expected by gs_model.render
    # — confirm against init_kal_camera.
    matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
    w2cs = matrix_mv.to(camera_params.device)
    c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
    Ks = torch.tensor(camera_params.Ks).to(camera_params.device)

    # Load GS model and normalize.
    gs_model = load_gs_model(args.input_gs, pre_quat=[0.0, 0.0, 1.0, 0.0])

    # Render GS color images.
    images = []
    for idx in tqdm(range(len(c2ws)), desc="Rendering GS"):
        result = gs_model.render(
            c2ws[idx],
            Ks=Ks,
            image_width=camera_params.resolution_hw[1],
            image_height=camera_params.resolution_hw[0],
        )
        # Downsample each rendered view to the grid cell size.
        color = cv2.resize(
            result.rgba,
            (args.image_size, args.image_size),
            interpolation=cv2.INTER_AREA,
        )
        images.append(color)

    # Cat color images into grid image and save.
    select_idxs = [[0, 2, 1], [5, 4, 3]]  # fix order for 6 views
    grid_image = []
    for row_idxs in select_idxs:
        row_image = []
        for row_idx in row_idxs:
            row_image.append(images[row_idx])
        row_image = np.concatenate(row_image, axis=1)
        grid_image.append(row_image)

    grid_image = np.concatenate(grid_image, axis=0)
    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
    cv2.imwrite(args.output_path, grid_image)
    logger.info(f"Saved grid image to {args.output_path}")
+
+
if __name__ == "__main__":
    # CLI entry: render a GS .ply into a grid of color views.
    entrypoint()
diff --git a/embodied_gen/scripts/render_mv.py b/embodied_gen/scripts/render_mv.py
new file mode 100644
index 0000000..61c8013
--- /dev/null
+++ b/embodied_gen/scripts/render_mv.py
@@ -0,0 +1,198 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import logging
+import os
+import random
+from typing import List, Tuple
+
+import fire
+import numpy as np
+import torch
+from diffusers.utils import make_image_grid
+from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
+ StableDiffusionXLControlNetImg2ImgPipeline,
+)
+from PIL import Image, ImageEnhance, ImageFilter
+from torchvision import transforms
+from embodied_gen.data.datasets import Asset3dGenDataset
+from embodied_gen.models.texture_model import build_texture_gen_pipe
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
def get_init_noise_image(image: Image.Image) -> Image.Image:
    """Return a blurred, low-contrast grayscale copy of ``image``.

    Serves as a soft initialization image for img2img generation instead
    of starting from the untouched conditioning image.
    """
    grayscale = image.convert("L")
    softened = grayscale.filter(ImageFilter.GaussianBlur(radius=3))
    return ImageEnhance.Contrast(softened).enhance(factor=0.5)
+
+
def infer_pipe(
    index_file: str,
    controlnet_ckpt: str = None,
    uid: str = None,
    prompt: str = None,
    controlnet_cond_scale: float = 0.4,
    control_guidance_end: float = 0.9,
    strength: float = 1.0,
    num_inference_steps: int = 50,
    guidance_scale: float = 10,
    ip_adapt_scale: float = 0,
    ip_img_path: str = None,
    sub_idxs: List[List[int]] = None,
    num_images_per_prompt: int = 3,  # increase if want similar images.
    device: str = "cuda",
    save_dir: str = "infer_vis",
    seed: int = None,
    target_hw: tuple[int, int] = (512, 512),
    pipeline: StableDiffusionXLControlNetImg2ImgPipeline = None,
) -> List[str]:
    """Generate textured multi-view color images with ControlNet guidance.

    Fetches geometry-conditioning grids (normal/position/mask) for one
    asset from the dataset, runs the Kolors ControlNet img2img pipeline,
    and writes the RGBA samples plus a comparison grid to ``save_dir``.

    Args:
        index_file: Dataset index file consumed by ``Asset3dGenDataset``.
        controlnet_ckpt: Optional ControlNet checkpoint path (hub default
            when ``None``); ignored if ``pipeline`` is given.
        uid: Asset id to render; a random one is chosen when ``None``.
        prompt: Text prompt; falls back to the asset's caption.
        sub_idxs: Grid layout of view indices, e.g. ``[[0, 1, 2],
            [3, 4, 5]]``; ``None`` picks one random view at 2x resolution.
        pipeline: Pre-built pipeline to reuse; built on the fly otherwise.

    Returns:
        Paths of the saved RGBA sample images (was mis-annotated ``str``).
    """
    # sub_idxs = [[0, 1, 2], [3, 4, 5]] # None for single image.
    if sub_idxs is None:
        sub_idxs = [[random.randint(0, 5)]]  # 6 views.
        target_hw = [2 * size for size in target_hw]

    transform_list = [
        transforms.Resize(
            target_hw, interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop(target_hw),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    image_transform = transforms.Compose(transform_list)
    # Control images stay un-normalized (drop the Normalize step).
    control_transform = transforms.Compose(transform_list[:-1])

    grid_hw = (target_hw[0] * len(sub_idxs), target_hw[1] * len(sub_idxs[0]))
    dataset = Asset3dGenDataset(
        index_file, target_hw=grid_hw, sub_idxs=sub_idxs
    )

    if uid is None:
        uid = random.choice(list(dataset.meta_info.keys()))
    if prompt is None:
        prompt = dataset.meta_info[uid]["capture"]
    if isinstance(prompt, (list, tuple)):
        prompt = ", ".join(map(str, prompt))
    # prompt += "high quality, ultra-clear, high resolution, best quality, 4k"
    # prompt += "高品质,清晰,细节"
    prompt += ", high quality, high resolution, best quality"
    # prompt += ", with diffuse lighting, showing no reflections."
    logger.info(f"Inference with prompt: {prompt}")

    # Chinese tags: shadows, low resolution, artifacts/blur, neon lights,
    # highlights, specular reflections.
    negative_prompt = "nsfw,阴影,低分辨率,伪影、模糊,霓虹灯,高光,镜面反射"

    control_image = dataset.fetch_sample_grid_images(
        uid,
        attrs=["image_view_normal", "image_position", "image_mask"],
        sub_idxs=sub_idxs,
        transform=control_transform,
    )

    color_image = dataset.fetch_sample_grid_images(
        uid,
        attrs=["image_color"],
        sub_idxs=sub_idxs,
        transform=image_transform,
    )

    normal_pil, position_pil, mask_pil, color_pil = dataset.visualize_item(
        control_image,
        color_image,
        save_dir=save_dir,
    )

    if pipeline is None:
        pipeline = build_texture_gen_pipe(
            base_ckpt_dir="./weights",
            controlnet_ckpt=controlnet_ckpt,
            ip_adapt_scale=ip_adapt_scale,
            device=device,
        )

    if ip_adapt_scale > 0 and ip_img_path is not None and len(ip_img_path) > 0:
        ip_image = Image.open(ip_img_path).convert("RGB")
        ip_image = ip_image.resize(target_hw[::-1])
        ip_image = [ip_image]
        pipeline.set_ip_adapter_scale([ip_adapt_scale])
    else:
        ip_image = None

    generator = None
    if seed is not None:
        # Seed both the pipeline generator and the global RNGs.
        generator = torch.Generator(device).manual_seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    init_image = get_init_noise_image(normal_pil)
    # init_image = get_init_noise_image(color_pil)

    images = []
    row_num, col_num = 2, 3
    img_save_paths = []
    # Keep sampling until at least one grid row of images is available.
    while len(images) < col_num:
        image = pipeline(
            prompt=prompt,
            image=init_image,
            controlnet_conditioning_scale=controlnet_cond_scale,
            control_guidance_end=control_guidance_end,
            strength=strength,
            control_image=control_image[None, ...],
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            ip_adapter_image=ip_image,
            generator=generator,
        ).images
        images.extend(image)

    grid_image = [normal_pil, position_pil, color_pil] + images[:col_num]
    # save_dir = os.path.join(save_dir, uid)
    os.makedirs(save_dir, exist_ok=True)

    # Save each sample with the dataset mask as its alpha channel.
    for idx in range(col_num):
        rgba_image = Image.merge("RGBA", (*images[idx].split(), mask_pil))
        img_save_path = os.path.join(save_dir, f"color_sample{idx}.png")
        rgba_image.save(img_save_path)
        img_save_paths.append(img_save_path)

    sub_idxs = "_".join(
        [str(item) for sublist in sub_idxs for item in sublist]
    )
    save_path = os.path.join(
        save_dir, f"sample_idx{str(sub_idxs)}_ip{ip_adapt_scale}.jpg"
    )
    make_image_grid(grid_image, row_num, col_num).save(save_path)
    logger.info(f"Visualize in {save_path}")

    return img_save_paths
+
+
def entrypoint() -> None:
    """CLI wrapper: expose ``infer_pipe`` arguments via python-fire."""
    fire.Fire(infer_pipe)
+
+
if __name__ == "__main__":
    # Script entry point.
    entrypoint()
diff --git a/embodied_gen/scripts/text2image.py b/embodied_gen/scripts/text2image.py
new file mode 100644
index 0000000..3b375a3
--- /dev/null
+++ b/embodied_gen/scripts/text2image.py
@@ -0,0 +1,168 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import argparse
+import logging
+import os
+
+from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
+ StableDiffusionXLPipeline,
+)
+from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
+ StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
+)
+from tqdm import tqdm
+from embodied_gen.models.text_model import (
+ build_text2img_ip_pipeline,
+ build_text2img_pipeline,
+ text2img_gen,
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
def parse_args():
    """Build and parse the CLI arguments for text-to-image generation."""
    cli = argparse.ArgumentParser(description="Text to Image.")
    cli.add_argument(
        "--prompts",
        type=str,
        nargs="+",
        help="List of prompts (space-separated).",
    )
    cli.add_argument(
        "--ref_image",
        type=str,
        nargs="+",
        help="List of ref_image paths (space-separated).",
    )
    cli.add_argument(
        "--output_root",
        type=str,
        help="Root directory for saving outputs.",
    )
    cli.add_argument(
        "--guidance_scale",
        type=float,
        default=12.0,
        help="Guidance scale for the diffusion model.",
    )
    cli.add_argument(
        "--ref_scale",
        type=float,
        default=0.3,
        help="Reference image scale for the IP adapter.",
    )
    # Sampling controls: sample count, square resolution, step count, seed.
    cli.add_argument("--n_sample", type=int, default=1)
    cli.add_argument("--resolution", type=int, default=1024)
    cli.add_argument("--infer_step", type=int, default=50)
    cli.add_argument("--seed", type=int, default=0)

    return cli.parse_args()
+
+
def entrypoint(
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP = None,
    **kwargs,
) -> list[str]:
    """Generate images from text prompts with optional IP-adapter guidance.

    Args:
        pipeline: Optional pre-built Kolors text2img pipeline. When None,
            one is constructed from ``weights/Kolors`` based on ``ref_scale``.
        **kwargs: Overrides for any parsed CLI argument; non-None values
            replace the corresponding ``args`` attribute.

    Returns:
        list[str]: Paths of all saved sample images, across all prompts.
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    prompts = args.prompts
    # A single "*.txt" argument is treated as a file with one prompt per line.
    if len(prompts) == 1 and prompts[0].endswith(".txt"):
        with open(prompts[0], "r") as f:
            prompts = f.readlines()
        prompts = [
            prompt.strip() for prompt in prompts if prompt.strip() != ""
        ]

    os.makedirs(args.output_root, exist_ok=True)

    ip_img_paths = args.ref_image
    if ip_img_paths is None or len(ip_img_paths) == 0:
        # No reference image: disable the IP adapter entirely.
        args.ref_scale = 0
        ip_img_paths = [None] * len(prompts)
    elif isinstance(ip_img_paths, str):
        ip_img_paths = [ip_img_paths] * len(prompts)
    elif isinstance(ip_img_paths, list):
        if len(ip_img_paths) == 1:
            # Broadcast a single reference image to all prompts.
            ip_img_paths = ip_img_paths * len(prompts)
        elif len(ip_img_paths) != len(prompts):
            # BUGFIX: previously ANY multi-element list raised here, even
            # when it matched the prompts one-to-one; only a genuine
            # length mismatch is an error now.
            raise ValueError("Invalid ref_image paths.")
    else:
        raise ValueError("Invalid ref_image paths.")
    assert len(ip_img_paths) == len(
        prompts
    ), f"Number of ref images does not match prompts, {len(ip_img_paths)} != {len(prompts)}"  # noqa

    if pipeline is None:
        if args.ref_scale > 0:
            pipeline = build_text2img_ip_pipeline(
                "weights/Kolors",
                ref_scale=args.ref_scale,
            )
        else:
            pipeline = build_text2img_pipeline("weights/Kolors")

    # BUGFIX: accumulate across ALL prompts; the list was previously reset
    # inside the loop, so only the last prompt's samples were returned.
    save_paths = []
    for idx, (prompt, ip_img_path) in tqdm(
        enumerate(zip(prompts, ip_img_paths)),
        desc="Generating images",
        total=len(prompts),
    ):
        images = text2img_gen(
            prompt=prompt,
            n_sample=args.n_sample,
            guidance_scale=args.guidance_scale,
            pipeline=pipeline,
            ip_image=ip_img_path,
            image_wh=[args.resolution, args.resolution],
            infer_step=args.infer_step,
            seed=args.seed,
        )

        for sub_idx, image in enumerate(images):
            save_path = (
                f"{args.output_root}/sample_{idx*args.n_sample+sub_idx}.png"
            )
            image.save(save_path)
            save_paths.append(save_path)

    logger.info(f"Images saved to {args.output_root}")

    return save_paths
+
+
# Script usage: python text2image.py --prompts ... --output_root ...
if __name__ == "__main__":
    entrypoint()
diff --git a/embodied_gen/scripts/textto3d.sh b/embodied_gen/scripts/textto3d.sh
new file mode 100644
index 0000000..d9648c7
--- /dev/null
+++ b/embodied_gen/scripts/textto3d.sh
@@ -0,0 +1,56 @@
#!/bin/bash
# Text-to-3D driver: text -> images (text2image.py) -> 3D assets (imageto3d.py).
# Usage: bash textto3d.sh --prompts "Prompt1" "Prompt2" --output_root <dir>

# Initialize variables
prompts=()
output_root=""

# Parse arguments
while [[ $# -gt 0 ]]; do
    case "$1" in
        --prompts)
            shift
            # Collect every value up to the next --flag.
            while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
                prompts+=("$1")
                shift
            done
            ;;
        --output_root)
            output_root="$2"
            shift 2
            ;;
        *)
            echo "Unknown argument: $1"
            exit 1
            ;;
    esac
done

# Validate required arguments
if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
    echo "Missing required arguments."
    echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" --output_root "
    exit 1
fi

# Print arguments (for debugging)
echo "Prompts:"
for p in "${prompts[@]}"; do
    echo " - $p"
done
echo "Output root: ${output_root}"

# Step 1: Text-to-Image
# BUGFIX: pass the prompt array directly instead of re-quoting through
# `eval`, which word-split prompts and could execute `$(...)`/quotes
# embedded in prompt text.
python3 embodied_gen/scripts/text2image.py \
    --prompts "${prompts[@]}" \
    --output_root "${output_root}/images"

# Step 2: Image-to-3D
python3 embodied_gen/scripts/imageto3d.py \
    --image_root "${output_root}/images" \
    --output_root "${output_root}/asset3d"
diff --git a/embodied_gen/scripts/texture_gen.sh b/embodied_gen/scripts/texture_gen.sh
new file mode 100644
index 0000000..747a5fe
--- /dev/null
+++ b/embodied_gen/scripts/texture_gen.sh
@@ -0,0 +1,80 @@
#!/bin/bash
# Texture-generation pipeline: condition render -> multi-view synthesis ->
# texture backprojection -> final turntable render.
# BUGFIX: all variable expansions are now quoted so mesh paths / uuids
# containing spaces or glob characters no longer word-split.

while [[ $# -gt 0 ]]; do
    case $1 in
        --mesh_path)
            mesh_path="$2"
            shift 2
            ;;
        --prompt)
            prompt="$2"
            shift 2
            ;;
        --uuid)
            uuid="$2"
            shift 2
            ;;
        --output_root)
            output_root="$2"
            shift 2
            ;;
        *)
            echo "unknown: $1"
            exit 1
            ;;
    esac
done


if [[ -z "$mesh_path" || -z "$prompt" || -z "$uuid" || -z "$output_root" ]]; then
    echo "params missing"
    echo "usage: bash run.sh --mesh_path --prompt --uuid --output_root "
    exit 1
fi

# Step 1: drender-cli for condition rendering
drender-cli --mesh_path "${mesh_path}" \
    --output_root "${output_root}/condition" \
    --uuid "${uuid}"

# Step 2: multi-view rendering
python embodied_gen/scripts/render_mv.py \
    --index_file "${output_root}/condition/index.json" \
    --controlnet_cond_scale 0.75 \
    --guidance_scale 9 \
    --strength 0.9 \
    --num_inference_steps 40 \
    --ip_adapt_scale 0 \
    --ip_img_path None \
    --uid "${uuid}" \
    --prompt "${prompt}" \
    --save_dir "${output_root}/multi_view" \
    --sub_idxs "[[0,1,2],[3,4,5]]" \
    --seed 0

# Step 3: backprojection
backproject-cli --mesh_path "${mesh_path}" \
    --color_path "${output_root}/multi_view/color_sample0.png" \
    --output_path "${output_root}/texture_mesh/${uuid}.obj" \
    --save_glb_path "${output_root}/texture_mesh/${uuid}.glb" \
    --skip_fix_mesh \
    --delight \
    --no_save_delight_img

# Step 4: final rendering of textured mesh
drender-cli --mesh_path "${output_root}/texture_mesh/${uuid}.obj" \
    --output_root "${output_root}/texture_mesh" \
    --num_images 90 \
    --elevation 20 \
    --with_mtl \
    --gen_color_mp4 \
    --pbr_light_factor 1.2

# Organize folders
rm -rf "${output_root}/condition"
video_path="${output_root}/texture_mesh/${uuid}/color.mp4"
if [ -f "${video_path}" ]; then
    cp "${video_path}" "${output_root}/texture_mesh/color.mp4"
    echo "Resave video to ${output_root}/texture_mesh/color.mp4"
fi
rm -rf "${output_root}/texture_mesh/${uuid}"
\ No newline at end of file
diff --git a/embodied_gen/utils/gpt_clients.py b/embodied_gen/utils/gpt_clients.py
new file mode 100644
index 0000000..7b3f72b
--- /dev/null
+++ b/embodied_gen/utils/gpt_clients.py
@@ -0,0 +1,211 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import base64
+import logging
+import os
+from io import BytesIO
+from typing import Optional
+
+import yaml
+from openai import AzureOpenAI, OpenAI # pip install openai
+from PIL import Image
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ stop_after_delay,
+ wait_random_exponential,
+)
+from embodied_gen.utils.process_media import combine_images_to_base64
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
class GPTclient:
    """A client to interact with the GPT model via OpenAI or Azure API.

    Args:
        endpoint (str): Base URL (Azure endpoint or OpenAI-compatible API).
        api_key (str): Credential for the endpoint.
        model_name (str): Model/deployment id sent with every request.
        api_version (str): When set, an AzureOpenAI client is built;
            otherwise a plain OpenAI-compatible client is used.
        verbose (bool): Log every prompt/response pair at INFO level.
    """

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        model_name: str = "yfb-gpt-4o",
        api_version: str = None,
        verbose: bool = False,
    ):
        # api_version doubles as the Azure-vs-OpenAI switch.
        if api_version is not None:
            self.client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        else:
            self.client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
            )

        self.endpoint = endpoint
        self.model_name = model_name
        # Extensions recognized as local image files in query().
        self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
        self.verbose = verbose
        logger.info(f"Using GPT model: {self.model_name}.")

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=(stop_after_attempt(10) | stop_after_delay(30)),
    )
    def completion_with_backoff(self, **kwargs):
        """chat.completions.create wrapped with exponential-backoff retries."""
        return self.client.chat.completions.create(**kwargs)

    def query(
        self,
        text_prompt: str,
        image_base64: Optional[list[str | Image.Image]] = None,
        system_role: Optional[str] = None,
    ) -> Optional[str]:
        """Queries the GPT model with a text and optional image prompts.

        Args:
            text_prompt (str): The main text input that the model responds to.
            image_base64 (Optional[List[str]]): A list of image base64 strings
                or local image paths or PIL.Image to accompany the text prompt.
            system_role (Optional[str]): Optional system-level instructions
                that specify the behavior of the assistant.

        Returns:
            Optional[str]: The response content generated by the model based on
            the prompt. Returns `None` if an error occurs.
        """
        if system_role is None:
            system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."  # noqa

        content_user = [
            {
                "type": "text",
                "text": text_prompt,
            },
        ]

        # Process images if provided; each entry may be a PIL image, a local
        # file path, or an already-encoded base64 string.
        if image_base64 is not None:
            image_base64 = (
                image_base64
                if isinstance(image_base64, list)
                else [image_base64]
            )
            for img in image_base64:
                if isinstance(img, Image.Image):
                    # In-memory image: serialize then base64-encode.
                    buffer = BytesIO()
                    img.save(buffer, format=img.format or "PNG")
                    buffer.seek(0)
                    image_binary = buffer.read()
                    img = base64.b64encode(image_binary).decode("utf-8")
                elif os.path.splitext(img)[-1].lower() in self.image_formats:
                    # BUGFIX: removed the dead `len(os.path.splitext(img)) > 1`
                    # guard — splitext always returns a 2-tuple, so it was
                    # always true; the extension check alone distinguishes a
                    # path from a raw base64 payload.
                    if not os.path.exists(img):
                        raise FileNotFoundError(f"Image file not found: {img}")
                    with open(img, "rb") as f:
                        img = base64.b64encode(f.read()).decode("utf-8")

                content_user.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{img}"},
                    }
                )

        payload = {
            "messages": [
                {"role": "system", "content": system_role},
                {"role": "user", "content": content_user},
            ],
            "temperature": 0.1,
            "max_tokens": 500,
            "top_p": 0.1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
        }
        payload.update({"model": self.model_name})

        response = None
        try:
            response = self.completion_with_backoff(**payload)
            response = response.choices[0].message.content
        except Exception as e:
            # BUGFIX: typo "GPTclint" -> "GPTclient" in the error log.
            logger.error(f"Error GPTclient {self.endpoint} API call: {e}")
            response = None

        if self.verbose:
            logger.info(f"Prompt: {text_prompt}")
            logger.info(f"Response: {response}")

        return response
+
+
+with open("embodied_gen/utils/gpt_config.yaml", "r") as f:
+ config = yaml.safe_load(f)
+
+agent_type = config["agent_type"]
+agent_config = config.get(agent_type, {})
+
+# Prefer environment variables, fallback to YAML config
+endpoint = os.environ.get("ENDPOINT", agent_config.get("endpoint"))
+api_key = os.environ.get("API_KEY", agent_config.get("api_key"))
+api_version = os.environ.get("API_VERSION", agent_config.get("api_version"))
+model_name = os.environ.get("MODEL_NAME", agent_config.get("model_name"))
+
+GPT_CLIENT = GPTclient(
+ endpoint=endpoint,
+ api_key=api_key,
+ api_version=api_version,
+ model_name=model_name,
+)
+
if __name__ == "__main__":
    # Smoke test 1: multi-image query; image handling depends on provider.
    if "openrouter" in GPT_CLIENT.endpoint:
        # OpenRouter cannot take multiple images, so tile them into one grid.
        combined = combine_images_to_base64(
            [
                "outputs/text2image/demo_objects/bed/sample_0.jpg",
                "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png",  # noqa
                "outputs/text2image/demo_objects/cardboard/sample_1.jpg",
            ]
        )  # input raw image_path if only one image
        print(
            GPT_CLIENT.query(
                text_prompt="What is the content in each image?",
                image_base64=combined,
            )
        )
    else:
        pil_inputs = [
            Image.open("outputs/text2image/demo_objects/bed/sample_0.jpg"),
            Image.open(
                "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png"  # noqa
            ),
        ]
        print(
            GPT_CLIENT.query(
                text_prompt="What is the content in the images?",
                image_base64=pil_inputs,
            )
        )

    # test2: text prompt
    print(GPT_CLIENT.query(text_prompt="What is the capital of China?"))
diff --git a/embodied_gen/utils/gpt_config.yaml b/embodied_gen/utils/gpt_config.yaml
new file mode 100644
index 0000000..2e966bf
--- /dev/null
+++ b/embodied_gen/utils/gpt_config.yaml
@@ -0,0 +1,14 @@
+# config.yaml
+agent_type: "qwen2.5-vl" # gpt-4o or qwen2.5-vl
+
+gpt-4o:
+ endpoint: https://xxx.openai.azure.com
+ api_key: xxx
+ api_version: 2025-xx-xx
+ model_name: yfb-gpt-4o
+
+qwen2.5-vl:
+ endpoint: https://openrouter.ai/api/v1
+  api_key: xxx  # SECURITY: live OpenRouter key removed — never commit secrets; set via the API_KEY env var instead
+ api_version: null
+ model_name: qwen/qwen2.5-vl-72b-instruct:free
diff --git a/embodied_gen/utils/process_media.py b/embodied_gen/utils/process_media.py
new file mode 100644
index 0000000..c5708e6
--- /dev/null
+++ b/embodied_gen/utils/process_media.py
@@ -0,0 +1,328 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import base64
+import logging
+import math
+import os
+import subprocess
+import sys
+from glob import glob
+from io import BytesIO
+from typing import Union
+
+import cv2
+import imageio
+import numpy as np
+import PIL.Image as Image
+import spaces
+import torch
+from moviepy.editor import VideoFileClip, clips_array
+from tqdm import tqdm
+
+current_file_path = os.path.abspath(__file__)
+current_dir = os.path.dirname(current_file_path)
+sys.path.append(os.path.join(current_dir, "../.."))
+from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
+from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
+from thirdparty.TRELLIS.trellis.utils.render_utils import (
+ render_frames,
+ yaw_pitch_r_fov_to_extrinsics_intrinsics,
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+__all__ = [
+ "render_asset3d",
+ "merge_images_video",
+ "filter_small_connected_components",
+ "filter_image_small_connected_components",
+ "combine_images_to_base64",
+ "render_mesh",
+ "render_video",
+ "create_mp4_from_images",
+ "create_gif_from_images",
+]
+
+
@spaces.GPU
def render_asset3d(
    mesh_path: str,
    output_root: str,
    distance: float = 5.0,
    num_images: int = 1,
    elevation: list[float] = (0.0,),
    pbr_light_factor: float = 1.5,
    return_key: str = "image_color/*",
    output_subdir: str = "renders",
    gen_color_mp4: bool = False,
    gen_viewnormal_mp4: bool = False,
    gen_glonormal_mp4: bool = False,
) -> list[str]:
    """Render a mesh by shelling out to the differentiable-render CLI.

    Returns the rendered file paths matching ``return_key`` under
    ``output_root/output_subdir`` (empty if rendering failed).
    """
    cmd = [
        "python3",
        "embodied_gen/data/differentiable_render.py",
        "--mesh_path", mesh_path,
        "--output_root", output_root,
        "--uuid", output_subdir,
        "--distance", str(distance),
        "--num_images", str(num_images),
        "--elevation", *map(str, elevation),
        "--pbr_light_factor", str(pbr_light_factor),
        "--with_mtl",
    ]
    optional_flags = [
        ("--gen_color_mp4", gen_color_mp4),
        ("--gen_viewnormal_mp4", gen_viewnormal_mp4),
        ("--gen_glonormal_mp4", gen_glonormal_mp4),
    ]
    cmd.extend(flag for flag, enabled in optional_flags if enabled)

    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        # Best-effort: log and fall through; caller gets whatever was written.
        logger.error(f"Error occurred during rendering: {e}.")

    return glob(os.path.join(output_root, output_subdir, return_key))
+
+
def merge_images_video(color_images, normal_images, output_path) -> None:
    """Write a 50-fps video: left half from color, right half from normals."""
    half = color_images[0].shape[1] // 2
    frames = []
    for rgb_img, normal_img in zip(color_images, normal_images):
        frames.append(np.hstack([rgb_img[:, :half], normal_img[:, half:]]))
    imageio.mimsave(output_path, frames, fps=50)

    return
+
+
def merge_video_video(
    video_path1: str, video_path2: str, output_path: str
) -> None:
    """Merge two videos by the left half and the right half of the videos."""
    left_src = VideoFileClip(video_path1)
    right_src = VideoFileClip(video_path2)

    if left_src.size != right_src.size:
        raise ValueError("The resolutions of the two videos do not match.")

    width, height = left_src.size
    left_half = left_src.crop(x1=0, y1=0, x2=width // 2, y2=height)
    right_half = right_src.crop(x1=width // 2, y1=0, x2=width, y2=height)
    clips_array([[left_half, right_half]]).write_videofile(
        output_path, codec="libx264"
    )
+
+
def filter_small_connected_components(
    mask: Union[Image.Image, np.ndarray],
    area_ratio: float,
    connectivity: int = 8,
) -> np.ndarray:
    """Zero out connected components smaller than mask_area / area_ratio."""
    if isinstance(mask, Image.Image):
        mask = np.array(mask)

    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        mask,
        connectivity=connectivity,
    )

    # Threshold is relative to the total foreground area.
    min_area = (mask != 0).sum() // area_ratio
    to_remove = np.zeros_like(mask, dtype=np.uint8)
    for label in range(1, num_labels):
        if stats[label, cv2.CC_STAT_AREA] < min_area:
            to_remove[labels == label] = 255

    return cv2.bitwise_and(mask, cv2.bitwise_not(to_remove))
+
+
def filter_image_small_connected_components(
    image: Union[Image.Image, np.ndarray],
    area_ratio: float = 10,
    connectivity: int = 8,
) -> np.ndarray:
    """Filter small components out of an RGBA image's alpha channel."""
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGBA"))

    alpha = image[..., 3]
    image[..., 3] = filter_small_connected_components(
        alpha, area_ratio, connectivity
    )

    return image
+
+
def combine_images_to_base64(
    images: list[str | Image.Image],
    cat_row_col: tuple[int, int] = None,
    target_wh: tuple[int, int] = (512, 512),
) -> str:
    """Tile images (paths or PIL) into a grid; return it as base64 PNG."""
    if cat_row_col is None:
        # Near-square layout by default.
        n_col = math.ceil(math.sqrt(len(images)))
        n_row = math.ceil(len(images) / n_col)
    else:
        n_row, n_col = cat_row_col

    cell_w, cell_h = target_wh
    tiles = []
    for item in images[: n_row * n_col]:
        img = Image.open(item).convert("RGB") if isinstance(item, str) else item
        tiles.append(img.resize(target_wh))

    canvas = Image.new("RGB", (n_col * cell_w, n_row * cell_h), (255, 255, 255))
    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, n_col)
        canvas.paste(tile, (col * cell_w, row * cell_h))

    buffer = BytesIO()
    canvas.save(buffer, format="PNG")

    return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+
@spaces.GPU
def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
    """Render per-view normal maps of `sample` for each camera pose pair."""
    renderer = MeshRenderer()
    opts = renderer.rendering_options
    opts.resolution = options.get("resolution", 512)
    opts.near = options.get("near", 1)
    opts.far = options.get("far", 100)
    opts.ssaa = options.get("ssaa", 4)

    rets = {}
    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
        res = renderer.render(sample, extr, intr)
        # Mask out background, then convert CHW float -> HWC uint8.
        masked = torch.lerp(
            torch.zeros_like(res["normal"]), res["normal"], res["mask"]
        )
        frame = np.clip(
            masked.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
        ).astype(np.uint8)
        rets.setdefault("normal", []).append(frame)

    return rets
+
+
@spaces.GPU
def render_video(
    sample,
    resolution=512,
    bg_color=(0, 0, 0),
    num_frames=300,
    r=2,
    fov=40,
    **kwargs,
):
    """Render a full turntable video of `sample`.

    Dispatches to the mesh renderer for ``MeshExtractResult`` inputs,
    otherwise to the generic frame renderer.

    Args:
        sample: Scene/mesh representation to render.
        resolution: Output frame resolution in pixels.
        bg_color: Background RGB color.
        num_frames: Number of frames over one full rotation.
        r: Camera orbit radius.
        fov: Camera field of view in degrees.

    Returns:
        Renderer-specific dict of frame lists.
    """
    # BUGFIX: use math.pi — the literal 2 * 3.1415 slightly under-rotated
    # the turntable, leaving a visible seam at the loop point.
    yaws = torch.linspace(0, 2 * math.pi, num_frames)
    yaws = yaws.tolist()
    pitch = [0.5] * num_frames
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
        yaws, pitch, r, fov
    )
    render_fn = (
        render_mesh if isinstance(sample, MeshExtractResult) else render_frames
    )
    result = render_fn(
        sample,
        extrinsics,
        intrinsics,
        {"resolution": resolution, "bg_color": bg_color},
        **kwargs,
    )

    return result
+
+
def create_mp4_from_images(images, output_path, fps=10, prompt=None):
    """Encode float images in [0, 1] to MP4, optionally stamping a caption."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1
    color = (255, 255, 255)
    position = (20, 25)

    with imageio.get_writer(output_path, fps=fps) as writer:
        for image in images:
            # Clamp, scale to uint8 and drop any alpha channel.
            frame = (255.0 * image.clip(min=0, max=1)).astype(np.uint8)
            frame = frame[..., :3]
            if prompt is not None:
                cv2.putText(
                    frame,
                    prompt,
                    position,
                    font,
                    font_scale,
                    color,
                    font_thickness,
                )

            writer.append_data(frame)

    logger.info(f"MP4 video saved to {output_path}")
+
+
def create_gif_from_images(images, output_path, fps=10):
    """Save float RGBA images in [0, 1] as a looping animated GIF."""
    frames = []
    for image in images:
        arr = (255.0 * image.clip(min=0, max=1)).astype(np.uint8)
        frames.append(Image.fromarray(arr, mode="RGBA").convert("RGB"))

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=1000 // fps,  # ms per frame
        loop=0,  # loop forever
    )

    logger.info(f"GIF saved to {output_path}")
+
+
if __name__ == "__main__":
    # Example usage:
    # Side-by-side comparison video: left half from the normal-map render,
    # right half from the textured render.
    merge_video_video(
        "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4",  # noqa
        "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4",  # noqa
        "merge.mp4",
    )

    # Tile three example images into one grid and base64-encode it.
    image_base64 = combine_images_to_base64(
        [
            "apps/assets/example_image/sample_00.jpg",
            "apps/assets/example_image/sample_01.jpg",
            "apps/assets/example_image/sample_02.jpg",
        ]
    )
diff --git a/embodied_gen/utils/tags.py b/embodied_gen/utils/tags.py
new file mode 100644
index 0000000..07a9a63
--- /dev/null
+++ b/embodied_gen/utils/tags.py
@@ -0,0 +1 @@
+VERSION = "v0.1.0"
diff --git a/embodied_gen/validators/aesthetic_predictor.py b/embodied_gen/validators/aesthetic_predictor.py
new file mode 100644
index 0000000..5b9c557
--- /dev/null
+++ b/embodied_gen/validators/aesthetic_predictor.py
@@ -0,0 +1,149 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import os
+
+import clip
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+from huggingface_hub import snapshot_download
+from PIL import Image
+
+
class AestheticPredictor:
    """Aesthetic Score Predictor.

    Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main

    Args:
        clip_model_dir (str): Path to the directory of the CLIP model.
        sac_model_path (str): Path to the pre-trained SAC model.
        device (str): Device to use for computation ("cuda" or "cpu").
    """

    # Hugging Face location of the default aesthetic weights.
    _HF_REPO = "xinjjj/RoboAssetGen"
    _HF_SUBDIR = "aesthetic"

    def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
        self.device = device

        # BUGFIX: the snapshot used to be downloaded twice per missing path
        # (four calls total); fetch it at most once per default.
        if clip_model_dir is None:
            clip_model_dir = self._default_asset_dir()

        if sac_model_path is None:
            sac_model_path = os.path.join(
                self._default_asset_dir(), "sac+logos+ava1-l14-linearMSE.pth"
            )

        self.clip_model, self.preprocess = self._load_clip_model(
            clip_model_dir
        )
        self.sac_model = self._load_sac_model(sac_model_path, input_size=768)

    @classmethod
    def _default_asset_dir(cls):
        """Download (or reuse cached) default weights; return their directory."""
        model_path = snapshot_download(
            repo_id=cls._HF_REPO, allow_patterns=f"{cls._HF_SUBDIR}/*"
        )
        return os.path.join(model_path, cls._HF_SUBDIR)

    class MLP(nn.Module):  # noqa
        """Linear head mapping 768-d CLIP features to a scalar score.

        NOTE: a plain nn.Module suffices here (inference only) — dropping
        the pytorch_lightning base removes a heavy import while keeping
        the state_dict layout identical.
        """

        def __init__(self, input_size):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(input_size, 1024),
                nn.Dropout(0.2),
                nn.Linear(1024, 128),
                nn.Dropout(0.2),
                nn.Linear(128, 64),
                nn.Dropout(0.1),
                nn.Linear(64, 16),
                nn.Linear(16, 1),
            )

        def forward(self, x):
            return self.layers(x)

    @staticmethod
    def normalized(a, axis=-1, order=2):
        """Normalize the array to unit norm."""
        l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
        l2[l2 == 0] = 1  # avoid division by zero for all-zero rows
        return a / np.expand_dims(l2, axis)

    def _load_clip_model(self, model_dir: str, model_name: str = "ViT-L/14"):
        """Load the CLIP model."""
        model, preprocess = clip.load(
            model_name, download_root=model_dir, device=self.device
        )
        return model, preprocess

    def _load_sac_model(self, model_path, input_size):
        """Load the SAC model."""
        model = self.MLP(input_size)
        # BUGFIX: map_location lets GPU-saved checkpoints load on CPU hosts.
        ckpt = torch.load(model_path, map_location=self.device)
        model.load_state_dict(ckpt)
        model.to(self.device)
        model.eval()
        return model

    def predict(self, image_path):
        """Predict the aesthetic score for a given image.

        Args:
            image_path (str): Path to the image file.

        Returns:
            float: Predicted aesthetic score.
        """
        pil_image = Image.open(image_path)
        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Extract CLIP features
            image_features = self.clip_model.encode_image(image)
            # Normalize features
            normalized_features = self.normalized(
                image_features.cpu().detach().numpy()
            )
            # Predict score
            prediction = self.sac_model(
                torch.from_numpy(normalized_features)
                .type(torch.FloatTensor)
                .to(self.device)
            )

        return prediction.item()
+
+
if __name__ == "__main__":
    # Configuration
    img_path = "apps/assets/example_image/sample_00.jpg"

    # Initialize the predictor (downloads default weights on first use)
    predictor = AestheticPredictor()

    # Predict the aesthetic score
    score = predictor.predict(img_path)
    print("Aesthetic score predicted by the model:", score)
diff --git a/embodied_gen/validators/quality_checkers.py b/embodied_gen/validators/quality_checkers.py
new file mode 100644
index 0000000..4608c6b
--- /dev/null
+++ b/embodied_gen/validators/quality_checkers.py
@@ -0,0 +1,242 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import logging
+import os
+
+from tqdm import tqdm
+from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
+from embodied_gen.utils.process_media import render_asset3d
+from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
class BaseChecker:
    """Common scaffolding for quality checkers: query + YES/NO verdicting."""

    def __init__(self, prompt: str = None, verbose: bool = False) -> None:
        self.prompt = prompt
        self.verbose = verbose

    def query(self, *args, **kwargs):
        """Produce the raw checker response; implemented by subclasses."""
        raise NotImplementedError(
            "Subclasses must implement the query method."
        )

    def __call__(self, *args, **kwargs) -> bool:
        """Run query(); return (passed, response) where passed means "YES"."""
        answer = self.query(*args, **kwargs)
        if answer is None:
            answer = "Error when calling gpt api."

        if self.verbose and answer != "YES":
            logger.info(answer)

        passed = "YES" in answer
        return passed, ("YES" if passed else answer)

    @staticmethod
    def validate(
        checkers: list["BaseChecker"], images_list: list[list[str]]
    ) -> list:
        """Run each checker on its images; append an overall verdict row."""
        assert len(checkers) == len(images_list)
        rows = []
        all_passed = True
        for checker, images in zip(checkers, images_list):
            ok, info = checker(images)
            if isinstance(info, str):
                info = info.replace("\n", ".")
            rows.append([checker.__class__.__name__, info])
            if ok is False:
                all_passed = False

        rows.append(["overall", "YES" if all_passed else "NO"])

        return rows
+
+
class MeshGeoChecker(BaseChecker):
    """GPT-backed geometry completeness checker for 3D mesh assets.

    Sends multi-view renders of an asset to a multi-modal GPT client and
    asks whether the geometry is complete and stably placeable.

    Attributes:
        gpt_client (GPTclient): The GPT client used for multi-modal querying.
        prompt (str): Prompt sent to the model; a default is used if None.
        verbose (bool): Whether to print debug information during evaluation.
    """

    def __init__(
        self,
        gpt_client: GPTclient,
        prompt: str = None,
        verbose: bool = False,
    ) -> None:
        super().__init__(prompt, verbose)
        self.gpt_client = gpt_client
        if self.prompt is None:
            self.prompt = """
            Refer to the provided multi-view rendering images to evaluate
            whether the geometry of the 3D object asset is complete and
            whether the asset can be placed stably on the ground.
            Return "YES" only if reach the requirments,
            otherwise "NO" and explain the reason very briefly.
            """

    def query(self, image_paths: str) -> str:
        """Send the rendered views to GPT; return the raw response text."""
        # Hardcode tmp because of the openrouter can't input multi images.
        if "openrouter" in self.gpt_client.endpoint:
            from embodied_gen.utils.process_media import (
                combine_images_to_base64,
            )

            image_paths = combine_images_to_base64(image_paths)

        return self.gpt_client.query(
            text_prompt=self.prompt,
            image_base64=image_paths,
        )
+
+
class ImageSegChecker(BaseChecker):
    """GPT-backed segmentation quality checker.

    Compares an original image against its segmented version and asks the
    model whether the main object was isolated with minimal truncation and
    correct foreground extraction.

    Attributes:
        gpt_client (GPTclient): GPT client used for multi-modal image analysis.
        prompt (str): Prompt guiding the evaluation; a default is used if None.
        verbose (bool): Whether to enable verbose logging.
    """

    def __init__(
        self,
        gpt_client: GPTclient,
        prompt: str = None,
        verbose: bool = False,
    ) -> None:
        super().__init__(prompt, verbose)
        self.gpt_client = gpt_client
        if self.prompt is None:
            self.prompt = """
            The first image is the original, and the second image is the
            result after segmenting the main object. Evaluate the segmentation
            quality to ensure the main object is clearly segmented without
            significant truncation. Note that the foreground of the object
            needs to be extracted instead of the background.
            Minor imperfections can be ignored. If segmentation is acceptable,
            return "YES" only; otherwise, return "NO" with
            very brief explanation.
            """

    def query(self, image_paths: list[str]) -> str:
        """Query GPT with [raw_image, seg_image]; return the raw response."""
        if len(image_paths) != 2:
            raise ValueError(
                "ImageSegChecker requires exactly two images: [raw_image, seg_image]."  # noqa
            )
        # Hardcode tmp because of the openrouter can't input multi images.
        if "openrouter" in self.gpt_client.endpoint:
            from embodied_gen.utils.process_media import (
                combine_images_to_base64,
            )

            image_paths = combine_images_to_base64(image_paths)

        return self.gpt_client.query(
            text_prompt=self.prompt,
            image_base64=image_paths,
        )
+
+
class ImageAestheticChecker(BaseChecker):
    """Scores image aesthetics; passes when the mean score exceeds a threshold.

    Attributes:
        clip_model_dir (str): Path to the CLIP model directory.
        sac_model_path (str): Path to the aesthetic predictor model weights.
        thresh (float): Minimum mean score considered acceptable.
        verbose (bool): Whether to print detailed log messages.
        predictor (AestheticPredictor): Model predicting aesthetic scores.
    """

    def __init__(
        self,
        clip_model_dir: str = None,
        sac_model_path: str = None,
        thresh: float = 4.50,
        verbose: bool = False,
    ) -> None:
        super().__init__(verbose=verbose)
        self.clip_model_dir = clip_model_dir
        self.sac_model_path = sac_model_path
        self.thresh = thresh
        self.predictor = AestheticPredictor(clip_model_dir, sac_model_path)

    def query(self, image_paths: list[str]) -> float:
        """Return the mean aesthetic score over the given images."""
        total = 0.0
        for img_path in image_paths:
            total += self.predictor.predict(img_path)
        return total / len(image_paths)

    def __call__(self, image_paths: list[str], **kwargs) -> bool:
        """Return (mean_score > thresh, mean_score)."""
        mean_score = self.query(image_paths)
        if self.verbose:
            logger.info(f"Average aesthetic score: {mean_score}")
        return mean_score > self.thresh, mean_score
+
+
if __name__ == "__main__":
    # Batch QA demo: render each generated mesh and run all three checkers.
    geo_checker = MeshGeoChecker(GPT_CLIENT)
    seg_checker = ImageSegChecker(GPT_CLIENT)
    aesthetic_checker = ImageAestheticChecker()

    checkers = [geo_checker, seg_checker, aesthetic_checker]

    output_root = "outputs/test_gpt"

    fails = []
    for idx in tqdm(range(150)):
        mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj"  # noqa
        if not os.path.exists(mesh_path):
            continue
        # Render 8 views on two elevation rings as checker input.
        image_paths = render_asset3d(
            mesh_path,
            f"{output_root}/{idx}",
            num_images=8,
            elevation=(30, -30),
            distance=5.5,
        )

        for cid, checker in enumerate(checkers):
            if isinstance(checker, ImageSegChecker):
                # Seg checker compares the raw capture with its segmentation.
                images = [
                    f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png",  # noqa
                    f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png",  # noqa
                ]
            else:
                images = image_paths
            result, info = checker(images)
            logger.info(
                f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}"  # noqa
            )

            if result is False:
                fails.append((idx, cid, info))

        # NOTE(review): this break stops after the first existing mesh —
        # presumably a debug leftover; confirm before relying on full-batch QA.
        break
diff --git a/embodied_gen/validators/urdf_convertor.py b/embodied_gen/validators/urdf_convertor.py
new file mode 100644
index 0000000..eed9e07
--- /dev/null
+++ b/embodied_gen/validators/urdf_convertor.py
@@ -0,0 +1,419 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import logging
+import os
+import shutil
+import xml.etree.ElementTree as ET
+from datetime import datetime
+from xml.dom.minidom import parseString
+
+import numpy as np
+import trimesh
+from embodied_gen.data.utils import zip_files
+from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
+from embodied_gen.utils.process_media import render_asset3d
+from embodied_gen.utils.tags import VERSION
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+__all__ = ["URDFGenerator"]
+
+
+URDF_TEMPLATE = """
+
+
+
+
+
+
+
+
+
+
+
+
+ 0.8
+ 0.6
+
+
+
+
+
+
+
+
+ 1.0
+ "0.0.0"
+ "unknown"
+ "unknown"
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ "-1"
+ ""
+
+
+
+"""
+
+
+class URDFGenerator(object):
+ def __init__(
+ self,
+ gpt_client: GPTclient,
+ mesh_file_list: list[str] = ["material_0.png", "material.mtl"],
+ prompt_template: str = None,
+ attrs_name: list[str] = None,
+ render_dir: str = "urdf_renders",
+ render_view_num: int = 4,
+ ) -> None:
+ if mesh_file_list is None:
+ mesh_file_list = []
+ self.mesh_file_list = mesh_file_list
+ self.output_mesh_dir = "mesh"
+ self.output_render_dir = render_dir
+ self.gpt_client = gpt_client
+ self.render_view_num = render_view_num
+ if render_view_num == 4:
+ view_desc = "This is orthographic projection showing the front, left, right and back views " # noqa
+ else:
+ view_desc = "This is the rendered views "
+
+ if prompt_template is None:
+ prompt_template = (
+ view_desc
+ + """of the 3D object asset,
+ category: {category}.
+ Give the category of this object asset (within 3 words),
+ (if category is already provided, use it directly),
+ accurately describe this 3D object asset (within 15 words),
+ and give the recommended geometric height range (unit: meter),
+ weight range (unit: kilogram), the average static friction
+ coefficient of the object relative to rubber and the average
+ dynamic friction coefficient of the object relative to rubber.
+ Return response format as shown in Example.
+
+ Example:
+ Category: cup
+ Description: shiny golden cup with floral design
+ Height: 0.1-0.15 m
+ Weight: 0.3-0.6 kg
+ Static friction coefficient: 1.1
+ Dynamic friction coefficient: 0.9
+ """
+ )
+
+ self.prompt_template = prompt_template
+ if attrs_name is None:
+ attrs_name = [
+ "category",
+ "description",
+ "min_height",
+ "max_height",
+ "real_height",
+ "min_mass",
+ "max_mass",
+ "version",
+ "generate_time",
+ "gs_model",
+ ]
+ self.attrs_name = attrs_name
+
+ def parse_response(self, response: str) -> dict[str, any]:
+ lines = response.split("\n")
+ lines = [line.strip() for line in lines if line]
+ category = lines[0].split(": ")[1]
+ description = lines[1].split(": ")[1]
+ min_height, max_height = map(
+ lambda x: float(x.strip().replace(",", "").split()[0]),
+ lines[2].split(": ")[1].split("-"),
+ )
+ min_mass, max_mass = map(
+ lambda x: float(x.strip().replace(",", "").split()[0]),
+ lines[3].split(": ")[1].split("-"),
+ )
+ mu1 = float(lines[4].split(": ")[1].replace(",", ""))
+ mu2 = float(lines[5].split(": ")[1].replace(",", ""))
+
+ return {
+ "category": category.lower(),
+ "description": description.lower(),
+ "min_height": round(min_height, 4),
+ "max_height": round(max_height, 4),
+ "min_mass": round(min_mass, 4),
+ "max_mass": round(max_mass, 4),
+ "mu1": round(mu1, 2),
+ "mu2": round(mu2, 2),
+ "version": VERSION,
+ "generate_time": datetime.now().strftime("%Y%m%d%H%M%S"),
+ }
+
+ def generate_urdf(
+ self,
+ input_mesh: str,
+ output_dir: str,
+ attr_dict: dict,
+ output_name: str = None,
+ ) -> str:
+ """Generate a URDF file for a given mesh with specified attributes.
+
+ Args:
+ input_mesh (str): Path to the input mesh file.
+ output_dir (str): Directory to store the generated URDF
+ and processed mesh.
+ attr_dict (dict): Dictionary containing attributes like height,
+ mass, and friction coefficients.
+ output_name (str, optional): Name for the generated URDF and robot.
+
+ Returns:
+ str: Path to the generated URDF file.
+ """
+
+ # 1. Load and normalize the mesh
+ mesh = trimesh.load(input_mesh)
+ mesh_scale = np.ptp(mesh.vertices, axis=0).max()
+ mesh.vertices /= mesh_scale # Normalize to [-0.5, 0.5]
+ raw_height = np.ptp(mesh.vertices, axis=0)[1]
+
+ # 2. Scale the mesh to real height
+ real_height = attr_dict["real_height"]
+ scale = round(real_height / raw_height, 6)
+ mesh = mesh.apply_scale(scale)
+
+ # 3. Prepare output directories and save scaled mesh
+ mesh_folder = os.path.join(output_dir, self.output_mesh_dir)
+ os.makedirs(mesh_folder, exist_ok=True)
+
+ obj_name = os.path.basename(input_mesh)
+ mesh_output_path = os.path.join(mesh_folder, obj_name)
+ mesh.export(mesh_output_path)
+
+ # 4. Copy additional mesh files, if any
+ input_dir = os.path.dirname(input_mesh)
+ for file in self.mesh_file_list:
+ src_file = os.path.join(input_dir, file)
+ dest_file = os.path.join(mesh_folder, file)
+ if os.path.isfile(src_file):
+ shutil.copy(src_file, dest_file)
+
+ # 5. Determine output name
+ if output_name is None:
+ output_name = os.path.splitext(obj_name)[0]
+
+ # 6. Load URDF template and update attributes
+ robot = ET.fromstring(URDF_TEMPLATE)
+ robot.set("name", output_name)
+
+ link = robot.find("link")
+ if link is None:
+ raise ValueError("URDF template is missing 'link' element.")
+ link.set("name", output_name)
+
+ # Update visual geometry
+ visual = link.find("visual/geometry/mesh")
+ if visual is not None:
+ visual.set(
+ "filename", os.path.join(self.output_mesh_dir, obj_name)
+ )
+ visual.set("scale", "1.0 1.0 1.0")
+
+ # Update collision geometry
+ collision = link.find("collision/geometry/mesh")
+ if collision is not None:
+ collision.set(
+ "filename", os.path.join(self.output_mesh_dir, obj_name)
+ )
+ collision.set("scale", "1.0 1.0 1.0")
+
+ # Update friction coefficients
+ gazebo = link.find("collision/gazebo")
+ if gazebo is not None:
+ for param, key in zip(["mu1", "mu2"], ["mu1", "mu2"]):
+ element = gazebo.find(param)
+ if element is not None:
+ element.text = f"{attr_dict[key]:.2f}"
+
+ # Update mass
+ inertial = link.find("inertial/mass")
+ if inertial is not None:
+ mass_value = (attr_dict["min_mass"] + attr_dict["max_mass"]) / 2
+ inertial.set("value", f"{mass_value:.4f}")
+
+ # Add extra_info element to the link
+ extra_info = link.find("extra_info/scale")
+ if extra_info is not None:
+ extra_info.text = f"{scale:.6f}"
+
+ for key in self.attrs_name:
+ extra_info = link.find(f"extra_info/{key}")
+ if extra_info is not None and key in attr_dict:
+ extra_info.text = f"{attr_dict[key]}"
+
+ # 7. Write URDF to file
+ os.makedirs(output_dir, exist_ok=True)
+ urdf_path = os.path.join(output_dir, f"{output_name}.urdf")
+ tree = ET.ElementTree(robot)
+ tree.write(urdf_path, encoding="utf-8", xml_declaration=True)
+
+ logger.info(f"URDF file saved to {urdf_path}")
+
+ return urdf_path
+
+ @staticmethod
+ def get_attr_from_urdf(
+ urdf_path: str,
+ attr_root: str = ".//link/extra_info",
+ attr_name: str = "scale",
+ ) -> float:
+ if not os.path.exists(urdf_path):
+ raise FileNotFoundError(f"URDF file not found: {urdf_path}")
+
+ mesh_scale = 1.0
+ tree = ET.parse(urdf_path)
+ root = tree.getroot()
+ extra_info = root.find(attr_root)
+ if extra_info is not None:
+ scale_element = extra_info.find(attr_name)
+ if scale_element is not None:
+ mesh_scale = float(scale_element.text)
+
+ return mesh_scale
+
+ @staticmethod
+ def add_quality_tag(
+ urdf_path: str, results, output_path: str = None
+ ) -> None:
+ if output_path is None:
+ output_path = urdf_path
+
+ tree = ET.parse(urdf_path)
+ root = tree.getroot()
+ custom_data = ET.SubElement(root, "custom_data")
+ quality = ET.SubElement(custom_data, "quality")
+ for key, value in results:
+ checker_tag = ET.SubElement(quality, key)
+ checker_tag.text = str(value)
+
+ rough_string = ET.tostring(root, encoding="utf-8")
+ formatted_string = parseString(rough_string).toprettyxml(indent=" ")
+ cleaned_string = "\n".join(
+ [line for line in formatted_string.splitlines() if line.strip()]
+ )
+
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ with open(output_path, "w", encoding="utf-8") as f:
+ f.write(cleaned_string)
+
+ logger.info(f"URDF files saved to {output_path}")
+
+ def get_estimated_attributes(self, asset_attrs: dict):
+ estimated_attrs = {
+ "height": round(
+ (asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
+ ),
+ "mass": round(
+ (asset_attrs["min_mass"] + asset_attrs["max_mass"]) / 2, 4
+ ),
+ "mu": round((asset_attrs["mu1"] + asset_attrs["mu2"]) / 2, 4),
+ "category": asset_attrs["category"],
+ }
+
+ return estimated_attrs
+
+ def __call__(
+ self,
+ mesh_path: str,
+ output_root: str,
+ text_prompt: str = None,
+ category: str = "unknown",
+ **kwargs,
+ ):
+ if text_prompt is None or len(text_prompt) == 0:
+ text_prompt = self.prompt_template
+ text_prompt = text_prompt.format(category=category.lower())
+
+ image_path = render_asset3d(
+ mesh_path,
+ output_root,
+ num_images=self.render_view_num,
+ output_subdir=self.output_render_dir,
+ )
+
+ # Hardcode tmp because of the openrouter can't input multi images.
+ if "openrouter" in self.gpt_client.endpoint:
+ from embodied_gen.utils.process_media import (
+ combine_images_to_base64,
+ )
+
+ image_path = combine_images_to_base64(image_path)
+
+ response = self.gpt_client.query(text_prompt, image_path)
+ if response is None:
+ asset_attrs = {
+ "category": category.lower(),
+ "description": category.lower(),
+ "min_height": 1,
+ "max_height": 1,
+ "min_mass": 1,
+ "max_mass": 1,
+ "mu1": 0.8,
+ "mu2": 0.6,
+ "version": VERSION,
+ "generate_time": datetime.now().strftime("%Y%m%d%H%M%S"),
+ }
+ else:
+ asset_attrs = self.parse_response(response)
+ for key in self.attrs_name:
+ if key in kwargs:
+ asset_attrs[key] = kwargs[key]
+
+ asset_attrs["real_height"] = round(
+ (asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
+ )
+
+ self.estimated_attrs = self.get_estimated_attributes(asset_attrs)
+
+ urdf_path = self.generate_urdf(mesh_path, output_root, asset_attrs)
+
+ logger.info(f"response: {response}")
+
+ return urdf_path
+
+
+if __name__ == "__main__":
+    # Manual smoke test: generate a URDF for one locally produced mesh.
+    urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
+    urdf_path = urdf_gen(
+        mesh_path="outputs/imageto3d/cma/o5/URDF_o5/mesh/o5.obj",
+        output_root="outputs/test_urdf",
+        # category="coffee machine",
+        # min_height=1.0,
+        # max_height=1.2,
+        version=VERSION,
+    )
+
+    # zip_files(
+    #     input_paths=[
+    #         "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",
+    #         "scripts/apps/tmp/2umpdum3e5n/URDF_sample/sample.urdf"
+    #     ],
+    #     output_zip="zip.zip"
+    # )
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..c261905
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,51 @@
+[build-system]
+requires = ["setuptools", "wheel", "build"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["embodied_gen"]
+
+[project]
+name = "embodied_gen"
+version = "v0.1.0"
+readme = "README.md"
+license = "Apache-2.0"
+license-files = ["LICENSE", "NOTICE"]
+
+dependencies = []
+requires-python = ">=3.10"
+
+[project.optional-dependencies]
+dev = [
+ "cpplint==2.0.0",
+ "pre-commit==2.13.0",
+ "pydocstyle",
+ "black",
+ "isort",
+]
+
+[project.scripts]
+drender-cli = "embodied_gen.data.differentiable_render:entrypoint"
+backproject-cli = "embodied_gen.data.backproject_v2:entrypoint"
+
+[tool.pydocstyle]
+match = '(?!test_).*(?!_pb2)\.py'
+match-dir = '^(?!(raw|projects|tools|k8s_submit|thirdparty)$)[\w.-]+$'
+convention = "google"
+add-ignore = 'D104,D107,D202,D105,D100,D102,D103,D101,E203'
+
+[tool.pycodestyle]
+max-line-length = 79
+ignore = "E203"
+
+[tool.black]
+line-length = 79
+exclude = "thirdparty"
+skip-string-normalization = true
+
+[tool.isort]
+line_length = 79
+profile = 'black'
+no_lines_before = 'FIRSTPARTY'
+known_first_party = ['embodied_gen']
+skip = "thirdparty/"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..c1812df
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,41 @@
+torch==2.4.0+cu118
+torchvision==0.19.0+cu118
+xformers==0.0.27.post2
+pytorch-lightning==2.4.0
+spconv-cu120==2.3.6
+numpy==1.26.4
+triton==2.1.0
+dataclasses_json
+easydict
+opencv-python>4.5
+imageio==2.36.1
+imageio-ffmpeg==0.5.1
+rembg==2.0.61
+trimesh==4.4.4
+moviepy==1.0.3
+pymeshfix==0.17.0
+igraph==0.11.8
+pyvista==0.36.1
+openai==1.58.1
+transformers==4.42.4
+gradio==5.12.0
+sentencepiece==0.2.0
+diffusers==0.31.0
+xatlas==0.0.9
+onnxruntime==1.20.1
+tenacity==8.2.2
+accelerate==0.33.0
+basicsr==1.4.2
+realesrgan==0.3.0
+pydantic==2.9.2
+vtk==9.3.1
+spaces
+utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
+clip@git+https://github.com/openai/CLIP.git
+kolors@git+https://github.com/Kwai-Kolors/Kolors.git@038818d
+segment-anything@git+https://github.com/facebookresearch/segment-anything.git@dca509f
+nvdiffrast@git+https://github.com/NVlabs/nvdiffrast.git@729261d
+kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0
+https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu118-cp310-cp310-linux_x86_64.whl
+https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu11torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu118/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl
diff --git a/scripts/autoformat.sh b/scripts/autoformat.sh
new file mode 100644
index 0000000..bf12f78
--- /dev/null
+++ b/scripts/autoformat.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+ROOT_DIR=${1}
+
+set -e
+
+black --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
+isort --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
+pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./
+pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
diff --git a/scripts/check_lint.py b/scripts/check_lint.py
new file mode 100644
index 0000000..991ad1b
--- /dev/null
+++ b/scripts/check_lint.py
@@ -0,0 +1,61 @@
+import argparse
+import os
+import subprocess
+import sys
+
+
+def get_root():
+ current_file_path = os.path.abspath(__file__)
+ root_path = os.path.dirname(current_file_path)
+ for _ in range(2):
+ root_path = os.path.dirname(root_path)
+ return root_path
+
+
+def cpp_lint(root_path: str):
+ # run external python file to lint cpp
+ subprocess.check_call(
+ " ".join(
+ [
+ "python3",
+ f"{root_path}/scripts/lint_src/lint.py",
+ "--project=asset_recons",
+ "--path",
+ f"{root_path}/src/",
+ f"{root_path}/include/",
+ f"{root_path}/module/",
+ "--exclude_path",
+ f"{root_path}/module/web_viz/front_end/",
+ ]
+ ),
+ shell=True,
+ )
+
+
+def python_lint(root_path: str, auto_format: bool = False):
+ # run external python file to lint python
+ subprocess.check_call(
+ " ".join(
+ [
+ "bash",
+ (
+ f"{root_path}/scripts/lint/check_pylint.sh"
+ if not auto_format
+ else f"{root_path}/scripts/lint/autoformat.sh"
+ ),
+ f"{root_path}/",
+ ]
+ ),
+ shell=True,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="check format.")
+ parser.add_argument(
+ "--auto_format", action="store_true", help="auto format python"
+ )
+ parser = parser.parse_args()
+ root_path = get_root()
+ cpp_lint(root_path)
+ python_lint(root_path, parser.auto_format)
diff --git a/scripts/check_pylint.sh b/scripts/check_pylint.sh
new file mode 100644
index 0000000..8938332
--- /dev/null
+++ b/scripts/check_pylint.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+ROOT_DIR=${1}
+
+set -e
+
+
+pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./
+pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
+black --check --diff --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
+isort --diff --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
diff --git a/scripts/lint/autoformat.sh b/scripts/lint/autoformat.sh
new file mode 100644
index 0000000..bf12f78
--- /dev/null
+++ b/scripts/lint/autoformat.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+ROOT_DIR=${1}
+
+set -e
+
+black --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
+isort --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
+pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./
+pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
diff --git a/scripts/lint/check_lint.py b/scripts/lint/check_lint.py
new file mode 100644
index 0000000..991ad1b
--- /dev/null
+++ b/scripts/lint/check_lint.py
@@ -0,0 +1,61 @@
+import argparse
+import os
+import subprocess
+import sys
+
+
+def get_root():
+ current_file_path = os.path.abspath(__file__)
+ root_path = os.path.dirname(current_file_path)
+ for _ in range(2):
+ root_path = os.path.dirname(root_path)
+ return root_path
+
+
+def cpp_lint(root_path: str):
+ # run external python file to lint cpp
+ subprocess.check_call(
+ " ".join(
+ [
+ "python3",
+ f"{root_path}/scripts/lint_src/lint.py",
+ "--project=asset_recons",
+ "--path",
+ f"{root_path}/src/",
+ f"{root_path}/include/",
+ f"{root_path}/module/",
+ "--exclude_path",
+ f"{root_path}/module/web_viz/front_end/",
+ ]
+ ),
+ shell=True,
+ )
+
+
+def python_lint(root_path: str, auto_format: bool = False):
+ # run external python file to lint python
+ subprocess.check_call(
+ " ".join(
+ [
+ "bash",
+ (
+ f"{root_path}/scripts/lint/check_pylint.sh"
+ if not auto_format
+ else f"{root_path}/scripts/lint/autoformat.sh"
+ ),
+ f"{root_path}/",
+ ]
+ ),
+ shell=True,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="check format.")
+ parser.add_argument(
+ "--auto_format", action="store_true", help="auto format python"
+ )
+ parser = parser.parse_args()
+ root_path = get_root()
+ cpp_lint(root_path)
+ python_lint(root_path, parser.auto_format)
diff --git a/scripts/lint/check_pylint.sh b/scripts/lint/check_pylint.sh
new file mode 100644
index 0000000..8938332
--- /dev/null
+++ b/scripts/lint/check_pylint.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+ROOT_DIR=${1}
+
+set -e
+
+
+pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./
+pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
+black --check --diff --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
+isort --diff --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
diff --git a/scripts/lint_src/cpplint.hook b/scripts/lint_src/cpplint.hook
new file mode 100755
index 0000000..a9d87ee
--- /dev/null
+++ b/scripts/lint_src/cpplint.hook
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# Pre-commit hook: run the cpplint wrapper over staged (non-deleted) files.
+# NOTE(review): TOTAL_ERRORS is initialized but never used here — the exit
+# status of lint.py itself determines whether the hook passes.
+TOTAL_ERRORS=0
+if [[ ! $(which cpplint) ]]; then
+    pip install cpplint
+fi
+# diff files on local machine.
+files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}')
+python3 scripts/lint_src/lint.py --project=asset_recons --path $files --exclude_path thirdparty patch_files;
+
diff --git a/scripts/lint_src/lint.py b/scripts/lint_src/lint.py
new file mode 100644
index 0000000..5003569
--- /dev/null
+++ b/scripts/lint_src/lint.py
@@ -0,0 +1,155 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import argparse
+import codecs
+import os
+import re
+import sys
+
+import cpplint
+import pycodestyle
+from cpplint import _cpplint_state
+
+CXX_SUFFIX = set(["cc", "c", "cpp", "h", "cu", "hpp"])
+
+
+def filepath_enumerate(paths):
+ """Enumerate the file paths of all subfiles of the list of paths."""
+ out = []
+ for path in paths:
+ if os.path.isfile(path):
+ out.append(path)
+ else:
+ for root, dirs, files in os.walk(path):
+ for name in files:
+ out.append(os.path.normpath(os.path.join(root, name)))
+ return out
+
+
+class LintHelper(object):
+ @staticmethod
+ def _print_summary_map(strm, result_map, ftype):
+ """Print summary of certain result map."""
+ if len(result_map) == 0:
+ return 0
+ npass = len([x for k, x in result_map.items() if len(x) == 0])
+ strm.write(
+ "=====%d/%d %s files passed check=====\n"
+ % (npass, len(result_map), ftype)
+ )
+ for fname, emap in result_map.items():
+ if len(emap) == 0:
+ continue
+ strm.write(
+ "%s: %d Errors of %d Categories map=%s\n"
+ % (fname, sum(emap.values()), len(emap), str(emap))
+ )
+ return len(result_map) - npass
+
+ def __init__(self) -> None:
+ self.project_name = None
+ self.cpp_header_map = {}
+ self.cpp_src_map = {}
+ super().__init__()
+ cpplint_args = [".", "--extensions=" + (",".join(CXX_SUFFIX))]
+ _ = cpplint.ParseArguments(cpplint_args)
+ cpplint._SetFilters(
+ ",".join(
+ [
+ "-build/c++11",
+ "-build/namespaces",
+ "-build/include,",
+ "+build/include_what_you_use",
+ "+build/include_order",
+ ]
+ )
+ )
+ cpplint._SetCountingStyle("toplevel")
+ cpplint._line_length = 80
+
+ def process_cpp(self, path, suffix):
+ """Process a cpp file."""
+ _cpplint_state.ResetErrorCounts()
+ cpplint.ProcessFile(str(path), _cpplint_state.verbose_level)
+ _cpplint_state.PrintErrorCounts()
+ errors = _cpplint_state.errors_by_category.copy()
+
+ if suffix == "h":
+ self.cpp_header_map[str(path)] = errors
+ else:
+ self.cpp_src_map[str(path)] = errors
+
+ def print_summary(self, strm):
+ """Print summary of lint."""
+ nerr = 0
+ nerr += LintHelper._print_summary_map(
+ strm, self.cpp_header_map, "cpp-header"
+ )
+ nerr += LintHelper._print_summary_map(
+ strm, self.cpp_src_map, "cpp-source"
+ )
+ if nerr == 0:
+ strm.write("All passed!\n")
+ else:
+ strm.write("%d files failed lint\n" % nerr)
+ return nerr
+
+
+# singleton helper for lint check
+_HELPER = LintHelper()
+
+
+def process(fname, allow_type):
+    """Lint ``fname`` if its extension is in ``allow_type``.
+
+    Paths containing '#' (editor backup/lock files) are skipped.
+    """
+    fname = str(fname)
+    arr = fname.rsplit(".", 1)
+    if fname.find("#") != -1 or arr[-1] not in allow_type:
+        return
+    if arr[-1] in CXX_SUFFIX:
+        _HELPER.process_cpp(fname, arr[-1])
+
+
+def main():
+    """Main entry function."""
+    parser = argparse.ArgumentParser(description="lint source codes")
+    parser.add_argument("--project", help="project name")
+    parser.add_argument(
+        "--path",
+        nargs="+",
+        default=[],
+        help="path to traverse",
+        required=False,
+    )
+    parser.add_argument(
+        "--exclude_path",
+        nargs="+",
+        default=[],
+        help="exclude this path, and all subfolders " + "if path is a folder",
+    )
+
+    args = parser.parse_args()
+    _HELPER.project_name = args.project
+    allow_type = []
+    allow_type += [x for x in CXX_SUFFIX]
+    allow_type = set(allow_type)
+
+    # get excluded files
+    excluded_paths = filepath_enumerate(args.exclude_path)
+    for path in args.path:
+        if os.path.isfile(path):
+            normpath = os.path.normpath(path)
+            if normpath not in excluded_paths:
+                process(path, allow_type)
+        else:
+            # Directories are walked recursively; each file is excluded
+            # individually against the expanded exclude list.
+            for root, dirs, files in os.walk(path):
+                for name in files:
+                    file_path = os.path.normpath(os.path.join(root, name))
+                    if file_path not in excluded_paths:
+                        process(file_path, allow_type)
+    # Exit status 1 when any file failed lint, 0 otherwise.
+    nerr = _HELPER.print_summary(sys.stderr)
+    sys.exit(nerr > 0)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/lint_src/pep8.hook b/scripts/lint_src/pep8.hook
new file mode 100755
index 0000000..6926a05
--- /dev/null
+++ b/scripts/lint_src/pep8.hook
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+TOTAL_ERRORS=0
+if [[ ! $(which pycodestyle) ]]; then
+ pip install pycodestyle
+fi
+# diff files on local machine.
+files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}')
+for file in $files; do
+ if [ "${file##*.}" == "py" & -f "${file}"] ; then
+ pycodestyle --show-source $file --config=setup.cfg;
+ TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
+ fi
+done
+
+exit $TOTAL_ERRORS
diff --git a/scripts/lint_src/pydocstyle.hook b/scripts/lint_src/pydocstyle.hook
new file mode 100755
index 0000000..8b689d6
--- /dev/null
+++ b/scripts/lint_src/pydocstyle.hook
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+# Pre-commit hook: run pydocstyle on every staged (non-deleted) Python file.
+TOTAL_ERRORS=0
+if [[ ! $(which pydocstyle) ]]; then
+    pip install pydocstyle
+fi
+# diff files on local machine.
+files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}')
+for file in $files; do
+    if [ "${file##*.}" == "py" ] ; then
+        pydocstyle $file;
+        TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
+    fi
+done
+
+# NOTE(review): this sums exit CODES, not error counts; any nonzero total
+# fails the hook, which is the intended effect.
+exit $TOTAL_ERRORS
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..7ef6850
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[pycodestyle]
+ignore = E203,W503,E402,E501
diff --git a/thirdparty/TRELLIS b/thirdparty/TRELLIS
new file mode 160000
index 0000000..55a8e81
--- /dev/null
+++ b/thirdparty/TRELLIS
@@ -0,0 +1 @@
+Subproject commit 55a8e8164b195bbf927e0978f00e76c835e6011f