feat(pipeline): Add EmbodiedGen version v0.1.0. (#2)

Add EmbodiedGen version v0.1.0.
This commit is contained in:
Xinjie 2025-06-11 22:09:22 +08:00 committed by GitHub
parent 7420364fee
commit 18075659de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
49 changed files with 10418 additions and 3 deletions

62
.gitignore vendored Normal file
View File

@ -0,0 +1,62 @@
build/
dummy/
!scripts/build
builddir/
conan-deps/
distribute/
lib/
bin/
dist/
deps/*
docs/build
python/dist/
docs/index.rst
python/MANIFEST.in
*.egg-info
*.pyc
*.pyi
*.json
*.bak
*.zip
wheels*/
# Compiled Object files
*.slo
*.lo
*.o
# Compiled Dynamic libraries
*.so
*.dylib
# Compiled Static libraries
*.lai
*.la
*.a
.cproject
.project
.settings/
*.db
*.bak
.arcconfig
.vscode/
# files
*.pack
*.pcd
*.html
*.ply
*.mp4
# node
node_modules
# local files
build.sh
__pycache__/
output*
*.log
scripts/tools/
weights/
apps/assets/example_texture/
apps/sessions/

4
.gitmodules vendored Normal file
View File

@ -0,0 +1,4 @@
[submodule "thirdparty/TRELLIS"]
path = thirdparty/TRELLIS
url = https://github.com/microsoft/TRELLIS.git
branch = main

78
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,78 @@
repos:
- repo: git@gitlab.hobot.cc:ptd/3rd/pre-commit/pre-commit-hooks.git
rev: v2.3.0 # Use the ref you want to point at
hooks:
- id: trailing-whitespace
- id: check-added-large-files
name: Check for added large files
description: Prevent giant files from being committed
entry: check-added-large-files
language: python
args: ["--maxkb=1024"]
- repo: local
hooks:
- id: cpplint-cpp-source
name: cpplint
description: Check cpp code style.
entry: python3 scripts/lint_src/lint.py
language: system
exclude: (?x)(^tools/|^thirdparty/|^patch_files/)
files: \.(c|cc|cxx|cpp|cu|h|hpp)$
args: [--project=asset_recons, --path]
- repo: local
hooks:
- id: pycodestyle-python
name: pep8-exclude-docs
description: Check python code style.
entry: pycodestyle
language: system
exclude: (?x)(^docs/|^thirdparty/|^scripts/build/)
files: \.(py)$
types: [file, python]
args: [--config=setup.cfg]
# pre-commit install --hook-type commit-msg to enable it
# - repo: local
# hooks:
# - id: commit-check
# name: check for commit msg format
# language: pygrep
# entry: '\A(?!(feat|fix|docs|style|refactor|perf|test|chore)\(.*\): (\[[a-zA-Z][a-zA-Z0-9_]+-[1-9][0-9]*\]|\[cr_id_skip\]) [A-Z]+.*)'
# args: [--multiline]
# stages: [commit-msg]
- repo: local
hooks:
- id: pydocstyle-python
name: pydocstyle-change-exclude-docs
description: Check python doc style.
entry: pydocstyle
language: system
exclude: (?x)(^docs/|^thirdparty/)
files: \.(py)$
types: [file, python]
args: [--config=pyproject.toml]
- repo: local
hooks:
- id: black
name: black-exclude-docs
description: black format
entry: black
language: system
exclude: (?x)(^docs/|^thirdparty/)
files: \.(py)$
types: [file, python]
args: [--config=pyproject.toml]
- repo: local
hooks:
- id: isort
name: isort
description: isort format
entry: isort
language: system
exclude: (?x)(^thirdparty/)
files: \.(py)$
types: [file, python]
args: [--settings-file=pyproject.toml]

View File

@ -1,3 +1,5 @@
Copyright (c) 2024 Horizon Robotics and EmbodiedGen Contributors. All rights reserved.
Apache License Apache License
Version 2.0, January 2004 Version 2.0, January 2004
http://www.apache.org/licenses/ http://www.apache.org/licenses/
@ -186,7 +188,7 @@
same "printed page" as the copyright notice for easier same "printed page" as the copyright notice for easier
identification within third-party archives. identification within third-party archives.
Copyright [yyyy] [name of copyright owner] Copyright 2024 Horizon Robotics and EmbodiedGen Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

1
MANIFEST.in Normal file
View File

@ -0,0 +1 @@
graft embodied_gen

180
README.md
View File

@ -1,2 +1,178 @@
# EmbodiedGen # EmbodiedGen: Towards a Generative 3D World Engine for Embodied Intelligence
Towards a Generative 3D World Engine for Embodied Intelligence
[![🌐 Project Page](https://img.shields.io/badge/🌐-Project_Page-blue)](https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html)
[![📄 arXiv](https://img.shields.io/badge/📄-arXiv-b31b1b)](#)
[![🎥 Video](https://img.shields.io/badge/🎥-Video-red)](https://www.youtube.com/watch?v=SnHhzHeb_aI)
[![🤗 Hugging Face](https://img.shields.io/badge/🤗-Image_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D)
[![🤗 Hugging Face](https://img.shields.io/badge/🤗-Text_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D)
[![🤗 Hugging Face](https://img.shields.io/badge/🤗-Texture_Gen_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen)
<img src="apps/assets/overall.jpg" alt="Overall Framework" width="700"/>
**EmbodiedGen** generates interactive 3D worlds with real-world scale and physical realism at low cost.
---
## ✨ Table of Contents of EmbodiedGen
- [🖼️ Image-to-3D](#image-to-3d)
- [📝 Text-to-3D](#text-to-3d)
- [🎨 Texture Generation](#texture-generation)
- [🌍 3D Scene Generation](#3d-scene-generation)
- [⚙️ Articulated Object Generation](#articulated-object-generation)
- [🏞️ Layout Generation](#layout-generation)
## 🚀 Quick Start
```sh
git clone https://github.com/HorizonRobotics/EmbodiedGen
cd EmbodiedGen
conda create -n embodiedgen python=3.10.13 -y
conda activate embodiedgen
pip install -r requirements.txt --use-deprecated=legacy-resolver
pip install -e .
```
---
## 🟢 Setup GPT Agent
Update the API key in file: `embodied_gen/utils/gpt_config.yaml`.
You can choose between two backends for the GPT agent:
- **`gpt-4o`** (Recommended) Use this if you have access to **Azure OpenAI**.
- **`qwen2.5-vl`** An open alternative with free usage via [OpenRouter](https://openrouter.ai/settings/keys) (50 free requests per day)
---
<h2 id="image-to-3d">🖼️ Image-to-3D</h2>
[![🤗 Hugging Face](https://img.shields.io/badge/🤗-Image_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) Generate physically plausible 3D asset from input image.
### Local Service
Run the image-to-3D generation service locally. The first run will download required models.
```sh
# Run in foreground
python apps/image_to_3d.py
# Or run in the background
CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
```
### Local API
Generate a 3D model from an image using the command-line API.
```sh
python3 embodied_gen/scripts/imageto3d.py \
--image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
--output_root outputs/imageto3d/
# See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result
```
---
<h2 id="text-to-3d">📝 Text-to-3D</h2>
[![🤗 Hugging Face](https://img.shields.io/badge/🤗-Text_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) Create 3D assets from text descriptions for a wide range of geometry and styles.
### Local Service
Run the text-to-3D generation service locally.
```sh
python apps/text_to_3d.py
```
### Local API
```sh
bash embodied_gen/scripts/textto3d.sh \
--prompts "small bronze figurine of a lion" "带木质底座,具有经纬线的地球仪" "橙色电动手钻,有磨损细节" \
--output_root outputs/textto3d/
```
---
<h2 id="texture-generation">🎨 Texture Generation</h2>
[![🤗 Hugging Face](https://img.shields.io/badge/🤗-Texture_Gen_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) Generate visually rich textures for 3D mesh.
### Local Service
Run the texture generation service locally.
```sh
python apps/texture_edit.py
```
### Local API
Generate textures for a 3D mesh using a text prompt.
```sh
bash embodied_gen/scripts/texture_gen.sh \
--mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \
--prompt "举着牌子的红色写实风格机器人牌子上写着“Hello”" \
--output_root "outputs/texture_gen/" \
--uuid "robot_text"
```
---
<h2 id="3d-scene-generation">🌍 3D Scene Generation</h2>
🚧 *Coming Soon*
---
<h2 id="articulated-object-generation">⚙️ Articulated Object Generation</h2>
🚧 *Coming Soon*
---
<h2 id="layout-generation">🏞️ Layout Generation</h2>
🚧 *Coming Soon*
---
## 📚 Citation
If you use EmbodiedGen in your research or projects, please cite:
```bibtex
Coming Soon
```
---
## 🙌 Acknowledgement
EmbodiedGen builds upon the following amazing projects and models:
- 🌟 [Trellis](https://github.com/microsoft/TRELLIS)
- 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0)
- 🌟 [Segment Anything Model](https://github.com/facebookresearch/segment-anything)
- 🌟 [Rembg: a tool to remove images background](https://github.com/danielgatis/rembg)
- 🌟 [RMBG-1.4: BRIA Background Removal](https://huggingface.co/briaai/RMBG-1.4)
- 🌟 [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)
- 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN)
- 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors)
- 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3)
- 🌟 [Aesthetic Score Model](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html)
- 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room)
- 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage)
- 🌟 [kaolin](https://github.com/NVIDIAGameWorks/kaolin)
- 🌟 [diffusers](https://github.com/huggingface/diffusers)
- 🌟 GPT: QWEN2.5VL, GPT4o
---
## ⚖️ License
This project is licensed under the [Apache License 2.0](LICENSE). See the `LICENSE` file for details.

899
apps/common.py Normal file
View File

@ -0,0 +1,899 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import gc
import logging
import os
import shutil
import subprocess
import sys
from glob import glob
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
import trimesh
from easydict import EasyDict as edict
from gradio.themes import Soft
from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
from PIL import Image
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
from embodied_gen.data.differentiable_render import entrypoint as render_api
from embodied_gen.data.utils import trellis_preprocess
from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.gs_model import GaussianOperator
from embodied_gen.models.segment_model import (
BMGG14Remover,
RembgRemover,
SAMPredictor,
)
from embodied_gen.models.sr_model import ImageRealESRGAN, ImageStableSR
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
from embodied_gen.scripts.render_mv import build_texture_gen_pipe, infer_pipe
from embodied_gen.scripts.text2image import (
build_text2img_ip_pipeline,
build_text2img_pipeline,
text2img_gen,
)
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.process_media import (
filter_image_small_connected_components,
merge_images_video,
render_video,
)
from embodied_gen.utils.tags import VERSION
from embodied_gen.validators.quality_checkers import (
BaseChecker,
ImageAestheticChecker,
ImageSegChecker,
MeshGeoChecker,
)
from embodied_gen.validators.urdf_convertor import URDFGenerator, zip_files
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, ".."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
from thirdparty.TRELLIS.trellis.representations import (
Gaussian,
MeshExtractResult,
)
from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
build_scaling_rotation,
inverse_sigmoid,
strip_symmetric,
)
from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
"~/.cache/torch_extensions"
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"
MAX_SEED = 100000
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
# IMAGESR_MODEL = ImageStableSR()
def patched_setup_functions(self):
def inverse_softplus(x):
return x + torch.log(-torch.expm1(-x))
def build_covariance_from_scaling_rotation(
scaling, scaling_modifier, rotation
):
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
actual_covariance = L @ L.transpose(1, 2)
symm = strip_symmetric(actual_covariance)
return symm
if self.scaling_activation_type == "exp":
self.scaling_activation = torch.exp
self.inverse_scaling_activation = torch.log
elif self.scaling_activation_type == "softplus":
self.scaling_activation = F.softplus
self.inverse_scaling_activation = inverse_softplus
self.covariance_activation = build_covariance_from_scaling_rotation
self.opacity_activation = torch.sigmoid
self.inverse_opacity_activation = inverse_sigmoid
self.rotation_activation = F.normalize
self.scale_bias = self.inverse_scaling_activation(
torch.tensor(self.scaling_bias)
).to(self.device)
self.rots_bias = torch.zeros((4)).to(self.device)
self.rots_bias[0] = 1
self.opacity_bias = self.inverse_opacity_activation(
torch.tensor(self.opacity_bias)
).to(self.device)
Gaussian.setup_functions = patched_setup_functions
def download_kolors_weights() -> None:
logger.info(f"Download kolors weights from huggingface...")
subprocess.run(
[
"huggingface-cli",
"download",
"--resume-download",
"Kwai-Kolors/Kolors",
"--local-dir",
"weights/Kolors",
],
check=True,
)
subprocess.run(
[
"huggingface-cli",
"download",
"--resume-download",
"Kwai-Kolors/Kolors-IP-Adapter-Plus",
"--local-dir",
"weights/Kolors-IP-Adapter-Plus",
],
check=True,
)
if os.getenv("GRADIO_APP") == "imageto3d":
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
"microsoft/TRELLIS-image-large"
)
# PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
TMP_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)
elif os.getenv("GRADIO_APP") == "textto3d":
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
"microsoft/TRELLIS-image-large"
)
# PIPELINE.cuda()
text_model_dir = "weights/Kolors"
if not os.path.exists(text_model_dir):
download_kolors_weights()
PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
TMP_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
)
elif os.getenv("GRADIO_APP") == "texture_edit":
if not os.path.exists("weights/Kolors"):
download_kolors_weights()
PIPELINE_IP = build_texture_gen_pipe(
base_ckpt_dir="./weights",
ip_adapt_scale=0.7,
device="cuda",
)
PIPELINE = build_texture_gen_pipe(
base_ckpt_dir="./weights",
ip_adapt_scale=0,
device="cuda",
)
TMP_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
)
os.makedirs(TMP_DIR, exist_ok=True)
lighting_css = """
<style>
#lighter_mesh canvas {
filter: brightness(1.8) !important;
}
</style>
"""
image_css = """
<style>
.image_fit .image-frame {
object-fit: contain !important;
height: 100% !important;
}
</style>
"""
custom_theme = Soft(
primary_hue=stone,
secondary_hue=gray,
radius_size="md",
text_size="sm",
spacing_size="sm",
)
def start_session(req: gr.Request) -> None:
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
def end_session(req: gr.Request) -> None:
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
if os.path.exists(user_dir):
shutil.rmtree(user_dir)
@spaces.GPU
def preprocess_image_fn(
image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg"
) -> tuple[Image.Image, Image.Image]:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
image_cache = image.copy().resize((512, 512))
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
image = bg_remover(image)
image = trellis_preprocess(image)
return image, image_cache
def preprocess_sam_image_fn(
image: Image.Image,
) -> tuple[Image.Image, Image.Image]:
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
sam_image = SAM_PREDICTOR.preprocess_image(image)
image_cache = Image.fromarray(sam_image).resize((512, 512))
SAM_PREDICTOR.predictor.set_image(sam_image)
return sam_image, image_cache
def active_btn_by_content(content: gr.Image) -> gr.Button:
interactive = True if content is not None else False
return gr.Button(interactive=interactive)
def active_btn_by_text_content(content: gr.Textbox) -> gr.Button:
if content is not None and len(content) > 0:
interactive = True
else:
interactive = False
return gr.Button(interactive=interactive)
def get_selected_image(
choice: str, sample1: str, sample2: str, sample3: str
) -> str:
if choice == "sample1":
return sample1
elif choice == "sample2":
return sample2
elif choice == "sample3":
return sample3
else:
raise ValueError(f"Invalid choice: {choice}")
def get_cached_image(image_path: str) -> Image.Image:
if isinstance(image_path, Image.Image):
return image_path
return Image.open(image_path).resize((512, 512))
@spaces.GPU
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
return {
"gaussian": {
**gs.init_params,
"_xyz": gs._xyz.cpu().numpy(),
"_features_dc": gs._features_dc.cpu().numpy(),
"_scaling": gs._scaling.cpu().numpy(),
"_rotation": gs._rotation.cpu().numpy(),
"_opacity": gs._opacity.cpu().numpy(),
},
"mesh": {
"vertices": mesh.vertices.cpu().numpy(),
"faces": mesh.faces.cpu().numpy(),
},
}
def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
gs = Gaussian(
aabb=state["gaussian"]["aabb"],
sh_degree=state["gaussian"]["sh_degree"],
mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
scaling_bias=state["gaussian"]["scaling_bias"],
opacity_bias=state["gaussian"]["opacity_bias"],
scaling_activation=state["gaussian"]["scaling_activation"],
device=device,
)
gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
gs._features_dc = torch.tensor(
state["gaussian"]["_features_dc"], device=device
)
gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device)
gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device)
gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device)
mesh = edict(
vertices=torch.tensor(state["mesh"]["vertices"], device=device),
faces=torch.tensor(state["mesh"]["faces"], device=device),
)
return gs, mesh
def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
return np.random.randint(0, max_seed) if randomize_seed else seed
def select_point(
image: np.ndarray,
sel_pix: list,
point_type: str,
evt: gr.SelectData,
):
if point_type == "foreground_point":
sel_pix.append((evt.index, 1)) # append the foreground_point
elif point_type == "background_point":
sel_pix.append((evt.index, 0)) # append the background_point
else:
sel_pix.append((evt.index, 1)) # default foreground_point
masks = SAM_PREDICTOR.generate_masks(image, sel_pix)
seg_image = SAM_PREDICTOR.get_segmented_image(image, masks)
for point, label in sel_pix:
color = (255, 0, 0) if label == 0 else (0, 255, 0)
marker_type = 1 if label == 0 else 5
cv2.drawMarker(
image,
point,
color,
markerType=marker_type,
markerSize=15,
thickness=10,
)
torch.cuda.empty_cache()
return (image, masks), seg_image
@spaces.GPU
def image_to_3d(
image: Image.Image,
seed: int,
ss_guidance_strength: float,
ss_sampling_steps: int,
slat_guidance_strength: float,
slat_sampling_steps: int,
raw_image_cache: Image.Image,
sam_image: Image.Image = None,
is_sam_image: bool = False,
req: gr.Request = None,
) -> tuple[dict, str]:
if is_sam_image:
seg_image = filter_image_small_connected_components(sam_image)
seg_image = Image.fromarray(seg_image, mode="RGBA")
seg_image = trellis_preprocess(seg_image)
else:
seg_image = image
if isinstance(seg_image, np.ndarray):
seg_image = Image.fromarray(seg_image)
output_root = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(output_root, exist_ok=True)
seg_image.save(f"{output_root}/seg_image.png")
raw_image_cache.save(f"{output_root}/raw_image.png")
PIPELINE.cuda()
outputs = PIPELINE.run(
seg_image,
seed=seed,
formats=["gaussian", "mesh"],
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"cfg_strength": ss_guidance_strength,
},
slat_sampler_params={
"steps": slat_sampling_steps,
"cfg_strength": slat_guidance_strength,
},
)
# Set to cpu for memory saving.
PIPELINE.cpu()
gs_model = outputs["gaussian"][0]
mesh_model = outputs["mesh"][0]
color_images = render_video(gs_model)["color"]
normal_images = render_video(mesh_model)["normal"]
video_path = os.path.join(output_root, "gs_mesh.mp4")
merge_images_video(color_images, normal_images, video_path)
state = pack_state(gs_model, mesh_model)
gc.collect()
torch.cuda.empty_cache()
return state, video_path
@spaces.GPU
def extract_3d_representations(
state: dict, enable_delight: bool, texture_size: int, req: gr.Request
):
output_root = TMP_DIR
output_root = os.path.join(output_root, str(req.session_hash))
gs_model, mesh_model = unpack_state(state, device="cuda")
mesh = postprocessing_utils.to_glb(
gs_model,
mesh_model,
simplify=0.9,
texture_size=1024,
verbose=True,
)
filename = "sample"
gs_path = os.path.join(output_root, f"{filename}_gs.ply")
gs_model.save_ply(gs_path)
# Rotate mesh and GS by 90 degrees around Z-axis.
rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
# Addtional rotation for GS to align mesh.
gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array(
rot_matrix
)
pose = GaussianOperator.trans_to_quatpose(gs_rot)
aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
GaussianOperator.resave_ply(
in_ply=gs_path,
out_ply=aligned_gs_path,
instance_pose=pose,
)
mesh.vertices = mesh.vertices @ np.array(rot_matrix)
mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
mesh.export(mesh_obj_path)
mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
mesh.export(mesh_glb_path)
torch.cuda.empty_cache()
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
def extract_3d_representations_v2(
state: dict,
enable_delight: bool,
texture_size: int,
req: gr.Request,
):
output_root = TMP_DIR
user_dir = os.path.join(output_root, str(req.session_hash))
gs_model, mesh_model = unpack_state(state, device="cpu")
filename = "sample"
gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
gs_model.save_ply(gs_path)
# Rotate mesh and GS by 90 degrees around Z-axis.
rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
# Addtional rotation for GS to align mesh.
gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
pose = GaussianOperator.trans_to_quatpose(gs_rot)
aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
GaussianOperator.resave_ply(
in_ply=gs_path,
out_ply=aligned_gs_path,
instance_pose=pose,
device="cpu",
)
color_path = os.path.join(user_dir, "color.png")
render_gs_api(aligned_gs_path, color_path)
mesh = trimesh.Trimesh(
vertices=mesh_model.vertices.cpu().numpy(),
faces=mesh_model.faces.cpu().numpy(),
)
mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
mesh.vertices = mesh.vertices @ np.array(rot_matrix)
mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
mesh.export(mesh_obj_path)
mesh = backproject_api(
delight_model=DELIGHT,
imagesr_model=IMAGESR_MODEL,
color_path=color_path,
mesh_path=mesh_obj_path,
output_path=mesh_obj_path,
skip_fix_mesh=False,
delight=enable_delight,
texture_wh=[texture_size, texture_size],
)
mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
mesh.export(mesh_glb_path)
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
def extract_urdf(
gs_path: str,
mesh_obj_path: str,
asset_cat_text: str,
height_range_text: str,
mass_range_text: str,
asset_version_text: str,
req: gr.Request = None,
):
output_root = TMP_DIR
if req is not None:
output_root = os.path.join(output_root, str(req.session_hash))
# Convert to URDF and recover attrs by GPT.
filename = "sample"
urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
asset_attrs = {
"version": VERSION,
"gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
}
if asset_version_text:
asset_attrs["version"] = asset_version_text
if asset_cat_text:
asset_attrs["category"] = asset_cat_text.lower()
if height_range_text:
try:
min_height, max_height = map(float, height_range_text.split("-"))
asset_attrs["min_height"] = min_height
asset_attrs["max_height"] = max_height
except ValueError:
return "Invalid height input format. Use the format: min-max."
if mass_range_text:
try:
min_mass, max_mass = map(float, mass_range_text.split("-"))
asset_attrs["min_mass"] = min_mass
asset_attrs["max_mass"] = max_mass
except ValueError:
return "Invalid mass input format. Use the format: min-max."
urdf_path = urdf_convertor(
mesh_path=mesh_obj_path,
output_root=f"{output_root}/URDF_{filename}",
**asset_attrs,
)
# Rescale GS and save to URDF/mesh folder.
real_height = urdf_convertor.get_attr_from_urdf(
urdf_path, attr_name="real_height"
)
out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa
GaussianOperator.resave_ply(
in_ply=gs_path,
out_ply=out_gs,
real_height=real_height,
device="cpu",
)
# Quality check and update .urdf file.
mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa
trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
# image_paths = render_asset3d(
# mesh_path=mesh_out,
# output_root=f"{output_root}/URDF_{filename}",
# output_subdir="qa_renders",
# num_images=8,
# elevation=(30, -30),
# distance=5.5,
# )
image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color" # noqa
image_paths = glob(f"{image_dir}/*.png")
images_list = []
for checker in CHECKERS:
images = image_paths
if isinstance(checker, ImageSegChecker):
images = [
f"{TMP_DIR}/{req.session_hash}/raw_image.png",
f"{TMP_DIR}/{req.session_hash}/seg_image.png",
]
images_list.append(images)
results = BaseChecker.validate(CHECKERS, images_list)
urdf_convertor.add_quality_tag(urdf_path, results)
# Zip urdf files
urdf_zip = zip_files(
input_paths=[
f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}",
f"{output_root}/URDF_{filename}/{filename}.urdf",
],
output_zip=f"{output_root}/urdf_{filename}.zip",
)
estimated_type = urdf_convertor.estimated_attrs["category"]
estimated_height = urdf_convertor.estimated_attrs["height"]
estimated_mass = urdf_convertor.estimated_attrs["mass"]
estimated_mu = urdf_convertor.estimated_attrs["mu"]
return (
urdf_zip,
estimated_type,
estimated_height,
estimated_mass,
estimated_mu,
)
@spaces.GPU
def text2image_fn(
prompt: str,
guidance_scale: float,
infer_step: int = 50,
ip_image: Image.Image | str = None,
ip_adapt_scale: float = 0.3,
image_wh: int | tuple[int, int] = [1024, 1024],
rmbg_tag: str = "rembg",
seed: int = None,
n_sample: int = 3,
req: gr.Request = None,
):
if isinstance(image_wh, int):
image_wh = (image_wh, image_wh)
output_root = TMP_DIR
if req is not None:
output_root = os.path.join(output_root, str(req.session_hash))
os.makedirs(output_root, exist_ok=True)
pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP
if ip_image is not None:
pipeline.set_ip_adapter_scale([ip_adapt_scale])
images = text2img_gen(
prompt=prompt,
n_sample=n_sample,
guidance_scale=guidance_scale,
pipeline=pipeline,
ip_image=ip_image,
image_wh=image_wh,
infer_step=infer_step,
seed=seed,
)
for idx in range(len(images)):
image = images[idx]
images[idx], _ = preprocess_image_fn(image, rmbg_tag)
save_paths = []
for idx, image in enumerate(images):
save_path = f"{output_root}/sample_{idx}.png"
image.save(save_path)
save_paths.append(save_path)
logger.info(f"Images saved to {output_root}")
gc.collect()
torch.cuda.empty_cache()
return save_paths + save_paths
@spaces.GPU
def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
output_root = os.path.join(TMP_DIR, str(req.session_hash))
_ = render_api(
mesh_path=mesh_path,
output_root=f"{output_root}/condition",
uuid=str(uuid),
)
gc.collect()
torch.cuda.empty_cache()
return None, None, None
@spaces.GPU
def generate_texture_mvimages(
prompt: str,
controlnet_cond_scale: float = 0.55,
guidance_scale: float = 9,
strength: float = 0.9,
num_inference_steps: int = 50,
seed: int = 0,
ip_adapt_scale: float = 0,
ip_img_path: str = None,
uid: str = "sample",
sub_idxs: tuple[tuple[int]] = ((0, 1, 2), (3, 4, 5)),
req: gr.Request = None,
) -> list[str]:
output_root = os.path.join(TMP_DIR, str(req.session_hash))
use_ip_adapter = True if ip_img_path and ip_adapt_scale > 0 else False
PIPELINE_IP.set_ip_adapter_scale([ip_adapt_scale])
img_save_paths = infer_pipe(
index_file=f"{output_root}/condition/index.json",
controlnet_cond_scale=controlnet_cond_scale,
guidance_scale=guidance_scale,
strength=strength,
num_inference_steps=num_inference_steps,
ip_adapt_scale=ip_adapt_scale,
ip_img_path=ip_img_path,
uid=uid,
prompt=prompt,
save_dir=f"{output_root}/multi_view",
sub_idxs=sub_idxs,
pipeline=PIPELINE_IP if use_ip_adapter else PIPELINE,
seed=seed,
)
gc.collect()
torch.cuda.empty_cache()
return img_save_paths + img_save_paths
def backproject_texture(
mesh_path: str,
input_image: str,
texture_size: int,
uuid: str = "sample",
req: gr.Request = None,
) -> str:
output_root = os.path.join(TMP_DIR, str(req.session_hash))
output_dir = os.path.join(output_root, "texture_mesh")
os.makedirs(output_dir, exist_ok=True)
command = [
"backproject-cli",
"--mesh_path",
mesh_path,
"--input_image",
input_image,
"--output_root",
output_dir,
"--uuid",
f"{uuid}",
"--texture_size",
str(texture_size),
"--skip_fix_mesh",
]
_ = subprocess.run(
command, capture_output=True, text=True, encoding="utf-8"
)
output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
_ = trimesh.load(output_obj_mesh).export(output_glb_mesh)
zip_file = zip_files(
input_paths=[
output_glb_mesh,
output_obj_mesh,
os.path.join(output_dir, "material.mtl"),
os.path.join(output_dir, "material_0.png"),
],
output_zip=os.path.join(output_dir, f"{uuid}.zip"),
)
gc.collect()
torch.cuda.empty_cache()
return output_glb_mesh, output_obj_mesh, zip_file
@spaces.GPU
def backproject_texture_v2(
mesh_path: str,
input_image: str,
texture_size: int,
enable_delight: bool = True,
fix_mesh: bool = False,
uuid: str = "sample",
req: gr.Request = None,
) -> str:
output_root = os.path.join(TMP_DIR, str(req.session_hash))
output_dir = os.path.join(output_root, "texture_mesh")
os.makedirs(output_dir, exist_ok=True)
textured_mesh = backproject_api(
delight_model=DELIGHT,
imagesr_model=IMAGESR_MODEL,
color_path=input_image,
mesh_path=mesh_path,
output_path=f"{output_dir}/{uuid}.obj",
skip_fix_mesh=not fix_mesh,
delight=enable_delight,
texture_wh=[texture_size, texture_size],
)
output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
_ = textured_mesh.export(output_glb_mesh)
zip_file = zip_files(
input_paths=[
output_glb_mesh,
output_obj_mesh,
os.path.join(output_dir, "material.mtl"),
os.path.join(output_dir, "material_0.png"),
],
output_zip=os.path.join(output_dir, f"{uuid}.zip"),
)
gc.collect()
torch.cuda.empty_cache()
return output_glb_mesh, output_obj_mesh, zip_file
@spaces.GPU
def render_result_video(
mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
) -> str:
output_root = os.path.join(TMP_DIR, str(req.session_hash))
output_dir = os.path.join(output_root, "texture_mesh")
_ = render_api(
mesh_path=mesh_path,
output_root=output_dir,
num_images=90,
elevation=[20],
with_mtl=True,
pbr_light_factor=1,
uuid=str(uuid),
gen_color_mp4=True,
gen_glonormal_mp4=True,
distance=5.5,
resolution_hw=(video_size, video_size),
)
gc.collect()
torch.cuda.empty_cache()
return f"{output_dir}/color.mp4"

501
apps/image_to_3d.py Normal file
View File

@ -0,0 +1,501 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
os.environ["GRADIO_APP"] = "imageto3d"
from glob import glob
import gradio as gr
from common import (
MAX_SEED,
VERSION,
active_btn_by_content,
custom_theme,
end_session,
extract_3d_representations_v2,
extract_urdf,
get_seed,
image_css,
image_to_3d,
lighting_css,
preprocess_image_fn,
preprocess_sam_image_fn,
select_point,
start_session,
)
with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
gr.Markdown(
"""
## ***EmbodiedGen***: Image-to-3D Asset
**🔖 Version**: {VERSION}
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
<a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
<img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
</a>
<a href="https://arxiv.org/abs/xxxx.xxxxx">
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
</a>
<a href="https://github.com/HorizonRobotics/EmbodiedGen">
<img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
</a>
<a href="https://www.youtube.com/watch?v=SnHhzHeb_aI">
<img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
</a>
</p>
🖼 Generate physically plausible 3D asset from single input image.
""".format(
VERSION=VERSION
),
elem_classes=["header"],
)
gr.HTML(image_css)
gr.HTML(lighting_css)
with gr.Row():
with gr.Column(scale=2):
with gr.Tabs() as input_tabs:
with gr.Tab(
label="Image(auto seg)", id=0
) as single_image_input_tab:
raw_image_cache = gr.Image(
format="png",
image_mode="RGB",
type="pil",
visible=False,
)
image_prompt = gr.Image(
label="Input Image",
format="png",
image_mode="RGBA",
type="pil",
height=400,
elem_classes=["image_fit"],
)
gr.Markdown(
"""
If you are not satisfied with the auto segmentation
result, please switch to the `Image(SAM seg)` tab."""
)
with gr.Tab(
label="Image(SAM seg)", id=1
) as samimage_input_tab:
with gr.Row():
with gr.Column(scale=1):
image_prompt_sam = gr.Image(
label="Input Image",
type="numpy",
height=400,
elem_classes=["image_fit"],
)
image_seg_sam = gr.Image(
label="SAM Seg Image",
image_mode="RGBA",
type="pil",
height=400,
visible=False,
)
with gr.Column(scale=1):
image_mask_sam = gr.AnnotatedImage(
elem_classes=["image_fit"]
)
fg_bg_radio = gr.Radio(
["foreground_point", "background_point"],
label="Select foreground(green) or background(red) points, by default foreground", # noqa
value="foreground_point",
)
gr.Markdown(
""" Click the `Input Image` to select SAM points,
after get the satisified segmentation, click `Generate`
button to generate the 3D asset. \n
Note: If the segmented foreground is too small relative
to the entire image area, the generation will fail.
"""
)
with gr.Accordion(label="Generation Settings", open=False):
with gr.Row():
seed = gr.Slider(
0, MAX_SEED, label="Seed", value=0, step=1
)
texture_size = gr.Slider(
1024,
4096,
label="UV texture size",
value=2048,
step=256,
)
rmbg_tag = gr.Radio(
choices=["rembg", "rmbg14"],
value="rembg",
label="Background Removal Model",
)
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize Seed", value=False
)
project_delight = gr.Checkbox(
label="Backproject delighting",
value=False,
)
gr.Markdown("Geo Structure Generation")
with gr.Row():
ss_guidance_strength = gr.Slider(
0.0,
10.0,
label="Guidance Strength",
value=7.5,
step=0.1,
)
ss_sampling_steps = gr.Slider(
1, 50, label="Sampling Steps", value=12, step=1
)
gr.Markdown("Visual Appearance Generation")
with gr.Row():
slat_guidance_strength = gr.Slider(
0.0,
10.0,
label="Guidance Strength",
value=3.0,
step=0.1,
)
slat_sampling_steps = gr.Slider(
1, 50, label="Sampling Steps", value=12, step=1
)
generate_btn = gr.Button(
"🚀 1. Generate(~0.5 mins)",
variant="primary",
interactive=False,
)
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
with gr.Row():
extract_rep3d_btn = gr.Button(
"🔍 2. Extract 3D Representation(~2 mins)",
variant="primary",
interactive=False,
)
with gr.Accordion(
label="Enter Asset Attributes(optional)", open=False
):
asset_cat_text = gr.Textbox(
label="Enter Asset Category (e.g., chair)"
)
height_range_text = gr.Textbox(
label="Enter **Height Range** in meter (e.g., 0.5-0.6)"
)
mass_range_text = gr.Textbox(
label="Enter **Mass Range** in kg (e.g., 1.1-1.2)"
)
asset_version_text = gr.Textbox(
label=f"Enter version (e.g., {VERSION})"
)
with gr.Row():
extract_urdf_btn = gr.Button(
"🧩 3. Extract URDF with physics(~1 mins)",
variant="primary",
interactive=False,
)
with gr.Row():
gr.Markdown(
"#### Estimated Asset 3D Attributes(No input required)"
)
with gr.Row():
est_type_text = gr.Textbox(
label="Asset category", interactive=False
)
est_height_text = gr.Textbox(
label="Real height(.m)", interactive=False
)
est_mass_text = gr.Textbox(
label="Mass(.kg)", interactive=False
)
est_mu_text = gr.Textbox(
label="Friction coefficient", interactive=False
)
with gr.Row():
download_urdf = gr.DownloadButton(
label="⬇️ 4. Download URDF",
variant="primary",
interactive=False,
)
gr.Markdown(
""" NOTE: If `Asset Attributes` are provided, the provided
properties will be used; otherwise, the GPT-preset properties
will be applied. \n
The `Download URDF` file is restored to the real scale and
has quality inspection, open with an editor to view details.
"""
)
with gr.Row() as single_image_example:
examples = gr.Examples(
label="Image Gallery",
examples=[
[image_path]
for image_path in sorted(
glob("apps/assets/example_image/*")
)
],
inputs=[image_prompt, rmbg_tag],
fn=preprocess_image_fn,
outputs=[image_prompt, raw_image_cache],
run_on_click=True,
examples_per_page=10,
)
with gr.Row(visible=False) as single_sam_image_example:
examples = gr.Examples(
label="Image Gallery",
examples=[
[image_path]
for image_path in sorted(
glob("apps/assets/example_image/*")
)
],
inputs=[image_prompt_sam],
fn=preprocess_sam_image_fn,
outputs=[image_prompt_sam, raw_image_cache],
run_on_click=True,
examples_per_page=10,
)
with gr.Column(scale=1):
video_output = gr.Video(
label="Generated 3D Asset",
autoplay=True,
loop=True,
height=300,
)
model_output_gs = gr.Model3D(
label="Gaussian Representation", height=300, interactive=False
)
aligned_gs = gr.Textbox(visible=False)
gr.Markdown(
""" The rendering of `Gaussian Representation` takes additional 10s. """ # noqa
)
with gr.Row():
model_output_mesh = gr.Model3D(
label="Mesh Representation",
height=300,
interactive=False,
clear_color=[0.8, 0.8, 0.8, 1],
elem_id="lighter_mesh",
)
is_samimage = gr.State(False)
output_buf = gr.State()
selected_points = gr.State(value=[])
demo.load(start_session)
demo.unload(end_session)
single_image_input_tab.select(
lambda: tuple(
[False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
),
outputs=[is_samimage, single_image_example, single_sam_image_example],
)
samimage_input_tab.select(
lambda: tuple(
[True, gr.Row.update(visible=True), gr.Row.update(visible=False)]
),
outputs=[is_samimage, single_sam_image_example, single_image_example],
)
image_prompt.upload(
preprocess_image_fn,
inputs=[image_prompt, rmbg_tag],
outputs=[image_prompt, raw_image_cache],
)
image_prompt.change(
lambda: tuple(
[
gr.Button(interactive=False),
gr.Button(interactive=False),
gr.Button(interactive=False),
None,
"",
None,
None,
"",
"",
"",
"",
"",
"",
"",
"",
]
),
outputs=[
extract_rep3d_btn,
extract_urdf_btn,
download_urdf,
model_output_gs,
aligned_gs,
model_output_mesh,
video_output,
asset_cat_text,
height_range_text,
mass_range_text,
asset_version_text,
est_type_text,
est_height_text,
est_mass_text,
est_mu_text,
],
)
image_prompt.change(
active_btn_by_content,
inputs=image_prompt,
outputs=generate_btn,
)
image_prompt_sam.upload(
preprocess_sam_image_fn,
inputs=[image_prompt_sam],
outputs=[image_prompt_sam, raw_image_cache],
)
image_prompt_sam.change(
lambda: tuple(
[
gr.Button(interactive=False),
gr.Button(interactive=False),
gr.Button(interactive=False),
None,
None,
None,
"",
"",
"",
"",
"",
"",
"",
"",
None,
[],
]
),
outputs=[
extract_rep3d_btn,
extract_urdf_btn,
download_urdf,
model_output_gs,
model_output_mesh,
video_output,
asset_cat_text,
height_range_text,
mass_range_text,
asset_version_text,
est_type_text,
est_height_text,
est_mass_text,
est_mu_text,
image_mask_sam,
selected_points,
],
)
image_prompt_sam.select(
select_point,
[
image_prompt_sam,
selected_points,
fg_bg_radio,
],
[image_mask_sam, image_seg_sam],
)
image_seg_sam.change(
active_btn_by_content,
inputs=image_seg_sam,
outputs=generate_btn,
)
generate_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).success(
image_to_3d,
inputs=[
image_prompt,
seed,
ss_guidance_strength,
ss_sampling_steps,
slat_guidance_strength,
slat_sampling_steps,
raw_image_cache,
image_seg_sam,
is_samimage,
],
outputs=[output_buf, video_output],
).success(
lambda: gr.Button(interactive=True),
outputs=[extract_rep3d_btn],
)
extract_rep3d_btn.click(
extract_3d_representations_v2,
inputs=[
output_buf,
project_delight,
texture_size,
],
outputs=[
model_output_mesh,
model_output_gs,
model_output_obj,
aligned_gs,
],
).success(
lambda: gr.Button(interactive=True),
outputs=[extract_urdf_btn],
)
extract_urdf_btn.click(
extract_urdf,
inputs=[
aligned_gs,
model_output_obj,
asset_cat_text,
height_range_text,
mass_range_text,
asset_version_text,
],
outputs=[
download_urdf,
est_type_text,
est_height_text,
est_mass_text,
est_mu_text,
],
queue=True,
show_progress="full",
).success(
lambda: gr.Button(interactive=True),
outputs=[download_urdf],
)
if __name__ == "__main__":
demo.launch(server_name="10.34.8.82", server_port=8081)

481
apps/text_to_3d.py Normal file
View File

@ -0,0 +1,481 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
os.environ["GRADIO_APP"] = "textto3d"
import gradio as gr
from common import (
MAX_SEED,
VERSION,
active_btn_by_text_content,
custom_theme,
end_session,
extract_3d_representations_v2,
extract_urdf,
get_cached_image,
get_seed,
get_selected_image,
image_css,
image_to_3d,
lighting_css,
start_session,
text2image_fn,
)
with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
gr.Markdown(
"""
## ***EmbodiedGen***: Text-to-3D Asset
**🔖 Version**: {VERSION}
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
<a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
<img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
</a>
<a href="https://arxiv.org/abs/xxxx.xxxxx">
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
</a>
<a href="https://github.com/HorizonRobotics/EmbodiedGen">
<img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
</a>
<a href="https://www.youtube.com/watch?v=SnHhzHeb_aI">
<img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
</a>
</p>
📝 Create 3D assets from text descriptions for a wide range of geometry and styles.
""".format(
VERSION=VERSION
),
elem_classes=["header"],
)
gr.HTML(image_css)
gr.HTML(lighting_css)
with gr.Row():
with gr.Column(scale=1):
raw_image_cache = gr.Image(
format="png",
image_mode="RGB",
type="pil",
visible=False,
)
text_prompt = gr.Textbox(
label="Text Prompt (Chinese or English)",
placeholder="Input text prompt here",
)
ip_image = gr.Image(
label="Reference Image(optional)",
format="png",
image_mode="RGB",
type="filepath",
height=250,
elem_classes=["image_fit"],
)
gr.Markdown(
"Note: The `reference image` is optional, if use, "
"please provide image in nearly square resolution."
)
with gr.Accordion(label="Image Generation Settings", open=False):
with gr.Row():
seed = gr.Slider(
0, MAX_SEED, label="Seed", value=0, step=1
)
randomize_seed = gr.Checkbox(
label="Randomize Seed", value=False
)
rmbg_tag = gr.Radio(
choices=["rembg", "rmbg14"],
value="rembg",
label="Background Removal Model",
)
ip_adapt_scale = gr.Slider(
0, 1, label="IP-adapter Scale", value=0.3, step=0.05
)
img_guidance_scale = gr.Slider(
1, 30, label="Text Guidance Scale", value=12, step=0.2
)
img_inference_steps = gr.Slider(
10, 100, label="Sampling Steps", value=50, step=5
)
img_resolution = gr.Slider(
512,
1536,
label="Image Resolution",
value=1024,
step=128,
)
generate_img_btn = gr.Button(
"🎨 1. Generate Images(~1min)",
variant="primary",
interactive=False,
)
dropdown = gr.Radio(
choices=["sample1", "sample2", "sample3"],
value="sample1",
label="Choose your favorite sample style.",
)
select_img = gr.Image(
visible=False,
format="png",
image_mode="RGBA",
type="pil",
height=300,
)
# text to 3d
with gr.Accordion(label="Generation Settings", open=False):
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
texture_size = gr.Slider(
1024, 4096, label="UV texture size", value=2048, step=256
)
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize Seed", value=False
)
project_delight = gr.Checkbox(
label="backproject delight", value=False
)
gr.Markdown("Geo Structure Generation")
with gr.Row():
ss_guidance_strength = gr.Slider(
0.0,
10.0,
label="Guidance Strength",
value=7.5,
step=0.1,
)
ss_sampling_steps = gr.Slider(
1, 50, label="Sampling Steps", value=12, step=1
)
gr.Markdown("Visual Appearance Generation")
with gr.Row():
slat_guidance_strength = gr.Slider(
0.0,
10.0,
label="Guidance Strength",
value=3.0,
step=0.1,
)
slat_sampling_steps = gr.Slider(
1, 50, label="Sampling Steps", value=12, step=1
)
generate_btn = gr.Button(
"🚀 2. Generate 3D(~0.5 mins)",
variant="primary",
interactive=False,
)
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
with gr.Row():
extract_rep3d_btn = gr.Button(
"🔍 3. Extract 3D Representation(~1 mins)",
variant="primary",
interactive=False,
)
with gr.Accordion(
label="Enter Asset Attributes(optional)", open=False
):
asset_cat_text = gr.Textbox(
label="Enter Asset Category (e.g., chair)"
)
height_range_text = gr.Textbox(
label="Enter Height Range in meter (e.g., 0.5-0.6)"
)
mass_range_text = gr.Textbox(
label="Enter Mass Range in kg (e.g., 1.1-1.2)"
)
asset_version_text = gr.Textbox(
label=f"Enter version (e.g., {VERSION})"
)
with gr.Row():
extract_urdf_btn = gr.Button(
"🧩 4. Extract URDF with physics(~1 mins)",
variant="primary",
interactive=False,
)
with gr.Row():
download_urdf = gr.DownloadButton(
label="⬇️ 5. Download URDF",
variant="primary",
interactive=False,
)
with gr.Column(scale=3):
with gr.Row():
image_sample1 = gr.Image(
label="sample1",
format="png",
image_mode="RGBA",
type="filepath",
height=300,
interactive=False,
elem_classes=["image_fit"],
)
image_sample2 = gr.Image(
label="sample2",
format="png",
image_mode="RGBA",
type="filepath",
height=300,
interactive=False,
elem_classes=["image_fit"],
)
image_sample3 = gr.Image(
label="sample3",
format="png",
image_mode="RGBA",
type="filepath",
height=300,
interactive=False,
elem_classes=["image_fit"],
)
usample1 = gr.Image(
format="png",
image_mode="RGBA",
type="filepath",
visible=False,
)
usample2 = gr.Image(
format="png",
image_mode="RGBA",
type="filepath",
visible=False,
)
usample3 = gr.Image(
format="png",
image_mode="RGBA",
type="filepath",
visible=False,
)
gr.Markdown(
"The generated image may be of poor quality due to auto "
"segmentation. Try adjusting the text prompt or seed."
)
with gr.Row():
video_output = gr.Video(
label="Generated 3D Asset",
autoplay=True,
loop=True,
height=300,
interactive=False,
)
model_output_gs = gr.Model3D(
label="Gaussian Representation",
height=300,
interactive=False,
)
aligned_gs = gr.Textbox(visible=False)
model_output_mesh = gr.Model3D(
label="Mesh Representation",
clear_color=[0.8, 0.8, 0.8, 1],
height=300,
interactive=False,
elem_id="lighter_mesh",
)
gr.Markdown("Estimated Asset 3D Attributes(No input required)")
with gr.Row():
est_type_text = gr.Textbox(
label="Asset category", interactive=False
)
est_height_text = gr.Textbox(
label="Real height(.m)", interactive=False
)
est_mass_text = gr.Textbox(
label="Mass(.kg)", interactive=False
)
est_mu_text = gr.Textbox(
label="Friction coefficient", interactive=False
)
prompt_examples = [
"satin gold tea cup with saucer",
"small bronze figurine of a lion",
"brown leather bag",
"Miniature cup with floral design",
"带木质底座, 具有经纬线的地球仪",
"橙色电动手钻, 有磨损细节",
"手工制作的皮革笔记本",
]
examples = gr.Examples(
label="Gallery",
examples=prompt_examples,
inputs=[text_prompt],
examples_per_page=10,
)
output_buf = gr.State()
demo.load(start_session)
demo.unload(end_session)
text_prompt.change(
active_btn_by_text_content,
inputs=[text_prompt],
outputs=[generate_img_btn],
)
generate_img_btn.click(
lambda: tuple(
[
gr.Button(interactive=False),
gr.Button(interactive=False),
gr.Button(interactive=False),
gr.Button(interactive=False),
None,
"",
None,
None,
"",
"",
"",
"",
"",
"",
"",
"",
None,
None,
None,
]
),
outputs=[
extract_rep3d_btn,
extract_urdf_btn,
download_urdf,
generate_btn,
model_output_gs,
aligned_gs,
model_output_mesh,
video_output,
asset_cat_text,
height_range_text,
mass_range_text,
asset_version_text,
est_type_text,
est_height_text,
est_mass_text,
est_mu_text,
image_sample1,
image_sample2,
image_sample3,
],
).success(
text2image_fn,
inputs=[
text_prompt,
img_guidance_scale,
img_inference_steps,
ip_image,
ip_adapt_scale,
img_resolution,
rmbg_tag,
seed,
],
outputs=[
image_sample1,
image_sample2,
image_sample3,
usample1,
usample2,
usample3,
],
).success(
lambda: gr.Button(interactive=True),
outputs=[generate_btn],
)
generate_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).success(
get_selected_image,
inputs=[dropdown, usample1, usample2, usample3],
outputs=select_img,
).success(
get_cached_image,
inputs=[select_img],
outputs=[raw_image_cache],
).success(
image_to_3d,
inputs=[
select_img,
seed,
ss_guidance_strength,
ss_sampling_steps,
slat_guidance_strength,
slat_sampling_steps,
raw_image_cache,
],
outputs=[output_buf, video_output],
).success(
lambda: gr.Button(interactive=True),
outputs=[extract_rep3d_btn],
)
extract_rep3d_btn.click(
extract_3d_representations_v2,
inputs=[
output_buf,
project_delight,
texture_size,
],
outputs=[
model_output_mesh,
model_output_gs,
model_output_obj,
aligned_gs,
],
).success(
lambda: gr.Button(interactive=True),
outputs=[extract_urdf_btn],
)
extract_urdf_btn.click(
extract_urdf,
inputs=[
aligned_gs,
model_output_obj,
asset_cat_text,
height_range_text,
mass_range_text,
asset_version_text,
],
outputs=[
download_urdf,
est_type_text,
est_height_text,
est_mass_text,
est_mu_text,
],
queue=True,
show_progress="full",
).success(
lambda: gr.Button(interactive=True),
outputs=[download_urdf],
)
if __name__ == "__main__":
demo.launch(server_name="10.34.8.82", server_port=8082)

382
apps/texture_edit.py Normal file
View File

@ -0,0 +1,382 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
os.environ["GRADIO_APP"] = "texture_edit"
import gradio as gr
from common import (
MAX_SEED,
VERSION,
backproject_texture_v2,
custom_theme,
end_session,
generate_condition,
generate_texture_mvimages,
get_seed,
get_selected_image,
image_css,
lighting_css,
render_result_video,
start_session,
)
def active_btn_by_content(mesh_content: gr.Model3D, text_content: gr.Textbox):
if (
mesh_content is not None
and text_content is not None
and len(text_content) > 0
):
interactive = True
else:
interactive = False
return gr.Button(interactive=interactive)
with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
gr.Markdown(
"""
## ***EmbodiedGen***: Texture Generation
**🔖 Version**: {VERSION}
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
<a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
<img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
</a>
<a href="https://arxiv.org/abs/xxxx.xxxxx">
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
</a>
<a href="https://github.com/HorizonRobotics/EmbodiedGen">
<img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
</a>
<a href="https://www.youtube.com/watch?v=SnHhzHeb_aI">
<img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
</a>
</p>
🎨 Generate visually rich textures for 3D mesh.
""".format(
VERSION=VERSION
),
elem_classes=["header"],
)
gr.HTML(image_css)
gr.HTML(lighting_css)
with gr.Row():
with gr.Column(scale=1):
mesh_input = gr.Model3D(
label="Upload Mesh File(.obj or .glb)", height=300
)
local_mesh = gr.Textbox(visible=False)
text_prompt = gr.Textbox(
label="Text Prompt (Chinese or English)",
placeholder="Input text prompt here",
)
ip_image = gr.Image(
label="Reference Image(optional)",
format="png",
image_mode="RGB",
type="filepath",
height=250,
elem_classes=["image_fit"],
)
gr.Markdown(
"Note: The `reference image` is optional. If provided, please "
"increase the `Condition Scale` in Generation Settings."
)
with gr.Accordion(label="Generation Settings", open=False):
with gr.Row():
seed = gr.Slider(
0, MAX_SEED, label="Seed", value=0, step=1
)
randomize_seed = gr.Checkbox(
label="Randomize Seed", value=False
)
ip_adapt_scale = gr.Slider(
0, 1, label="IP-adapter Scale", value=0.7, step=0.05
)
cond_scale = gr.Slider(
0.0,
1.0,
label="Geo Condition Scale",
value=0.60,
step=0.01,
)
guidance_scale = gr.Slider(
1, 30, label="Text Guidance Scale", value=9, step=0.2
)
guidance_strength = gr.Slider(
0.0,
1.0,
label="Strength",
value=0.9,
step=0.05,
)
num_inference_steps = gr.Slider(
10, 100, label="Sampling Steps", value=50, step=5
)
texture_size = gr.Slider(
1024, 4096, label="UV texture size", value=2048, step=256
)
video_size = gr.Slider(
512, 2048, label="Video Resolution", value=512, step=256
)
generate_mv_btn = gr.Button(
"🎨 1. Generate MV Images(~1min)",
variant="primary",
interactive=False,
)
with gr.Column(scale=3):
with gr.Row():
image_sample1 = gr.Image(
label="sample1",
format="png",
image_mode="RGBA",
type="filepath",
height=300,
interactive=False,
elem_classes=["image_fit"],
)
image_sample2 = gr.Image(
label="sample2",
format="png",
image_mode="RGBA",
type="filepath",
height=300,
interactive=False,
elem_classes=["image_fit"],
)
image_sample3 = gr.Image(
label="sample3",
format="png",
image_mode="RGBA",
type="filepath",
height=300,
interactive=False,
elem_classes=["image_fit"],
)
usample1 = gr.Image(
format="png",
image_mode="RGBA",
type="filepath",
visible=False,
)
usample2 = gr.Image(
format="png",
image_mode="RGBA",
type="filepath",
visible=False,
)
usample3 = gr.Image(
format="png",
image_mode="RGBA",
type="filepath",
visible=False,
)
gr.Markdown(
"Note: Select samples with consistent textures from various "
"perspectives and no obvious reflections."
)
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
dropdown = gr.Radio(
choices=["sample1", "sample2", "sample3"],
value="sample1",
label="Choose your favorite sample style.",
)
select_img = gr.Image(
visible=False,
format="png",
image_mode="RGBA",
type="filepath",
height=300,
)
with gr.Row():
project_delight = gr.Checkbox(
label="delight", value=True
)
fix_mesh = gr.Checkbox(
label="simplify mesh", value=False
)
with gr.Column(scale=1):
texture_bake_btn = gr.Button(
"🛠️ 2. Texture Baking(~2min)",
variant="primary",
interactive=False,
)
download_btn = gr.DownloadButton(
label="⬇️ 3. Download Mesh",
variant="primary",
interactive=False,
)
with gr.Row():
mesh_output = gr.Model3D(
label="Mesh Edit Result",
clear_color=[0.8, 0.8, 0.8, 1],
height=380,
interactive=False,
elem_id="lighter_mesh",
)
mesh_outpath = gr.Textbox(visible=False)
video_output = gr.Video(
label="Mesh Edit Video",
autoplay=True,
loop=True,
height=380,
)
with gr.Row():
prompt_examples = []
with open("apps/assets/example_texture/text_prompts.txt", "r") as f:
for line in f:
parts = line.strip().split("\\")
prompt_examples.append([parts[0].strip(), parts[1].strip()])
examples = gr.Examples(
label="Mesh Gallery",
examples=prompt_examples,
inputs=[mesh_input, text_prompt],
examples_per_page=10,
)
demo.load(start_session)
demo.unload(end_session)
mesh_input.change(
lambda: tuple(
[
None,
None,
None,
gr.Button(interactive=False),
gr.Button(interactive=False),
None,
None,
None,
]
),
outputs=[
mesh_outpath,
mesh_output,
video_output,
texture_bake_btn,
download_btn,
image_sample1,
image_sample2,
image_sample3,
],
).success(
active_btn_by_content,
inputs=[mesh_input, text_prompt],
outputs=[generate_mv_btn],
)
text_prompt.change(
active_btn_by_content,
inputs=[mesh_input, text_prompt],
outputs=[generate_mv_btn],
)
generate_mv_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).success(
lambda: tuple(
[
None,
None,
None,
gr.Button(interactive=False),
gr.Button(interactive=False),
]
),
outputs=[
mesh_outpath,
mesh_output,
video_output,
texture_bake_btn,
download_btn,
],
).success(
generate_condition,
inputs=[mesh_input],
outputs=[image_sample1, image_sample2, image_sample3],
).success(
generate_texture_mvimages,
inputs=[
text_prompt,
cond_scale,
guidance_scale,
guidance_strength,
num_inference_steps,
seed,
ip_adapt_scale,
ip_image,
],
outputs=[
image_sample1,
image_sample2,
image_sample3,
usample1,
usample2,
usample3,
],
).success(
lambda: gr.Button(interactive=True),
outputs=[texture_bake_btn],
)
texture_bake_btn.click(
lambda: tuple([None, None, None, gr.Button(interactive=False)]),
outputs=[mesh_outpath, mesh_output, video_output, download_btn],
).success(
get_selected_image,
inputs=[dropdown, usample1, usample2, usample3],
outputs=select_img,
).success(
backproject_texture_v2,
inputs=[
mesh_input,
select_img,
texture_size,
project_delight,
fix_mesh,
],
outputs=[mesh_output, mesh_outpath, download_btn],
).success(
lambda: gr.DownloadButton(interactive=True),
outputs=[download_btn],
).success(
render_result_video,
inputs=[mesh_outpath, video_size],
outputs=[video_output],
)
if __name__ == "__main__":
demo.launch(server_name="10.34.8.82", server_port=8083)

View File

@ -0,0 +1,518 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import logging
import math
import os
from typing import List, Literal, Union
import cv2
import numpy as np
import nvdiffrast.torch as dr
import torch
import trimesh
import utils3d
import xatlas
from tqdm import tqdm
from embodied_gen.data.mesh_operator import MeshFixer
from embodied_gen.data.utils import (
CameraSetting,
get_images_from_grid,
init_kal_camera,
normalize_vertices_array,
post_process_texture,
save_mesh_with_mtl,
)
from embodied_gen.models.delight_model import DelightingModel
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
class TextureBaker(object):
"""Baking textures onto a mesh from multiple observations.
This class take 3D mesh data, camera settings and texture baking parameters
to generate texture map by projecting images to the mesh from diff views.
It supports both a fast texture baking approach and a more optimized method
with total variation regularization.
Attributes:
vertices (torch.Tensor): The vertices of the mesh.
faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
uvs (torch.Tensor): The UV coordinates of the mesh.
camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
device (str): The device to run computations on ("cpu" or "cuda").
w2cs (torch.Tensor): World-to-camera transformation matrices.
projections (torch.Tensor): Camera projection matrices.
Example:
>>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces) # noqa
>>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params)
>>> images = get_images_from_grid(args.color_path, image_size)
>>> texture = texture_backer.bake_texture(
... images, texture_size=args.texture_size, mode=args.baker_mode
... )
>>> texture = post_process_texture(texture)
"""
def __init__(
self,
vertices: np.ndarray,
faces: np.ndarray,
uvs: np.ndarray,
camera_params: CameraSetting,
device: str = "cuda",
) -> None:
self.vertices = (
torch.tensor(vertices, device=device)
if isinstance(vertices, np.ndarray)
else vertices.to(device)
)
self.faces = (
torch.tensor(faces.astype(np.int32), device=device)
if isinstance(faces, np.ndarray)
else faces.to(device)
)
self.uvs = (
torch.tensor(uvs, device=device)
if isinstance(uvs, np.ndarray)
else uvs.to(device)
)
self.camera_params = camera_params
self.device = device
camera = init_kal_camera(camera_params)
matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
matrix_mv = kaolin_to_opencv_view(matrix_mv)
matrix_p = (
camera.intrinsics.projection_matrix()
) # (n_cam 4 4) cam2pixel
self.w2cs = matrix_mv.to(self.device)
self.projections = matrix_p.to(self.device)
@staticmethod
def parametrize_mesh(
vertices: np.array, faces: np.array
) -> Union[np.array, np.array, np.array]:
vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
vertices = vertices[vmapping]
faces = indices
return vertices, faces, uvs
def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
texture = torch.zeros(
(texture_size * texture_size, 3), dtype=torch.float32
).cuda()
texture_weights = torch.zeros(
(texture_size * texture_size), dtype=torch.float32
).cuda()
rastctx = utils3d.torch.RastContext(backend="cuda")
for observation, w2c, projection in tqdm(
zip(observations, w2cs, projections),
total=len(observations),
desc="Texture baking (fast)",
):
with torch.no_grad():
rast = utils3d.torch.rasterize_triangle_faces(
rastctx,
self.vertices[None],
self.faces,
observation.shape[1],
observation.shape[0],
uv=self.uvs[None],
view=w2c,
projection=projection,
)
uv_map = rast["uv"][0].detach().flip(0)
mask = rast["mask"][0].detach().bool() & masks[0]
# nearest neighbor interpolation
uv_map = (uv_map * texture_size).floor().long()
obs = observation[mask]
uv_map = uv_map[mask]
idx = (
uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
)
texture = texture.scatter_add(
0, idx.view(-1, 1).expand(-1, 3), obs
)
texture_weights = texture_weights.scatter_add(
0,
idx,
torch.ones(
(obs.shape[0]), dtype=torch.float32, device=texture.device
),
)
mask = texture_weights > 0
texture[mask] /= texture_weights[mask][:, None]
texture = np.clip(
texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
0,
255,
).astype(np.uint8)
# inpaint
mask = (
(texture_weights == 0)
.cpu()
.numpy()
.astype(np.uint8)
.reshape(texture_size, texture_size)
)
texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
return texture
def _bake_opt(
self,
observations,
w2cs,
projections,
texture_size,
lambda_tv,
masks,
total_steps,
):
rastctx = utils3d.torch.RastContext(backend="cuda")
observations = [observations.flip(0) for observations in observations]
masks = [m.flip(0) for m in masks]
_uv = []
_uv_dr = []
for observation, w2c, projection in tqdm(
zip(observations, w2cs, projections),
total=len(w2cs),
):
with torch.no_grad():
rast = utils3d.torch.rasterize_triangle_faces(
rastctx,
self.vertices[None],
self.faces,
observation.shape[1],
observation.shape[0],
uv=self.uvs[None],
view=w2c,
projection=projection,
)
_uv.append(rast["uv"].detach())
_uv_dr.append(rast["uv_dr"].detach())
texture = torch.nn.Parameter(
torch.zeros(
(1, texture_size, texture_size, 3), dtype=torch.float32
).cuda()
)
optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
def cosine_anealing(step, total_steps, start_lr, end_lr):
return end_lr + 0.5 * (start_lr - end_lr) * (
1 + np.cos(np.pi * step / total_steps)
)
def tv_loss(texture):
return torch.nn.functional.l1_loss(
texture[:, :-1, :, :], texture[:, 1:, :, :]
) + torch.nn.functional.l1_loss(
texture[:, :, :-1, :], texture[:, :, 1:, :]
)
with tqdm(total=total_steps, desc="Texture baking") as pbar:
for step in range(total_steps):
optimizer.zero_grad()
selected = np.random.randint(0, len(w2cs))
uv, uv_dr, observation, mask = (
_uv[selected],
_uv_dr[selected],
observations[selected],
masks[selected],
)
render = dr.texture(texture, uv, uv_dr)[0]
loss = torch.nn.functional.l1_loss(
render[mask], observation[mask]
)
if lambda_tv > 0:
loss += lambda_tv * tv_loss(texture)
loss.backward()
optimizer.step()
optimizer.param_groups[0]["lr"] = cosine_anealing(
step, total_steps, 1e-2, 1e-5
)
pbar.set_postfix({"loss": loss.item()})
pbar.update()
texture = np.clip(
texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
).astype(np.uint8)
mask = 1 - utils3d.torch.rasterize_triangle_faces(
rastctx,
(self.uvs * 2 - 1)[None],
self.faces,
texture_size,
texture_size,
)["mask"][0].detach().cpu().numpy().astype(np.uint8)
texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
return texture
def bake_texture(
self,
images: List[np.array],
texture_size: int = 1024,
mode: Literal["fast", "opt"] = "opt",
lambda_tv: float = 1e-2,
opt_step: int = 2000,
):
masks = [np.any(img > 0, axis=-1) for img in images]
masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
images = [
torch.tensor(obs / 255.0).float().to(self.device) for obs in images
]
if mode == "fast":
return self._bake_fast(
images, self.w2cs, self.projections, texture_size, masks
)
elif mode == "opt":
return self._bake_opt(
images,
self.w2cs,
self.projections,
texture_size,
lambda_tv,
masks,
opt_step,
)
else:
raise ValueError(f"Unknown mode: {mode}")
def kaolin_to_opencv_view(raw_matrix):
R_orig = raw_matrix[:, :3, :3]
t_orig = raw_matrix[:, :3, 3]
R_target = torch.zeros_like(R_orig)
R_target[:, :, 0] = R_orig[:, :, 2]
R_target[:, :, 1] = R_orig[:, :, 0]
R_target[:, :, 2] = R_orig[:, :, 1]
t_target = t_orig
target_matrix = (
torch.eye(4, device=raw_matrix.device)
.unsqueeze(0)
.repeat(raw_matrix.size(0), 1, 1)
)
target_matrix[:, :3, :3] = R_target
target_matrix[:, :3, 3] = t_target
return target_matrix
def parse_args():
parser = argparse.ArgumentParser(description="Render settings")
parser.add_argument(
"--mesh_path",
type=str,
nargs="+",
required=True,
help="Paths to the mesh files for rendering.",
)
parser.add_argument(
"--color_path",
type=str,
nargs="+",
required=True,
help="Paths to the mesh files for rendering.",
)
parser.add_argument(
"--output_root",
type=str,
default="./outputs",
help="Root directory for output",
)
parser.add_argument(
"--uuid",
type=str,
nargs="+",
default=None,
help="uuid for rendering saving.",
)
parser.add_argument(
"--num_images", type=int, default=6, help="Number of images to render."
)
parser.add_argument(
"--elevation",
type=float,
nargs="+",
default=[20.0, -10.0],
help="Elevation angles for the camera (default: [20.0, -10.0])",
)
parser.add_argument(
"--distance",
type=float,
default=5,
help="Camera distance (default: 5)",
)
parser.add_argument(
"--resolution_hw",
type=int,
nargs=2,
default=(512, 512),
help="Resolution of the output images (default: (512, 512))",
)
parser.add_argument(
"--fov",
type=float,
default=30,
help="Field of view in degrees (default: 30)",
)
parser.add_argument(
"--device",
type=str,
choices=["cpu", "cuda"],
default="cuda",
help="Device to run on (default: `cuda`)",
)
parser.add_argument(
"--texture_size",
type=int,
default=1024,
help="Texture size for texture baking (default: 1024)",
)
parser.add_argument(
"--baker_mode",
type=str,
default="opt",
help="Texture baking mode, `fast` or `opt` (default: opt)",
)
parser.add_argument(
"--opt_step",
type=int,
default=2500,
help="Optimization steps for texture baking (default: 2500)",
)
parser.add_argument(
"--mesh_sipmlify_ratio",
type=float,
default=0.9,
help="Mesh simplification ratio (default: 0.9)",
)
parser.add_argument(
"--no_coor_trans",
action="store_true",
help="Do not transform the asset coordinate system.",
)
parser.add_argument(
"--delight", action="store_true", help="Use delighting model."
)
parser.add_argument(
"--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
)
args = parser.parse_args()
if args.uuid is None:
args.uuid = []
for path in args.mesh_path:
uuid = os.path.basename(path).split(".")[0]
args.uuid.append(uuid)
return args
def entrypoint() -> None:
args = parse_args()
camera_params = CameraSetting(
num_images=args.num_images,
elevation=args.elevation,
distance=args.distance,
resolution_hw=args.resolution_hw,
fov=math.radians(args.fov),
device=args.device,
)
for mesh_path, uuid, img_path in zip(
args.mesh_path, args.uuid, args.color_path
):
mesh = trimesh.load(mesh_path)
if isinstance(mesh, trimesh.Scene):
mesh = mesh.dump(concatenate=True)
vertices, scale, center = normalize_vertices_array(mesh.vertices)
if not args.no_coor_trans:
x_rot = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
z_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
vertices = vertices @ x_rot
vertices = vertices @ z_rot
faces = mesh.faces.astype(np.int32)
vertices = vertices.astype(np.float32)
if not args.skip_fix_mesh:
mesh_fixer = MeshFixer(vertices, faces, args.device)
vertices, faces = mesh_fixer(
filter_ratio=args.mesh_sipmlify_ratio,
max_hole_size=0.04,
resolution=1024,
num_views=1000,
norm_mesh_ratio=0.5,
)
vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
texture_backer = TextureBaker(
vertices,
faces,
uvs,
camera_params,
)
images = get_images_from_grid(
img_path, img_size=camera_params.resolution_hw[0]
)
if args.delight:
delight_model = DelightingModel()
images = [delight_model(img) for img in images]
images = [np.array(img) for img in images]
texture = texture_backer.bake_texture(
images=[img[..., :3] for img in images],
texture_size=args.texture_size,
mode=args.baker_mode,
opt_step=args.opt_step,
)
texture = post_process_texture(texture)
if not args.no_coor_trans:
vertices = vertices @ np.linalg.inv(z_rot)
vertices = vertices @ np.linalg.inv(x_rot)
vertices = vertices / scale
vertices = vertices + center
output_path = os.path.join(args.output_root, f"{uuid}.obj")
mesh = save_mesh_with_mtl(vertices, faces, uvs, texture, output_path)
return
if __name__ == "__main__":
entrypoint()

View File

@ -0,0 +1,702 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import logging
import math
import os
import cv2
import numpy as np
import nvdiffrast.torch as dr
import spaces
import torch
import torch.nn.functional as F
import trimesh
import xatlas
from PIL import Image
from embodied_gen.data.mesh_operator import MeshFixer
from embodied_gen.data.utils import (
CameraSetting,
DiffrastRender,
get_images_from_grid,
init_kal_camera,
normalize_vertices_array,
post_process_texture,
save_mesh_with_mtl,
)
from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.sr_model import ImageRealESRGAN
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
__all__ = [
"TextureBacker",
]
def _transform_vertices(
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
) -> torch.Tensor:
"""Transform 3D vertices using a projection matrix."""
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
if pos.size(-1) == 3:
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
result = pos @ t_mtx.T
return result if keepdim else result.unsqueeze(0)
def _bilinear_interpolation_scattering(
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
) -> torch.Tensor:
"""Bilinear interpolation scattering for grid-based value accumulation."""
device = values.device
dtype = values.dtype
C = values.shape[-1]
indices = coords * torch.tensor(
[image_h - 1, image_w - 1], dtype=dtype, device=device
)
i, j = indices.unbind(-1)
i0, j0 = (
indices.floor()
.long()
.clamp(0, image_h - 2)
.clamp(0, image_w - 2)
.unbind(-1)
)
i1, j1 = i0 + 1, j0 + 1
w_i = i - i0.float()
w_j = j - j0.float()
weights = torch.stack(
[(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
dim=1,
)
indices_comb = torch.stack(
[
torch.stack([i0, j0], dim=1),
torch.stack([i0, j1], dim=1),
torch.stack([i1, j0], dim=1),
torch.stack([i1, j1], dim=1),
],
dim=1,
)
grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
for k in range(4):
idx = indices_comb[:, k]
w = weights[:, k].unsqueeze(-1)
stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
flat_idx = (idx * stride).sum(-1)
grid.view(-1, C).scatter_add_(
0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
)
cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
mask = cnt.squeeze(-1) > 0
grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
return grid
def _texture_inpaint_smooth(
texture: np.ndarray,
mask: np.ndarray,
vertices: np.ndarray,
faces: np.ndarray,
uv_map: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""Perform texture inpainting using vertex-based color propagation."""
image_h, image_w, C = texture.shape
N = vertices.shape[0]
# Initialize vertex data structures
vtx_mask = np.zeros(N, dtype=np.float32)
vtx_colors = np.zeros((N, C), dtype=np.float32)
unprocessed = []
adjacency = [[] for _ in range(N)]
# Build adjacency graph and initial color assignment
for face_idx in range(faces.shape[0]):
for k in range(3):
uv_idx_k = faces[face_idx, k]
v_idx = faces[face_idx, k]
# Convert UV to pixel coordinates with boundary clamping
u = np.clip(
int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
)
v = np.clip(
int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
0,
image_h - 1,
)
if mask[v, u]:
vtx_mask[v_idx] = 1.0
vtx_colors[v_idx] = texture[v, u]
elif v_idx not in unprocessed:
unprocessed.append(v_idx)
# Build undirected adjacency graph
neighbor = faces[face_idx, (k + 1) % 3]
if neighbor not in adjacency[v_idx]:
adjacency[v_idx].append(neighbor)
if v_idx not in adjacency[neighbor]:
adjacency[neighbor].append(v_idx)
# Color propagation with dynamic stopping
remaining_iters, prev_count = 2, 0
while remaining_iters > 0:
current_unprocessed = []
for v_idx in unprocessed:
valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
if not valid_neighbors:
current_unprocessed.append(v_idx)
continue
# Calculate inverse square distance weights
neighbors_pos = vertices[valid_neighbors]
dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
weights = 1 / np.maximum(dist_sq, 1e-8)
vtx_colors[v_idx] = np.average(
vtx_colors[valid_neighbors], weights=weights, axis=0
)
vtx_mask[v_idx] = 1.0
# Update iteration control
if len(current_unprocessed) == prev_count:
remaining_iters -= 1
else:
remaining_iters = min(remaining_iters + 1, 2)
prev_count = len(current_unprocessed)
unprocessed = current_unprocessed
# Generate output texture
inpainted_texture, updated_mask = texture.copy(), mask.copy()
for face_idx in range(faces.shape[0]):
for k in range(3):
v_idx = faces[face_idx, k]
if not vtx_mask[v_idx]:
continue
# UV coordinate conversion
uv_idx_k = faces[face_idx, k]
u = np.clip(
int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
)
v = np.clip(
int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
0,
image_h - 1,
)
inpainted_texture[v, u] = vtx_colors[v_idx]
updated_mask[v, u] = 255
return inpainted_texture, updated_mask
class TextureBacker:
"""Texture baking pipeline for multi-view projection and fusion.
This class performs UV-based texture generation for a 3D mesh using
multi-view color images, depth, and normal information. The pipeline
includes mesh normalization and UV unwrapping, visibility-aware
back-projection, confidence-weighted texture fusion, and inpainting
of missing texture regions.
Args:
camera_params (CameraSetting): Camera intrinsics and extrinsics used
for rendering each view.
view_weights (list[float]): A list of weights for each view, used
to blend confidence maps during texture fusion.
render_wh (tuple[int, int], optional): Resolution (width, height) for
intermediate rendering passes. Defaults to (2048, 2048).
texture_wh (tuple[int, int], optional): Output texture resolution
(width, height). Defaults to (2048, 2048).
bake_angle_thresh (int, optional): Maximum angle (in degrees) between
view direction and surface normal for projection to be considered valid.
Defaults to 75.
mask_thresh (float, optional): Threshold applied to visibility masks
during rendering. Defaults to 0.5.
smooth_texture (bool, optional): If True, apply post-processing (e.g.,
blurring) to the final texture. Defaults to True.
"""
def __init__(
self,
camera_params: CameraSetting,
view_weights: list[float],
render_wh: tuple[int, int] = (2048, 2048),
texture_wh: tuple[int, int] = (2048, 2048),
bake_angle_thresh: int = 75,
mask_thresh: float = 0.5,
smooth_texture: bool = True,
) -> None:
self.camera_params = camera_params
self.renderer = None
self.view_weights = view_weights
self.device = camera_params.device
self.render_wh = render_wh
self.texture_wh = texture_wh
self.mask_thresh = mask_thresh
self.smooth_texture = smooth_texture
self.bake_angle_thresh = bake_angle_thresh
self.bake_unreliable_kernel_size = int(
(2 / 512) * max(self.render_wh[0], self.render_wh[1])
)
def _lazy_init_render(self, camera_params, mask_thresh):
if self.renderer is None:
camera = init_kal_camera(camera_params)
mv = camera.view_matrix() # (n 4 4) world2cam
p = camera.intrinsics.projection_matrix()
# NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
p[:, 1, 1] = -p[:, 1, 1]
self.renderer = DiffrastRender(
p_matrix=p,
mv_matrix=mv,
resolution_hw=camera_params.resolution_hw,
context=dr.RasterizeCudaContext(),
mask_thresh=mask_thresh,
grad_db=False,
device=self.device,
antialias_mask=True,
)
def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
self.scale, self.center = scale, center
vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
uvs[:, 1] = 1 - uvs[:, 1]
mesh.vertices = mesh.vertices[vmapping]
mesh.faces = indices
mesh.visual.uv = uvs
return mesh
def get_mesh_np_attrs(
self,
mesh: trimesh.Trimesh,
scale: float = None,
center: np.ndarray = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
vertices = mesh.vertices.copy()
faces = mesh.faces.copy()
uv_map = mesh.visual.uv.copy()
uv_map[:, 1] = 1.0 - uv_map[:, 1]
if scale is not None:
vertices = vertices / scale
if center is not None:
vertices = vertices + center
return vertices, faces, uv_map
def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
depth_image_np = depth_image.cpu().numpy()
depth_image_np = (depth_image_np * 255).astype(np.uint8)
depth_edges = cv2.Canny(depth_image_np, 30, 80)
sketch_image = (
torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
)
sketch_image = sketch_image.unsqueeze(-1)
return sketch_image
def compute_enhanced_viewnormal(
self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
) -> torch.Tensor:
rast, _ = self.renderer.compute_dr_raster(vertices, faces)
rendered_view_normals = []
for idx in range(len(mv_mtx)):
pos_cam = _transform_vertices(mv_mtx[idx], vertices, keepdim=True)
pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
face_norm = F.normalize(
torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
)
vertex_norm = (
torch.from_numpy(
trimesh.geometry.mean_vertex_normals(
len(pos_cam), faces.cpu(), face_norm.cpu()
)
)
.to(vertices.device)
.contiguous()
)
im_base_normals, _ = dr.interpolate(
vertex_norm[None, ...].float(),
rast[idx : idx + 1],
faces.to(torch.int32),
)
rendered_view_normals.append(im_base_normals)
rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
return rendered_view_normals
def back_project(
self, image, vis_mask, depth, normal, uv
) -> tuple[torch.Tensor, torch.Tensor]:
image = np.array(image)
image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
if image.ndim == 2:
image = image.unsqueeze(-1)
image = image / 255
depth_inv = (1.0 - depth) * vis_mask
sketch_image = self._render_depth_edges(depth_inv)
cos = F.cosine_similarity(
torch.tensor([[0, 0, 1]], device=self.device),
normal.view(-1, 3),
).view_as(normal[..., :1])
cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
k = self.bake_unreliable_kernel_size * 2 + 1
kernel = torch.ones((1, 1, k, k), device=self.device)
vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
vis_mask = F.conv2d(
1.0 - vis_mask,
kernel,
padding=k // 2,
)
vis_mask = 1.0 - (vis_mask > 0).float()
vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
sketch_image = (sketch_image > 0).float()
sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
vis_mask = vis_mask * (sketch_image < 0.5)
cos[vis_mask == 0] = 0
valid_pixels = (vis_mask != 0).view(-1)
return (
self._scatter_texture(uv, image, valid_pixels),
self._scatter_texture(uv, cos, valid_pixels),
)
def _scatter_texture(self, uv, data, mask):
def __filter_data(data, mask):
return data.view(-1, data.shape[-1])[mask]
return _bilinear_interpolation_scattering(
self.texture_wh[1],
self.texture_wh[0],
__filter_data(uv, mask)[..., [1, 0]],
__filter_data(data, mask),
)
@torch.no_grad()
def fast_bake_texture(
self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
channel = textures[0].shape[-1]
texture_merge = torch.zeros(self.texture_wh + [channel]).to(
self.device
)
trust_map_merge = torch.zeros(self.texture_wh + [1]).to(self.device)
for texture, cos_map in zip(textures, confidence_maps):
view_sum = (cos_map > 0).sum()
painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
if painted_sum / view_sum > 0.99:
continue
texture_merge += texture * cos_map
trust_map_merge += cos_map
texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
return texture_merge, trust_map_merge > 1e-8
def uv_inpaint(
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
) -> np.ndarray:
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
texture, mask = _texture_inpaint_smooth(
texture, mask, vertices, faces, uv_map
)
texture = texture.clip(0, 1)
texture = cv2.inpaint(
(texture * 255).astype(np.uint8),
255 - mask,
3,
cv2.INPAINT_NS,
)
return texture
@spaces.GPU
def compute_texture(
self,
colors: list[Image.Image],
mesh: trimesh.Trimesh,
) -> trimesh.Trimesh:
self._lazy_init_render(self.camera_params, self.mask_thresh)
vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
rendered_depth, masks = self.renderer.render_depth(vertices, faces)
norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
render_uvs, _ = self.renderer.render_uv(vertices, faces, uv_map)
view_normals = self.compute_enhanced_viewnormal(
self.renderer.mv_mtx, vertices, faces
)
textures, weighted_cos_maps = [], []
for color, mask, dep, normal, uv, weight in zip(
colors,
masks,
norm_deps,
view_normals,
render_uvs,
self.view_weights,
):
texture, cos_map = self.back_project(color, mask, dep, normal, uv)
textures.append(texture)
weighted_cos_maps.append(weight * (cos_map**4))
texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
texture_np = texture.cpu().numpy()
mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
return texture_np, mask_np
def __call__(
self,
colors: list[Image.Image],
mesh: trimesh.Trimesh,
output_path: str,
) -> trimesh.Trimesh:
"""Runs the texture baking and exports the textured mesh.
Args:
colors (list[Image.Image]): List of input view images.
mesh (trimesh.Trimesh): Input mesh to be textured.
output_path (str): Path to save the output textured mesh (.obj or .glb).
Returns:
trimesh.Trimesh: The textured mesh with UV and texture image.
"""
mesh = self.load_mesh(mesh)
texture_np, mask_np = self.compute_texture(colors, mesh)
texture_np = self.uv_inpaint(mesh, texture_np, mask_np)
if self.smooth_texture:
texture_np = post_process_texture(texture_np)
vertices, faces, uv_map = self.get_mesh_np_attrs(
mesh, self.scale, self.center
)
textured_mesh = save_mesh_with_mtl(
vertices, faces, uv_map, texture_np, output_path
)
return textured_mesh
def parse_args():
parser = argparse.ArgumentParser(description="Backproject texture")
parser.add_argument(
"--color_path",
type=str,
help="Multiview color image in 6x512x512 file path",
)
parser.add_argument(
"--mesh_path",
type=str,
help="Mesh path, .obj, .glb or .ply",
)
parser.add_argument(
"--output_path",
type=str,
help="Output mesh path with suffix",
)
parser.add_argument(
"--num_images", type=int, default=6, help="Number of images to render."
)
parser.add_argument(
"--elevation",
nargs=2,
type=float,
default=[20.0, -10.0],
help="Elevation angles for the camera (default: [20.0, -10.0])",
)
parser.add_argument(
"--distance",
type=float,
default=5,
help="Camera distance (default: 5)",
)
parser.add_argument(
"--resolution_hw",
type=int,
nargs=2,
default=(2048, 2048),
help="Resolution of the output images (default: (2048, 2048))",
)
parser.add_argument(
"--fov",
type=float,
default=30,
help="Field of view in degrees (default: 30)",
)
parser.add_argument(
"--device",
type=str,
choices=["cpu", "cuda"],
default="cuda",
help="Device to run on (default: `cuda`)",
)
parser.add_argument(
"--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
)
parser.add_argument(
"--texture_wh",
nargs=2,
type=int,
default=[2048, 2048],
help="Texture resolution width and height",
)
parser.add_argument(
"--mesh_sipmlify_ratio",
type=float,
default=0.9,
help="Mesh simplification ratio (default: 0.9)",
)
parser.add_argument(
"--delight", action="store_true", help="Use delighting model."
)
parser.add_argument(
"--no_smooth_texture",
action="store_true",
help="Do not smooth the texture.",
)
parser.add_argument(
"--save_glb_path", type=str, default=None, help="Save glb path."
)
parser.add_argument(
"--no_save_delight_img",
action="store_true",
help="Disable saving delight image",
)
args, unknown = parser.parse_known_args()
return args
def entrypoint(
delight_model: DelightingModel = None,
imagesr_model: ImageRealESRGAN = None,
**kwargs,
) -> trimesh.Trimesh:
args = parse_args()
for k, v in kwargs.items():
if hasattr(args, k) and v is not None:
setattr(args, k, v)
# Setup camera parameters.
camera_params = CameraSetting(
num_images=args.num_images,
elevation=args.elevation,
distance=args.distance,
resolution_hw=args.resolution_hw,
fov=math.radians(args.fov),
device=args.device,
)
view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
color_grid = Image.open(args.color_path)
if args.delight:
if delight_model is None:
delight_model = DelightingModel()
save_dir = os.path.dirname(args.output_path)
os.makedirs(save_dir, exist_ok=True)
color_grid = delight_model(color_grid)
if not args.no_save_delight_img:
color_grid.save(f"{save_dir}/color_grid_delight.png")
multiviews = get_images_from_grid(color_grid, img_size=512)
# Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
if imagesr_model is None:
imagesr_model = ImageRealESRGAN(outscale=4)
multiviews = [imagesr_model(img) for img in multiviews]
multiviews = [img.convert("RGB") for img in multiviews]
mesh = trimesh.load(args.mesh_path)
if isinstance(mesh, trimesh.Scene):
mesh = mesh.dump(concatenate=True)
if not args.skip_fix_mesh:
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
mesh.vertices, mesh.faces = mesh_fixer(
filter_ratio=args.mesh_sipmlify_ratio,
max_hole_size=0.04,
resolution=1024,
num_views=1000,
norm_mesh_ratio=0.5,
)
# Restore scale.
mesh.vertices = mesh.vertices / scale
mesh.vertices = mesh.vertices + center
# Baking texture to mesh.
texture_backer = TextureBacker(
camera_params=camera_params,
view_weights=view_weights,
render_wh=camera_params.resolution_hw,
texture_wh=args.texture_wh,
smooth_texture=not args.no_smooth_texture,
)
textured_mesh = texture_backer(multiviews, mesh, args.output_path)
if args.save_glb_path is not None:
os.makedirs(os.path.dirname(args.save_glb_path), exist_ok=True)
textured_mesh.export(args.save_glb_path)
return textured_mesh
if __name__ == "__main__":
entrypoint()

View File

@ -0,0 +1,256 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import logging
import os
import random
from typing import Any, Callable, Dict, List, Tuple
import torch
import torch.utils.checkpoint
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
__all__ = [
"Asset3dGenDataset",
]
class Asset3dGenDataset(Dataset):
def __init__(
self,
index_file: str,
target_hw: Tuple[int, int],
transform: Callable = None,
control_transform: Callable = None,
max_train_samples: int = None,
sub_idxs: List[List[int]] = None,
seed: int = 79,
) -> None:
if not os.path.exists(index_file):
raise FileNotFoundError(f"{index_file} index_file not found.")
self.index_file = index_file
self.target_hw = target_hw
self.transform = transform
self.control_transform = control_transform
self.max_train_samples = max_train_samples
self.meta_info = self.prepare_data_index(index_file)
self.data_list = sorted(self.meta_info.keys())
self.sub_idxs = sub_idxs # sub_idxs [[0,1,2], [3,4,5], [...], ...]
self.image_num = 6 # hardcode temp.
random.seed(seed)
logger.info(f"Trainset: {len(self)} asset3d instances.")
def __len__(self) -> int:
return len(self.meta_info)
def prepare_data_index(self, index_file: str) -> Dict[str, Any]:
with open(index_file, "r") as fin:
meta_info = json.load(fin)
meta_info_filtered = dict()
for idx, uid in enumerate(meta_info):
if "status" not in meta_info[uid]:
continue
if meta_info[uid]["status"] != "success":
continue
if self.max_train_samples and idx >= self.max_train_samples:
break
meta_info_filtered[uid] = meta_info[uid]
logger.info(
f"Load {len(meta_info)} assets, keep {len(meta_info_filtered)} valids." # noqa
)
return meta_info_filtered
def fetch_sample_images(
self,
uid: str,
attrs: List[str],
sub_index: int = None,
transform: Callable = None,
) -> torch.Tensor:
sample = self.meta_info[uid]
images = []
for attr in attrs:
item = sample[attr]
if sub_index is not None:
item = item[sub_index]
mode = "L" if attr == "image_mask" else "RGB"
image = Image.open(item).convert(mode)
if transform is not None:
image = transform(image)
if len(image.shape) == 2:
image = image[..., None]
images.append(image)
images = torch.cat(images, dim=0)
return images
def fetch_sample_grid_images(
self,
uid: str,
attrs: List[str],
sub_idxs: List[List[int]],
transform: Callable = None,
) -> torch.Tensor:
assert transform is not None
grid_image = []
for row_idxs in sub_idxs:
row_image = []
for row_idx in row_idxs:
image = self.fetch_sample_images(
uid, attrs, row_idx, transform
)
row_image.append(image)
row_image = torch.cat(row_image, dim=2) # (c h w)
grid_image.append(row_image)
grid_image = torch.cat(grid_image, dim=1)
return grid_image
def compute_text_embeddings(
self, embed_path: str, original_size: Tuple[int, int]
) -> Dict[str, nn.Module]:
data_dict = torch.load(embed_path)
prompt_embeds = data_dict["prompt_embeds"][0]
add_text_embeds = data_dict["pooled_prompt_embeds"][0]
# Need changed if random crop, set as crop_top_left [y1, x1], center crop as [0, 0]. # noqa
crops_coords_top_left = (0, 0)
add_time_ids = list(
original_size + crops_coords_top_left + self.target_hw
)
add_time_ids = torch.tensor([add_time_ids])
# add_time_ids = add_time_ids.repeat((len(add_text_embeds), 1))
unet_added_cond_kwargs = {
"text_embeds": add_text_embeds,
"time_ids": add_time_ids,
}
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
def visualize_item(
self,
control: torch.Tensor,
color: torch.Tensor,
save_dir: str = None,
) -> List[Image.Image]:
to_pil = transforms.ToPILImage()
color = (color + 1) / 2
color_pil = to_pil(color)
normal_pil = to_pil(control[0:3])
position_pil = to_pil(control[3:6])
mask_pil = to_pil(control[6:])
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
color_pil.save(f"{save_dir}/rgb.jpg")
normal_pil.save(f"{save_dir}/normal.jpg")
position_pil.save(f"{save_dir}/position.jpg")
mask_pil.save(f"{save_dir}/mask.jpg")
logger.info(f"Visualization in {save_dir}")
return normal_pil, position_pil, mask_pil, color_pil
def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
uid = self.data_list[index]
sub_idxs = self.sub_idxs
if sub_idxs is None:
sub_idxs = [[random.randint(0, self.image_num - 1)]]
input_image = self.fetch_sample_grid_images(
uid,
attrs=["image_view_normal", "image_position", "image_mask"],
sub_idxs=sub_idxs,
transform=self.control_transform,
)
assert input_image.shape[1:] == self.target_hw
output_image = self.fetch_sample_grid_images(
uid,
attrs=["image_color"],
sub_idxs=sub_idxs,
transform=self.transform,
)
sample = self.meta_info[uid]
text_feats = self.compute_text_embeddings(
sample["text_feat"], tuple(sample["image_hw"])
)
data = dict(
pixel_values=output_image,
conditioning_pixel_values=input_image,
prompt_embeds=text_feats["prompt_embeds"],
text_embeds=text_feats["text_embeds"],
time_ids=text_feats["time_ids"],
)
return data
if __name__ == "__main__":
index_file = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json" # noqa
target_hw = (512, 512)
transform_list = [
transforms.Resize(
target_hw, interpolation=transforms.InterpolationMode.BILINEAR
),
transforms.CenterCrop(target_hw),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
image_transform = transforms.Compose(transform_list)
control_transform = transforms.Compose(transform_list[:-1])
sub_idxs = [[0, 1, 2], [3, 4, 5]] # None
if sub_idxs is not None:
target_hw = (
target_hw[0] * len(sub_idxs),
target_hw[1] * len(sub_idxs[0]),
)
dataset = Asset3dGenDataset(
index_file,
target_hw,
image_transform,
control_transform,
sub_idxs=sub_idxs,
)
data = dataset[0]
dataset.visualize_item(
data["conditioning_pixel_values"], data["pixel_values"], save_dir="./"
)

View File

@ -0,0 +1,526 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import json
import logging
import math
import os
from collections import defaultdict
from typing import List, Union
import cv2
import nvdiffrast.torch as dr
import torch
from tqdm import tqdm
from embodied_gen.data.utils import (
CameraSetting,
DiffrastRender,
RenderItems,
as_list,
calc_vertex_normals,
import_kaolin_mesh,
init_kal_camera,
normalize_vertices_array,
render_pbr,
save_images,
)
from embodied_gen.utils.process_media import (
create_gif_from_images,
create_mp4_from_images,
)
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
"~/.cache/torch_extensions"
)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
__all__ = ["ImageRender"]
class ImageRender(object):
"""A differentiable mesh renderer supporting multi-view rendering.
This class wraps a differentiable rasterization using `nvdiffrast` to
render mesh geometry to various maps (normal, depth, alpha, albedo, etc.).
Args:
render_items (list[RenderItems]): A list of rendering targets to
generate (e.g., IMAGE, DEPTH, NORMAL, etc.).
camera_params (CameraSetting): The camera parameters for rendering,
including intrinsic and extrinsic matrices.
recompute_vtx_normal (bool, optional): If True, recomputes
vertex normals from the mesh geometry. Defaults to True.
with_mtl (bool, optional): Whether to load `.mtl` material files
for meshes. Defaults to False.
gen_color_gif (bool, optional): Generate a GIF of rendered
color images. Defaults to False.
gen_color_mp4 (bool, optional): Generate an MP4 video of rendered
color images. Defaults to False.
gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of
view-space normals. Defaults to False.
gen_glonormal_mp4 (bool, optional): Generate an MP4 video of
global-space normals. Defaults to False.
no_index_file (bool, optional): If True, skip saving the `index.json`
summary file. Defaults to False.
light_factor (float, optional): A scalar multiplier for
PBR light intensity. Defaults to 1.0.
"""
def __init__(
self,
render_items: list[RenderItems],
camera_params: CameraSetting,
recompute_vtx_normal: bool = True,
with_mtl: bool = False,
gen_color_gif: bool = False,
gen_color_mp4: bool = False,
gen_viewnormal_mp4: bool = False,
gen_glonormal_mp4: bool = False,
no_index_file: bool = False,
light_factor: float = 1.0,
) -> None:
camera = init_kal_camera(camera_params)
self.camera = camera
# Setup MVP matrix and renderer.
mv = camera.view_matrix() # (n 4 4) world2cam
p = camera.intrinsics.projection_matrix()
# NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
p[:, 1, 1] = -p[:, 1, 1]
# mvp = torch.bmm(p, mv) # camera.view_projection_matrix()
self.mv = mv
self.p = p
renderer = DiffrastRender(
p_matrix=p,
mv_matrix=mv,
resolution_hw=camera_params.resolution_hw,
context=dr.RasterizeCudaContext(),
mask_thresh=0.5,
grad_db=False,
device=camera_params.device,
antialias_mask=True,
)
self.renderer = renderer
self.recompute_vtx_normal = recompute_vtx_normal
self.render_items = render_items
self.device = camera_params.device
self.with_mtl = with_mtl
self.gen_color_gif = gen_color_gif
self.gen_color_mp4 = gen_color_mp4
self.gen_viewnormal_mp4 = gen_viewnormal_mp4
self.gen_glonormal_mp4 = gen_glonormal_mp4
self.light_factor = light_factor
self.no_index_file = no_index_file
def render_mesh(
self,
mesh_path: Union[str, List[str]],
output_root: str,
uuid: Union[str, List[str]] = None,
prompts: List[str] = None,
) -> None:
mesh_path = as_list(mesh_path)
if uuid is None:
uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
uuid = as_list(uuid)
assert len(mesh_path) == len(uuid)
os.makedirs(output_root, exist_ok=True)
meta_info = dict()
for idx, (path, uid) in tqdm(
enumerate(zip(mesh_path, uuid)), total=len(mesh_path)
):
output_dir = os.path.join(output_root, uid)
os.makedirs(output_dir, exist_ok=True)
prompt = prompts[idx] if prompts else None
data_dict = self(path, output_dir, prompt)
meta_info[uid] = data_dict
if self.no_index_file:
return
index_file = os.path.join(output_root, "index.json")
with open(index_file, "w") as fout:
json.dump(meta_info, fout)
logger.info(f"Rendering meta info logged in {index_file}")
def __call__(
self, mesh_path: str, output_dir: str, prompt: str = None
) -> dict[str, str]:
"""Render a single mesh and return paths to the rendered outputs.
Processes the input mesh, renders multiple modalities (e.g., normals,
depth, albedo), and optionally saves video or image sequences.
Args:
mesh_path (str): Path to the mesh file (.obj/.glb).
output_dir (str): Directory to save rendered outputs.
prompt (str, optional): Optional caption prompt for MP4 metadata.
Returns:
dict[str, str]: A mapping render types to the saved image paths.
"""
try:
mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
except Exception as e:
logger.error(f"[ERROR MESH LOAD]: {e}, skip {mesh_path}")
return
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
if self.recompute_vtx_normal:
mesh.vertex_normals = calc_vertex_normals(
mesh.vertices, mesh.faces
)
mesh = mesh.to(self.device)
vertices, faces, vertex_normals = (
mesh.vertices,
mesh.faces,
mesh.vertex_normals,
)
# Perform rendering.
data_dict = defaultdict(list)
if RenderItems.ALPHA.value in self.render_items:
masks, _ = self.renderer.render_rast_alpha(vertices, faces)
render_paths = save_images(
masks, f"{output_dir}/{RenderItems.ALPHA}"
)
data_dict[RenderItems.ALPHA.value] = render_paths
if RenderItems.GLOBAL_NORMAL.value in self.render_items:
rendered_normals, masks = self.renderer.render_global_normal(
vertices, faces, vertex_normals
)
if self.gen_glonormal_mp4:
if isinstance(rendered_normals, torch.Tensor):
rendered_normals = rendered_normals.detach().cpu().numpy()
create_mp4_from_images(
rendered_normals,
output_path=f"{output_dir}/normal.mp4",
fps=15,
prompt=prompt,
)
else:
render_paths = save_images(
rendered_normals,
f"{output_dir}/{RenderItems.GLOBAL_NORMAL}",
cvt_color=cv2.COLOR_BGR2RGB,
)
data_dict[RenderItems.GLOBAL_NORMAL.value] = render_paths
if RenderItems.VIEW_NORMAL.value in self.render_items:
assert (
RenderItems.GLOBAL_NORMAL in self.render_items
), f"Must render global normal firstly, got render_items: {self.render_items}." # noqa
rendered_view_normals = self.renderer.transform_normal(
rendered_normals, self.mv, masks, to_view=True
)
if self.gen_viewnormal_mp4:
create_mp4_from_images(
rendered_view_normals,
output_path=f"{output_dir}/view_normal.mp4",
fps=15,
prompt=prompt,
)
else:
render_paths = save_images(
rendered_view_normals,
f"{output_dir}/{RenderItems.VIEW_NORMAL}",
cvt_color=cv2.COLOR_BGR2RGB,
)
data_dict[RenderItems.VIEW_NORMAL.value] = render_paths
if RenderItems.POSITION_MAP.value in self.render_items:
rendered_position, masks = self.renderer.render_position(
vertices, faces
)
norm_position = self.renderer.normalize_map_by_mask(
rendered_position, masks
)
render_paths = save_images(
norm_position,
f"{output_dir}/{RenderItems.POSITION_MAP}",
cvt_color=cv2.COLOR_BGR2RGB,
)
data_dict[RenderItems.POSITION_MAP.value] = render_paths
if RenderItems.DEPTH.value in self.render_items:
rendered_depth, masks = self.renderer.render_depth(vertices, faces)
norm_depth = self.renderer.normalize_map_by_mask(
rendered_depth, masks
)
render_paths = save_images(
norm_depth,
f"{output_dir}/{RenderItems.DEPTH}",
)
data_dict[RenderItems.DEPTH.value] = render_paths
render_paths = save_images(
rendered_depth,
f"{output_dir}/{RenderItems.DEPTH}_exr",
to_uint8=False,
format=".exr",
)
data_dict[f"{RenderItems.DEPTH.value}_exr"] = render_paths
if RenderItems.IMAGE.value in self.render_items:
images = []
albedos = []
diffuses = []
masks, _ = self.renderer.render_rast_alpha(vertices, faces)
try:
for idx, cam in enumerate(self.camera):
image, albedo, diffuse, _ = render_pbr(
mesh, cam, light_factor=self.light_factor
)
image = torch.cat([image[0], masks[idx]], axis=-1)
images.append(image.detach().cpu().numpy())
if RenderItems.ALBEDO.value in self.render_items:
albedo = torch.cat([albedo[0], masks[idx]], axis=-1)
albedos.append(albedo.detach().cpu().numpy())
if RenderItems.DIFFUSE.value in self.render_items:
diffuse = torch.cat([diffuse[0], masks[idx]], axis=-1)
diffuses.append(diffuse.detach().cpu().numpy())
except Exception as e:
logger.error(f"[ERROR pbr render]: {e}, skip {mesh_path}")
return
if self.gen_color_gif:
create_gif_from_images(
images,
output_path=f"{output_dir}/color.gif",
fps=15,
)
if self.gen_color_mp4:
create_mp4_from_images(
images,
output_path=f"{output_dir}/color.mp4",
fps=15,
prompt=prompt,
)
if self.gen_color_mp4 or self.gen_color_gif:
return data_dict
render_paths = save_images(
images,
f"{output_dir}/{RenderItems.IMAGE}",
cvt_color=cv2.COLOR_BGRA2RGBA,
)
data_dict[RenderItems.IMAGE.value] = render_paths
render_paths = save_images(
albedos,
f"{output_dir}/{RenderItems.ALBEDO}",
cvt_color=cv2.COLOR_BGRA2RGBA,
)
data_dict[RenderItems.ALBEDO.value] = render_paths
render_paths = save_images(
diffuses,
f"{output_dir}/{RenderItems.DIFFUSE}",
cvt_color=cv2.COLOR_BGRA2RGBA,
)
data_dict[RenderItems.DIFFUSE.value] = render_paths
data_dict["status"] = "success"
logger.info(f"Finish rendering in {output_dir}")
return data_dict
def parse_args():
parser = argparse.ArgumentParser(description="Render settings")
parser.add_argument(
"--mesh_path",
type=str,
nargs="+",
help="Paths to the mesh files for rendering.",
)
parser.add_argument(
"--output_root",
type=str,
help="Root directory for output",
)
parser.add_argument(
"--uuid",
type=str,
nargs="+",
default=None,
help="uuid for rendering saving.",
)
parser.add_argument(
"--num_images", type=int, default=6, help="Number of images to render."
)
parser.add_argument(
"--elevation",
type=float,
nargs="+",
default=[20.0, -10.0],
help="Elevation angles for the camera (default: [20.0, -10.0])",
)
parser.add_argument(
"--distance",
type=float,
default=5,
help="Camera distance (default: 5)",
)
parser.add_argument(
"--resolution_hw",
type=int,
nargs=2,
default=(512, 512),
help="Resolution of the output images (default: (512, 512))",
)
parser.add_argument(
"--fov",
type=float,
default=30,
help="Field of view in degrees (default: 30)",
)
parser.add_argument(
"--pbr_light_factor",
type=float,
default=1.0,
help="Light factor for mesh PBR rendering (default: 2.)",
)
parser.add_argument(
"--with_mtl",
action="store_true",
help="Whether to render with mesh material.",
)
parser.add_argument(
"--gen_color_gif",
action="store_true",
help="Whether to generate color .gif rendering file.",
)
parser.add_argument(
"--gen_color_mp4",
action="store_true",
help="Whether to generate color .mp4 rendering file.",
)
parser.add_argument(
"--gen_viewnormal_mp4",
action="store_true",
help="Whether to generate view normal .mp4 rendering file.",
)
parser.add_argument(
"--gen_glonormal_mp4",
action="store_true",
help="Whether to generate global normal .mp4 rendering file.",
)
parser.add_argument(
"--prompts",
type=str,
nargs="+",
default=None,
help="Text prompts for the rendering.",
)
args = parser.parse_args()
if args.uuid is None and args.mesh_path is not None:
args.uuid = []
for path in args.mesh_path:
uuid = os.path.basename(path).split(".")[0]
args.uuid.append(uuid)
return args
def entrypoint(**kwargs) -> None:
args = parse_args()
for k, v in kwargs.items():
if hasattr(args, k) and v is not None:
setattr(args, k, v)
camera_settings = CameraSetting(
num_images=args.num_images,
elevation=args.elevation,
distance=args.distance,
resolution_hw=args.resolution_hw,
fov=math.radians(args.fov),
device="cuda",
)
render_items = [
RenderItems.ALPHA.value,
RenderItems.GLOBAL_NORMAL.value,
RenderItems.VIEW_NORMAL.value,
RenderItems.POSITION_MAP.value,
RenderItems.IMAGE.value,
RenderItems.DEPTH.value,
# RenderItems.ALBEDO.value,
# RenderItems.DIFFUSE.value,
]
gen_video = (
args.gen_color_gif
or args.gen_color_mp4
or args.gen_viewnormal_mp4
or args.gen_glonormal_mp4
)
if gen_video:
render_items = []
if args.gen_color_gif or args.gen_color_mp4:
render_items.append(RenderItems.IMAGE.value)
if args.gen_glonormal_mp4:
render_items.append(RenderItems.GLOBAL_NORMAL.value)
if args.gen_viewnormal_mp4:
render_items.append(RenderItems.VIEW_NORMAL.value)
if RenderItems.GLOBAL_NORMAL.value not in render_items:
render_items.append(RenderItems.GLOBAL_NORMAL.value)
image_render = ImageRender(
render_items=render_items,
camera_params=camera_settings,
with_mtl=args.with_mtl,
gen_color_gif=args.gen_color_gif,
gen_color_mp4=args.gen_color_mp4,
gen_viewnormal_mp4=args.gen_viewnormal_mp4,
gen_glonormal_mp4=args.gen_glonormal_mp4,
light_factor=args.pbr_light_factor,
no_index_file=gen_video,
)
image_render.render_mesh(
mesh_path=args.mesh_path,
output_root=args.output_root,
uuid=args.uuid,
prompts=args.prompts,
)
return
if __name__ == "__main__":
entrypoint()

View File

@ -0,0 +1,452 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
from typing import Tuple, Union
import igraph
import numpy as np
import pyvista as pv
import spaces
import torch
import utils3d
from pymeshfix import _meshfix
from tqdm import tqdm
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
__all__ = ["MeshFixer"]
def _radical_inverse(base, n):
val = 0
inv_base = 1.0 / base
inv_base_n = inv_base
while n > 0:
digit = n % base
val += digit * inv_base_n
n //= base
inv_base_n *= inv_base
return val
def _halton_sequence(dim, n):
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
return [_radical_inverse(PRIMES[dim], n) for dim in range(dim)]
def _hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + _halton_sequence(dim - 1, n)
def _sphere_hammersley_seq(n, num_samples, offset=(0, 0), remap=False):
"""Generate a point on a unit sphere using the Hammersley sequence.
Args:
n (int): The index of the sample.
num_samples (int): The total number of samples.
offset (tuple, optional): Offset for the u and v coordinates.
remap (bool, optional): Whether to remap the u coordinate.
Returns:
list: A list containing the spherical coordinates [phi, theta].
"""
u, v = _hammersley_sequence(2, n, num_samples)
u += offset[0] / num_samples
v += offset[1]
if remap:
u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
theta = np.arccos(1 - 2 * u) - np.pi / 2
phi = v * 2 * np.pi
return [phi, theta]
class MeshFixer(object):
"""MeshFixer simplifies and repairs 3D triangle meshes by TSDF.
Attributes:
vertices (torch.Tensor): A tensor of shape (V, 3) representing vertex positions.
faces (torch.Tensor): A tensor of shape (F, 3) representing face indices.
device (str): Device to run computations on, typically "cuda" or "cpu".
Main logic reference: https://github.com/microsoft/TRELLIS/blob/main/trellis/utils/postprocessing_utils.py#L22
"""
def __init__(
self,
vertices: Union[torch.Tensor, np.ndarray],
faces: Union[torch.Tensor, np.ndarray],
device: str = "cuda",
) -> None:
self.device = device
if isinstance(vertices, np.ndarray):
vertices = torch.tensor(vertices)
self.vertices = vertices
if isinstance(faces, np.ndarray):
faces = torch.tensor(faces)
self.faces = faces
@staticmethod
def log_mesh_changes(method):
def wrapper(self, *args, **kwargs):
logger.info(
f"Before {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
)
result = method(self, *args, **kwargs)
logger.info(
f"After {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
)
return result
return wrapper
@log_mesh_changes
def fill_holes(
self,
max_hole_size: float,
max_hole_nbe: int,
resolution: int,
num_views: int,
norm_mesh_ratio: float = 1.0,
) -> None:
self.vertices = self.vertices * norm_mesh_ratio
vertices, self.faces = self._fill_holes(
self.vertices,
self.faces,
max_hole_size,
max_hole_nbe,
resolution,
num_views,
)
self.vertices = vertices / norm_mesh_ratio
@staticmethod
@torch.no_grad()
def _fill_holes(
vertices: torch.Tensor,
faces: torch.Tensor,
max_hole_size: float,
max_hole_nbe: int,
resolution: int,
num_views: int,
) -> Union[torch.Tensor, torch.Tensor]:
yaws, pitchs = [], []
for i in range(num_views):
y, p = _sphere_hammersley_seq(i, num_views)
yaws.append(y)
pitchs.append(p)
yaws, pitchs = torch.tensor(yaws).to(vertices), torch.tensor(
pitchs
).to(vertices)
radius, fov = 2.0, torch.deg2rad(torch.tensor(40)).to(vertices)
projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
views = []
for yaw, pitch in zip(yaws, pitchs):
orig = (
torch.tensor(
[
torch.sin(yaw) * torch.cos(pitch),
torch.cos(yaw) * torch.cos(pitch),
torch.sin(pitch),
]
).to(vertices)
* radius
)
view = utils3d.torch.view_look_at(
orig,
torch.tensor([0, 0, 0]).to(vertices),
torch.tensor([0, 0, 1]).to(vertices),
)
views.append(view)
views = torch.stack(views, dim=0)
# Rasterize the mesh
visibility = torch.zeros(
faces.shape[0], dtype=torch.int32, device=faces.device
)
rastctx = utils3d.torch.RastContext(backend="cuda")
for i in tqdm(
range(views.shape[0]), total=views.shape[0], desc="Rasterizing"
):
view = views[i]
buffers = utils3d.torch.rasterize_triangle_faces(
rastctx,
vertices[None],
faces,
resolution,
resolution,
view=view,
projection=projection,
)
face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1
face_id = torch.unique(face_id).long()
visibility[face_id] += 1
# Normalize visibility by the number of views
visibility = visibility.float() / num_views
# Mincut: Identify outer and inner faces
edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
connected_components = utils3d.torch.compute_connected_components(
faces, edges, face2edge
)
outer_face_indices = torch.zeros(
faces.shape[0], dtype=torch.bool, device=faces.device
)
for i in range(len(connected_components)):
outer_face_indices[connected_components[i]] = visibility[
connected_components[i]
] > min(
max(
visibility[connected_components[i]].quantile(0.75).item(),
0.25,
),
0.5,
)
outer_face_indices = outer_face_indices.nonzero().reshape(-1)
inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
if inner_face_indices.shape[0] == 0:
return vertices, faces
# Construct dual graph (faces as nodes, edges as edges)
dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(
face2edge
)
dual_edge2edge = edges[dual_edge2edge]
dual_edges_weights = torch.norm(
vertices[dual_edge2edge[:, 0]] - vertices[dual_edge2edge[:, 1]],
dim=1,
)
# Mincut: Construct main graph and solve the mincut problem
g = igraph.Graph()
g.add_vertices(faces.shape[0])
g.add_edges(dual_edges.cpu().numpy())
g.es["weight"] = dual_edges_weights.cpu().numpy()
g.add_vertex("s") # source
g.add_vertex("t") # target
g.add_edges(
[(f, "s") for f in inner_face_indices],
attributes={
"weight": torch.ones(
inner_face_indices.shape[0], dtype=torch.float32
)
.cpu()
.numpy()
},
)
g.add_edges(
[(f, "t") for f in outer_face_indices],
attributes={
"weight": torch.ones(
outer_face_indices.shape[0], dtype=torch.float32
)
.cpu()
.numpy()
},
)
cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist())
remove_face_indices = torch.tensor(
[v for v in cut.partition[0] if v < faces.shape[0]],
dtype=torch.long,
device=faces.device,
)
# Check if the cut is valid with each connected component
to_remove_cc = utils3d.torch.compute_connected_components(
faces[remove_face_indices]
)
valid_remove_cc = []
cutting_edges = []
for cc in to_remove_cc:
# Check visibility median for connected component
visibility_median = visibility[remove_face_indices[cc]].median()
if visibility_median > 0.25:
continue
# Check if the cutting loop is small enough
cc_edge_indices, cc_edges_degree = torch.unique(
face2edge[remove_face_indices[cc]], return_counts=True
)
cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
cc_new_boundary_edge_indices = cc_boundary_edge_indices[
~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)
]
if len(cc_new_boundary_edge_indices) > 0:
cc_new_boundary_edge_cc = (
utils3d.torch.compute_edge_connected_components(
edges[cc_new_boundary_edge_indices]
)
)
cc_new_boundary_edges_cc_center = [
vertices[edges[cc_new_boundary_edge_indices[edge_cc]]]
.mean(dim=1)
.mean(dim=0)
for edge_cc in cc_new_boundary_edge_cc
]
cc_new_boundary_edges_cc_area = []
for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
_e1 = (
vertices[
edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]
]
- cc_new_boundary_edges_cc_center[i]
)
_e2 = (
vertices[
edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]
]
- cc_new_boundary_edges_cc_center[i]
)
cc_new_boundary_edges_cc_area.append(
torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum()
* 0.5
)
cutting_edges.append(cc_new_boundary_edge_indices)
if any(
[
_l > max_hole_size
for _l in cc_new_boundary_edges_cc_area
]
):
continue
valid_remove_cc.append(cc)
if len(valid_remove_cc) > 0:
remove_face_indices = remove_face_indices[
torch.cat(valid_remove_cc)
]
mask = torch.ones(
faces.shape[0], dtype=torch.bool, device=faces.device
)
mask[remove_face_indices] = 0
faces = faces[mask]
faces, vertices = utils3d.torch.remove_unreferenced_vertices(
faces, vertices
)
tqdm.write(f"Removed {(~mask).sum()} faces by mincut")
else:
tqdm.write(f"Removed 0 faces by mincut")
# Fill small boundaries (holes)
mesh = _meshfix.PyTMesh()
mesh.load_array(vertices.cpu().numpy(), faces.cpu().numpy())
mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
_vertices, _faces = mesh.return_arrays()
vertices = torch.tensor(_vertices).to(vertices)
faces = torch.tensor(_faces).to(faces)
return vertices, faces
@property
def vertices_np(self) -> np.ndarray:
return self.vertices.cpu().numpy()
@property
def faces_np(self) -> np.ndarray:
return self.faces.cpu().numpy()
@log_mesh_changes
def simplify(self, ratio: float) -> None:
"""Simplify the mesh using quadric edge collapse decimation.
Args:
ratio (float): Ratio of faces to filter out.
"""
if ratio <= 0 or ratio >= 1:
raise ValueError("Simplify ratio must be between 0 and 1.")
# Convert to PyVista format for simplification
mesh = pv.PolyData(
self.vertices_np,
np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
)
mesh = mesh.decimate(ratio, progress_bar=True)
# Update vertices and faces
self.vertices = torch.tensor(
mesh.points, device=self.device, dtype=torch.float32
)
self.faces = torch.tensor(
mesh.faces.reshape(-1, 4)[:, 1:],
device=self.device,
dtype=torch.int32,
)
@spaces.GPU
def __call__(
self,
filter_ratio: float,
max_hole_size: float,
resolution: int,
num_views: int,
norm_mesh_ratio: float = 1.0,
) -> Tuple[np.ndarray, np.ndarray]:
"""Post-process the mesh by simplifying and filling holes.
This method performs a two-step process:
1. Simplifies mesh by reducing faces using quadric edge decimation.
2. Fills holes by removing invisible faces, repairing small boundaries.
Args:
filter_ratio (float): Ratio of faces to simplify out.
Must be in the range (0, 1).
max_hole_size (float): Maximum area of a hole to fill. Connected
components of holes larger than this size will not be repaired.
resolution (int): Resolution of the rasterization buffer.
num_views (int): Number of viewpoints to sample for rasterization.
norm_mesh_ratio (float, optional): A scaling factor applied to the
vertices of the mesh during processing.
Returns:
Tuple[np.ndarray, np.ndarray]:
- vertices: Simplified and repaired vertex array of (V, 3).
- faces: Simplified and repaired face array of (F, 3).
"""
self.vertices = self.vertices.to(self.device)
self.faces = self.faces.to(self.device)
self.simplify(ratio=filter_ratio)
self.fill_holes(
max_hole_size=max_hole_size,
max_hole_nbe=int(250 * np.sqrt(1 - filter_ratio)),
resolution=resolution,
num_views=num_views,
norm_mesh_ratio=norm_mesh_ratio,
)
return self.vertices_np, self.faces_np

1009
embodied_gen/data/utils.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,200 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
from typing import Union
import cv2
import numpy as np
import spaces
import torch
from diffusers import (
EulerAncestralDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
)
from huggingface_hub import snapshot_download
from PIL import Image
from embodied_gen.models.segment_model import RembgRemover
__all__ = [
"DelightingModel",
]
class DelightingModel(object):
"""A model to remove the lighting in image space.
This model is encapsulated based on the Hunyuan3D-Delight model
from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa
Attributes:
image_guide_scale (float): Weight of image guidance in diffusion process.
text_guide_scale (float): Weight of text (prompt) guidance in diffusion process.
num_infer_step (int): Number of inference steps for diffusion model.
mask_erosion_size (int): Size of erosion kernel for alpha mask cleanup.
device (str): Device used for inference, e.g., 'cuda' or 'cpu'.
seed (int): Random seed for diffusion model reproducibility.
model_path (str): Filesystem path to pretrained model weights.
pipeline: Lazy-loaded diffusion pipeline instance.
"""
def __init__(
self,
model_path: str = None,
num_infer_step: int = 50,
mask_erosion_size: int = 3,
image_guide_scale: float = 1.5,
text_guide_scale: float = 1.0,
device: str = "cuda",
seed: int = 0,
) -> None:
self.image_guide_scale = image_guide_scale
self.text_guide_scale = text_guide_scale
self.num_infer_step = num_infer_step
self.mask_erosion_size = mask_erosion_size
self.kernel = np.ones(
(self.mask_erosion_size, self.mask_erosion_size), np.uint8
)
self.seed = seed
self.device = device
self.pipeline = None # lazy load model adapt to @spaces.GPU
if model_path is None:
suffix = "hunyuan3d-delight-v2-0"
model_path = snapshot_download(
repo_id="tencent/Hunyuan3D-2", allow_patterns=f"{suffix}/*"
)
model_path = os.path.join(model_path, suffix)
self.model_path = model_path
def _lazy_init_pipeline(self):
if self.pipeline is None:
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
self.model_path,
torch_dtype=torch.float16,
safety_checker=None,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config
)
pipeline.set_progress_bar_config(disable=True)
pipeline.to(self.device, torch.float16)
self.pipeline = pipeline
def recenter_image(
self, image: Image.Image, border_ratio: float = 0.2
) -> Image.Image:
if image.mode == "RGB":
return image
elif image.mode == "L":
image = image.convert("RGB")
return image
alpha_channel = np.array(image)[:, :, 3]
non_zero_indices = np.argwhere(alpha_channel > 0)
if non_zero_indices.size == 0:
raise ValueError("Image is fully transparent")
min_row, min_col = non_zero_indices.min(axis=0)
max_row, max_col = non_zero_indices.max(axis=0)
cropped_image = image.crop(
(min_col, min_row, max_col + 1, max_row + 1)
)
width, height = cropped_image.size
border_width = int(width * border_ratio)
border_height = int(height * border_ratio)
new_width = width + 2 * border_width
new_height = height + 2 * border_height
square_size = max(new_width, new_height)
new_image = Image.new(
"RGBA", (square_size, square_size), (255, 255, 255, 0)
)
paste_x = (square_size - new_width) // 2 + border_width
paste_y = (square_size - new_height) // 2 + border_height
new_image.paste(cropped_image, (paste_x, paste_y))
return new_image
@spaces.GPU
@torch.no_grad()
def __call__(
self,
image: Union[str, np.ndarray, Image.Image],
preprocess: bool = False,
target_wh: tuple[int, int] = None,
) -> Image.Image:
self._lazy_init_pipeline()
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
if preprocess:
bg_remover = RembgRemover()
image = bg_remover(image)
image = self.recenter_image(image)
if target_wh is not None:
image = image.resize(target_wh)
else:
target_wh = image.size
image_array = np.array(image)
assert image_array.shape[-1] == 4, "Image must have alpha channel"
raw_alpha_channel = image_array[:, :, 3]
alpha_channel = cv2.erode(raw_alpha_channel, self.kernel, iterations=1)
image_array[alpha_channel == 0, :3] = 255 # must be white background
image_array[:, :, 3] = alpha_channel
image = self.pipeline(
prompt="",
image=Image.fromarray(image_array).convert("RGB"),
generator=torch.manual_seed(self.seed),
num_inference_steps=self.num_infer_step,
image_guidance_scale=self.image_guide_scale,
guidance_scale=self.text_guide_scale,
).images[0]
alpha_channel = Image.fromarray(alpha_channel)
rgba_image = image.convert("RGBA").resize(target_wh)
rgba_image.putalpha(alpha_channel)
return rgba_image
if __name__ == "__main__":
delighting_model = DelightingModel()
image_path = "apps/assets/example_image/sample_12.jpg"
image = delighting_model(
image_path, preprocess=True, target_wh=(512, 512)
) # noqa
image.save("delight.png")
# image_path = "embodied_gen/scripts/test_robot.png"
# image = delighting_model(image_path)
# image.save("delighting_image_a2.png")

View File

@ -0,0 +1,526 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import os
import struct
from dataclasses import dataclass
from typing import Optional
import cv2
import numpy as np
import torch
from gsplat.cuda._wrapper import spherical_harmonics
from gsplat.rendering import rasterization
from plyfile import PlyData
from scipy.spatial.transform import Rotation
from embodied_gen.data.utils import gamma_shs, quat_mult, quat_to_rotmat
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
__all__ = [
"RenderResult",
"GaussianOperator",
]
@dataclass
class RenderResult:
rgb: np.ndarray
depth: np.ndarray
opacity: np.ndarray
mask_threshold: float = 10
mask: Optional[np.ndarray] = None
rgba: Optional[np.ndarray] = None
def __post_init__(self):
if isinstance(self.rgb, torch.Tensor):
rgb = self.rgb.detach().cpu().numpy()
rgb = (rgb * 255).astype(np.uint8)
self.rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
if isinstance(self.depth, torch.Tensor):
self.depth = self.depth.detach().cpu().numpy()
if isinstance(self.opacity, torch.Tensor):
opacity = self.opacity.detach().cpu().numpy()
opacity = (opacity * 255).astype(np.uint8)
self.opacity = cv2.cvtColor(opacity, cv2.COLOR_GRAY2RGB)
mask = np.where(self.opacity > self.mask_threshold, 255, 0)
self.mask = mask[..., 0:1].astype(np.uint8)
self.rgba = np.concatenate([self.rgb, self.mask], axis=-1)
@dataclass
class GaussianBase:
_opacities: torch.Tensor
_means: torch.Tensor
_scales: torch.Tensor
_quats: torch.Tensor
_rgbs: Optional[torch.Tensor] = None
_features_dc: Optional[torch.Tensor] = None
_features_rest: Optional[torch.Tensor] = None
sh_degree: Optional[int] = 0
device: str = "cuda"
def __post_init__(self):
self.active_sh_degree: int = self.sh_degree
self.to(self.device)
def to(self, device: str) -> None:
for k, v in self.__dict__.items():
if not isinstance(v, torch.Tensor):
continue
self.__dict__[k] = v.to(device)
def get_numpy_data(self):
data = {}
for k, v in self.__dict__.items():
if not isinstance(v, torch.Tensor):
continue
data[k] = v.detach().cpu().numpy()
return data
def quat_norm(self, x: torch.Tensor) -> torch.Tensor:
return x / x.norm(dim=-1, keepdim=True)
@classmethod
def load_from_ply(
cls,
path: str,
gamma: float = 1.0,
device: str = "cuda",
) -> "GaussianBase":
plydata = PlyData.read(path)
xyz = torch.stack(
(
torch.tensor(plydata.elements[0]["x"], dtype=torch.float32),
torch.tensor(plydata.elements[0]["y"], dtype=torch.float32),
torch.tensor(plydata.elements[0]["z"], dtype=torch.float32),
),
dim=1,
)
opacities = torch.tensor(
plydata.elements[0]["opacity"], dtype=torch.float32
).unsqueeze(-1)
features_dc = torch.zeros((xyz.shape[0], 3), dtype=torch.float32)
features_dc[:, 0] = torch.tensor(
plydata.elements[0]["f_dc_0"], dtype=torch.float32
)
features_dc[:, 1] = torch.tensor(
plydata.elements[0]["f_dc_1"], dtype=torch.float32
)
features_dc[:, 2] = torch.tensor(
plydata.elements[0]["f_dc_2"], dtype=torch.float32
)
scale_names = [
p.name
for p in plydata.elements[0].properties
if p.name.startswith("scale_")
]
scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
scales = torch.zeros(
(xyz.shape[0], len(scale_names)), dtype=torch.float32
)
for idx, attr_name in enumerate(scale_names):
scales[:, idx] = torch.tensor(
plydata.elements[0][attr_name], dtype=torch.float32
)
rot_names = [
p.name
for p in plydata.elements[0].properties
if p.name.startswith("rot_")
]
rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
rots = torch.zeros((xyz.shape[0], len(rot_names)), dtype=torch.float32)
for idx, attr_name in enumerate(rot_names):
rots[:, idx] = torch.tensor(
plydata.elements[0][attr_name], dtype=torch.float32
)
rots = rots / torch.norm(rots, dim=-1, keepdim=True)
# extra features
extra_f_names = [
p.name
for p in plydata.elements[0].properties
if p.name.startswith("f_rest_")
]
extra_f_names = sorted(
extra_f_names, key=lambda x: int(x.split("_")[-1])
)
max_sh_degree = int(np.sqrt((len(extra_f_names) + 3) / 3) - 1)
if max_sh_degree != 0:
features_extra = torch.zeros(
(xyz.shape[0], len(extra_f_names)), dtype=torch.float32
)
for idx, attr_name in enumerate(extra_f_names):
features_extra[:, idx] = torch.tensor(
plydata.elements[0][attr_name], dtype=torch.float32
)
features_extra = features_extra.view(
(features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)
)
features_extra = features_extra.permute(0, 2, 1)
if abs(gamma - 1.0) > 1e-3:
features_dc = gamma_shs(features_dc, gamma)
features_extra[..., :] = 0.0
opacities *= 0.8
shs = torch.cat(
[
features_dc.reshape(-1, 3),
features_extra.reshape(len(features_dc), -1),
],
dim=-1,
)
else:
# sh_dim is 0, only dc features
shs = features_dc
features_extra = None
return cls(
sh_degree=max_sh_degree,
_means=xyz,
_opacities=opacities,
_rgbs=shs,
_scales=scales,
_quats=rots,
_features_dc=features_dc,
_features_rest=features_extra,
device=device,
)
def save_to_ply(
self, path: str, colors: torch.Tensor = None, enable_mask: bool = False
):
os.makedirs(os.path.dirname(path), exist_ok=True)
numpy_data = self.get_numpy_data()
means = numpy_data["_means"]
scales = numpy_data["_scales"]
quats = numpy_data["_quats"]
opacities = numpy_data["_opacities"]
sh0 = numpy_data["_features_dc"]
shN = numpy_data.get("_features_rest", np.zeros((means.shape[0], 0)))
shN = shN.reshape(means.shape[0], -1)
# Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays # noqa
if enable_mask:
invalid_mask = (
np.isnan(means).any(axis=1)
| np.isinf(means).any(axis=1)
| np.isnan(scales).any(axis=1)
| np.isinf(scales).any(axis=1)
| np.isnan(quats).any(axis=1)
| np.isinf(quats).any(axis=1)
| np.isnan(opacities).any(axis=0)
| np.isinf(opacities).any(axis=0)
| np.isnan(sh0).any(axis=1)
| np.isinf(sh0).any(axis=1)
| np.isnan(shN).any(axis=1)
| np.isinf(shN).any(axis=1)
)
# Filter out rows with NaNs or Infs from all data arrays
means = means[~invalid_mask]
scales = scales[~invalid_mask]
quats = quats[~invalid_mask]
opacities = opacities[~invalid_mask]
sh0 = sh0[~invalid_mask]
shN = shN[~invalid_mask]
num_points = means.shape[0]
with open(path, "wb") as f:
# Write PLY header
f.write(b"ply\n")
f.write(b"format binary_little_endian 1.0\n")
f.write(f"element vertex {num_points}\n".encode())
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
f.write(b"property float nx\n")
f.write(b"property float ny\n")
f.write(b"property float nz\n")
if colors is not None:
for j in range(colors.shape[1]):
f.write(f"property float f_dc_{j}\n".encode())
else:
for i, data in enumerate([sh0, shN]):
prefix = "f_dc" if i == 0 else "f_rest"
for j in range(data.shape[1]):
f.write(f"property float {prefix}_{j}\n".encode())
f.write(b"property float opacity\n")
for i in range(scales.shape[1]):
f.write(f"property float scale_{i}\n".encode())
for i in range(quats.shape[1]):
f.write(f"property float rot_{i}\n".encode())
f.write(b"end_header\n")
# Write vertex data
for i in range(num_points):
f.write(struct.pack("<fff", *means[i])) # x, y, z
f.write(struct.pack("<fff", 0, 0, 0)) # nx, ny, nz (zeros)
if colors is not None:
color = colors.detach().cpu().numpy()
for j in range(color.shape[1]):
f_dc = (color[i, j] - 0.5) / 0.2820947917738781
f.write(struct.pack("<f", f_dc))
else:
for data in [sh0, shN]:
for j in range(data.shape[1]):
f.write(struct.pack("<f", data[i, j]))
f.write(struct.pack("<f", opacities[i])) # opacity
for data in [scales, quats]:
for j in range(data.shape[1]):
f.write(struct.pack("<f", data[i, j]))
@dataclass
class GaussianOperator(GaussianBase):
"""Gaussian Splatting operator.
Supports transformation, scaling, color computation, and
rasterization-based rendering.
Inherits:
GaussianBase: Base class with Gaussian params (means, scales, etc.)
Functionality includes:
- Applying instance poses to transform Gaussian means and quaternions.
- Scaling Gaussians to a real-world size.
- Computing colors using spherical harmonics.
- Rendering images via differentiable rasterization.
- Exporting transformed and rescaled models to .ply format.
"""
def _compute_transform(
self,
means: torch.Tensor,
quats: torch.Tensor,
instance_pose: torch.Tensor,
):
"""Compute the transform of the GS models.
Args:
means: tensor of gs means.
quats: tensor of gs quaternions.
instance_pose: instances poses in [x y z qx qy qz qw] format.
"""
# (x y z qx qy qz qw) -> (x y z qw qx qy qz)
instance_pose = instance_pose[[0, 1, 2, 6, 3, 4, 5]]
cur_instances_quats = self.quat_norm(instance_pose[3:])
rot_cur = quat_to_rotmat(cur_instances_quats, mode="wxyz")
# update the means
num_gs = means.shape[0]
trans_per_pts = torch.stack([instance_pose[:3]] * num_gs, dim=0)
quat_per_pts = torch.stack([instance_pose[3:]] * num_gs, dim=0)
rot_per_pts = torch.stack([rot_cur] * num_gs, dim=0) # (num_gs, 3, 3)
# update the means
cur_means = (
torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1)
+ trans_per_pts
)
# update the quats
_quats = self.quat_norm(quats)
cur_quats = quat_mult(quat_per_pts, _quats)
return cur_means, cur_quats
def get_gaussians(
self,
c2w: torch.Tensor = None,
instance_pose: torch.Tensor = None,
apply_activate: bool = False,
) -> "GaussianBase":
"""Get Gaussian data under the given instance_pose."""
if c2w is None:
c2w = torch.eye(4).to(self.device)
if instance_pose is not None:
# compute the transformed gs means and quats
world_means, world_quats = self._compute_transform(
self._means, self._quats, instance_pose.float().to(self.device)
)
else:
world_means, world_quats = self._means, self._quats
# get colors of gaussians
if self._features_rest is not None:
colors = torch.cat(
(self._features_dc[:, None, :], self._features_rest), dim=1
)
else:
colors = self._features_dc[:, None, :]
if self.sh_degree > 0:
viewdirs = world_means.detach() - c2w[..., :3, 3] # (N, 3)
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors)
rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
else:
rgbs = torch.sigmoid(colors[:, 0, :])
gs_dict = dict(
_means=world_means,
_opacities=(
torch.sigmoid(self._opacities)
if apply_activate
else self._opacities
),
_rgbs=rgbs,
_scales=(
torch.exp(self._scales) if apply_activate else self._scales
),
_quats=self.quat_norm(world_quats),
_features_dc=self._features_dc,
_features_rest=self._features_rest,
sh_degree=self.sh_degree,
device=self.device,
)
return GaussianOperator(**gs_dict)
def rescale(self, scale: float):
if scale != 1.0:
self._means *= scale
self._scales += torch.log(self._scales.new_tensor(scale))
def set_scale_by_height(self, real_height: float) -> None:
def _ptp(tensor, dim):
val = tensor.max(dim=dim).values - tensor.min(dim=dim).values
return val.tolist()
xyz_scale = max(_ptp(self._means, dim=0))
self.rescale(1 / (xyz_scale + 1e-6)) # Normalize to [-0.5, 0.5]
raw_height = _ptp(self._means, dim=0)[1]
scale = real_height / raw_height
self.rescale(scale)
return
@staticmethod
def resave_ply(
in_ply: str,
out_ply: str,
real_height: float = None,
instance_pose: np.ndarray = None,
device: str = "cuda",
) -> None:
gs_model = GaussianOperator.load_from_ply(in_ply, device=device)
if instance_pose is not None:
gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
if real_height is not None:
gs_model.set_scale_by_height(real_height)
gs_model.save_to_ply(out_ply)
return
@staticmethod
def trans_to_quatpose(
rot_matrix: list[list[float]],
trans_matrix: list[float] = [0, 0, 0],
) -> torch.Tensor:
if isinstance(rot_matrix, list):
rot_matrix = np.array(rot_matrix)
rot = Rotation.from_matrix(rot_matrix)
qx, qy, qz, qw = rot.as_quat()
instance_pose = torch.tensor([*trans_matrix, qx, qy, qz, qw])
return instance_pose
def render(
self,
c2w: torch.Tensor,
Ks: torch.Tensor,
image_width: int,
image_height: int,
) -> RenderResult:
gs = self.get_gaussians(c2w, apply_activate=True)
renders, alphas, _ = rasterization(
means=gs._means,
quats=gs._quats,
scales=gs._scales,
opacities=gs._opacities.squeeze(),
colors=gs._rgbs,
viewmats=torch.linalg.inv(c2w)[None, ...],
Ks=Ks[None, ...],
width=image_width,
height=image_height,
packed=False,
absgrad=True,
sparse_grad=False,
# rasterize_mode="classic",
rasterize_mode="antialiased",
**{
"near_plane": 0.01,
"far_plane": 1000000000,
"radius_clip": 0.0,
"render_mode": "RGB+ED",
},
)
renders = renders[0]
alphas = alphas[0].squeeze(-1)
assert renders.shape[-1] == 4, f"Must render rgb, depth and alpha"
rendered_rgb, rendered_depth = torch.split(renders, [3, 1], dim=-1)
return RenderResult(
torch.clamp(rendered_rgb, min=0, max=1),
rendered_depth,
alphas[..., None],
)
if __name__ == "__main__":
input_gs = "outputs/test/debug.ply"
output_gs = "./debug_v3.ply"
gs_model: GaussianOperator = GaussianOperator.load_from_ply(input_gs)
# 绕 x 轴旋转 180°
R_x = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
instance_pose = gs_model.trans_to_quatpose(R_x)
gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
gs_model.rescale(2)
gs_model.set_scale_by_height(1.3)
gs_model.save_to_ply(output_gs)

View File

@ -0,0 +1,379 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import os
from typing import Literal, Union
import cv2
import numpy as np
import rembg
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from segment_anything import (
SamAutomaticMaskGenerator,
SamPredictor,
sam_model_registry,
)
from transformers import pipeline
from embodied_gen.data.utils import resize_pil, trellis_preprocess
from embodied_gen.utils.process_media import filter_small_connected_components
from embodied_gen.validators.quality_checkers import ImageSegChecker
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
__all__ = [
"SAMRemover",
"SAMPredictor",
"RembgRemover",
"get_segmented_image_by_agent",
]
class SAMRemover(object):
"""Loading SAM models and performing background removal on images.
Attributes:
checkpoint (str): Path to the model checkpoint.
model_type (str): Type of the SAM model to load (default: "vit_h").
area_ratio (float): Area ratio filtering small connected components.
"""
def __init__(
self,
checkpoint: str = None,
model_type: str = "vit_h",
area_ratio: float = 15,
):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model_type = model_type
self.area_ratio = area_ratio
if checkpoint is None:
suffix = "sam"
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
)
checkpoint = os.path.join(
model_path, suffix, "sam_vit_h_4b8939.pth"
)
self.mask_generator = self._load_sam_model(checkpoint)
def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
sam.to(device=self.device)
return SamAutomaticMaskGenerator(sam)
def __call__(
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
) -> Image.Image:
"""Removes the background from an image using the SAM model.
Args:
image (Union[str, Image.Image, np.ndarray]): Input image,
can be a file path, PIL Image, or numpy array.
save_path (str): Path to save the output image (default: None).
Returns:
Image.Image: The image with background removed,
including an alpha channel.
"""
# Convert input to numpy array
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image).convert("RGB")
image = resize_pil(image)
image = np.array(image.convert("RGB"))
# Generate masks
masks = self.mask_generator.generate(image)
masks = sorted(masks, key=lambda x: x["area"], reverse=True)
if not masks:
logger.warning(
"Segmentation failed: No mask generated, return raw image."
)
output_image = Image.fromarray(image, mode="RGB")
else:
# Use the largest mask
best_mask = masks[0]["segmentation"]
mask = (best_mask * 255).astype(np.uint8)
mask = filter_small_connected_components(
mask, area_ratio=self.area_ratio
)
# Apply the mask to remove the background
background_removed = cv2.bitwise_and(image, image, mask=mask)
output_image = np.dstack((background_removed, mask))
output_image = Image.fromarray(output_image, mode="RGBA")
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
output_image.save(save_path)
return output_image
class SAMPredictor(object):
def __init__(
self,
checkpoint: str = None,
model_type: str = "vit_h",
binary_thresh: float = 0.1,
device: str = "cuda",
):
self.device = device
self.model_type = model_type
if checkpoint is None:
suffix = "sam"
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
)
checkpoint = os.path.join(
model_path, suffix, "sam_vit_h_4b8939.pth"
)
self.predictor = self._load_sam_model(checkpoint)
self.binary_thresh = binary_thresh
def _load_sam_model(self, checkpoint: str) -> SamPredictor:
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
sam.to(device=self.device)
return SamPredictor(sam)
def preprocess_image(self, image: Image.Image) -> np.ndarray:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image).convert("RGB")
image = resize_pil(image)
image = np.array(image.convert("RGB"))
return image
def generate_masks(
self,
image: np.ndarray,
selected_points: list[list[int]],
) -> np.ndarray:
if len(selected_points) == 0:
return []
points = (
torch.Tensor([p for p, _ in selected_points])
.to(self.predictor.device)
.unsqueeze(1)
)
labels = (
torch.Tensor([int(l) for _, l in selected_points])
.to(self.predictor.device)
.unsqueeze(1)
)
transformed_points = self.predictor.transform.apply_coords_torch(
points, image.shape[:2]
)
masks, scores, _ = self.predictor.predict_torch(
point_coords=transformed_points,
point_labels=labels,
multimask_output=True,
)
valid_mask = masks[:, torch.argmax(scores, dim=1)]
masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
if len(masks_neg) == 0:
masks_neg = np.zeros_like(masks_pos)
if len(masks_pos) == 0:
masks_pos = np.zeros_like(masks_neg)
masks_neg = masks_neg.max(axis=0, keepdims=True)
masks_pos = masks_pos.max(axis=0, keepdims=True)
valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)
binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)
return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]
def get_segmented_image(
self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
) -> Image.Image:
seg_image = Image.fromarray(image, mode="RGB")
alpha_channel = np.zeros(
(seg_image.height, seg_image.width), dtype=np.uint8
)
for mask, _ in masks:
# Use the maximum to combine multiple masks
alpha_channel = np.maximum(alpha_channel, mask)
alpha_channel = np.clip(alpha_channel, 0, 1)
alpha_channel = (alpha_channel * 255).astype(np.uint8)
alpha_image = Image.fromarray(alpha_channel, mode="L")
r, g, b = seg_image.split()
seg_image = Image.merge("RGBA", (r, g, b, alpha_image))
return seg_image
def __call__(
self,
image: Union[str, Image.Image, np.ndarray],
selected_points: list[list[int]],
) -> Image.Image:
image = self.preprocess_image(image)
self.predictor.set_image(image)
masks = self.generate_masks(image, selected_points)
return self.get_segmented_image(image, masks)
class RembgRemover(object):
def __init__(self):
self.rembg_session = rembg.new_session("u2net")
def __call__(
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
) -> Image.Image:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = resize_pil(image)
output_image = rembg.remove(image, session=self.rembg_session)
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
output_image.save(save_path)
return output_image
class BMGG14Remover(object):
def __init__(self) -> None:
self.model = pipeline(
"image-segmentation",
model="briaai/RMBG-1.4",
trust_remote_code=True,
)
def __call__(
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
):
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = resize_pil(image)
output_image = self.model(image)
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
output_image.save(save_path)
return output_image
def invert_rgba_pil(
image: Image.Image, mask: Image.Image, save_path: str = None
) -> Image.Image:
mask = (255 - np.array(mask))[..., None]
image_array = np.concatenate([np.array(image), mask], axis=-1)
inverted_image = Image.fromarray(image_array, "RGBA")
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
inverted_image.save(save_path)
return inverted_image
def get_segmented_image_by_agent(
image: Image.Image,
sam_remover: SAMRemover,
rbg_remover: RembgRemover,
seg_checker: ImageSegChecker = None,
save_path: str = None,
mode: Literal["loose", "strict"] = "loose",
) -> Image.Image:
def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
if seg_checker is None:
return True
return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]
out_sam = f"{save_path}_sam.png" if save_path else None
out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
out_rbg = f"{save_path}_rbg.png" if save_path else None
seg_image = sam_remover(image, out_sam)
seg_image = seg_image.convert("RGBA")
_, _, _, alpha = seg_image.split()
seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
seg_image_rbg = rbg_remover(image, out_rbg)
final_image = None
if _is_valid_seg(image, seg_image):
final_image = seg_image
elif _is_valid_seg(image, seg_image_inv):
final_image = seg_image_inv
elif _is_valid_seg(image, seg_image_rbg):
logger.warning(f"Failed to segment by `SAM`, retry with `rembg`.")
final_image = seg_image_rbg
else:
if mode == "strict":
raise RuntimeError(
f"Failed to segment by `SAM` or `rembg`, abort."
)
logger.warning("Failed to segment by SAM or rembg, use raw image.")
final_image = image.convert("RGBA")
if save_path:
final_image.save(save_path)
final_image = trellis_preprocess(final_image)
return final_image
if __name__ == "__main__":
input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg"
output_image = "sample_0_seg2.png"
# input_image = "outputs/text2image/tmp/coffee_machine.jpeg"
# output_image = "outputs/text2image/tmp/coffee_machine_seg.png"
# input_image = "outputs/text2image/tmp/bucket.jpeg"
# output_image = "outputs/text2image/tmp/bucket_seg.png"
remover = SAMRemover(model_type="vit_h")
remover = RembgRemover()
clean_image = remover(input_image)
clean_image.save(output_image)
get_segmented_image_by_agent(
Image.open(input_image), remover, remover, None, "./test_seg.png"
)
remover = BMGG14Remover()
remover("embodied_gen/models/test_seg.jpg", "./seg.png")

View File

@ -0,0 +1,174 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import os
from typing import Union
import numpy as np
import spaces
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from embodied_gen.data.utils import get_images_from_grid
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
__all__ = [
"ImageStableSR",
"ImageRealESRGAN",
]
class ImageStableSR:
"""Super-resolution image upscaler using Stable Diffusion x4 upscaling model from StabilityAI."""
def __init__(
self,
model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
device="cuda",
) -> None:
from diffusers import StableDiffusionUpscalePipeline
self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
model_path,
torch_dtype=torch.float16,
).to(device)
self.up_pipeline_x4.set_progress_bar_config(disable=True)
self.up_pipeline_x4.enable_model_cpu_offload()
@spaces.GPU
def __call__(
self,
image: Union[Image.Image, np.ndarray],
prompt: str = "",
infer_step: int = 20,
) -> Image.Image:
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = image.convert("RGB")
with torch.no_grad():
upscaled_image = self.up_pipeline_x4(
image=image,
prompt=[prompt],
num_inference_steps=infer_step,
).images[0]
return upscaled_image
class ImageRealESRGAN:
"""A wrapper for Real-ESRGAN-based image super-resolution.
This class uses the RealESRGAN model to perform image upscaling,
typically by a factor of 4.
Attributes:
outscale (int): The output image scale factor (e.g., 2, 4).
model_path (str): Path to the pre-trained model weights.
"""
def __init__(self, outscale: int, model_path: str = None) -> None:
# monkey patch to support torchvision>=0.16
import torchvision
from packaging import version
if version.parse(torchvision.__version__) > version.parse("0.16"):
import sys
import types
import torchvision.transforms.functional as TF
functional_tensor = types.ModuleType(
"torchvision.transforms.functional_tensor"
)
functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
sys.modules["torchvision.transforms.functional_tensor"] = (
functional_tensor
)
self.outscale = outscale
self.upsampler = None
if model_path is None:
suffix = "super_resolution"
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
)
model_path = os.path.join(
model_path, suffix, "RealESRGAN_x4plus.pth"
)
self.model_path = model_path
def _lazy_init(self):
if self.upsampler is None:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
self.upsampler = RealESRGANer(
scale=4,
model_path=self.model_path,
model=model,
pre_pad=0,
half=True,
)
@spaces.GPU
def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
self._lazy_init()
if isinstance(image, Image.Image):
image = np.array(image)
with torch.no_grad():
output, _ = self.upsampler.enhance(image, outscale=self.outscale)
return Image.fromarray(output)
if __name__ == "__main__":
color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
# Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
super_model = ImageRealESRGAN(outscale=4)
multiviews = get_images_from_grid(color_path, img_size=512)
multiviews = [super_model(img.convert("RGB")) for img in multiviews]
for idx, img in enumerate(multiviews):
img.save(f"sr{idx}.png")
# # Use stable diffusion for x4 (512->2048) image super resolution.
# super_model = ImageStableSR()
# multiviews = get_images_from_grid(color_path, img_size=512)
# multiviews = [super_model(img) for img in multiviews]
# for idx, img in enumerate(multiviews):
# img.save(f"sr_stable{idx}.png")

View File

@ -0,0 +1,171 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import random
import numpy as np
import torch
from diffusers import (
AutoencoderKL,
EulerDiscreteScheduler,
UNet2DConditionModel,
)
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import (
UNet2DConditionModel as UNet2DConditionModelIP,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
StableDiffusionXLPipeline,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
)
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
__all__ = [
"build_text2img_ip_pipeline",
"build_text2img_pipeline",
"text2img_gen",
]
def build_text2img_ip_pipeline(
ckpt_dir: str,
ref_scale: float,
device: str = "cuda",
) -> StableDiffusionXLPipelineIP:
text_encoder = ChatGLMModel.from_pretrained(
f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
).half()
tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
vae = AutoencoderKL.from_pretrained(
f"{ckpt_dir}/vae", revision=None
).half()
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModelIP.from_pretrained(
f"{ckpt_dir}/unet", revision=None
).half()
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
f"{ckpt_dir}/../Kolors-IP-Adapter-Plus/image_encoder",
ignore_mismatched_sizes=True,
).to(dtype=torch.float16)
clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
pipe = StableDiffusionXLPipelineIP(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=clip_image_processor,
force_zeros_for_empty_prompt=False,
)
if hasattr(pipe.unet, "encoder_hid_proj"):
pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
pipe.load_ip_adapter(
f"{ckpt_dir}/../Kolors-IP-Adapter-Plus",
subfolder="",
weight_name=["ip_adapter_plus_general.bin"],
)
pipe.set_ip_adapter_scale([ref_scale])
pipe = pipe.to(device)
pipe.enable_model_cpu_offload()
# pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_vae_slicing()
return pipe
def build_text2img_pipeline(
ckpt_dir: str,
device: str = "cuda",
) -> StableDiffusionXLPipeline:
text_encoder = ChatGLMModel.from_pretrained(
f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
).half()
tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
vae = AutoencoderKL.from_pretrained(
f"{ckpt_dir}/vae", revision=None
).half()
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(
f"{ckpt_dir}/unet", revision=None
).half()
pipe = StableDiffusionXLPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
force_zeros_for_empty_prompt=False,
)
pipe = pipe.to(device)
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()
return pipe
def text2img_gen(
prompt: str,
n_sample: int,
guidance_scale: float,
pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
ip_image: Image.Image | str = None,
image_wh: tuple[int, int] = [1024, 1024],
infer_step: int = 50,
ip_image_size: int = 512,
seed: int = None,
) -> list[Image.Image]:
prompt = "Single " + prompt + ", in the center of the image"
prompt += ", high quality, high resolution, best quality, white background, 3D style" # noqa
logger.info(f"Processing prompt: {prompt}")
generator = None
if seed is not None:
generator = torch.Generator(pipeline.device).manual_seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
kwargs = dict(
prompt=prompt,
height=image_wh[1],
width=image_wh[0],
num_inference_steps=infer_step,
guidance_scale=guidance_scale,
num_images_per_prompt=n_sample,
generator=generator,
)
if ip_image is not None:
if isinstance(ip_image, str):
ip_image = Image.open(ip_image)
ip_image = ip_image.resize((ip_image_size, ip_image_size))
kwargs.update(ip_adapter_image=[ip_image])
return pipeline(**kwargs).images

View File

@ -0,0 +1,108 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
from kolors.models.controlnet import ControlNetModel
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
StableDiffusionXLControlNetImg2ImgPipeline,
)
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
__all__ = [
"build_texture_gen_pipe",
]
def build_texture_gen_pipe(
base_ckpt_dir: str,
controlnet_ckpt: str = None,
ip_adapt_scale: float = 0,
device: str = "cuda",
) -> DiffusionPipeline:
tokenizer = ChatGLMTokenizer.from_pretrained(
f"{base_ckpt_dir}/Kolors/text_encoder"
)
text_encoder = ChatGLMModel.from_pretrained(
f"{base_ckpt_dir}/Kolors/text_encoder", torch_dtype=torch.float16
).half()
vae = AutoencoderKL.from_pretrained(
f"{base_ckpt_dir}/Kolors/vae", revision=None
).half()
unet = UNet2DConditionModel.from_pretrained(
f"{base_ckpt_dir}/Kolors/unet", revision=None
).half()
scheduler = EulerDiscreteScheduler.from_pretrained(
f"{base_ckpt_dir}/Kolors/scheduler"
)
if controlnet_ckpt is None:
suffix = "geo_cond_mv"
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
)
controlnet_ckpt = os.path.join(model_path, suffix)
controlnet = ControlNetModel.from_pretrained(
controlnet_ckpt, use_safetensors=True
).half()
# IP-Adapter model
image_encoder = None
clip_image_processor = None
if ip_adapt_scale > 0:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus/image_encoder",
# ignore_mismatched_sizes=True,
).to(dtype=torch.float16)
ip_img_size = 336
clip_image_processor = CLIPImageProcessor(
size=ip_img_size, crop_size=ip_img_size
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline(
vae=vae,
controlnet=controlnet,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=clip_image_processor,
force_zeros_for_empty_prompt=False,
)
if ip_adapt_scale > 0:
if hasattr(pipe.unet, "encoder_hid_proj"):
pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
pipe.load_ip_adapter(
f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus",
subfolder="",
weight_name=["ip_adapter_plus_general.bin"],
)
pipe.set_ip_adapter_scale([ip_adapt_scale])
pipe = pipe.to(device)
pipe.enable_model_cpu_offload()
return pipe

View File

@ -0,0 +1,311 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import logging
import os
import sys
from glob import glob
from shutil import copy, copytree
import numpy as np
import trimesh
from PIL import Image
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
from embodied_gen.data.utils import delete_dir, trellis_preprocess
from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.gs_model import GaussianOperator
from embodied_gen.models.segment_model import (
BMGG14Remover,
RembgRemover,
SAMPredictor,
)
from embodied_gen.models.sr_model import ImageRealESRGAN
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.process_media import merge_images_video, render_video
from embodied_gen.utils.tags import VERSION
from embodied_gen.validators.quality_checkers import (
BaseChecker,
ImageAestheticChecker,
ImageSegChecker,
MeshGeoChecker,
)
from embodied_gen.validators.urdf_convertor import URDFGenerator
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
"~/.cache/torch_extensions"
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
"microsoft/TRELLIS-image-large"
)
PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
TMP_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)
def parse_args():
parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
parser.add_argument(
"--image_path", type=str, nargs="+", help="Path to the input images."
)
parser.add_argument(
"--image_root", type=str, help="Path to the input images folder."
)
parser.add_argument(
"--output_root",
type=str,
required=True,
help="Root directory for saving outputs.",
)
parser.add_argument(
"--height_range",
type=str,
default=None,
help="The hight in meter to restore the mesh real size.",
)
parser.add_argument(
"--mass_range",
type=str,
default=None,
help="The mass in kg to restore the mesh real weight.",
)
parser.add_argument("--asset_type", type=str, default=None)
parser.add_argument("--skip_exists", action="store_true")
parser.add_argument("--strict_seg", action="store_true")
parser.add_argument("--version", type=str, default=VERSION)
parser.add_argument("--remove_intermediate", type=bool, default=True)
args = parser.parse_args()
assert (
args.image_path or args.image_root
), "Please provide either --image_path or --image_root."
if not args.image_path:
args.image_path = glob(os.path.join(args.image_root, "*.png"))
args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))
return args
if __name__ == "__main__":
args = parse_args()
for image_path in args.image_path:
try:
filename = os.path.basename(image_path).split(".")[0]
output_root = args.output_root
if args.image_root is not None or len(args.image_path) > 1:
output_root = os.path.join(output_root, filename)
os.makedirs(output_root, exist_ok=True)
mesh_out = f"{output_root}/{filename}.obj"
if args.skip_exists and os.path.exists(mesh_out):
logger.info(
f"Skip {image_path}, already processed in {mesh_out}"
)
continue
image = Image.open(image_path)
image.save(f"{output_root}/{filename}_raw.png")
# Segmentation: Get segmented image using SAM or Rembg.
seg_path = f"{output_root}/{filename}_cond.png"
if image.mode != "RGBA":
seg_image = RBG_REMOVER(image, save_path=seg_path)
seg_image = trellis_preprocess(seg_image)
else:
seg_image = image
seg_image.save(seg_path)
# Run the pipeline
try:
outputs = PIPELINE.run(
seg_image,
preprocess_image=False,
# Optional parameters
# seed=1,
# sparse_structure_sampler_params={
# "steps": 12,
# "cfg_strength": 7.5,
# },
# slat_sampler_params={
# "steps": 12,
# "cfg_strength": 3,
# },
)
except Exception as e:
logger.error(
f"[Pipeline Failed] process {image_path}: {e}, skip."
)
continue
# Render and save color and mesh videos
gs_model = outputs["gaussian"][0]
mesh_model = outputs["mesh"][0]
color_images = render_video(gs_model)["color"]
normal_images = render_video(mesh_model)["normal"]
video_path = os.path.join(output_root, "gs_mesh.mp4")
merge_images_video(color_images, normal_images, video_path)
# Save the raw Gaussian model
gs_path = mesh_out.replace(".obj", "_gs.ply")
gs_model.save_ply(gs_path)
# Rotate mesh and GS by 90 degrees around Z-axis.
rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
# Addtional rotation for GS to align mesh.
gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
pose = GaussianOperator.trans_to_quatpose(gs_rot)
aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
GaussianOperator.resave_ply(
in_ply=gs_path,
out_ply=aligned_gs_path,
instance_pose=pose,
device="cpu",
)
color_path = os.path.join(output_root, "color.png")
render_gs_api(aligned_gs_path, color_path)
mesh = trimesh.Trimesh(
vertices=mesh_model.vertices.cpu().numpy(),
faces=mesh_model.faces.cpu().numpy(),
)
mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
mesh.vertices = mesh.vertices @ np.array(rot_matrix)
mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
mesh.export(mesh_obj_path)
mesh = backproject_api(
delight_model=DELIGHT,
imagesr_model=IMAGESR_MODEL,
color_path=color_path,
mesh_path=mesh_obj_path,
output_path=mesh_obj_path,
skip_fix_mesh=False,
delight=True,
texture_wh=[2048, 2048],
)
mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
mesh.export(mesh_glb_path)
urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
asset_attrs = {
"version": VERSION,
"gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
}
if args.height_range:
min_height, max_height = map(
float, args.height_range.split("-")
)
asset_attrs["min_height"] = min_height
asset_attrs["max_height"] = max_height
if args.mass_range:
min_mass, max_mass = map(float, args.mass_range.split("-"))
asset_attrs["min_mass"] = min_mass
asset_attrs["max_mass"] = max_mass
if args.asset_type:
asset_attrs["category"] = args.asset_type
if args.version:
asset_attrs["version"] = args.version
urdf_root = f"{output_root}/URDF_{filename}"
urdf_path = urdf_convertor(
mesh_path=mesh_obj_path,
output_root=urdf_root,
**asset_attrs,
)
# Rescale GS and save to URDF/mesh folder.
real_height = urdf_convertor.get_attr_from_urdf(
urdf_path, attr_name="real_height"
)
out_gs = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa
GaussianOperator.resave_ply(
in_ply=aligned_gs_path,
out_ply=out_gs,
real_height=real_height,
device="cpu",
)
# Quality check and update .urdf file.
mesh_out = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa
trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
image_dir = f"{urdf_root}/{urdf_convertor.output_render_dir}/image_color" # noqa
image_paths = glob(f"{image_dir}/*.png")
images_list = []
for checker in CHECKERS:
images = image_paths
if isinstance(checker, ImageSegChecker):
images = [
f"{output_root}/{filename}_raw.png",
f"{output_root}/{filename}_cond.png",
]
images_list.append(images)
results = BaseChecker.validate(CHECKERS, images_list)
urdf_convertor.add_quality_tag(urdf_path, results)
# Organize the final result files
result_dir = f"{output_root}/result"
os.makedirs(result_dir, exist_ok=True)
copy(urdf_path, f"{result_dir}/{os.path.basename(urdf_path)}")
copytree(
f"{urdf_root}/{urdf_convertor.output_mesh_dir}",
f"{result_dir}/{urdf_convertor.output_mesh_dir}",
)
copy(video_path, f"{result_dir}/video.mp4")
if args.remove_intermediate:
delete_dir(output_root, keep_subs=["result"])
except Exception as e:
logger.error(f"Failed to process {image_path}: {e}, skip.")
continue
logger.info(f"Processing complete. Outputs saved to {args.output_root}")

View File

@ -0,0 +1,175 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import logging
import math
import os
import cv2
import numpy as np
import spaces
import torch
from tqdm import tqdm
from embodied_gen.data.utils import (
CameraSetting,
init_kal_camera,
normalize_vertices_array,
)
from embodied_gen.models.gs_model import GaussianOperator
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Render GS color images")
parser.add_argument(
"--input_gs", type=str, help="Input render GS.ply path."
)
parser.add_argument(
"--output_path",
type=str,
help="Output grid image path for rendered GS color images.",
)
parser.add_argument(
"--num_images", type=int, default=6, help="Number of images to render."
)
parser.add_argument(
"--elevation",
type=float,
nargs="+",
default=[20.0, -10.0],
help="Elevation angles for the camera (default: [20.0, -10.0])",
)
parser.add_argument(
"--distance",
type=float,
default=5,
help="Camera distance (default: 5)",
)
parser.add_argument(
"--resolution_hw",
type=int,
nargs=2,
default=(512, 512),
help="Resolution of the output images (default: (512, 512))",
)
parser.add_argument(
"--fov",
type=float,
default=30,
help="Field of view in degrees (default: 30)",
)
parser.add_argument(
"--device",
type=str,
choices=["cpu", "cuda"],
default="cuda",
help="Device to run on (default: `cuda`)",
)
parser.add_argument(
"--image_size",
type=int,
default=512,
help="Output image size for single view in color grid (default: 512)",
)
args, unknown = parser.parse_known_args()
return args
def load_gs_model(
input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
) -> GaussianOperator:
gs_model = GaussianOperator.load_from_ply(input_gs)
# Normalize vertices to [-1, 1], center to (0, 0, 0).
_, scale, center = normalize_vertices_array(gs_model._means)
scale, center = float(scale), center.tolist()
transpose = [*[-v for v in center], *pre_quat]
instance_pose = torch.tensor(transpose).to(gs_model.device)
gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
gs_model.rescale(scale)
return gs_model
@spaces.GPU
def entrypoint(input_gs: str = None, output_path: str = None) -> None:
args = parse_args()
if isinstance(input_gs, str):
args.input_gs = input_gs
if isinstance(output_path, str):
args.output_path = output_path
# Setup camera parameters
camera_params = CameraSetting(
num_images=args.num_images,
elevation=args.elevation,
distance=args.distance,
resolution_hw=args.resolution_hw,
fov=math.radians(args.fov),
device=args.device,
)
camera = init_kal_camera(camera_params)
matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
w2cs = matrix_mv.to(camera_params.device)
c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
# Load GS model and normalize.
gs_model = load_gs_model(args.input_gs, pre_quat=[0.0, 0.0, 1.0, 0.0])
# Render GS color images.
images = []
for idx in tqdm(range(len(c2ws)), desc="Rendering GS"):
result = gs_model.render(
c2ws[idx],
Ks=Ks,
image_width=camera_params.resolution_hw[1],
image_height=camera_params.resolution_hw[0],
)
color = cv2.resize(
result.rgba,
(args.image_size, args.image_size),
interpolation=cv2.INTER_AREA,
)
images.append(color)
# Cat color images into grid image and save.
select_idxs = [[0, 2, 1], [5, 4, 3]] # fix order for 6 views
grid_image = []
for row_idxs in select_idxs:
row_image = []
for row_idx in row_idxs:
row_image.append(images[row_idx])
row_image = np.concatenate(row_image, axis=1)
grid_image.append(row_image)
grid_image = np.concatenate(grid_image, axis=0)
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
cv2.imwrite(args.output_path, grid_image)
logger.info(f"Saved grid image to {args.output_path}")
if __name__ == "__main__":
entrypoint()

View File

@ -0,0 +1,198 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import os
import random
from typing import List, Tuple
import fire
import numpy as np
import torch
from diffusers.utils import make_image_grid
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
StableDiffusionXLControlNetImg2ImgPipeline,
)
from PIL import Image, ImageEnhance, ImageFilter
from torchvision import transforms
from embodied_gen.data.datasets import Asset3dGenDataset
from embodied_gen.models.texture_model import build_texture_gen_pipe
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_init_noise_image(image: Image.Image) -> Image.Image:
blurred_image = image.convert("L").filter(
ImageFilter.GaussianBlur(radius=3)
)
enhancer = ImageEnhance.Contrast(blurred_image)
image_decreased_contrast = enhancer.enhance(factor=0.5)
return image_decreased_contrast
def infer_pipe(
index_file: str,
controlnet_ckpt: str = None,
uid: str = None,
prompt: str = None,
controlnet_cond_scale: float = 0.4,
control_guidance_end: float = 0.9,
strength: float = 1.0,
num_inference_steps: int = 50,
guidance_scale: float = 10,
ip_adapt_scale: float = 0,
ip_img_path: str = None,
sub_idxs: List[List[int]] = None,
num_images_per_prompt: int = 3, # increase if want similar images.
device: str = "cuda",
save_dir: str = "infer_vis",
seed: int = None,
target_hw: tuple[int, int] = (512, 512),
pipeline: StableDiffusionXLControlNetImg2ImgPipeline = None,
) -> str:
# sub_idxs = [[0, 1, 2], [3, 4, 5]] # None for single image.
if sub_idxs is None:
sub_idxs = [[random.randint(0, 5)]] # 6 views.
target_hw = [2 * size for size in target_hw]
transform_list = [
transforms.Resize(
target_hw, interpolation=transforms.InterpolationMode.BILINEAR
),
transforms.CenterCrop(target_hw),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
image_transform = transforms.Compose(transform_list)
control_transform = transforms.Compose(transform_list[:-1])
grid_hw = (target_hw[0] * len(sub_idxs), target_hw[1] * len(sub_idxs[0]))
dataset = Asset3dGenDataset(
index_file, target_hw=grid_hw, sub_idxs=sub_idxs
)
if uid is None:
uid = random.choice(list(dataset.meta_info.keys()))
if prompt is None:
prompt = dataset.meta_info[uid]["capture"]
if isinstance(prompt, List) or isinstance(prompt, Tuple):
prompt = ", ".join(map(str, prompt))
# prompt += "high quality, ultra-clear, high resolution, best quality, 4k"
# prompt += "高品质,清晰,细节"
prompt += ", high quality, high resolution, best quality"
# prompt += ", with diffuse lighting, showing no reflections."
logger.info(f"Inference with prompt: {prompt}")
negative_prompt = "nsfw,阴影,低分辨率,伪影、模糊,霓虹灯,高光,镜面反射"
control_image = dataset.fetch_sample_grid_images(
uid,
attrs=["image_view_normal", "image_position", "image_mask"],
sub_idxs=sub_idxs,
transform=control_transform,
)
color_image = dataset.fetch_sample_grid_images(
uid,
attrs=["image_color"],
sub_idxs=sub_idxs,
transform=image_transform,
)
normal_pil, position_pil, mask_pil, color_pil = dataset.visualize_item(
control_image,
color_image,
save_dir=save_dir,
)
if pipeline is None:
pipeline = build_texture_gen_pipe(
base_ckpt_dir="./weights",
controlnet_ckpt=controlnet_ckpt,
ip_adapt_scale=ip_adapt_scale,
device=device,
)
if ip_adapt_scale > 0 and ip_img_path is not None and len(ip_img_path) > 0:
ip_image = Image.open(ip_img_path).convert("RGB")
ip_image = ip_image.resize(target_hw[::-1])
ip_image = [ip_image]
pipeline.set_ip_adapter_scale([ip_adapt_scale])
else:
ip_image = None
generator = None
if seed is not None:
generator = torch.Generator(device).manual_seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
init_image = get_init_noise_image(normal_pil)
# init_image = get_init_noise_image(color_pil)
images = []
row_num, col_num = 2, 3
img_save_paths = []
while len(images) < col_num:
image = pipeline(
prompt=prompt,
image=init_image,
controlnet_conditioning_scale=controlnet_cond_scale,
control_guidance_end=control_guidance_end,
strength=strength,
control_image=control_image[None, ...],
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
ip_adapter_image=ip_image,
generator=generator,
).images
images.extend(image)
grid_image = [normal_pil, position_pil, color_pil] + images[:col_num]
# save_dir = os.path.join(save_dir, uid)
os.makedirs(save_dir, exist_ok=True)
for idx in range(col_num):
rgba_image = Image.merge("RGBA", (*images[idx].split(), mask_pil))
img_save_path = os.path.join(save_dir, f"color_sample{idx}.png")
rgba_image.save(img_save_path)
img_save_paths.append(img_save_path)
sub_idxs = "_".join(
[str(item) for sublist in sub_idxs for item in sublist]
)
save_path = os.path.join(
save_dir, f"sample_idx{str(sub_idxs)}_ip{ip_adapt_scale}.jpg"
)
make_image_grid(grid_image, row_num, col_num).save(save_path)
logger.info(f"Visualize in {save_path}")
return img_save_paths
def entrypoint() -> None:
fire.Fire(infer_pipe)
if __name__ == "__main__":
entrypoint()

View File

@ -0,0 +1,168 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import logging
import os
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
StableDiffusionXLPipeline,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
)
from tqdm import tqdm
from embodied_gen.models.text_model import (
build_text2img_ip_pipeline,
build_text2img_pipeline,
text2img_gen,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Text to Image.")
parser.add_argument(
"--prompts",
type=str,
nargs="+",
help="List of prompts (space-separated).",
)
parser.add_argument(
"--ref_image",
type=str,
nargs="+",
help="List of ref_image paths (space-separated).",
)
parser.add_argument(
"--output_root",
type=str,
help="Root directory for saving outputs.",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=12.0,
help="Guidance scale for the diffusion model.",
)
parser.add_argument(
"--ref_scale",
type=float,
default=0.3,
help="Reference image scale for the IP adapter.",
)
parser.add_argument(
"--n_sample",
type=int,
default=1,
)
parser.add_argument(
"--resolution",
type=int,
default=1024,
)
parser.add_argument(
"--infer_step",
type=int,
default=50,
)
parser.add_argument(
"--seed",
type=int,
default=0,
)
args = parser.parse_args()
return args
def entrypoint(
pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP = None,
**kwargs,
) -> list[str]:
args = parse_args()
for k, v in kwargs.items():
if hasattr(args, k) and v is not None:
setattr(args, k, v)
prompts = args.prompts
if len(prompts) == 1 and prompts[0].endswith(".txt"):
with open(prompts[0], "r") as f:
prompts = f.readlines()
prompts = [
prompt.strip() for prompt in prompts if prompt.strip() != ""
]
os.makedirs(args.output_root, exist_ok=True)
ip_img_paths = args.ref_image
if ip_img_paths is None or len(ip_img_paths) == 0:
args.ref_scale = 0
ip_img_paths = [None] * len(prompts)
elif isinstance(ip_img_paths, str):
ip_img_paths = [ip_img_paths] * len(prompts)
elif isinstance(ip_img_paths, list):
if len(ip_img_paths) == 1:
ip_img_paths = ip_img_paths * len(prompts)
else:
raise ValueError("Invalid ref_image paths.")
assert len(ip_img_paths) == len(
prompts
), f"Number of ref images does not match prompts, {len(ip_img_paths)} != {len(prompts)}" # noqa
if pipeline is None:
if args.ref_scale > 0:
pipeline = build_text2img_ip_pipeline(
"weights/Kolors",
ref_scale=args.ref_scale,
)
else:
pipeline = build_text2img_pipeline("weights/Kolors")
for idx, (prompt, ip_img_path) in tqdm(
enumerate(zip(prompts, ip_img_paths)),
desc="Generating images",
total=len(prompts),
):
images = text2img_gen(
prompt=prompt,
n_sample=args.n_sample,
guidance_scale=args.guidance_scale,
pipeline=pipeline,
ip_image=ip_img_path,
image_wh=[args.resolution, args.resolution],
infer_step=args.infer_step,
seed=args.seed,
)
save_paths = []
for sub_idx, image in enumerate(images):
save_path = (
f"{args.output_root}/sample_{idx*args.n_sample+sub_idx}.png"
)
image.save(save_path)
save_paths.append(save_path)
logger.info(f"Images saved to {args.output_root}")
return save_paths
if __name__ == "__main__":
entrypoint()

View File

@ -0,0 +1,56 @@
#!/bin/bash
# Initialize variables
prompts=()
output_root=""
# Parse arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--prompts)
shift
while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
prompts+=("$1")
shift
done
;;
--output_root)
output_root="$2"
shift 2
;;
*)
echo "Unknown argument: $1"
exit 1
;;
esac
done
# Validate required arguments
if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
echo "Missing required arguments."
echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" --output_root <path>"
exit 1
fi
# Print arguments (for debugging)
echo "Prompts:"
for p in "${prompts[@]}"; do
echo " - $p"
done
echo "Output root: ${output_root}"
# Concatenate prompts for Python command
prompt_args=""
for p in "${prompts[@]}"; do
prompt_args+="\"$p\" "
done
# Step 1: Text-to-Image
eval python3 embodied_gen/scripts/text2image.py \
--prompts ${prompt_args} \
--output_root "${output_root}/images"
# Step 2: Image-to-3D
python3 embodied_gen/scripts/imageto3d.py \
--image_root "${output_root}/images" \
--output_root "${output_root}/asset3d"

View File

@ -0,0 +1,80 @@
#!/bin/bash
while [[ $# -gt 0 ]]; do
case $1 in
--mesh_path)
mesh_path="$2"
shift 2
;;
--prompt)
prompt="$2"
shift 2
;;
--uuid)
uuid="$2"
shift 2
;;
--output_root)
output_root="$2"
shift 2
;;
*)
echo "unknown: $1"
exit 1
;;
esac
done
if [[ -z "$mesh_path" || -z "$prompt" || -z "$uuid" || -z "$output_root" ]]; then
echo "params missing"
echo "usage: bash run.sh --mesh_path <path> --prompt <text> --uuid <id> --output_root <path>"
exit 1
fi
# Step 1: drender-cli for condition rendering
drender-cli --mesh_path ${mesh_path} \
--output_root ${output_root}/condition \
--uuid ${uuid}
# Step 2: multi-view rendering
python embodied_gen/scripts/render_mv.py \
--index_file "${output_root}/condition/index.json" \
--controlnet_cond_scale 0.75 \
--guidance_scale 9 \
--strength 0.9 \
--num_inference_steps 40 \
--ip_adapt_scale 0 \
--ip_img_path None \
--uid ${uuid} \
--prompt "${prompt}" \
--save_dir "${output_root}/multi_view" \
--sub_idxs "[[0,1,2],[3,4,5]]" \
--seed 0
# Step 3: backprojection
backproject-cli --mesh_path ${mesh_path} \
--color_path ${output_root}/multi_view/color_sample0.png \
--output_path "${output_root}/texture_mesh/${uuid}.obj" \
--save_glb_path "${output_root}/texture_mesh/${uuid}.glb" \
--skip_fix_mesh \
--delight \
--no_save_delight_img
# Step 4: final rendering of textured mesh
drender-cli --mesh_path "${output_root}/texture_mesh/${uuid}.obj" \
--output_root ${output_root}/texture_mesh \
--num_images 90 \
--elevation 20 \
--with_mtl \
--gen_color_mp4 \
--pbr_light_factor 1.2
# Organize folders
rm -rf ${output_root}/condition
video_path="${output_root}/texture_mesh/${uuid}/color.mp4"
if [ -f "${video_path}" ]; then
cp "${video_path}" "${output_root}/texture_mesh/color.mp4"
echo "Resave video to ${output_root}/texture_mesh/color.mp4"
fi
rm -rf ${output_root}/texture_mesh/${uuid}

View File

@ -0,0 +1,211 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import base64
import logging
import os
from io import BytesIO
from typing import Optional
import yaml
from openai import AzureOpenAI, OpenAI # pip install openai
from PIL import Image
from tenacity import (
retry,
stop_after_attempt,
stop_after_delay,
wait_random_exponential,
)
from embodied_gen.utils.process_media import combine_images_to_base64
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GPTclient:
"""A client to interact with the GPT model via OpenAI or Azure API."""
def __init__(
self,
endpoint: str,
api_key: str,
model_name: str = "yfb-gpt-4o",
api_version: str = None,
verbose: bool = False,
):
if api_version is not None:
self.client = AzureOpenAI(
azure_endpoint=endpoint,
api_key=api_key,
api_version=api_version,
)
else:
self.client = OpenAI(
base_url=endpoint,
api_key=api_key,
)
self.endpoint = endpoint
self.model_name = model_name
self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
self.verbose = verbose
logger.info(f"Using GPT model: {self.model_name}.")
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=(stop_after_attempt(10) | stop_after_delay(30)),
)
def completion_with_backoff(self, **kwargs):
return self.client.chat.completions.create(**kwargs)
def query(
self,
text_prompt: str,
image_base64: Optional[list[str | Image.Image]] = None,
system_role: Optional[str] = None,
) -> Optional[str]:
"""Queries the GPT model with a text and optional image prompts.
Args:
text_prompt (str): The main text input that the model responds to.
image_base64 (Optional[List[str]]): A list of image base64 strings
or local image paths or PIL.Image to accompany the text prompt.
system_role (Optional[str]): Optional system-level instructions
that specify the behavior of the assistant.
Returns:
Optional[str]: The response content generated by the model based on
the prompt. Returns `None` if an error occurs.
"""
if system_role is None:
system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
content_user = [
{
"type": "text",
"text": text_prompt,
},
]
# Process images if provided
if image_base64 is not None:
image_base64 = (
image_base64
if isinstance(image_base64, list)
else [image_base64]
)
for img in image_base64:
if isinstance(img, Image.Image):
buffer = BytesIO()
img.save(buffer, format=img.format or "PNG")
buffer.seek(0)
image_binary = buffer.read()
img = base64.b64encode(image_binary).decode("utf-8")
elif (
len(os.path.splitext(img)) > 1
and os.path.splitext(img)[-1].lower() in self.image_formats
):
if not os.path.exists(img):
raise FileNotFoundError(f"Image file not found: {img}")
with open(img, "rb") as f:
img = base64.b64encode(f.read()).decode("utf-8")
content_user.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{img}"},
}
)
payload = {
"messages": [
{"role": "system", "content": system_role},
{"role": "user", "content": content_user},
],
"temperature": 0.1,
"max_tokens": 500,
"top_p": 0.1,
"frequency_penalty": 0,
"presence_penalty": 0,
"stop": None,
}
payload.update({"model": self.model_name})
response = None
try:
response = self.completion_with_backoff(**payload)
response = response.choices[0].message.content
except Exception as e:
logger.error(f"Error GPTclint {self.endpoint} API call: {e}")
response = None
if self.verbose:
logger.info(f"Prompt: {text_prompt}")
logger.info(f"Response: {response}")
return response
with open("embodied_gen/utils/gpt_config.yaml", "r") as f:
config = yaml.safe_load(f)
agent_type = config["agent_type"]
agent_config = config.get(agent_type, {})
# Prefer environment variables, fallback to YAML config
endpoint = os.environ.get("ENDPOINT", agent_config.get("endpoint"))
api_key = os.environ.get("API_KEY", agent_config.get("api_key"))
api_version = os.environ.get("API_VERSION", agent_config.get("api_version"))
model_name = os.environ.get("MODEL_NAME", agent_config.get("model_name"))
GPT_CLIENT = GPTclient(
endpoint=endpoint,
api_key=api_key,
api_version=api_version,
model_name=model_name,
)
if __name__ == "__main__":
if "openrouter" in GPT_CLIENT.endpoint:
response = GPT_CLIENT.query(
text_prompt="What is the content in each image?",
image_base64=combine_images_to_base64(
[
"outputs/text2image/demo_objects/bed/sample_0.jpg",
"outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png", # noqa
"outputs/text2image/demo_objects/cardboard/sample_1.jpg",
]
), # input raw image_path if only one image
)
print(response)
else:
response = GPT_CLIENT.query(
text_prompt="What is the content in the images?",
image_base64=[
Image.open("outputs/text2image/demo_objects/bed/sample_0.jpg"),
Image.open(
"outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png" # noqa
),
],
)
print(response)
# test2: text prompt
response = GPT_CLIENT.query(
text_prompt="What is the capital of China?"
)
print(response)

View File

@ -0,0 +1,14 @@
# config.yaml
agent_type: "qwen2.5-vl" # gpt-4o or qwen2.5-vl
gpt-4o:
endpoint: https://xxx.openai.azure.com
api_key: xxx
api_version: 2025-xx-xx
model_name: yfb-gpt-4o
qwen2.5-vl:
endpoint: https://openrouter.ai/api/v1
api_key: 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
api_version: null
model_name: qwen/qwen2.5-vl-72b-instruct:free

View File

@ -0,0 +1,328 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import base64
import logging
import math
import os
import subprocess
import sys
from glob import glob
from io import BytesIO
from typing import Union
import cv2
import imageio
import numpy as np
import PIL.Image as Image
import spaces
import torch
from moviepy.editor import VideoFileClip, clips_array
from tqdm import tqdm
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
from thirdparty.TRELLIS.trellis.utils.render_utils import (
render_frames,
yaw_pitch_r_fov_to_extrinsics_intrinsics,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
__all__ = [
"render_asset3d",
"merge_images_video",
"filter_small_connected_components",
"filter_image_small_connected_components",
"combine_images_to_base64",
"render_mesh",
"render_video",
"create_mp4_from_images",
"create_gif_from_images",
]
@spaces.GPU
def render_asset3d(
mesh_path: str,
output_root: str,
distance: float = 5.0,
num_images: int = 1,
elevation: list[float] = (0.0,),
pbr_light_factor: float = 1.5,
return_key: str = "image_color/*",
output_subdir: str = "renders",
gen_color_mp4: bool = False,
gen_viewnormal_mp4: bool = False,
gen_glonormal_mp4: bool = False,
) -> list[str]:
command = [
"python3",
"embodied_gen/data/differentiable_render.py",
"--mesh_path",
mesh_path,
"--output_root",
output_root,
"--uuid",
output_subdir,
"--distance",
str(distance),
"--num_images",
str(num_images),
"--elevation",
*map(str, elevation),
"--pbr_light_factor",
str(pbr_light_factor),
"--with_mtl",
]
if gen_color_mp4:
command.append("--gen_color_mp4")
if gen_viewnormal_mp4:
command.append("--gen_viewnormal_mp4")
if gen_glonormal_mp4:
command.append("--gen_glonormal_mp4")
try:
subprocess.run(command, check=True)
except subprocess.CalledProcessError as e:
logger.error(f"Error occurred during rendering: {e}.")
dst_paths = glob(os.path.join(output_root, output_subdir, return_key))
return dst_paths
def merge_images_video(color_images, normal_images, output_path) -> None:
width = color_images[0].shape[1]
combined_video = [
np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
for rgb_img, normal_img in zip(color_images, normal_images)
]
imageio.mimsave(output_path, combined_video, fps=50)
return
def merge_video_video(
video_path1: str, video_path2: str, output_path: str
) -> None:
"""Merge two videos by the left half and the right half of the videos."""
clip1 = VideoFileClip(video_path1)
clip2 = VideoFileClip(video_path2)
if clip1.size != clip2.size:
raise ValueError("The resolutions of the two videos do not match.")
width, height = clip1.size
clip1_half = clip1.crop(x1=0, y1=0, x2=width // 2, y2=height)
clip2_half = clip2.crop(x1=width // 2, y1=0, x2=width, y2=height)
final_clip = clips_array([[clip1_half, clip2_half]])
final_clip.write_videofile(output_path, codec="libx264")
def filter_small_connected_components(
mask: Union[Image.Image, np.ndarray],
area_ratio: float,
connectivity: int = 8,
) -> np.ndarray:
if isinstance(mask, Image.Image):
mask = np.array(mask)
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
mask,
connectivity=connectivity,
)
small_components = np.zeros_like(mask, dtype=np.uint8)
mask_area = (mask != 0).sum()
min_area = mask_area // area_ratio
for label in range(1, num_labels):
area = stats[label, cv2.CC_STAT_AREA]
if area < min_area:
small_components[labels == label] = 255
mask = cv2.bitwise_and(mask, cv2.bitwise_not(small_components))
return mask
def filter_image_small_connected_components(
image: Union[Image.Image, np.ndarray],
area_ratio: float = 10,
connectivity: int = 8,
) -> np.ndarray:
if isinstance(image, Image.Image):
image = image.convert("RGBA")
image = np.array(image)
mask = image[..., 3]
mask = filter_small_connected_components(mask, area_ratio, connectivity)
image[..., 3] = mask
return image
def combine_images_to_base64(
images: list[str | Image.Image],
cat_row_col: tuple[int, int] = None,
target_wh: tuple[int, int] = (512, 512),
) -> str:
n_images = len(images)
if cat_row_col is None:
n_col = math.ceil(math.sqrt(n_images))
n_row = math.ceil(n_images / n_col)
else:
n_row, n_col = cat_row_col
images = [
Image.open(p).convert("RGB") if isinstance(p, str) else p
for p in images[: n_row * n_col]
]
images = [img.resize(target_wh) for img in images]
grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
grid = Image.new("RGB", (grid_w, grid_h), (255, 255, 255))
for idx, img in enumerate(images):
row, col = divmod(idx, n_col)
grid.paste(img, (col * target_wh[0], row * target_wh[1]))
buffer = BytesIO()
grid.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
@spaces.GPU
def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
renderer = MeshRenderer()
renderer.rendering_options.resolution = options.get("resolution", 512)
renderer.rendering_options.near = options.get("near", 1)
renderer.rendering_options.far = options.get("far", 100)
renderer.rendering_options.ssaa = options.get("ssaa", 4)
rets = {}
for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
res = renderer.render(sample, extr, intr)
if "normal" not in rets:
rets["normal"] = []
normal = torch.lerp(
torch.zeros_like(res["normal"]), res["normal"], res["mask"]
)
normal = np.clip(
normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
).astype(np.uint8)
rets["normal"].append(normal)
return rets
@spaces.GPU
def render_video(
sample,
resolution=512,
bg_color=(0, 0, 0),
num_frames=300,
r=2,
fov=40,
**kwargs,
):
yaws = torch.linspace(0, 2 * 3.1415, num_frames)
yaws = yaws.tolist()
pitch = [0.5] * num_frames
extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
yaws, pitch, r, fov
)
render_fn = (
render_mesh if isinstance(sample, MeshExtractResult) else render_frames
)
result = render_fn(
sample,
extrinsics,
intrinsics,
{"resolution": resolution, "bg_color": bg_color},
**kwargs,
)
return result
def create_mp4_from_images(images, output_path, fps=10, prompt=None):
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
font_thickness = 1
color = (255, 255, 255)
position = (20, 25)
with imageio.get_writer(output_path, fps=fps) as writer:
for image in images:
image = image.clip(min=0, max=1)
image = (255.0 * image).astype(np.uint8)
image = image[..., :3]
if prompt is not None:
cv2.putText(
image,
prompt,
position,
font,
font_scale,
color,
font_thickness,
)
writer.append_data(image)
logger.info(f"MP4 video saved to {output_path}")
def create_gif_from_images(images, output_path, fps=10):
pil_images = []
for image in images:
image = image.clip(min=0, max=1)
image = (255.0 * image).astype(np.uint8)
image = Image.fromarray(image, mode="RGBA")
pil_images.append(image.convert("RGB"))
duration = 1000 // fps
pil_images[0].save(
output_path,
save_all=True,
append_images=pil_images[1:],
duration=duration,
loop=0,
)
logger.info(f"GIF saved to {output_path}")
if __name__ == "__main__":
# Example usage:
merge_video_video(
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa
"merge.mp4",
)
image_base64 = combine_images_to_base64(
[
"apps/assets/example_image/sample_00.jpg",
"apps/assets/example_image/sample_01.jpg",
"apps/assets/example_image/sample_02.jpg",
]
)

View File

@ -0,0 +1 @@
VERSION = "v0.1.0"

View File

@ -0,0 +1,149 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import clip
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from PIL import Image
class AestheticPredictor:
"""Aesthetic Score Predictor.
Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main
Args:
clip_model_dir (str): Path to the directory of the CLIP model.
sac_model_path (str): Path to the pre-trained SAC model.
device (str): Device to use for computation ("cuda" or "cpu").
"""
def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
self.device = device
if clip_model_dir is None:
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns="aesthetic/*"
)
suffix = "aesthetic"
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
)
clip_model_dir = os.path.join(model_path, suffix)
if sac_model_path is None:
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns="aesthetic/*"
)
suffix = "aesthetic"
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
)
sac_model_path = os.path.join(
model_path, suffix, "sac+logos+ava1-l14-linearMSE.pth"
)
self.clip_model, self.preprocess = self._load_clip_model(
clip_model_dir
)
self.sac_model = self._load_sac_model(sac_model_path, input_size=768)
class MLP(pl.LightningModule): # noqa
def __init__(self, input_size):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(input_size, 1024),
nn.Dropout(0.2),
nn.Linear(1024, 128),
nn.Dropout(0.2),
nn.Linear(128, 64),
nn.Dropout(0.1),
nn.Linear(64, 16),
nn.Linear(16, 1),
)
def forward(self, x):
return self.layers(x)
@staticmethod
def normalized(a, axis=-1, order=2):
"""Normalize the array to unit norm."""
l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
l2[l2 == 0] = 1
return a / np.expand_dims(l2, axis)
def _load_clip_model(self, model_dir: str, model_name: str = "ViT-L/14"):
"""Load the CLIP model."""
model, preprocess = clip.load(
model_name, download_root=model_dir, device=self.device
)
return model, preprocess
def _load_sac_model(self, model_path, input_size):
"""Load the SAC model."""
model = self.MLP(input_size)
ckpt = torch.load(model_path)
model.load_state_dict(ckpt)
model.to(self.device)
model.eval()
return model
def predict(self, image_path):
"""Predict the aesthetic score for a given image.
Args:
image_path (str): Path to the image file.
Returns:
float: Predicted aesthetic score.
"""
pil_image = Image.open(image_path)
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
with torch.no_grad():
# Extract CLIP features
image_features = self.clip_model.encode_image(image)
# Normalize features
normalized_features = self.normalized(
image_features.cpu().detach().numpy()
)
# Predict score
prediction = self.sac_model(
torch.from_numpy(normalized_features)
.type(torch.FloatTensor)
.to(self.device)
)
return prediction.item()
if __name__ == "__main__":
# Configuration
img_path = "apps/assets/example_image/sample_00.jpg"
# Initialize the predictor
predictor = AestheticPredictor()
# Predict the aesthetic score
score = predictor.predict(img_path)
print("Aesthetic score predicted by the model:", score)

View File

@ -0,0 +1,242 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import os
from tqdm import tqdm
from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
from embodied_gen.utils.process_media import render_asset3d
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class BaseChecker:
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
self.prompt = prompt
self.verbose = verbose
def query(self, *args, **kwargs):
raise NotImplementedError(
"Subclasses must implement the query method."
)
def __call__(self, *args, **kwargs) -> bool:
response = self.query(*args, **kwargs)
if response is None:
response = "Error when calling gpt api."
if self.verbose and response != "YES":
logger.info(response)
flag = "YES" in response
response = "YES" if flag else response
return flag, response
@staticmethod
def validate(
checkers: list["BaseChecker"], images_list: list[list[str]]
) -> list:
assert len(checkers) == len(images_list)
results = []
overall_result = True
for checker, images in zip(checkers, images_list):
qa_flag, qa_info = checker(images)
if isinstance(qa_info, str):
qa_info = qa_info.replace("\n", ".")
results.append([checker.__class__.__name__, qa_info])
if qa_flag is False:
overall_result = False
results.append(["overall", "YES" if overall_result else "NO"])
return results
class MeshGeoChecker(BaseChecker):
"""A geometry quality checker for 3D mesh assets using GPT-based reasoning.
This class leverages a multi-modal GPT client to analyze rendered images
of a 3D object and determine if its geometry is complete.
Attributes:
gpt_client (GPTclient): The GPT client used for multi-modal querying.
prompt (str): The prompt sent to the GPT model. If not provided, a default one is used.
verbose (bool): Whether to print debug information during evaluation.
"""
def __init__(
self,
gpt_client: GPTclient,
prompt: str = None,
verbose: bool = False,
) -> None:
super().__init__(prompt, verbose)
self.gpt_client = gpt_client
if self.prompt is None:
self.prompt = """
Refer to the provided multi-view rendering images to evaluate
whether the geometry of the 3D object asset is complete and
whether the asset can be placed stably on the ground.
Return "YES" only if reach the requirments,
otherwise "NO" and explain the reason very briefly.
"""
def query(self, image_paths: str) -> str:
# Hardcode tmp because of the openrouter can't input multi images.
if "openrouter" in self.gpt_client.endpoint:
from embodied_gen.utils.process_media import (
combine_images_to_base64,
)
image_paths = combine_images_to_base64(image_paths)
return self.gpt_client.query(
text_prompt=self.prompt,
image_base64=image_paths,
)
class ImageSegChecker(BaseChecker):
"""A segmentation quality checker for 3D assets using GPT-based reasoning.
This class compares an original image with its segmented version to
evaluate whether the segmentation successfully isolates the main object
with minimal truncation and correct foreground extraction.
Attributes:
gpt_client (GPTclient): GPT client used for multi-modal image analysis.
prompt (str): The prompt used to guide the GPT model for evaluation.
verbose (bool): Whether to enable verbose logging.
"""
def __init__(
self,
gpt_client: GPTclient,
prompt: str = None,
verbose: bool = False,
) -> None:
super().__init__(prompt, verbose)
self.gpt_client = gpt_client
if self.prompt is None:
self.prompt = """
The first image is the original, and the second image is the
result after segmenting the main object. Evaluate the segmentation
quality to ensure the main object is clearly segmented without
significant truncation. Note that the foreground of the object
needs to be extracted instead of the background.
Minor imperfections can be ignored. If segmentation is acceptable,
return "YES" only; otherwise, return "NO" with
very brief explanation.
"""
def query(self, image_paths: list[str]) -> str:
if len(image_paths) != 2:
raise ValueError(
"ImageSegChecker requires exactly two images: [raw_image, seg_image]." # noqa
)
# Hardcode tmp because of the openrouter can't input multi images.
if "openrouter" in self.gpt_client.endpoint:
from embodied_gen.utils.process_media import (
combine_images_to_base64,
)
image_paths = combine_images_to_base64(image_paths)
return self.gpt_client.query(
text_prompt=self.prompt,
image_base64=image_paths,
)
class ImageAestheticChecker(BaseChecker):
"""A class for evaluating the aesthetic quality of images.
Attributes:
clip_model_dir (str): Path to the CLIP model directory.
sac_model_path (str): Path to the aesthetic predictor model weights.
thresh (float): Threshold above which images are considered aesthetically acceptable.
verbose (bool): Whether to print detailed log messages.
predictor (AestheticPredictor): The model used to predict aesthetic scores.
"""
def __init__(
self,
clip_model_dir: str = None,
sac_model_path: str = None,
thresh: float = 4.50,
verbose: bool = False,
) -> None:
super().__init__(verbose=verbose)
self.clip_model_dir = clip_model_dir
self.sac_model_path = sac_model_path
self.thresh = thresh
self.predictor = AestheticPredictor(clip_model_dir, sac_model_path)
def query(self, image_paths: list[str]) -> float:
scores = [self.predictor.predict(img_path) for img_path in image_paths]
return sum(scores) / len(scores)
def __call__(self, image_paths: list[str], **kwargs) -> bool:
avg_score = self.query(image_paths)
if self.verbose:
logger.info(f"Average aesthetic score: {avg_score}")
return avg_score > self.thresh, avg_score
if __name__ == "__main__":
geo_checker = MeshGeoChecker(GPT_CLIENT)
seg_checker = ImageSegChecker(GPT_CLIENT)
aesthetic_checker = ImageAestheticChecker()
checkers = [geo_checker, seg_checker, aesthetic_checker]
output_root = "outputs/test_gpt"
fails = []
for idx in tqdm(range(150)):
mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj" # noqa
if not os.path.exists(mesh_path):
continue
image_paths = render_asset3d(
mesh_path,
f"{output_root}/{idx}",
num_images=8,
elevation=(30, -30),
distance=5.5,
)
for cid, checker in enumerate(checkers):
if isinstance(checker, ImageSegChecker):
images = [
f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png", # noqa
f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png", # noqa
]
else:
images = image_paths
result, info = checker(images)
logger.info(
f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}" # noqa
)
if result is False:
fails.append((idx, cid, info))
break

View File

@ -0,0 +1,419 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import os
import shutil
import xml.etree.ElementTree as ET
from datetime import datetime
from xml.dom.minidom import parseString
import numpy as np
import trimesh
from embodied_gen.data.utils import zip_files
from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
from embodied_gen.utils.process_media import render_asset3d
from embodied_gen.utils.tags import VERSION
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
__all__ = ["URDFGenerator"]
URDF_TEMPLATE = """
<robot name="template_robot">
<link name="template_link">
<visual>
<geometry>
<mesh filename="mesh.obj" scale="1.0 1.0 1.0"/>
</geometry>
</visual>
<collision>
<geometry>
<mesh filename="mesh.obj" scale="1.0 1.0 1.0"/>
</geometry>
<gazebo>
<mu1>0.8</mu1> <!-- Main friction coefficient -->
<mu2>0.6</mu2> <!-- Secondary friction coefficient -->
</gazebo>
</collision>
<inertial>
<mass value="1.0"/>
<origin xyz="0 0 0"/>
<inertia ixx="1.0" ixy="0.0" ixz="0.0" iyy="1.0" iyz="0.0" izz="1.0"/>
</inertial>
<extra_info>
<scale>1.0</scale>
<version>"0.0.0"</version>
<category>"unknown"</category>
<description>"unknown"</description>
<min_height>0.0</min_height>
<max_height>0.0</max_height>
<real_height>0.0</real_height>
<min_mass>0.0</min_mass>
<max_mass>0.0</max_mass>
<generate_time>"-1"</generate_time>
<gs_model>""</gs_model>
</extra_info>
</link>
</robot>
"""
class URDFGenerator(object):
def __init__(
self,
gpt_client: GPTclient,
mesh_file_list: list[str] = ["material_0.png", "material.mtl"],
prompt_template: str = None,
attrs_name: list[str] = None,
render_dir: str = "urdf_renders",
render_view_num: int = 4,
) -> None:
if mesh_file_list is None:
mesh_file_list = []
self.mesh_file_list = mesh_file_list
self.output_mesh_dir = "mesh"
self.output_render_dir = render_dir
self.gpt_client = gpt_client
self.render_view_num = render_view_num
if render_view_num == 4:
view_desc = "This is orthographic projection showing the front, left, right and back views " # noqa
else:
view_desc = "This is the rendered views "
if prompt_template is None:
prompt_template = (
view_desc
+ """of the 3D object asset,
category: {category}.
Give the category of this object asset (within 3 words),
(if category is already provided, use it directly),
accurately describe this 3D object asset (within 15 words),
and give the recommended geometric height range (unit: meter),
weight range (unit: kilogram), the average static friction
coefficient of the object relative to rubber and the average
dynamic friction coefficient of the object relative to rubber.
Return response format as shown in Example.
Example:
Category: cup
Description: shiny golden cup with floral design
Height: 0.1-0.15 m
Weight: 0.3-0.6 kg
Static friction coefficient: 1.1
Dynamic friction coefficient: 0.9
"""
)
self.prompt_template = prompt_template
if attrs_name is None:
attrs_name = [
"category",
"description",
"min_height",
"max_height",
"real_height",
"min_mass",
"max_mass",
"version",
"generate_time",
"gs_model",
]
self.attrs_name = attrs_name
def parse_response(self, response: str) -> dict[str, any]:
lines = response.split("\n")
lines = [line.strip() for line in lines if line]
category = lines[0].split(": ")[1]
description = lines[1].split(": ")[1]
min_height, max_height = map(
lambda x: float(x.strip().replace(",", "").split()[0]),
lines[2].split(": ")[1].split("-"),
)
min_mass, max_mass = map(
lambda x: float(x.strip().replace(",", "").split()[0]),
lines[3].split(": ")[1].split("-"),
)
mu1 = float(lines[4].split(": ")[1].replace(",", ""))
mu2 = float(lines[5].split(": ")[1].replace(",", ""))
return {
"category": category.lower(),
"description": description.lower(),
"min_height": round(min_height, 4),
"max_height": round(max_height, 4),
"min_mass": round(min_mass, 4),
"max_mass": round(max_mass, 4),
"mu1": round(mu1, 2),
"mu2": round(mu2, 2),
"version": VERSION,
"generate_time": datetime.now().strftime("%Y%m%d%H%M%S"),
}
def generate_urdf(
self,
input_mesh: str,
output_dir: str,
attr_dict: dict,
output_name: str = None,
) -> str:
"""Generate a URDF file for a given mesh with specified attributes.
Args:
input_mesh (str): Path to the input mesh file.
output_dir (str): Directory to store the generated URDF
and processed mesh.
attr_dict (dict): Dictionary containing attributes like height,
mass, and friction coefficients.
output_name (str, optional): Name for the generated URDF and robot.
Returns:
str: Path to the generated URDF file.
"""
# 1. Load and normalize the mesh
mesh = trimesh.load(input_mesh)
mesh_scale = np.ptp(mesh.vertices, axis=0).max()
mesh.vertices /= mesh_scale # Normalize to [-0.5, 0.5]
raw_height = np.ptp(mesh.vertices, axis=0)[1]
# 2. Scale the mesh to real height
real_height = attr_dict["real_height"]
scale = round(real_height / raw_height, 6)
mesh = mesh.apply_scale(scale)
# 3. Prepare output directories and save scaled mesh
mesh_folder = os.path.join(output_dir, self.output_mesh_dir)
os.makedirs(mesh_folder, exist_ok=True)
obj_name = os.path.basename(input_mesh)
mesh_output_path = os.path.join(mesh_folder, obj_name)
mesh.export(mesh_output_path)
# 4. Copy additional mesh files, if any
input_dir = os.path.dirname(input_mesh)
for file in self.mesh_file_list:
src_file = os.path.join(input_dir, file)
dest_file = os.path.join(mesh_folder, file)
if os.path.isfile(src_file):
shutil.copy(src_file, dest_file)
# 5. Determine output name
if output_name is None:
output_name = os.path.splitext(obj_name)[0]
# 6. Load URDF template and update attributes
robot = ET.fromstring(URDF_TEMPLATE)
robot.set("name", output_name)
link = robot.find("link")
if link is None:
raise ValueError("URDF template is missing 'link' element.")
link.set("name", output_name)
# Update visual geometry
visual = link.find("visual/geometry/mesh")
if visual is not None:
visual.set(
"filename", os.path.join(self.output_mesh_dir, obj_name)
)
visual.set("scale", "1.0 1.0 1.0")
# Update collision geometry
collision = link.find("collision/geometry/mesh")
if collision is not None:
collision.set(
"filename", os.path.join(self.output_mesh_dir, obj_name)
)
collision.set("scale", "1.0 1.0 1.0")
# Update friction coefficients
gazebo = link.find("collision/gazebo")
if gazebo is not None:
for param, key in zip(["mu1", "mu2"], ["mu1", "mu2"]):
element = gazebo.find(param)
if element is not None:
element.text = f"{attr_dict[key]:.2f}"
# Update mass
inertial = link.find("inertial/mass")
if inertial is not None:
mass_value = (attr_dict["min_mass"] + attr_dict["max_mass"]) / 2
inertial.set("value", f"{mass_value:.4f}")
# Add extra_info element to the link
extra_info = link.find("extra_info/scale")
if extra_info is not None:
extra_info.text = f"{scale:.6f}"
for key in self.attrs_name:
extra_info = link.find(f"extra_info/{key}")
if extra_info is not None and key in attr_dict:
extra_info.text = f"{attr_dict[key]}"
# 7. Write URDF to file
os.makedirs(output_dir, exist_ok=True)
urdf_path = os.path.join(output_dir, f"{output_name}.urdf")
tree = ET.ElementTree(robot)
tree.write(urdf_path, encoding="utf-8", xml_declaration=True)
logger.info(f"URDF file saved to {urdf_path}")
return urdf_path
@staticmethod
def get_attr_from_urdf(
urdf_path: str,
attr_root: str = ".//link/extra_info",
attr_name: str = "scale",
) -> float:
if not os.path.exists(urdf_path):
raise FileNotFoundError(f"URDF file not found: {urdf_path}")
mesh_scale = 1.0
tree = ET.parse(urdf_path)
root = tree.getroot()
extra_info = root.find(attr_root)
if extra_info is not None:
scale_element = extra_info.find(attr_name)
if scale_element is not None:
mesh_scale = float(scale_element.text)
return mesh_scale
@staticmethod
def add_quality_tag(
urdf_path: str, results, output_path: str = None
) -> None:
if output_path is None:
output_path = urdf_path
tree = ET.parse(urdf_path)
root = tree.getroot()
custom_data = ET.SubElement(root, "custom_data")
quality = ET.SubElement(custom_data, "quality")
for key, value in results:
checker_tag = ET.SubElement(quality, key)
checker_tag.text = str(value)
rough_string = ET.tostring(root, encoding="utf-8")
formatted_string = parseString(rough_string).toprettyxml(indent=" ")
cleaned_string = "\n".join(
[line for line in formatted_string.splitlines() if line.strip()]
)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
f.write(cleaned_string)
logger.info(f"URDF files saved to {output_path}")
def get_estimated_attributes(self, asset_attrs: dict):
estimated_attrs = {
"height": round(
(asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
),
"mass": round(
(asset_attrs["min_mass"] + asset_attrs["max_mass"]) / 2, 4
),
"mu": round((asset_attrs["mu1"] + asset_attrs["mu2"]) / 2, 4),
"category": asset_attrs["category"],
}
return estimated_attrs
def __call__(
self,
mesh_path: str,
output_root: str,
text_prompt: str = None,
category: str = "unknown",
**kwargs,
):
if text_prompt is None or len(text_prompt) == 0:
text_prompt = self.prompt_template
text_prompt = text_prompt.format(category=category.lower())
image_path = render_asset3d(
mesh_path,
output_root,
num_images=self.render_view_num,
output_subdir=self.output_render_dir,
)
# Hardcode tmp because of the openrouter can't input multi images.
if "openrouter" in self.gpt_client.endpoint:
from embodied_gen.utils.process_media import (
combine_images_to_base64,
)
image_path = combine_images_to_base64(image_path)
response = self.gpt_client.query(text_prompt, image_path)
if response is None:
asset_attrs = {
"category": category.lower(),
"description": category.lower(),
"min_height": 1,
"max_height": 1,
"min_mass": 1,
"max_mass": 1,
"mu1": 0.8,
"mu2": 0.6,
"version": VERSION,
"generate_time": datetime.now().strftime("%Y%m%d%H%M%S"),
}
else:
asset_attrs = self.parse_response(response)
for key in self.attrs_name:
if key in kwargs:
asset_attrs[key] = kwargs[key]
asset_attrs["real_height"] = round(
(asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
)
self.estimated_attrs = self.get_estimated_attributes(asset_attrs)
urdf_path = self.generate_urdf(mesh_path, output_root, asset_attrs)
logger.info(f"response: {response}")
return urdf_path
if __name__ == "__main__":
urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
urdf_path = urdf_gen(
mesh_path="outputs/imageto3d/cma/o5/URDF_o5/mesh/o5.obj",
output_root="outputs/test_urdf",
# category="coffee machine",
# min_height=1.0,
# max_height=1.2,
version=VERSION,
)
# zip_files(
# input_paths=[
# "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",
# "scripts/apps/tmp/2umpdum3e5n/URDF_sample/sample.urdf"
# ],
# output_zip="zip.zip"
# )

51
pyproject.toml Normal file
View File

@ -0,0 +1,51 @@
[build-system]
requires = ["setuptools", "wheel", "build"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
packages = ["embodied_gen"]
[project]
name = "embodied_gen"
version = "v0.1.0"
readme = "README.md"
license = "Apache-2.0"
license-files = ["LICENSE", "NOTICE"]
dependencies = []
requires-python = ">=3.10"
[project.optional-dependencies]
dev = [
"cpplint==2.0.0",
"pre-commit==2.13.0",
"pydocstyle",
"black",
"isort",
]
[project.scripts]
drender-cli = "embodied_gen.data.differentiable_render:entrypoint"
backproject-cli = "embodied_gen.data.backproject_v2:entrypoint"
[tool.pydocstyle]
match = '(?!test_).*(?!_pb2)\.py'
match-dir = '^(?!(raw|projects|tools|k8s_submit|thirdparty)$)[\w.-]+$'
convention = "google"
add-ignore = 'D104,D107,D202,D105,D100,D102,D103,D101,E203'
[tool.pycodestyle]
max-line-length = 79
ignore = "E203"
[tool.black]
line-length = 79
exclude = "thirdparty"
skip-string-normalization = true
[tool.isort]
line_length = 79
profile = 'black'
no_lines_before = 'FIRSTPARTY'
known_first_party = ['embodied_gen']
skip = "thirdparty/"

41
requirements.txt Normal file
View File

@ -0,0 +1,41 @@
torch==2.4.0+cu118
torchvision==0.19.0+cu118
xformers==0.0.27.post2
pytorch-lightning==2.4.0
spconv-cu120==2.3.6
numpy==1.26.4
triton==2.1.0
dataclasses_json
easydict
opencv-python>4.5
imageio==2.36.1
imageio-ffmpeg==0.5.1
rembg==2.0.61
trimesh==4.4.4
moviepy==1.0.3
pymeshfix==0.17.0
igraph==0.11.8
pyvista==0.36.1
openai==1.58.1
transformers==4.42.4
gradio==5.12.0
sentencepiece==0.2.0
diffusers==0.31.0
xatlas==0.0.9
onnxruntime==1.20.1
tenacity==8.2.2
accelerate==0.33.0
basicsr==1.4.2
realesrgan==0.3.0
pydantic==2.9.2
vtk==9.3.1
spaces
utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
clip@git+https://github.com/openai/CLIP.git
kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d
segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f
nvdiffrast@git+https://github.com/NVlabs/nvdiffrast.git#egg=729261d
kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0
https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu118-cp310-cp310-linux_x86_64.whl
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu11torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu118/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl

10
scripts/autoformat.sh Normal file
View File

@ -0,0 +1,10 @@
#!/bin/bash
ROOT_DIR=${1}
set -e
black --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
isort --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./
pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./

61
scripts/check_lint.py Normal file
View File

@ -0,0 +1,61 @@
import argparse
import os
import subprocess
import sys
def get_root():
current_file_path = os.path.abspath(__file__)
root_path = os.path.dirname(current_file_path)
for _ in range(2):
root_path = os.path.dirname(root_path)
return root_path
def cpp_lint(root_path: str):
# run external python file to lint cpp
subprocess.check_call(
" ".join(
[
"python3",
f"{root_path}/scripts/lint_src/lint.py",
"--project=asset_recons",
"--path",
f"{root_path}/src/",
f"{root_path}/include/",
f"{root_path}/module/",
"--exclude_path",
f"{root_path}/module/web_viz/front_end/",
]
),
shell=True,
)
def python_lint(root_path: str, auto_format: bool = False):
# run external python file to lint python
subprocess.check_call(
" ".join(
[
"bash",
(
f"{root_path}/scripts/lint/check_pylint.sh"
if not auto_format
else f"{root_path}/scripts/lint/autoformat.sh"
),
f"{root_path}/",
]
),
shell=True,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="check format.")
parser.add_argument(
"--auto_format", action="store_true", help="auto format python"
)
parser = parser.parse_args()
root_path = get_root()
cpp_lint(root_path)
python_lint(root_path, parser.auto_format)

11
scripts/check_pylint.sh Normal file
View File

@ -0,0 +1,11 @@
#!/bin/bash
ROOT_DIR=${1}
set -e
pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./
pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
black --check --diff --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
isort --diff --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./

View File

@ -0,0 +1,10 @@
#!/bin/bash
ROOT_DIR=${1}
set -e
black --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
isort --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./
pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./

View File

@ -0,0 +1,61 @@
import argparse
import os
import subprocess
import sys
def get_root():
current_file_path = os.path.abspath(__file__)
root_path = os.path.dirname(current_file_path)
for _ in range(2):
root_path = os.path.dirname(root_path)
return root_path
def cpp_lint(root_path: str):
# run external python file to lint cpp
subprocess.check_call(
" ".join(
[
"python3",
f"{root_path}/scripts/lint_src/lint.py",
"--project=asset_recons",
"--path",
f"{root_path}/src/",
f"{root_path}/include/",
f"{root_path}/module/",
"--exclude_path",
f"{root_path}/module/web_viz/front_end/",
]
),
shell=True,
)
def python_lint(root_path: str, auto_format: bool = False):
# run external python file to lint python
subprocess.check_call(
" ".join(
[
"bash",
(
f"{root_path}/scripts/lint/check_pylint.sh"
if not auto_format
else f"{root_path}/scripts/lint/autoformat.sh"
),
f"{root_path}/",
]
),
shell=True,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="check format.")
parser.add_argument(
"--auto_format", action="store_true", help="auto format python"
)
parser = parser.parse_args()
root_path = get_root()
cpp_lint(root_path)
python_lint(root_path, parser.auto_format)

View File

@ -0,0 +1,11 @@
#!/bin/bash
ROOT_DIR=${1}
set -e
pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./
pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
black --check --diff --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./
isort --diff --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./

10
scripts/lint_src/cpplint.hook Executable file
View File

@ -0,0 +1,10 @@
#!/bin/bash
TOTAL_ERRORS=0
if [[ ! $(which cpplint) ]]; then
pip install cpplint
fi
# diff files on local machine.
files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}')
python3 scripts/lint_src/lint.py --project=asset_recons --path $files --exclude_path thirdparty patch_files;

155
scripts/lint_src/lint.py Normal file
View File

@ -0,0 +1,155 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import codecs
import os
import re
import sys
import cpplint
import pycodestyle
from cpplint import _cpplint_state
CXX_SUFFIX = set(["cc", "c", "cpp", "h", "cu", "hpp"])
def filepath_enumerate(paths):
"""Enumerate the file paths of all subfiles of the list of paths."""
out = []
for path in paths:
if os.path.isfile(path):
out.append(path)
else:
for root, dirs, files in os.walk(path):
for name in files:
out.append(os.path.normpath(os.path.join(root, name)))
return out
class LintHelper(object):
@staticmethod
def _print_summary_map(strm, result_map, ftype):
"""Print summary of certain result map."""
if len(result_map) == 0:
return 0
npass = len([x for k, x in result_map.items() if len(x) == 0])
strm.write(
"=====%d/%d %s files passed check=====\n"
% (npass, len(result_map), ftype)
)
for fname, emap in result_map.items():
if len(emap) == 0:
continue
strm.write(
"%s: %d Errors of %d Categories map=%s\n"
% (fname, sum(emap.values()), len(emap), str(emap))
)
return len(result_map) - npass
def __init__(self) -> None:
self.project_name = None
self.cpp_header_map = {}
self.cpp_src_map = {}
super().__init__()
cpplint_args = [".", "--extensions=" + (",".join(CXX_SUFFIX))]
_ = cpplint.ParseArguments(cpplint_args)
cpplint._SetFilters(
",".join(
[
"-build/c++11",
"-build/namespaces",
"-build/include,",
"+build/include_what_you_use",
"+build/include_order",
]
)
)
cpplint._SetCountingStyle("toplevel")
cpplint._line_length = 80
def process_cpp(self, path, suffix):
"""Process a cpp file."""
_cpplint_state.ResetErrorCounts()
cpplint.ProcessFile(str(path), _cpplint_state.verbose_level)
_cpplint_state.PrintErrorCounts()
errors = _cpplint_state.errors_by_category.copy()
if suffix == "h":
self.cpp_header_map[str(path)] = errors
else:
self.cpp_src_map[str(path)] = errors
def print_summary(self, strm):
"""Print summary of lint."""
nerr = 0
nerr += LintHelper._print_summary_map(
strm, self.cpp_header_map, "cpp-header"
)
nerr += LintHelper._print_summary_map(
strm, self.cpp_src_map, "cpp-source"
)
if nerr == 0:
strm.write("All passed!\n")
else:
strm.write("%d files failed lint\n" % nerr)
return nerr
# singleton helper for lint check
_HELPER = LintHelper()
def process(fname, allow_type):
"""Process a file."""
fname = str(fname)
arr = fname.rsplit(".", 1)
if fname.find("#") != -1 or arr[-1] not in allow_type:
return
if arr[-1] in CXX_SUFFIX:
_HELPER.process_cpp(fname, arr[-1])
def main():
"""Main entry function."""
parser = argparse.ArgumentParser(description="lint source codes")
parser.add_argument("--project", help="project name")
parser.add_argument(
"--path",
nargs="+",
default=[],
help="path to traverse",
required=False,
)
parser.add_argument(
"--exclude_path",
nargs="+",
default=[],
help="exclude this path, and all subfolders " + "if path is a folder",
)
args = parser.parse_args()
_HELPER.project_name = args.project
allow_type = []
allow_type += [x for x in CXX_SUFFIX]
allow_type = set(allow_type)
# get excluded files
excluded_paths = filepath_enumerate(args.exclude_path)
for path in args.path:
if os.path.isfile(path):
normpath = os.path.normpath(path)
if normpath not in excluded_paths:
process(path, allow_type)
else:
for root, dirs, files in os.walk(path):
for name in files:
file_path = os.path.normpath(os.path.join(root, name))
if file_path not in excluded_paths:
process(file_path, allow_type)
nerr = _HELPER.print_summary(sys.stderr)
sys.exit(nerr > 0)
if __name__ == "__main__":
main()

16
scripts/lint_src/pep8.hook Executable file
View File

@ -0,0 +1,16 @@
#!/bin/bash
TOTAL_ERRORS=0
if [[ ! $(which pycodestyle) ]]; then
pip install pycodestyle
fi
# diff files on local machine.
files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}')
for file in $files; do
if [ "${file##*.}" == "py" & -f "${file}"] ; then
pycodestyle --show-source $file --config=setup.cfg;
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
fi
done
exit $TOTAL_ERRORS

View File

@ -0,0 +1,16 @@
#!/bin/bash
TOTAL_ERRORS=0
if [[ ! $(which pydocstyle) ]]; then
pip install pydocstyle
fi
# diff files on local machine.
files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}')
for file in $files; do
if [ "${file##*.}" == "py" ] ; then
pydocstyle $file;
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
fi
done
exit $TOTAL_ERRORS

2
setup.cfg Normal file
View File

@ -0,0 +1,2 @@
[pycodestyle]
ignore = E203,W503,E402,E501

1
thirdparty/TRELLIS vendored Submodule

@ -0,0 +1 @@
Subproject commit 55a8e8164b195bbf927e0978f00e76c835e6011f