feat(pipe): Refine generation pipeline and add utests. (#23)

- Added intelligent quality checkers and auto-retry pipeline for `image-to-3d` and `text-to-3d`.
- Added unit tests for quality checkers.
- `text-to-3d` now supports more `text-to-image` models, pipeline success rate improved to 94%.
Xinjie 2025-07-11 16:51:35 +08:00 committed by GitHub
parent 8bcfac9190
commit a3924ae4a8
GPG Key ID: B5690EEEBB952194
30 changed files with 1863 additions and 372 deletions

CHANGELOG.md Normal file
@@ -0,0 +1,29 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0).
## [0.1.1] - 2025-07-xx
### Feature
- Added intelligent quality checkers and auto-retry pipeline for `image-to-3d` and `text-to-3d`.
- Added unit tests for quality checkers.
- `text-to-3d` now supports more `text-to-image` models, pipeline success rate improved to 94%.
## [0.1.0] - 2025-07-04
### Feature
🖼️ Single Image to Physics Realistic 3D Asset
- Generates watertight, simulation-ready 3D meshes with physical attributes.
- Auto-labels semantic and quality tags (geometry, texture, foreground, etc.).
- Produces 2K textures with highlight removal and multi-view fusion.
📝 Text-to-3D Asset Generation
- Creates realistic 3D assets from natural language (English & Chinese).
- Filters assets via QA tags to ensure visual and geometric quality.
🎨 Texture Generation & Editing
- Generates 2K textures from mesh and text with semantic alignment.
- Plug-and-play modules adapt text-to-image models for 3D textures.
- Supports realistic and stylized texture outputs, including text textures.


@@ -6,6 +6,7 @@
[![🤗 Hugging Face](https://img.shields.io/badge/🤗-Image_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D)
[![🤗 Hugging Face](https://img.shields.io/badge/🤗-Text_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D)
[![🤗 Hugging Face](https://img.shields.io/badge/🤗-Texture_Gen_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen)
+[![中文介绍](https://img.shields.io/badge/中文介绍-07C160?logo=wechat&logoColor=white)](https://mp.weixin.qq.com/s/HH1cPBhK2xcDbyCK4BBTbw)
> ***EmbodiedGen*** is a generative engine to create diverse and interactive 3D worlds composed of high-quality 3D assets (mesh & 3DGS) with plausible physics, leveraging generative AI to address the challenges of generalization in embodied intelligence related research.
@@ -29,7 +30,7 @@
```sh
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
cd EmbodiedGen
-git checkout v0.1.0
+git checkout v0.1.1
git submodule update --init --recursive --progress
conda create -n embodiedgen python=3.10.13 -y
conda activate embodiedgen
@@ -67,9 +68,8 @@ CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
### ⚡ API
Generate physically plausible 3D assets from image input via the command-line API.
```sh
-python3 embodied_gen/scripts/imageto3d.py \
-    --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
-    --output_root outputs/imageto3d
+img3d-cli --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
+    --n_retry 2 --output_root outputs/imageto3d
# See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result
```
@@ -86,18 +86,29 @@ python3 embodied_gen/scripts/imageto3d.py \
### ☁️ Service
Deploy the text-to-3D generation service locally.
-Text-to-image based on the Kolors model, supporting Chinese and English prompts.
-Models downloaded automatically on first run, see `download_kolors_weights`, please be patient.
+Text-to-image model based on the Kolors model, supporting Chinese and English prompts.
+Models are downloaded automatically on first run; please be patient.
```sh
python apps/text_to_3d.py
```
### ⚡ API
-Text-to-image based on the Kolors model.
+Text-to-image model based on SD3.5 Medium, English prompts only.
+Usage requires agreeing to the [model license (click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium); models are downloaded automatically. (Models with more permissive licenses can be found in `embodied_gen/models/image_comm_model.py`.)
+For large-scale 3D asset generation, set `--n_pipe_retry=2` to ensure high end-to-end asset usability through automatic quality checks and retries. For more diverse results, do not set `--seed_img`.
+```sh
+text3d-cli --prompts "small bronze figurine of a lion" "A globe with wooden base" "wooden table with embroidery" \
+    --n_image_retry 2 --n_asset_retry 2 --n_pipe_retry 1 --seed_img 0 \
+    --output_root outputs/textto3d
+```
+Text-to-image model based on the Kolors model.
```sh
bash embodied_gen/scripts/textto3d.sh \
    --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
-    --output_root outputs/textto3d
+    --output_root outputs/textto3d_k
```
---
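The retry flags above (`--n_image_retry`, `--n_asset_retry`, `--n_pipe_retry`) follow a check-then-retry pattern: generate, run a quality checker, and re-seed on failure. A minimal sketch of that idea, using hypothetical `generate`/`check` stand-ins rather than the EmbodiedGen API:

```python
import random

def generate(seed: int) -> dict:
    # Hypothetical stand-in for one generation attempt.
    random.seed(seed)
    return {"seed": seed, "score": random.random()}

def check(asset: dict) -> bool:
    # Hypothetical stand-in for a quality-checker verdict.
    return asset["score"] > 0.5

def run_with_retry(n_retry: int, seed: int = 0) -> dict:
    """Regenerate until a checker passes, drawing a fresh seed per failure."""
    asset = generate(seed)
    for _ in range(n_retry - 1):
        if check(asset):
            break
        seed = random.randint(0, 100000)  # retry with a new random seed
        asset = generate(seed)
    return asset

result = run_with_retry(n_retry=5)
```

Fixing the initial seed makes the first attempt reproducible; omitting it (as with `--seed_img` unset) trades reproducibility for diversity.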
@@ -118,12 +129,17 @@ python apps/texture_edit.py
```
### ⚡ API
+Supports Chinese and English prompts.
```sh
bash embodied_gen/scripts/texture_gen.sh \
    --mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \
    --prompt "举着牌子的写实风格机器人大眼睛牌子上写着“Hello”的文字" \
-    --output_root "outputs/texture_gen/" \
-    --uuid "robot_text"
+    --output_root "outputs/texture_gen/robot_text"
+bash embodied_gen/scripts/texture_gen.sh \
+    --mesh_path "apps/assets/example_texture/meshes/horse.obj" \
+    --prompt "A gray horse head with flying mane and brown eyes" \
+    --output_root "outputs/texture_gen/gray_horse"
```
---
@@ -171,6 +187,12 @@ bash embodied_gen/scripts/texture_gen.sh \
---
+## For Developer
+```sh
+pip install .[dev] && pre-commit install
+python -m pytest  # Passing all unit tests is required.
+```
## 📚 Citation
If you use EmbodiedGen in your research or projects, please cite:
@@ -192,7 +214,7 @@ If you use EmbodiedGen in your research or projects, please cite:
## 🙌 Acknowledgement
EmbodiedGen builds upon the following amazing projects and models:
-🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o)
+🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)
---


@@ -55,9 +55,9 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.process_media import (
    filter_image_small_connected_components,
    merge_images_video,
-    render_video,
)
from embodied_gen.utils.tags import VERSION
+from embodied_gen.utils.trender import render_video
from embodied_gen.validators.quality_checkers import (
    BaseChecker,
    ImageAestheticChecker,


@@ -33,7 +33,6 @@ from tqdm import tqdm
from embodied_gen.data.utils import (
    CameraSetting,
    DiffrastRender,
-    RenderItems,
    as_list,
    calc_vertex_normals,
    import_kaolin_mesh,
@@ -42,6 +41,7 @@ from embodied_gen.data.utils import (
    render_pbr,
    save_images,
)
+from embodied_gen.utils.enum import RenderItems

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
@@ -470,7 +470,7 @@ def parse_args():
        "--pbr_light_factor",
        type=float,
        default=1.0,
-        help="Light factor for mesh PBR rendering (default: 2.)",
+        help="Light factor for mesh PBR rendering (default: 1.)",
    )
    parser.add_argument(
        "--with_mtl",
@@ -482,6 +482,11 @@ def parse_args():
        action="store_true",
        help="Whether to generate color .gif rendering file.",
    )
+    parser.add_argument(
+        "--no_index_file",
+        action="store_true",
+        help="Whether to skip saving the index file.",
+    )
    parser.add_argument(
        "--gen_color_mp4",
        action="store_true",
@@ -568,7 +573,7 @@ def entrypoint(**kwargs) -> None:
        gen_viewnormal_mp4=args.gen_viewnormal_mp4,
        gen_glonormal_mp4=args.gen_glonormal_mp4,
        light_factor=args.pbr_light_factor,
-        no_index_file=gen_video,
+        no_index_file=gen_video or args.no_index_file,
    )
    image_render.render_mesh(
        mesh_path=args.mesh_path,


@@ -395,6 +395,8 @@ class MeshFixer(object):
            self.vertices_np,
            np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
        )
+        mesh.clean(inplace=True)
+        mesh.clear_data()
        mesh = mesh.decimate(ratio, progress_bar=True)

        # Update vertices and faces


@@ -38,7 +38,6 @@ except ImportError:
    ChatGLMModel = None
import logging
from dataclasses import dataclass, field
-from enum import Enum

import trimesh
from kaolin.render.camera import Camera
@@ -57,7 +56,6 @@ __all__ = [
    "load_mesh_to_unit_cube",
    "as_list",
    "CameraSetting",
-    "RenderItems",
    "import_kaolin_mesh",
    "save_mesh_with_mtl",
    "get_images_from_grid",
@@ -738,18 +736,6 @@ class CameraSetting:
        self.Ks = Ks

-@dataclass
-class RenderItems(str, Enum):
-    IMAGE = "image_color"
-    ALPHA = "image_mask"
-    VIEW_NORMAL = "image_view_normal"
-    GLOBAL_NORMAL = "image_global_normal"
-    POSITION_MAP = "image_position"
-    DEPTH = "image_depth"
-    ALBEDO = "image_albedo"
-    DIFFUSE = "image_diffuse"

def _compute_az_el_by_camera_params(
    camera_params: CameraSetting, flip_az: bool = False
):


@@ -0,0 +1,236 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Text-to-Image generation models from Hugging Face community.
import os
from abc import ABC, abstractmethod
import torch
from diffusers import (
ChromaPipeline,
Cosmos2TextToImagePipeline,
DPMSolverMultistepScheduler,
FluxPipeline,
KolorsPipeline,
StableDiffusion3Pipeline,
)
from diffusers.quantizers import PipelineQuantizationConfig
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, SiglipProcessor
__all__ = [
    "build_hf_image_pipeline",
]


class BasePipelineLoader(ABC):
    def __init__(self, device="cuda"):
        self.device = device

    @abstractmethod
    def load(self):
        pass


class BasePipelineRunner(ABC):
    def __init__(self, pipe):
        self.pipe = pipe

    @abstractmethod
    def run(self, prompt: str, **kwargs) -> Image.Image:
        pass
# ===== SD3.5-medium =====
class SD35Loader(BasePipelineLoader):
    def load(self):
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16,
        )
        pipe = pipe.to(self.device)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe


class SD35Runner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images
# ===== Cosmos2 =====
class CosmosLoader(BasePipelineLoader):
    def __init__(
        self,
        model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
        local_dir="weights/cosmos2",
        device="cuda",
    ):
        super().__init__(device)
        self.model_id = model_id
        self.local_dir = local_dir

    def _patch(self):
        def patch_model(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                kwargs.setdefault("attn_implementation", "flash_attention_2")
                kwargs.setdefault("torch_dtype", torch.bfloat16)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        def patch_processor(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                kwargs.setdefault("use_fast", True)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        patch_model(AutoModelForCausalLM)
        patch_processor(SiglipProcessor)

    def load(self):
        self._patch()
        snapshot_download(
            repo_id=self.model_id,
            local_dir=self.local_dir,
            local_dir_use_symlinks=False,
            resume_download=True,
        )
        config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
                "bnb_4bit_use_double_quant": True,
            },
            components_to_quantize=["text_encoder", "transformer", "unet"],
        )
        pipe = Cosmos2TextToImagePipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch.bfloat16,
            quantization_config=config,
            use_safetensors=True,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.device)
        return pipe


class CosmosRunner(BasePipelineRunner):
    def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images
# ===== Kolors =====
class KolorsLoader(BasePipelineLoader):
    def load(self):
        pipe = KolorsPipeline.from_pretrained(
            "Kwai-Kolors/Kolors-diffusers",
            torch_dtype=torch.float16,
            variant="fp16",
        ).to(self.device)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True
        )
        return pipe


class KolorsRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images
# ===== Flux =====
class FluxLoader(BasePipelineLoader):
    def load(self):
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe.to(self.device)


class FluxRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images
# ===== Chroma =====
class ChromaLoader(BasePipelineLoader):
    def load(self):
        return ChromaPipeline.from_pretrained(
            "lodestones/Chroma", torch_dtype=torch.bfloat16
        ).to(self.device)


class ChromaRunner(BasePipelineRunner):
    def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images
PIPELINE_REGISTRY = {
    "sd35": (SD35Loader, SD35Runner),
    "cosmos": (CosmosLoader, CosmosRunner),
    "kolors": (KolorsLoader, KolorsRunner),
    "flux": (FluxLoader, FluxRunner),
    "chroma": (ChromaLoader, ChromaRunner),
}


def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
    if name not in PIPELINE_REGISTRY:
        raise ValueError(f"Unsupported model: {name}")
    loader_cls, runner_cls = PIPELINE_REGISTRY[name]
    pipe = loader_cls(device=device).load()
    return runner_cls(pipe)


if __name__ == "__main__":
    model_name = "sd35"
    runner = build_hf_image_pipeline(model_name)
    # NOTE: For pipeline testing only; generation quality at low resolution is poor.
    images = runner.run(
        prompt="A robot holding a sign that says 'Hello'",
        height=512,
        width=512,
        num_inference_steps=10,
        guidance_scale=6,
        num_images_per_prompt=1,
    )
    for i, img in enumerate(images):
        img.save(f"image_{model_name}_{i}.jpg")
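The loader/runner registry above reduces model selection to a single string key. The dispatch mechanics in isolation, with a toy loader and runner so no GPU or diffusers install is needed (all names here are illustrative):

```python
class EchoLoader:
    def __init__(self, device="cpu"):
        self.device = device

    def load(self):
        # Pretend-pipeline: a callable that tags its input with the device.
        return lambda prompt: f"[{self.device}] {prompt}"


class EchoRunner:
    def __init__(self, pipe):
        self.pipe = pipe

    def run(self, prompt, **kwargs):
        return self.pipe(prompt)


REGISTRY = {"echo": (EchoLoader, EchoRunner)}


def build(name, device="cpu"):
    # Same shape as build_hf_image_pipeline: look up, load, wrap in a runner.
    if name not in REGISTRY:
        raise ValueError(f"Unsupported model: {name}")
    loader_cls, runner_cls = REGISTRY[name]
    return runner_cls(loader_cls(device=device).load())


runner = build("echo")
print(runner.run("hello"))  # [cpu] hello
```

Adding a backend is then just a new `(Loader, Runner)` entry; callers never change.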


@@ -52,8 +52,11 @@ __all__ = [
    "download_kolors_weights",
]

-PROMPT_APPEND = "Full view of one {}, no cropping, centered, no occlusion, isolated product photo, matte, 3D style, on a plain clean surface"
+PROMPT_APPEND = (
+    "Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
+    "no surroundings, matte, on a plain clean surface, 3D style revealing multiple surfaces"
+)
+PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"

def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
@@ -182,9 +185,7 @@ def text2img_gen(
    ip_image_size: int = 512,
    seed: int = None,
) -> list[Image.Image]:
-    # prompt = "Single " + prompt + ", in the center of the image"
-    # prompt += ", high quality, high resolution, best quality, white background, 3D style"  # noqa
-    prompt = PROMPT_APPEND.format(prompt.strip())
+    prompt = PROMPT_KAPPEND.format(object=prompt.strip())
    logger.info(f"Processing prompt: {prompt}")
    generator = None
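The templates now use a named `{object}` placeholder instead of the old positional `{}`, so formatting must pass `object=` explicitly; a positional `format(...)` call against the new template raises `KeyError`. A quick check with the `PROMPT_KAPPEND` string copied from the diff:

```python
PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"

# Named placeholder requires a keyword argument.
prompt = PROMPT_KAPPEND.format(object="bronze lion figurine".strip())
print(prompt)

# A positional argument no longer works with a named field.
try:
    PROMPT_KAPPEND.format("bronze lion figurine")
except KeyError as e:
    print(f"KeyError: {e}")  # KeyError: 'object'
```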


@@ -16,13 +16,14 @@
import argparse
-import logging
import os
+import random
import sys
from glob import glob
from shutil import copy, copytree, rmtree

import numpy as np
+import torch
import trimesh
from PIL import Image

from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
@@ -37,8 +38,10 @@ from embodied_gen.models.segment_model import (
from embodied_gen.models.sr_model import ImageRealESRGAN
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT
-from embodied_gen.utils.process_media import merge_images_video, render_video
+from embodied_gen.utils.log import logger
+from embodied_gen.utils.process_media import merge_images_video
from embodied_gen.utils.tags import VERSION
+from embodied_gen.utils.trender import render_video
from embodied_gen.validators.quality_checkers import (
    BaseChecker,
    ImageAestheticChecker,
@@ -52,19 +55,14 @@ current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline

-logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
-)
-logger = logging.getLogger(__name__)

os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
    "~/.cache/torch_extensions"
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"
+random.seed(0)

-logger.info("Loading Models...")
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
@@ -74,7 +72,7 @@ SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
    "microsoft/TRELLIS-image-large"
)
-PIPELINE.cuda()
+# PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -95,7 +93,6 @@ def parse_args():
    parser.add_argument(
        "--output_root",
        type=str,
-        required=True,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
@@ -110,12 +107,26 @@ def parse_args():
        default=None,
        help="The mass in kg to restore the mesh real weight.",
    )
-    parser.add_argument("--asset_type", type=str, default=None)
+    parser.add_argument("--asset_type", type=str, nargs="+", default=None)
    parser.add_argument("--skip_exists", action="store_true")
+    parser.add_argument("--strict_seg", action="store_true")
    parser.add_argument("--version", type=str, default=VERSION)
-    parser.add_argument("--remove_intermediate", type=bool, default=True)
-    args = parser.parse_args()
+    parser.add_argument("--keep_intermediate", action="store_true")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument(
+        "--n_retry",
+        type=int,
+        default=2,
+    )
+    args, unknown = parser.parse_known_args()
+    return args
+def entrypoint(**kwargs):
+    args = parse_args()
+    for k, v in kwargs.items():
+        if hasattr(args, k) and v is not None:
+            setattr(args, k, v)

    assert (
        args.image_path or args.image_root
@@ -125,13 +136,7 @@
        args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
        args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))
-        return args

-if __name__ == "__main__":
-    args = parse_args()
-    for image_path in args.image_path:
+    for idx, image_path in enumerate(args.image_path):
        try:
            filename = os.path.basename(image_path).split(".")[0]
            output_root = args.output_root
@@ -141,7 +146,7 @@
            mesh_out = f"{output_root}/(unknown).obj"
            if args.skip_exists and os.path.exists(mesh_out):
-                logger.info(
+                logger.warning(
                    f"Skip {image_path}, already processed in {mesh_out}"
                )
                continue
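The new `entrypoint(**kwargs)` lets Python callers (such as the text-to-3D pipeline) override parsed CLI defaults without touching `sys.argv`, and `parse_known_args` tolerates flags meant for other entrypoints. The override pattern with a hypothetical two-flag parser:

```python
import argparse

def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_root", type=str, default="outputs")
    parser.add_argument("--n_retry", type=int, default=2)
    # parse_known_args ignores unrecognized flags instead of erroring out.
    args, unknown = parser.parse_known_args(argv)
    return args

def entrypoint(**kwargs):
    args = parse_args([])  # [] avoids reading sys.argv in this sketch
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)  # known, non-None kwargs win over defaults
    return args

print(entrypoint(n_retry=5).n_retry)  # 5
```

Unknown keys are silently ignored by the `hasattr` guard, so callers cannot accidentally create attributes the parser never defined.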
@@ -149,67 +154,84 @@
            image = Image.open(image_path)
            image.save(f"{output_root}/(unknown)_raw.png")

-            # Segmentation: Get segmented image using SAM or Rembg.
-            seg_path = f"{output_root}/(unknown)_cond.png"
-            if image.mode != "RGBA":
-                seg_image = RBG_REMOVER(image, save_path=seg_path)
-                seg_image = trellis_preprocess(seg_image)
-            else:
-                seg_image = image
-            seg_image.save(seg_path)
-
-            # Run the pipeline
-            try:
-                outputs = PIPELINE.run(
-                    seg_image,
-                    preprocess_image=False,
-                    # Optional parameters
-                    # seed=1,
-                    # sparse_structure_sampler_params={
-                    #     "steps": 12,
-                    #     "cfg_strength": 7.5,
-                    # },
-                    # slat_sampler_params={
-                    #     "steps": 12,
-                    #     "cfg_strength": 3,
-                    # },
-                )
-            except Exception as e:
-                logger.error(
-                    f"[Pipeline Failed] process {image_path}: {e}, skip."
-                )
-                continue
-
-            # Render and save color and mesh videos
-            gs_model = outputs["gaussian"][0]
-            mesh_model = outputs["mesh"][0]
+            # Segmentation: Get segmented image using Rembg.
+            seg_path = f"{output_root}/(unknown)_cond.png"
+            seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
+            seg_image = trellis_preprocess(seg_image)
+            seg_image.save(seg_path)
+
+            seed = args.seed
+            for try_idx in range(args.n_retry):
+                logger.info(
+                    f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
+                )
+                # Run the pipeline
+                try:
+                    PIPELINE.cuda()
+                    outputs = PIPELINE.run(
+                        seg_image,
+                        preprocess_image=False,
+                        seed=(
+                            random.randint(0, 100000) if seed is None else seed
+                        ),
+                        # Optional parameters
+                        # sparse_structure_sampler_params={
+                        #     "steps": 12,
+                        #     "cfg_strength": 7.5,
+                        # },
+                        # slat_sampler_params={
+                        #     "steps": 12,
+                        #     "cfg_strength": 3,
+                        # },
+                    )
+                    PIPELINE.cpu()
+                    torch.cuda.empty_cache()
+                except Exception as e:
+                    logger.error(
+                        f"[Pipeline Failed] process {image_path}: {e}, skip."
+                    )
+                    continue
+
+                gs_model = outputs["gaussian"][0]
+                mesh_model = outputs["mesh"][0]
+
+                # Save the raw Gaussian model
+                gs_path = mesh_out.replace(".obj", "_gs.ply")
+                gs_model.save_ply(gs_path)
+
+                # Rotate mesh and GS by 90 degrees around Z-axis.
+                rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
+                gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
+                mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
+
+                # Additional rotation for GS to align mesh.
+                gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
+                pose = GaussianOperator.trans_to_quatpose(gs_rot)
+                aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
+                GaussianOperator.resave_ply(
+                    in_ply=gs_path,
+                    out_ply=aligned_gs_path,
+                    instance_pose=pose,
+                    device="cpu",
+                )
+
+                color_path = os.path.join(output_root, "color.png")
+                render_gs_api(aligned_gs_path, color_path)
+                geo_flag, geo_result = GEO_CHECKER([color_path])
+                logger.warning(
+                    f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
+                )
+                if geo_flag is True or geo_flag is None:
+                    break
+                seed = random.randint(0, 100000) if seed is not None else None
+
+            # Render the video for generated 3D asset.
            color_images = render_video(gs_model)["color"]
            normal_images = render_video(mesh_model)["normal"]
            video_path = os.path.join(output_root, "gs_mesh.mp4")
            merge_images_video(color_images, normal_images, video_path)
-
-            # Save the raw Gaussian model
-            gs_path = mesh_out.replace(".obj", "_gs.ply")
-            gs_model.save_ply(gs_path)
-
-            # Rotate mesh and GS by 90 degrees around Z-axis.
-            rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
-            gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
-            mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
-
-            # Additional rotation for GS to align mesh.
-            gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
-            pose = GaussianOperator.trans_to_quatpose(gs_rot)
-            aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
-            GaussianOperator.resave_ply(
-                in_ply=gs_path,
-                out_ply=aligned_gs_path,
-                instance_pose=pose,
-                device="cpu",
-            )
-
-            color_path = os.path.join(output_root, "color.png")
-            render_gs_api(aligned_gs_path, color_path)

            mesh = trimesh.Trimesh(
                vertices=mesh_model.vertices.cpu().numpy(),
                faces=mesh_model.faces.cpu().numpy(),
@ -249,8 +271,8 @@ if __name__ == "__main__":
min_mass, max_mass = map(float, args.mass_range.split("-")) min_mass, max_mass = map(float, args.mass_range.split("-"))
asset_attrs["min_mass"] = min_mass asset_attrs["min_mass"] = min_mass
asset_attrs["max_mass"] = max_mass asset_attrs["max_mass"] = max_mass
if args.asset_type: if isinstance(args.asset_type, list) and args.asset_type[idx]:
asset_attrs["category"] = args.asset_type asset_attrs["category"] = args.asset_type[idx]
if args.version: if args.version:
asset_attrs["version"] = args.version asset_attrs["version"] = args.version
@ -289,8 +311,8 @@ if __name__ == "__main__":
] ]
images_list.append(images) images_list.append(images)
results = BaseChecker.validate(CHECKERS, images_list) qa_results = BaseChecker.validate(CHECKERS, images_list)
urdf_convertor.add_quality_tag(urdf_path, results) urdf_convertor.add_quality_tag(urdf_path, qa_results)
# Organize the final result files # Organize the final result files
result_dir = f"{output_root}/result" result_dir = f"{output_root}/result"
@ -303,7 +325,7 @@ if __name__ == "__main__":
            f"{result_dir}/{urdf_convertor.output_mesh_dir}",
        )
        copy(video_path, f"{result_dir}/video.mp4")
-        if args.remove_intermediate:
+        if not args.keep_intermediate:
            delete_dir(output_root, keep_subs=["result"])
    except Exception as e:
@ -311,3 +333,7 @@ if __name__ == "__main__":
        continue
    logger.info(f"Processing complete. Outputs saved to {args.output_root}")
if __name__ == "__main__":
entrypoint()
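The GS/mesh alignment above composes a base 90-degree rotation with an extra flip before converting the result to a quaternion pose. The composition can be sanity-checked with plain Python; `matmul3` is a hypothetical helper written for this sketch, and `GaussianOperator` itself is not used here:

```python
def matmul3(a, b):
    """Multiply two 3x3 matrices given as nested lists (row-major)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
        for i in range(3)
    ]

rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]   # base 90-degree rotation
gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]  # extra flip aligning GS to mesh

# Same composition order as the script: gs_rot = gs_add_rot @ rot_matrix
gs_rot = matmul3(gs_add_rot, rot_matrix)
```

The composed matrix is still a proper rotation (orthonormal, determinant 1), which is what `trans_to_quatpose` expects.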

View File

@ -85,7 +85,7 @@ def parse_args():
    parser.add_argument(
        "--seed",
        type=int,
-        default=0,
+        default=None,
    )
    args = parser.parse_args()

View File

@ -0,0 +1,271 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import os
import random
from collections import defaultdict
import torch
from PIL import Image
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
from embodied_gen.models.segment_model import RembgRemover
from embodied_gen.models.text_model import PROMPT_APPEND
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import render_asset3d
from embodied_gen.validators.quality_checkers import (
ImageSegChecker,
SemanticConsistChecker,
TextGenAlignChecker,
)
# Silence the huggingface/tokenizers warning: "The current process just got forked".
os.environ["TOKENIZERS_PARALLELISM"] = "false"
random.seed(0)
__all__ = [
"text_to_image",
"text_to_3d",
]
logger.info("Loading Models...")
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
PIPE_IMG = build_hf_image_pipeline("sd35")
BG_REMOVER = RembgRemover()
def text_to_image(
prompt: str,
save_path: str,
n_retry: int,
img_denoise_step: int,
text_guidance_scale: float,
n_img_sample: int,
image_hw: tuple[int, int] = (1024, 1024),
seed: int = None,
) -> bool:
select_image = None
success_flag = False
assert save_path.endswith(".png"), "Image save path must end with `.png`."
for try_idx in range(n_retry):
if select_image is not None:
select_image[0].save(save_path.replace(".png", "_raw.png"))
select_image[1].save(save_path)
break
f_prompt = PROMPT_APPEND.format(object=prompt)
logger.info(
f"Image GEN for {os.path.basename(save_path)}\n"
f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
)
images = PIPE_IMG.run(
f_prompt,
num_inference_steps=img_denoise_step,
guidance_scale=text_guidance_scale,
num_images_per_prompt=n_img_sample,
height=image_hw[0],
width=image_hw[1],
generator=(
torch.Generator().manual_seed(seed)
if seed is not None
else None
),
)
for idx in range(len(images)):
raw_image: Image.Image = images[idx]
image = BG_REMOVER(raw_image)
image.save(save_path)
semantic_flag, semantic_result = SEMANTIC_CHECKER(
prompt, [image.convert("RGB")]
)
seg_flag, seg_result = SEG_CHECKER(
[raw_image, image.convert("RGB")]
)
if (
(semantic_flag and seg_flag)
or semantic_flag is None
or seg_flag is None
):
select_image = [raw_image, image]
success_flag = True
break
torch.cuda.empty_cache()
seed = random.randint(0, 100000) if seed is not None else None
return success_flag
def text_to_3d(**kwargs) -> dict:
args = parse_args()
for k, v in kwargs.items():
if hasattr(args, k) and v is not None:
setattr(args, k, v)
if args.asset_names is None or len(args.asset_names) == 0:
args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
img_save_dir = os.path.join(args.output_root, "images")
asset_save_dir = os.path.join(args.output_root, "asset3d")
os.makedirs(img_save_dir, exist_ok=True)
os.makedirs(asset_save_dir, exist_ok=True)
results = defaultdict(dict)
for prompt, node in zip(args.prompts, args.asset_names):
success_flag = False
n_pipe_retry = args.n_pipe_retry
seed_img = args.seed_img
seed_3d = args.seed_3d
while success_flag is False and n_pipe_retry > 0:
logger.info(
f"GEN pipeline for node {node}\n"
f"Try round: {args.n_pipe_retry-n_pipe_retry+1}/{args.n_pipe_retry}, Prompt: {prompt}"
)
# Text-to-image GEN
save_node = node.replace(" ", "_")
gen_image_path = f"{img_save_dir}/{save_node}.png"
textgen_flag = text_to_image(
prompt,
gen_image_path,
args.n_image_retry,
args.img_denoise_step,
args.text_guidance_scale,
args.n_img_sample,
seed=seed_img,
)
# Asset 3D GEN
node_save_dir = f"{asset_save_dir}/{save_node}"
asset_type = node if "sample3d_" not in node else None
imageto3d_api(
image_path=[gen_image_path],
output_root=node_save_dir,
asset_type=[asset_type],
seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
n_retry=args.n_asset_retry,
keep_intermediate=args.keep_intermediate,
)
mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
image_path = render_asset3d(
mesh_path,
output_root=f"{node_save_dir}/result",
num_images=6,
elevation=(30, -30),
output_subdir="renders",
no_index_file=True,
)
check_text = asset_type if asset_type is not None else prompt
qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
logger.warning(
f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}"
)
results["assets"][node] = f"{node_save_dir}/result"
results["quality"][node] = qa_result
if qa_flag is None or qa_flag is True:
success_flag = True
break
n_pipe_retry -= 1
seed_img = (
random.randint(0, 100000) if seed_img is not None else None
)
seed_3d = (
random.randint(0, 100000) if seed_3d is not None else None
)
torch.cuda.empty_cache()
return results
def parse_args():
    parser = argparse.ArgumentParser(
        description="Text-to-3D Asset Generation Config"
    )
parser.add_argument("--prompts", nargs="+", help="text descriptions")
parser.add_argument(
"--output_root",
type=str,
help="Directory to save outputs",
)
parser.add_argument(
"--asset_names",
type=str,
nargs="+",
default=None,
help="Asset names to generate",
)
parser.add_argument(
"--n_img_sample",
type=int,
default=3,
help="Number of image samples to generate",
)
parser.add_argument(
"--text_guidance_scale",
type=float,
default=7,
help="Text-to-image guidance scale",
)
parser.add_argument(
"--img_denoise_step",
type=int,
default=25,
help="Denoising steps for image generation",
)
parser.add_argument(
"--n_image_retry",
type=int,
default=2,
help="Max retry count for image generation",
)
parser.add_argument(
"--n_asset_retry",
type=int,
default=2,
help="Max retry count for 3D generation",
)
parser.add_argument(
"--n_pipe_retry",
type=int,
default=1,
        help="Max retry count for the whole text-to-3D pipeline",
)
parser.add_argument(
"--seed_img",
type=int,
default=None,
help="Random seed for image generation",
)
parser.add_argument(
"--seed_3d",
type=int,
default=0,
help="Random seed for 3D generation",
)
parser.add_argument("--keep_intermediate", action="store_true")
args, unknown = parser.parse_known_args()
return args
if __name__ == "__main__":
text_to_3d()
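The commit's headline feature is the auto-retry loop in `text_to_3d`: generate, run a quality checker, and retry with a fresh seed until the checker passes or the retry budget is exhausted. A stripped-down, library-free sketch of that control flow (the `generate` and `check` callables are placeholders, not EmbodiedGen APIs):

```python
import random

def retry_pipeline(generate, check, n_retry=3, seed=0):
    """Generate, quality-check, and retry with a fresh seed on failure.

    Mirrors the control flow of text_to_3d above: a None verdict from the
    checker (checker unavailable) counts as a pass, and a None seed stays
    None so generation remains fully random.
    """
    for attempt in range(1, n_retry + 1):
        result = generate(seed)
        verdict = check(result)
        if verdict is None or verdict is True:
            return result, attempt
        seed = random.randint(0, 100000) if seed is not None else None
    return None, n_retry

# Toy usage: "generation" echoes the seed; the check accepts even seeds.
outcome, rounds = retry_pipeline(
    generate=lambda s: s,
    check=lambda r: r % 2 == 0,
    n_retry=5,
    seed=4,
)
```

Treating a `None` verdict as a pass keeps the pipeline moving when a checker (e.g. the GPT client) is unreachable, trading strictness for availability.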

View File

@ -2,7 +2,9 @@
# Initialize variables
prompts=()
asset_types=()
output_root=""
seed=0
# Parse arguments
while [[ $# -gt 0 ]]; do
@ -14,10 +16,21 @@ while [[ $# -gt 0 ]]; do
                shift
            done
            ;;
--asset_types)
shift
while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
asset_types+=("$1")
shift
done
;;
        --output_root)
            output_root="$2"
            shift 2
            ;;
--seed)
seed="$2"
shift 2
;;
        *)
            echo "Unknown argument: $1"
            exit 1
@ -28,7 +41,21 @@ done
# Validate required arguments
if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
    echo "Missing required arguments."
-    echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" --output_root <path>"
+    echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" \
+--asset_types \"type1\" \"type2\" --seed <seed_value> --output_root <path>"
exit 1
fi
# If no asset_types provided, default to ""
if [[ ${#asset_types[@]} -eq 0 ]]; then
for (( i=0; i<${#prompts[@]}; i++ )); do
asset_types+=("")
done
fi
# Ensure the number of asset_types matches the number of prompts
if [[ ${#prompts[@]} -ne ${#asset_types[@]} ]]; then
echo "The number of asset types must match the number of prompts."
    exit 1
fi
@ -37,20 +64,30 @@ echo "Prompts:"
for p in "${prompts[@]}"; do
    echo " - $p"
done
# echo "Asset types:"
# for at in "${asset_types[@]}"; do
# echo " - $at"
# done
echo "Output root: ${output_root}"
echo "Seed: ${seed}"
-# Concatenate prompts for Python command
+# Concatenate prompts and asset types for Python command
prompt_args=""
+asset_type_args=""
-for p in "${prompts[@]}"; do
-    prompt_args+="\"$p\" "
+for i in "${!prompts[@]}"; do
+    prompt_args+="\"${prompts[$i]}\" "
+    asset_type_args+="\"${asset_types[$i]}\" "
done
# Step 1: Text-to-Image
eval python3 embodied_gen/scripts/text2image.py \
    --prompts ${prompt_args} \
-    --output_root "${output_root}/images"
+    --output_root "${output_root}/images" \
+    --seed ${seed}
# Step 2: Image-to-3D
python3 embodied_gen/scripts/imageto3d.py \
    --image_root "${output_root}/images" \
-    --output_root "${output_root}/asset3d"
+    --output_root "${output_root}/asset3d" \
+    --asset_type ${asset_type_args}
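The bash loop above hand-rolls multi-value flags because `getopts` cannot consume a variable number of values per option; on the Python side those flags land in `argparse` via `nargs="+"`, as the scripts in this commit do. A minimal sketch of the receiving side (flag names mirror the scripts; the sample values are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--prompts", nargs="+", default=[])
parser.add_argument("--asset_types", nargs="+", default=[])
parser.add_argument("--output_root", type=str, default="outputs")

# parse_known_args tolerates extra flags, as in the commit's parse_args()
args, unknown = parser.parse_known_args(
    ["--prompts", "red mug", "wooden chair", "--asset_types", "mug", "chair"]
)

# Pad asset_types with "" to match prompts, mirroring the shell default logic.
if not args.asset_types:
    args.asset_types = [""] * len(args.prompts)
```

`nargs="+"` greedily consumes values until the next `--flag`, which is exactly the behavior the shell parser reproduces with its inner `while` loops.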

View File

@ -10,10 +10,6 @@ while [[ $# -gt 0 ]]; do
            prompt="$2"
            shift 2
            ;;
--uuid)
uuid="$2"
shift 2
;;
        --output_root)
            output_root="$2"
            shift 2
@ -26,12 +22,13 @@ while [[ $# -gt 0 ]]; do
done
-if [[ -z "$mesh_path" || -z "$prompt" || -z "$uuid" || -z "$output_root" ]]; then
+if [[ -z "$mesh_path" || -z "$prompt" || -z "$output_root" ]]; then
    echo "params missing"
-    echo "usage: bash run.sh --mesh_path <path> --prompt <text> --uuid <id> --output_root <path>"
+    echo "usage: bash run.sh --mesh_path <path> --prompt <text> --output_root <path>"
    exit 1
fi
uuid=$(basename "$output_root")
# Step 1: drender-cli for condition rendering
drender-cli --mesh_path ${mesh_path} \
    --output_root ${output_root}/condition \

embodied_gen/utils/enum.py Normal file
View File

@ -0,0 +1,107 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from dataclasses import dataclass, field
from enum import Enum
from dataclasses_json import DataClassJsonMixin
__all__ = [
"RenderItems",
"Scene3DItemEnum",
"SpatialRelationEnum",
"RobotItemEnum",
]
@dataclass
class RenderItems(str, Enum):
IMAGE = "image_color"
ALPHA = "image_mask"
VIEW_NORMAL = "image_view_normal"
GLOBAL_NORMAL = "image_global_normal"
POSITION_MAP = "image_position"
DEPTH = "image_depth"
ALBEDO = "image_albedo"
DIFFUSE = "image_diffuse"
@dataclass
class Scene3DItemEnum(str, Enum):
BACKGROUND = "background"
CONTEXT = "context"
ROBOT = "robot"
MANIPULATED_OBJS = "manipulated_objs"
DISTRACTOR_OBJS = "distractor_objs"
OTHERS = "others"
@classmethod
def object_list(cls, layout_relation: dict) -> list:
return (
[
layout_relation[cls.BACKGROUND.value],
layout_relation[cls.CONTEXT.value],
]
+ layout_relation[cls.MANIPULATED_OBJS.value]
+ layout_relation[cls.DISTRACTOR_OBJS.value]
)
@classmethod
def object_mapping(cls, layout_relation):
relation_mapping = {
# layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
layout_relation[cls.CONTEXT.value]: cls.CONTEXT.value,
}
relation_mapping.update(
{
item: cls.MANIPULATED_OBJS.value
for item in layout_relation[cls.MANIPULATED_OBJS.value]
}
)
relation_mapping.update(
{
item: cls.DISTRACTOR_OBJS.value
for item in layout_relation[cls.DISTRACTOR_OBJS.value]
}
)
return relation_mapping
@dataclass
class SpatialRelationEnum(str, Enum):
ON = "ON" # objects on the table
IN = "IN" # objects in the room
INSIDE = "INSIDE" # objects inside the shelf/rack
    FLOOR = "FLOOR"  # objects on the floor of a room/bin
@dataclass
class RobotItemEnum(str, Enum):
FRANKA = "franka"
UR5 = "ur5"
PIPER = "piper"
@dataclass
class LayoutInfo(DataClassJsonMixin):
tree: dict[str, list]
relation: dict[str, str | list[str]]
objs_desc: dict[str, str] = field(default_factory=dict)
assets: dict[str, str] = field(default_factory=dict)
quality: dict[str, str] = field(default_factory=dict)
position: dict[str, list[float]] = field(default_factory=dict)
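`LayoutInfo` above leans on `dataclasses_json` for (de)serialization. A stdlib-only sketch of the same shape, approximating `DataClassJsonMixin` with `json` plus `dataclasses.asdict` (round-trip only, no nested decoding; the class name and sample values are illustrative):

```python
import json
from dataclasses import asdict, dataclass, field

@dataclass
class LayoutInfoSketch:
    """Stdlib stand-in for LayoutInfo; DataClassJsonMixin approximated."""
    tree: dict
    relation: dict
    objs_desc: dict = field(default_factory=dict)
    assets: dict = field(default_factory=dict)
    quality: dict = field(default_factory=dict)
    position: dict = field(default_factory=dict)

    def to_json(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, payload: str) -> "LayoutInfoSketch":
        return cls(**json.loads(payload))

info = LayoutInfoSketch(
    tree={"table": [["mug", "ON"]]},
    relation={"context": "table", "manipulated_objs": ["mug"]},
)
round_trip = LayoutInfoSketch.from_json(info.to_json())
```

Because dataclasses generate `__eq__`, a JSON round trip can be verified by direct comparison, which is the main convenience the mixin provides here.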

View File

@ -30,12 +30,20 @@ from tenacity import (
    stop_after_delay,
    wait_random_exponential,
)
-from embodied_gen.utils.process_media import combine_images_to_base64
+from embodied_gen.utils.process_media import combine_images_to_grid

-logging.basicConfig(level=logging.INFO)
+logging.getLogger("httpx").setLevel(logging.WARNING)
+logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
__all__ = [
"GPTclient",
]
CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
class GPTclient:
    """A client to interact with the GPT model via OpenAI or Azure API."""
@ -45,6 +53,7 @@ class GPTclient:
        api_key: str,
        model_name: str = "yfb-gpt-4o",
        api_version: str = None,
+        check_connection: bool = True,
        verbose: bool = False,
    ):
        if api_version is not None:
@ -63,6 +72,9 @@ class GPTclient:
        self.model_name = model_name
        self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
        self.verbose = verbose
+        if check_connection:
+            self.check_connection()
        logger.info(f"Using GPT model: {self.model_name}.")
    @retry(
@ -77,6 +89,7 @@ class GPTclient:
        text_prompt: str,
        image_base64: Optional[list[str | Image.Image]] = None,
        system_role: Optional[str] = None,
+        params: Optional[dict] = None,
    ) -> Optional[str]:
        """Queries the GPT model with a text and optional image prompts.
@ -86,6 +99,7 @@ class GPTclient:
                or local image paths or PIL.Image to accompany the text prompt.
            system_role (Optional[str]): Optional system-level instructions
                that specify the behavior of the assistant.
+            params (Optional[dict]): Additional parameters for the GPT request.

        Returns:
            Optional[str]: The response content generated by the model based on
        # Process images if provided
        if image_base64 is not None:
-            image_base64 = (
-                image_base64
-                if isinstance(image_base64, list)
-                else [image_base64]
-            )
+            if not isinstance(image_base64, list):
+                image_base64 = [image_base64]
+            # Temporary workaround: OpenRouter cannot take multiple input images.
+            if "openrouter" in self.endpoint:
+                image_base64 = combine_images_to_grid(image_base64)
            for img in image_base64:
                if isinstance(img, Image.Image):
                    buffer = BytesIO()
@ -142,8 +156,11 @@ class GPTclient:
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
+            "model": self.model_name,
        }
-        payload.update({"model": self.model_name})
+        if params:
+            payload.update(params)
        response = None
        try:
@ -159,8 +176,28 @@
        return response
def check_connection(self) -> None:
"""Check whether the GPT API connection is working."""
try:
response = self.completion_with_backoff(
messages=[
{"role": "system", "content": "You are a test system."},
{"role": "user", "content": "Hello"},
],
model=self.model_name,
temperature=0,
max_tokens=100,
)
content = response.choices[0].message.content
    logger.info("GPT API connection check succeeded.")
except Exception as e:
    raise ConnectionError(
        f"Failed to connect to GPT API at {self.endpoint}, "
        f"please check the settings in `{CONFIG_FILE}` and `README`."
    ) from e
-with open("embodied_gen/utils/gpt_config.yaml", "r") as f:
+with open(CONFIG_FILE, "r") as f:
    config = yaml.safe_load(f)
agent_type = config["agent_type"]
@ -177,32 +214,5 @@ GPT_CLIENT = GPTclient(
    api_key=api_key,
    api_version=api_version,
    model_name=model_name,
+    check_connection=False,
)
if __name__ == "__main__":
if "openrouter" in GPT_CLIENT.endpoint:
response = GPT_CLIENT.query(
text_prompt="What is the content in each image?",
image_base64=combine_images_to_base64(
[
"apps/assets/example_image/sample_02.jpg",
"apps/assets/example_image/sample_03.jpg",
]
), # input raw image_path if only one image
)
print(response)
else:
response = GPT_CLIENT.query(
text_prompt="What is the content in the images?",
image_base64=[
Image.open("apps/assets/example_image/sample_02.jpg"),
Image.open("apps/assets/example_image/sample_03.jpg"),
],
)
print(response)
# test2: text prompt
response = GPT_CLIENT.query(
text_prompt="What is the capital of China?"
)
print(response)
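The reworked `query()` above builds a default payload, pins the model, and only then merges caller-supplied `params`, so per-call settings win. A small sketch of that override pattern (the default keys mirror the payload shown above; `temperature` is an assumed default, and this is not the full OpenAI payload):

```python
def build_payload(model_name, messages, params=None):
    """Defaults first, then the model, then caller overrides — later keys win."""
    payload = {
        "temperature": 0.0,       # assumed default, not from the source
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
        "model": model_name,
        "messages": messages,
    }
    if params:
        payload.update(params)    # per-call settings override the defaults
    return payload

p = build_payload(
    "gpt-4o",
    [{"role": "user", "content": "Hello"}],
    params={"temperature": 0.7, "max_tokens": 100},
)
```

Folding the model name into the dict literal (instead of a second `update` call) keeps the precedence obvious: anything in `params` is the last writer.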

embodied_gen/utils/log.py Normal file
View File

@ -0,0 +1,48 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
from colorlog import ColoredFormatter
__all__ = [
"logger",
]
LOG_FORMAT = (
"%(log_color)s[%(asctime)s] %(levelname)-8s | %(message)s%(reset)s"
)
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
formatter = ColoredFormatter(
LOG_FORMAT,
datefmt=DATE_FORMAT,
log_colors={
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "bold_red",
},
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False

View File

@ -15,34 +15,24 @@
# permissions and limitations under the License.
-import base64
import logging
import math
import os
-import sys
+import textwrap
from glob import glob
-from io import BytesIO
from typing import Union

import cv2
import imageio
+import matplotlib.pyplot as plt
+import networkx as nx
import numpy as np
import PIL.Image as Image
import spaces
-import torch
+from matplotlib.patches import Patch
from moviepy.editor import VideoFileClip, clips_array
-from tqdm import tqdm

from embodied_gen.data.differentiable_render import entrypoint as render_api
+from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum

-current_file_path = os.path.abspath(__file__)
-current_dir = os.path.dirname(current_file_path)
-sys.path.append(os.path.join(current_dir, "../.."))
-from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
-from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
-from thirdparty.TRELLIS.trellis.utils.render_utils import (
-    render_frames,
-    yaw_pitch_r_fov_to_extrinsics_intrinsics,
-)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@ -53,9 +43,8 @@ __all__ = [
    "merge_images_video",
    "filter_small_connected_components",
    "filter_image_small_connected_components",
-    "combine_images_to_base64",
-    "render_mesh",
-    "render_video",
+    "combine_images_to_grid",
+    "SceneTreeVisualizer",
]
@ -66,12 +55,14 @@ def render_asset3d(
    distance: float = 5.0,
    num_images: int = 1,
    elevation: list[float] = (0.0,),
-    pbr_light_factor: float = 1.5,
+    pbr_light_factor: float = 1.2,
    return_key: str = "image_color/*",
    output_subdir: str = "renders",
    gen_color_mp4: bool = False,
    gen_viewnormal_mp4: bool = False,
    gen_glonormal_mp4: bool = False,
+    no_index_file: bool = False,
+    with_mtl: bool = True,
) -> list[str]:
    input_args = dict(
        mesh_path=mesh_path,
@ -81,14 +72,13 @@ def render_asset3d(
        num_images=num_images,
        elevation=elevation,
        pbr_light_factor=pbr_light_factor,
-        with_mtl=True,
+        with_mtl=with_mtl,
+        gen_color_mp4=gen_color_mp4,
+        gen_viewnormal_mp4=gen_viewnormal_mp4,
+        gen_glonormal_mp4=gen_glonormal_mp4,
+        no_index_file=no_index_file,
    )
-    if gen_color_mp4:
-        input_args["gen_color_mp4"] = True
-    if gen_viewnormal_mp4:
-        input_args["gen_viewnormal_mp4"] = True
-    if gen_glonormal_mp4:
-        input_args["gen_glonormal_mp4"] = True
    try:
        _ = render_api(**input_args)
    except Exception as e:
@ -168,12 +158,15 @@ def filter_image_small_connected_components(
    return image
-def combine_images_to_base64(
+def combine_images_to_grid(
    images: list[str | Image.Image],
    cat_row_col: tuple[int, int] = None,
    target_wh: tuple[int, int] = (512, 512),
-) -> str:
+) -> list[str | Image.Image]:
    n_images = len(images)
+    if n_images == 1:
+        return images
    if cat_row_col is None:
        n_col = math.ceil(math.sqrt(n_images))
        n_row = math.ceil(n_images / n_col)
@ -182,88 +175,190 @@ def combine_images_to_base64(
    images = [
        Image.open(p).convert("RGB") if isinstance(p, str) else p
-        for p in images[: n_row * n_col]
+        for p in images
    ]
    images = [img.resize(target_wh) for img in images]
    grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
-    grid = Image.new("RGB", (grid_w, grid_h), (255, 255, 255))
+    grid = Image.new("RGB", (grid_w, grid_h), (0, 0, 0))
    for idx, img in enumerate(images):
        row, col = divmod(idx, n_col)
        grid.paste(img, (col * target_wh[0], row * target_wh[1]))
-    buffer = BytesIO()
-    grid.save(buffer, format="PNG")
-    return base64.b64encode(buffer.getvalue()).decode("utf-8")
+    return [grid]
-@spaces.GPU
-def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
-    renderer = MeshRenderer()
-    renderer.rendering_options.resolution = options.get("resolution", 512)
-    renderer.rendering_options.near = options.get("near", 1)
-    renderer.rendering_options.far = options.get("far", 100)
-    renderer.rendering_options.ssaa = options.get("ssaa", 4)
-    rets = {}
-    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
-        res = renderer.render(sample, extr, intr)
-        if "normal" not in rets:
-            rets["normal"] = []
-        normal = torch.lerp(
-            torch.zeros_like(res["normal"]), res["normal"], res["mask"]
-        )
-        normal = np.clip(
-            normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
-        ).astype(np.uint8)
-        rets["normal"].append(normal)
-    return rets
-
-@spaces.GPU
-def render_video(
-    sample,
-    resolution=512,
-    bg_color=(0, 0, 0),
-    num_frames=300,
-    r=2,
-    fov=40,
-    **kwargs,
-):
-    yaws = torch.linspace(0, 2 * 3.1415, num_frames)
-    yaws = yaws.tolist()
-    pitch = [0.5] * num_frames
-    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
-        yaws, pitch, r, fov
-    )
-    render_fn = (
-        render_mesh if isinstance(sample, MeshExtractResult) else render_frames
-    )
-    result = render_fn(
-        sample,
-        extrinsics,
-        intrinsics,
-        {"resolution": resolution, "bg_color": bg_color},
-        **kwargs,
-    )
-    return result
+class SceneTreeVisualizer:
+    def __init__(self, layout_info: LayoutInfo) -> None:
+        self.tree = layout_info.tree
+        self.relation = layout_info.relation
+        self.objs_desc = layout_info.objs_desc
+        self.G = nx.DiGraph()
+        self.root = self._find_root()
+        self._build_graph()
+
+        self.role_colors = {
+            Scene3DItemEnum.BACKGROUND.value: "plum",
+            Scene3DItemEnum.CONTEXT.value: "lightblue",
+            Scene3DItemEnum.ROBOT.value: "lightcoral",
+            Scene3DItemEnum.MANIPULATED_OBJS.value: "lightgreen",
+            Scene3DItemEnum.DISTRACTOR_OBJS.value: "lightgray",
+            Scene3DItemEnum.OTHERS.value: "orange",
+        }
+
+    def _find_root(self) -> str:
+        children = {c for cs in self.tree.values() for c, _ in cs}
+        parents = set(self.tree.keys())
+        roots = parents - children
+        if not roots:
+            raise ValueError("No root node found.")
+        return next(iter(roots))
+
+    def _build_graph(self):
+        for parent, children in self.tree.items():
+            for child, relation in children:
+                self.G.add_edge(parent, child, relation=relation)
+
+    def _get_node_role(self, node: str) -> str:
+        if node == self.relation.get(Scene3DItemEnum.BACKGROUND.value):
+            return Scene3DItemEnum.BACKGROUND.value
+        if node == self.relation.get(Scene3DItemEnum.CONTEXT.value):
+            return Scene3DItemEnum.CONTEXT.value
+        if node == self.relation.get(Scene3DItemEnum.ROBOT.value):
+            return Scene3DItemEnum.ROBOT.value
+        if node in self.relation.get(
+            Scene3DItemEnum.MANIPULATED_OBJS.value, []
+        ):
+            return Scene3DItemEnum.MANIPULATED_OBJS.value
+        if node in self.relation.get(
+            Scene3DItemEnum.DISTRACTOR_OBJS.value, []
+        ):
+            return Scene3DItemEnum.DISTRACTOR_OBJS.value
+        return Scene3DItemEnum.OTHERS.value
+
+    def _get_positions(
+        self, root, width=1.0, vert_gap=0.1, vert_loc=1, xcenter=0.5, pos=None
+    ):
+        if pos is None:
+            pos = {root: (xcenter, vert_loc)}
+        else:
+            pos[root] = (xcenter, vert_loc)
+        children = list(self.G.successors(root))
+        if children:
+            dx = width / len(children)
+            next_x = xcenter - width / 2 - dx / 2
+            for child in children:
+                next_x += dx
+                pos = self._get_positions(
+                    child,
+                    width=dx,
+                    vert_gap=vert_gap,
+                    vert_loc=vert_loc - vert_gap,
+                    xcenter=next_x,
+                    pos=pos,
+                )
+        return pos
+
+    def render(
+        self,
+        save_path: str,
+        figsize=(8, 6),
+        dpi=300,
+        title: str = "Scene 3D Hierarchy Tree",
+    ):
+        node_colors = [
+            self.role_colors[self._get_node_role(n)] for n in self.G.nodes
+        ]
+        pos = self._get_positions(self.root)
+        plt.figure(figsize=figsize)
+        nx.draw(
+            self.G,
+            pos,
+            with_labels=True,
+            arrows=False,
+            node_size=2000,
+            node_color=node_colors,
+            font_size=10,
+            font_weight="bold",
+        )
+        # Draw edge labels
+        edge_labels = nx.get_edge_attributes(self.G, "relation")
+        nx.draw_networkx_edge_labels(
+            self.G,
+            pos,
+            edge_labels=edge_labels,
+            font_size=9,
+            font_color="black",
+        )
+        # Draw small description text under each node (if available)
+        for node, (x, y) in pos.items():
+            desc = self.objs_desc.get(node)
+            if desc:
+                wrapped = "\n".join(textwrap.wrap(desc, width=30))
+                plt.text(
+                    x,
+                    y - 0.006,
+                    wrapped,
+                    fontsize=6,
+                    ha="center",
+                    va="top",
+                    wrap=True,
+                    color="black",
+                    bbox=dict(
+                        facecolor="dimgray",
+                        edgecolor="darkgray",
+                        alpha=0.1,
+                        boxstyle="round,pad=0.2",
+                    ),
+                )
+        plt.title(title, fontsize=12)
+        task_desc = self.relation.get("task_desc", "")
+        if task_desc:
+            plt.suptitle(
+                f"Task Description: {task_desc}", fontsize=10, y=0.999
+            )
+        plt.axis("off")
+        legend_handles = [
+            Patch(facecolor=color, edgecolor='black', label=role)
+            for role, color in self.role_colors.items()
+        ]
+        plt.legend(
+            handles=legend_handles,
+            loc="lower center",
+            ncol=3,
+            bbox_to_anchor=(0.5, -0.1),
+            fontsize=9,
+        )
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
+        plt.close()
+
+def load_scene_dict(file_path: str) -> dict:
+    scene_dict = {}
+    with open(file_path, "r", encoding='utf-8') as f:
+        for line in f:
+            line = line.strip()
+            if not line or ":" not in line:
+                continue
+            scene_id, desc = line.split(":", 1)
+            scene_dict[scene_id.strip()] = desc.strip()
+    return scene_dict
if __name__ == "__main__":
+    # Example usage:
    merge_video_video(
        "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4",  # noqa
        "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4",  # noqa
        "merge.mp4",
    )
-    image_base64 = combine_images_to_base64(
-        [
-            "apps/assets/example_image/sample_00.jpg",
-            "apps/assets/example_image/sample_01.jpg",
-            "apps/assets/example_image/sample_02.jpg",
-        ]
-    )
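`combine_images_to_grid` above picks a near-square layout via `ceil(sqrt(n))`. The dimension math in isolation, as a stdlib-only sketch (no PIL; `grid_shape` is a helper written for this example):

```python
import math

def grid_shape(n_images, cat_row_col=None):
    """Near-square (rows, cols) for n_images, matching the helper's logic."""
    if cat_row_col is not None:
        return cat_row_col
    n_col = math.ceil(math.sqrt(n_images))
    n_row = math.ceil(n_images / n_col)
    return n_row, n_col

# e.g. 6 images tile as 2 rows x 3 cols; unused cells keep the background color.
shape = grid_shape(6)
```

Choosing columns first biases the grid to be wider than tall, which suits the multi-view renders (6 images at two elevations) fed to the checkers.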

View File

@ -1 +1 @@
-VERSION = "v0.1.0"
+VERSION = "v0.1.1"

View File

@ -0,0 +1,90 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import sys
import numpy as np
import spaces
import torch
from tqdm import tqdm
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
from thirdparty.TRELLIS.trellis.utils.render_utils import (
render_frames,
yaw_pitch_r_fov_to_extrinsics_intrinsics,
)
__all__ = [
"render_video",
]
@spaces.GPU
def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
renderer = MeshRenderer()
renderer.rendering_options.resolution = options.get("resolution", 512)
renderer.rendering_options.near = options.get("near", 1)
renderer.rendering_options.far = options.get("far", 100)
renderer.rendering_options.ssaa = options.get("ssaa", 4)
rets = {}
for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
res = renderer.render(sample, extr, intr)
if "normal" not in rets:
rets["normal"] = []
normal = torch.lerp(
torch.zeros_like(res["normal"]), res["normal"], res["mask"]
)
normal = np.clip(
normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
).astype(np.uint8)
rets["normal"].append(normal)
return rets
@spaces.GPU
def render_video(
sample,
resolution=512,
bg_color=(0, 0, 0),
num_frames=300,
r=2,
fov=40,
**kwargs,
):
yaws = torch.linspace(0, 2 * 3.1415, num_frames)
yaws = yaws.tolist()
pitch = [0.5] * num_frames
extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
yaws, pitch, r, fov
)
render_fn = (
render_mesh if isinstance(sample, MeshExtractResult) else render_frames
)
result = render_fn(
sample,
extrinsics,
intrinsics,
{"resolution": resolution, "bg_color": bg_color},
**kwargs,
)
return result
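The `render_video` helper above sweeps the camera through one full yaw turn at a fixed pitch before rendering. The yaw schedule can be sketched without torch; `make_yaw_schedule` below is a hypothetical stand-in for the `torch.linspace` call, not part of the repo:

```python
import math

def make_yaw_schedule(num_frames: int) -> list[float]:
    """Evenly spaced yaw angles over one full turn, endpoints included,
    mirroring torch.linspace(0, 2 * pi, num_frames)."""
    if num_frames == 1:
        return [0.0]
    step = 2 * math.pi / (num_frames - 1)
    return [i * step for i in range(num_frames)]

# Matches the defaults used by render_video: 300 frames, fixed pitch.
yaws = make_yaw_schedule(300)
pitch = [0.5] * 300
```

These lists are then converted to camera extrinsics/intrinsics by `yaw_pitch_r_fov_to_extrinsics_intrinsics` in the actual pipeline.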


@@ -102,7 +102,7 @@ class AestheticPredictor:
    def _load_sac_model(self, model_path, input_size):
        """Load the SAC model."""
        model = self.MLP(input_size)
-        ckpt = torch.load(model_path)
+        ckpt = torch.load(model_path, weights_only=True)
        model.load_state_dict(ckpt)
        model.to(self.device)
        model.eval()
@@ -135,15 +135,3 @@ class AestheticPredictor:
        )
        return prediction.item()
-if __name__ == "__main__":
-    # Configuration
-    img_path = "apps/assets/example_image/sample_00.jpg"
-    # Initialize the predictor
-    predictor = AestheticPredictor()
-    # Predict the aesthetic score
-    score = predictor.predict(img_path)
-    print("Aesthetic score predicted by the model:", score)


@@ -16,17 +16,26 @@
import logging
-import os
+import random
-from tqdm import tqdm
+import json_repair
+from PIL import Image
from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
-from embodied_gen.utils.process_media import render_asset3d
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
+__all__ = [
+    "MeshGeoChecker",
+    "ImageSegChecker",
+    "ImageAestheticChecker",
+    "SemanticConsistChecker",
+    "TextGenAlignChecker",
+]
class BaseChecker:
    def __init__(self, prompt: str = None, verbose: bool = False) -> None:
        self.prompt = prompt
@@ -37,16 +46,20 @@ class BaseChecker:
            "Subclasses must implement the query method."
        )
-    def __call__(self, *args, **kwargs) -> bool:
+    def __call__(self, *args, **kwargs) -> tuple[bool, str]:
        response = self.query(*args, **kwargs)
-        if response is None:
-            response = "Error when calling gpt api."
-        if self.verbose and response != "YES":
+        if self.verbose:
            logger.info(response)
-        flag = "YES" in response
-        response = "YES" if flag else response
+        if response is None:
+            flag = None
+            response = (
+                "Error when calling GPT api, check config in "
+                "`embodied_gen/utils/gpt_config.yaml` or net connection."
+            )
+        else:
+            flag = "YES" in response
+            response = "YES" if flag else response
        return flag, response
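The revised `__call__` contract returns a `(flag, response)` pair, with `flag` set to `None` when the GPT call itself failed, so callers can distinguish "quality check rejected" from "API error". A minimal dependency-free sketch of that contract (`DummyChecker` is hypothetical, not part of the repo):

```python
class DummyChecker:
    """Stands in for a quality checker whose query() hits a GPT endpoint."""

    def __init__(self, canned):
        self.canned = canned  # canned "GPT response" (None simulates an API failure)

    def query(self):
        return self.canned

    def __call__(self):
        response = self.query()
        if response is None:
            # API failure: flag is None, not False, so it is not a quality verdict.
            return None, "Error when calling GPT api."
        flag = "YES" in response
        return flag, "YES" if flag else response

flag_ok, _ = DummyChecker("YES")()
flag_bad, reason = DummyChecker("NO: broken geometry")()
flag_err, _ = DummyChecker(None)()
```

Callers that retry on failure can then treat `flag is None` as "retry the API call" and `flag is False` as "regenerate the asset".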
@@ -92,21 +105,29 @@ class MeshGeoChecker(BaseChecker):
        self.gpt_client = gpt_client
        if self.prompt is None:
            self.prompt = """
-        Refer to the provided multi-view rendering images to evaluate
-        whether the geometry of the 3D object asset is complete and
-        whether the asset can be placed stably on the ground.
-        Return "YES" only if reach the requirments,
-        otherwise "NO" and explain the reason very briefly.
+        You are an expert in evaluating the geometry quality of generated 3D asset.
+        You will be given rendered views of a generated 3D asset with black background.
+        Your task is to evaluate the quality of the 3D asset generation,
+        including geometry, structure, and appearance, based on the rendered views.
+        Criteria:
+        - Is the geometry complete and well-formed, without missing parts or redundant structures?
+        - Is the geometric structure of the object complete?
+        - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
+          soft edges) are acceptable if the object is structurally sound and recognizable.
+        - Only evaluate geometry. Do not assess texture quality.
+        - The asset should not contain any unrelated elements, such as
+          ground planes, platforms, or background props (e.g., paper, flooring).
+        If all the above criteria are met, return "YES". Otherwise, return
+        "NO" followed by a brief explanation (no more than 20 words).
+        Example:
+        Images show a yellow cup standing on a flat white plane
+        -> Response: NO: extra white surface under the object.
+        Image shows a chair with simplified back legs and soft edges
+        -> Response: YES
        """
-    def query(self, image_paths: str) -> str:
+    def query(self, image_paths: list[str | Image.Image]) -> str:
-        # Hardcode tmp because of the openrouter can't input multi images.
-        if "openrouter" in self.gpt_client.endpoint:
-            from embodied_gen.utils.process_media import (
-                combine_images_to_base64,
-            )
-            image_paths = combine_images_to_base64(image_paths)
        return self.gpt_client.query(
            text_prompt=self.prompt,
@@ -137,14 +158,19 @@ class ImageSegChecker(BaseChecker):
        self.gpt_client = gpt_client
        if self.prompt is None:
            self.prompt = """
-        The first image is the original, and the second image is the
-        result after segmenting the main object. Evaluate the segmentation
-        quality to ensure the main object is clearly segmented without
-        significant truncation. Note that the foreground of the object
-        needs to be extracted instead of the background.
-        Minor imperfections can be ignored. If segmentation is acceptable,
-        return "YES" only; otherwise, return "NO" with
-        very brief explanation.
+        Task: Evaluate the quality of object segmentation between two images:
+        the first is the original, the second is the segmented result.
+        Criteria:
+        - The main foreground object should be clearly extracted (not the background).
+        - The object must appear realistic, with reasonable geometry and color.
+        - The object should be geometrically complete: no missing, truncated, or cropped parts.
+        - The object must be centered, with a margin on all sides.
+        - Ignore minor imperfections (e.g., small holes or fine edge artifacts).
+        Output Rules:
+        If segmentation is acceptable, respond with "YES" (and nothing else).
+        If not acceptable, respond with "NO", followed by a brief reason (max 20 words).
        """
    def query(self, image_paths: list[str]) -> str:
@@ -152,13 +178,6 @@ class ImageSegChecker(BaseChecker):
            raise ValueError(
                "ImageSegChecker requires exactly two images: [raw_image, seg_image]."  # noqa
            )
-        # Hardcode tmp because of the openrouter can't input multi images.
-        if "openrouter" in self.gpt_client.endpoint:
-            from embodied_gen.utils.process_media import (
-                combine_images_to_base64,
-            )
-            image_paths = combine_images_to_base64(image_paths)
        return self.gpt_client.query(
            text_prompt=self.prompt,
@@ -201,42 +220,204 @@ class ImageAestheticChecker(BaseChecker):
        return avg_score > self.thresh, avg_score
-if __name__ == "__main__":
-    geo_checker = MeshGeoChecker(GPT_CLIENT)
-    seg_checker = ImageSegChecker(GPT_CLIENT)
-    aesthetic_checker = ImageAestheticChecker()
-    checkers = [geo_checker, seg_checker, aesthetic_checker]
-    output_root = "outputs/test_gpt"
-    fails = []
-    for idx in tqdm(range(150)):
-        mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj"  # noqa
-        if not os.path.exists(mesh_path):
-            continue
-        image_paths = render_asset3d(
-            mesh_path,
-            f"{output_root}/{idx}",
-            num_images=8,
-            elevation=(30, -30),
-            distance=5.5,
-        )
-        for cid, checker in enumerate(checkers):
-            if isinstance(checker, ImageSegChecker):
-                images = [
-                    f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png",  # noqa
-                    f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png",  # noqa
-                ]
-            else:
-                images = image_paths
-            result, info = checker(images)
-            logger.info(
-                f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}"  # noqa
-            )
-            if result is False:
-                fails.append((idx, cid, info))
-                break
+class SemanticConsistChecker(BaseChecker):
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        prompt: str = None,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(prompt, verbose)
+        self.gpt_client = gpt_client
+        if self.prompt is None:
+            self.prompt = """
+        You are an expert in image-text consistency assessment.
+        You will be given:
+        - A short text description of an object.
+        - A segmented image of the same object with the background removed.
+        Criteria:
+        - The image must visually match the text description in terms of object type, structure, geometry, and color.
+        - The object must appear realistic, with reasonable geometry (e.g., a table must have a stable number of legs).
+        - Geometric completeness is required: the object must not have missing, truncated, or cropped parts.
+        - The object must be centered in the image frame with clear margins on all sides,
+          it should not touch or nearly touch any image edge.
+        - The image must contain exactly one object. Multiple distinct objects are not allowed.
+          A single composite object (e.g., a chair with legs) is acceptable.
+        - The object should be shown from a slightly angled (three-quarter) perspective,
+          not a flat, front-facing view showing only one surface.
+        Instructions:
+        - If all criteria are met, return "YES".
+        - Otherwise, return "NO" with a brief explanation (max 20 words).
+        Respond in exactly one of the following formats:
+        YES
+        or
+        NO: brief explanation.
+        Input:
+        {}
+        """
+
+    def query(self, text: str, image: list[Image.Image | str]) -> str:
+        return self.gpt_client.query(
+            text_prompt=self.prompt.format(text),
+            image_base64=image,
+        )
+
+
+class TextGenAlignChecker(BaseChecker):
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        prompt: str = None,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(prompt, verbose)
+        self.gpt_client = gpt_client
+        if self.prompt is None:
+            self.prompt = """
+        You are an expert in evaluating the quality of generated 3D assets.
+        You will be given:
+        - A text description of an object: TEXT
+        - Rendered views of the generated 3D asset.
+        Your task is to:
+        1. Determine whether the generated 3D asset roughly reflects the object class
+           or a semantically adjacent category described in the text.
+        2. Evaluate the geometry quality of the 3D asset generation based on the rendered views.
+        Criteria:
+        - Determine if the generated 3D asset belongs to the text described or a similar category.
+        - Focus on functional similarity: if the object serves the same general
+          purpose (e.g., writing, placing items), it should be accepted.
+        - Is the geometry complete and well-formed, with no missing parts,
+          distortions, visual artifacts, or redundant structures?
+        - Does the number of object instances match the description?
+          There should be only one object unless otherwise specified.
+        - Minor flaws in geometry or texture are acceptable, high tolerance for texture quality defects.
+        - Minor simplifications in geometry or texture (e.g. soft edges, less detail)
+          are acceptable if the object is still recognizable.
+        - The asset should not contain any unrelated elements, such as
+          ground planes, platforms, or background props (e.g., paper, flooring).
+        Example:
+        Text: "yellow cup"
+        Image: shows a yellow cup standing on a flat white plane -> NO: extra surface under the object.
+        Instructions:
+        - If the quality of generated asset is acceptable and faithfully represents the text, return "YES".
+        - Otherwise, return "NO" followed by a brief explanation (no more than 20 words).
+        Respond in exactly one of the following formats:
+        YES
+        or
+        NO: brief explanation
+        Input:
+        Text description: {}
+        """
+
+    def query(self, text: str, image: list[Image.Image | str]) -> str:
+        return self.gpt_client.query(
+            text_prompt=self.prompt.format(text),
+            image_base64=image,
+        )
+
+
+class SemanticMatcher(BaseChecker):
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        prompt: str = None,
+        verbose: bool = False,
+        seed: int = None,
+    ) -> None:
+        super().__init__(prompt, verbose)
+        self.gpt_client = gpt_client
+        self.seed = seed
+        random.seed(seed)
+        if self.prompt is None:
+            self.prompt = """
+        You are an expert in semantic similarity and scene retrieval.
+        You will be given:
+        - A dictionary where each key is a scene ID, and each value is a scene description.
+        - A query text describing a target scene.
+        Your task:
+        return_num = 2
+        - Find the <return_num> most semantically similar scene IDs to the query text.
+        - If there are fewer than <return_num> distinct relevant matches, repeat the closest ones to make a list of <return_num>.
+        - Only output the list of <return_num> scene IDs, sorted from most to less similar.
+        - Do NOT use markdown, JSON code blocks, or any formatting syntax, only return a plain list like ["id1", ...].
+        Input example:
+        Dictionary:
+        "{{
+            "t_scene_008": "A study room with full bookshelves and a lamp in the corner.",
+            "t_scene_019": "A child's bedroom with pink walls and a small desk.",
+            "t_scene_020": "A living room with a wooden floor.",
+            "t_scene_021": "A living room with toys scattered on the floor.",
+            ...
+            "t_scene_office_001": "A very spacious, modern open-plan office with wide desks and no people, panoramic view."
+        }}"
+        Text:
+        "A traditional indoor room"
+        Output:
+        '["t_scene_office_001", ...]'
+        Input:
+        Dictionary:
+        {context}
+        Text:
+        {text}
+        Output:
+        <topk_key_list>
+        """
+
+    def query(
+        self, text: str, context: dict, rand: bool = True, params: dict = None
+    ) -> str:
+        match_list = self.gpt_client.query(
+            self.prompt.format(context=context, text=text),
+            params=params,
+        )
+        match_list = json_repair.loads(match_list)
+        result = random.choice(match_list) if rand else match_list[0]
+        return result
+
+
+def test_semantic_matcher(
+    bg_file: str = "outputs/bg_scenes/bg_scene_list.txt",
+):
+    scene_dict = {}
+    with open(bg_file, "r") as f:
+        for line in f:
+            line = line.strip()
+            if not line or ":" not in line:
+                continue
+            scene_id, desc = line.split(":", 1)
+            scene_dict[scene_id.strip()] = desc.strip()
+    office_scene = scene_dict.get("t_scene_office_001")
+    text = "bright kitchen"
+    SCENE_MATCHER = SemanticMatcher(GPT_CLIENT)
+    # gpt_params = {
+    #     "temperature": 0.8,
+    #     "max_tokens": 500,
+    #     "top_p": 0.8,
+    #     "frequency_penalty": 0.3,
+    #     "presence_penalty": 0.3,
+    # }
+    gpt_params = None
+    match_key = SCENE_MATCHER.query(text, str(scene_dict), params=gpt_params)
+    print(match_key, ",", scene_dict[match_key])
+
+
+if __name__ == "__main__":
+    test_semantic_matcher()
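`SemanticMatcher` feeds the raw model reply through `json_repair.loads` before picking a scene ID, so slightly malformed output (single quotes, stray fences) still parses. A stdlib-only approximation of that tolerance, for illustration only (`parse_id_list` is a hypothetical helper, not the repo's implementation):

```python
import ast
import json

def parse_id_list(reply: str) -> list:
    """Parse a model reply expected to be a plain list like ["id1", "id2"]."""
    text = reply.strip()
    if text.startswith("```"):
        # Strip an accidental markdown fence despite the prompt forbidding it.
        text = text.strip("`")
        text = text.split("\n", 1)[-1] if "\n" in text else text
    try:
        return json.loads(text)  # strict JSON first
    except json.JSONDecodeError:
        return ast.literal_eval(text)  # tolerates single-quoted lists

ids = parse_id_list("['t_scene_008', 't_scene_019']")
```

The real pipeline's `json_repair` handles a wider range of breakage (trailing commas, unquoted keys), which is why it is pinned in `requirements.txt`.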


@@ -297,20 +297,24 @@ class URDFGenerator(object):
        if not os.path.exists(urdf_path):
            raise FileNotFoundError(f"URDF file not found: {urdf_path}")
-        mesh_scale = 1.0
+        mesh_attr = None
        tree = ET.parse(urdf_path)
        root = tree.getroot()
        extra_info = root.find(attr_root)
        if extra_info is not None:
            scale_element = extra_info.find(attr_name)
            if scale_element is not None:
-                mesh_scale = float(scale_element.text)
+                mesh_attr = scale_element.text
+                try:
+                    mesh_attr = float(mesh_attr)
+                except ValueError:
+                    pass
-        return mesh_scale
+        return mesh_attr

    @staticmethod
    def add_quality_tag(
-        urdf_path: str, results, output_path: str = None
+        urdf_path: str, results: list, output_path: str = None
    ) -> None:
        if output_path is None:
            output_path = urdf_path
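The attribute lookup above now returns the raw string when the stored value is not numeric, instead of crashing on `float()`. The fallback can be sketched standalone (`parse_mesh_attr` is a hypothetical helper mirroring the try/except in the diff):

```python
def parse_mesh_attr(text):
    """Return a float when the URDF attribute text is numeric,
    otherwise pass the original value through unchanged."""
    if text is None:
        return None
    try:
        return float(text)
    except ValueError:
        return text  # e.g. a quality tag such as "low_quality"

scale = parse_mesh_attr("0.75")        # numeric attribute
label = parse_mesh_attr("low_quality") # non-numeric attribute survives
```

This lets one accessor serve both numeric fields (mesh scale, height) and string-valued quality tags.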
@@ -366,16 +370,9 @@
            output_root,
            num_images=self.render_view_num,
            output_subdir=self.output_render_dir,
+            no_index_file=True,
        )
-        # Hardcode tmp because of the openrouter can't input multi images.
-        if "openrouter" in self.gpt_client.endpoint:
-            from embodied_gen.utils.process_media import (
-                combine_images_to_base64,
-            )
-            image_path = combine_images_to_base64(image_path)
        response = self.gpt_client.query(text_prompt, image_path)
        if response is None:
            asset_attrs = {
@@ -412,14 +409,18 @@
if __name__ == "__main__":
    urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
    urdf_path = urdf_gen(
-        mesh_path="outputs/imageto3d/cma/o5/URDF_o5/mesh/o5.obj",
+        mesh_path="outputs/layout2/asset3d/marker/result/mesh/marker.obj",
        output_root="outputs/test_urdf",
-        # category="coffee machine",
+        category="marker",
        # min_height=1.0,
        # max_height=1.2,
        version=VERSION,
    )
+    URDFGenerator.add_quality_tag(
+        urdf_path, [[urdf_gen.__class__.__name__, "OK"]]
+    )
    # zip_files(
    #     input_paths=[
    #         "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",


@@ -8,6 +8,12 @@ NC='\033[0m'
echo -e "${GREEN}Starting installation process...${NC}"
git config --global http.postBuffer 524288000
+echo -e "${GREEN}Installing flash-attn...${NC}"
+pip install flash-attn==2.7.0.post2 --no-build-isolation || {
+    echo -e "${RED}Failed to install flash-attn${NC}"
+    exit 1
+}
echo -e "${GREEN}Installing dependencies from requirements.txt...${NC}"
pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60 || {
    echo -e "${RED}Failed to install requirements${NC}"
@@ -15,16 +21,16 @@
}
-echo -e "${GREEN}Installing kaolin from GitHub...${NC}"
-pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || {
-    echo -e "${RED}Failed to install kaolin${NC}"
+echo -e "${GREEN}Installing kolors from GitHub...${NC}"
+pip install kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d || {
+    echo -e "${RED}Failed to install kolors${NC}"
    exit 1
}
-echo -e "${GREEN}Installing flash-attn...${NC}"
-pip install flash-attn==2.7.0.post2 --no-build-isolation || {
-    echo -e "${RED}Failed to install flash-attn${NC}"
+echo -e "${GREEN}Installing kaolin from GitHub...${NC}"
+pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || {
+    echo -e "${RED}Failed to install kaolin${NC}"
    exit 1
}
@@ -39,7 +45,6 @@ rm -rf "$TMP_DIR" || {
    rm -rf "$TMP_DIR"
    exit 1
}
-echo -e "${GREEN}Installation completed successfully!${NC}"
echo -e "${GREEN}Installing gsplat from GitHub...${NC}"
@@ -50,8 +55,9 @@ pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.0 || {
echo -e "${GREEN}Installing EmbodiedGen...${NC}"
+pip install triton==2.1.0
pip install -e . || {
-    echo -e "${RED}Failed to install local package${NC}"
+    echo -e "${RED}Failed to install EmbodiedGen pyproject.toml${NC}"
    exit 1
}


@@ -7,7 +7,7 @@ packages = ["embodied_gen"]
[project]
name = "embodied_gen"
-version = "v0.1.0"
+version = "v0.1.1"
readme = "README.md"
license = "Apache-2.0"
license-files = ["LICENSE", "NOTICE"]
@@ -17,16 +17,20 @@ requires-python = ">=3.10"
[project.optional-dependencies]
dev = [
-    "cpplint==2.0.0",
-    "pre-commit==2.13.0",
+    "cpplint",
+    "pre-commit",
    "pydocstyle",
    "black",
    "isort",
+    "pytest",
+    "pytest-mock",
]
[project.scripts]
drender-cli = "embodied_gen.data.differentiable_render:entrypoint"
backproject-cli = "embodied_gen.data.backproject_v2:entrypoint"
+img3d-cli = "embodied_gen.scripts.imageto3d:entrypoint"
+text3d-cli = "embodied_gen.scripts.textto3d:text_to_3d"
[tool.pydocstyle]
match = '(?!test_).*(?!_pb2)\.py'

pytest.ini (new file)

@@ -0,0 +1,12 @@
[pytest]
testpaths =
tests/test_unit
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -ra --tb=short --disable-warnings --log-cli-level=INFO
filterwarnings =
ignore::DeprecationWarning
markers =
slow: marks tests as slow


@@ -8,10 +8,10 @@ triton==2.1.0
dataclasses_json
easydict
opencv-python>4.5
-imageio==2.36.1
+imageio
-imageio-ffmpeg==0.5.1
+imageio-ffmpeg
rembg==2.0.61
-trimesh==4.4.4
+trimesh
moviepy==1.0.3
pymeshfix==0.17.0
igraph==0.11.8
@@ -20,21 +20,19 @@ openai==1.58.1
transformers==4.42.4
gradio==5.12.0
sentencepiece==0.2.0
-diffusers==0.31.0
+diffusers==0.34.0
-xatlas==0.0.9
+xatlas
onnxruntime==1.20.1
-tenacity==8.2.2
+tenacity
accelerate==0.33.0
basicsr==1.4.2
realesrgan==0.3.0
pydantic==2.9.2
vtk==9.3.1
spaces
+colorlog
+json-repair
utils3d@git+https://github.com/EasternJournalist/utils3d.git#egg=9a4eb15
clip@git+https://github.com/openai/CLIP.git
-kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d
segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f
nvdiffrast@git+https://github.com/NVlabs/nvdiffrast.git#egg=729261d
# https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu118-cp310-cp310-linux_x86_64.whl
# https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu11torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
# https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu118/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl


@@ -0,0 +1,31 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
logger = logging.getLogger(__name__)
# @pytest.mark.manual
def test_aesthetic_predictor():
image_path = "apps/assets/example_image/sample_02.jpg"
predictor = AestheticPredictor(device="cpu")
score = predictor.predict(image_path)
assert isinstance(score, float)
logger.info(f"Aesthetic score: {score:.3f}")


@@ -0,0 +1,119 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import tempfile
import pytest
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.process_media import render_asset3d
from embodied_gen.validators.quality_checkers import (
ImageAestheticChecker,
ImageSegChecker,
MeshGeoChecker,
SemanticConsistChecker,
TextGenAlignChecker,
)
logger = logging.getLogger(__name__)
@pytest.fixture(scope="module")
def geo_checker():
return MeshGeoChecker(GPT_CLIENT)
@pytest.fixture(scope="module")
def seg_checker():
return ImageSegChecker(GPT_CLIENT)
@pytest.fixture(scope="module")
def aesthetic_checker():
return ImageAestheticChecker()
@pytest.fixture(scope="module")
def semantic_checker():
return SemanticConsistChecker(GPT_CLIENT)
@pytest.fixture(scope="module")
def textalign_checker():
return TextGenAlignChecker(GPT_CLIENT)
def test_geo_checker(geo_checker):
flag, result = geo_checker(
[
"apps/assets/example_image/sample_02.jpg",
]
)
logger.info(f"geo_checker: {flag}, {result}")
assert isinstance(flag, bool)
assert isinstance(result, str)
def test_aesthetic_checker(aesthetic_checker):
flag, result = aesthetic_checker("apps/assets/example_image/sample_02.jpg")
logger.info(f"aesthetic_checker: {flag}, {result}")
assert isinstance(flag, bool)
assert isinstance(result, float)
def test_seg_checker(seg_checker):
flag, result = seg_checker(
[
"apps/assets/example_image/sample_02.jpg",
"apps/assets/example_image/sample_02.jpg",
]
)
logger.info(f"seg_checker: {flag}, {result}")
assert isinstance(flag, bool)
assert isinstance(result, str)
def test_semantic_checker(semantic_checker):
flag, result = semantic_checker(
text="can",
image=["apps/assets/example_image/sample_02.jpg"],
)
logger.info(f"semantic_checker: {flag}, {result}")
assert isinstance(flag, bool)
assert isinstance(result, str)
@pytest.mark.parametrize(
"mesh_path, text_desc",
[
("apps/assets/example_texture/meshes/chair.obj", "chair"),
("apps/assets/example_texture/meshes/clock.obj", "clock"),
],
)
def test_textgen_checker(textalign_checker, mesh_path, text_desc):
with tempfile.TemporaryDirectory() as output_root:
image_list = render_asset3d(
mesh_path,
output_root=output_root,
num_images=6,
elevation=(30, -30),
output_subdir="renders",
no_index_file=True,
with_mtl=False,
)
flag, result = textalign_checker(text_desc, image_list)
logger.info(f"textalign_checker: {flag}, {result}")

View File

@@ -0,0 +1,95 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from unittest.mock import patch
import pytest
from PIL import Image
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.validators.quality_checkers import (
ImageSegChecker,
MeshGeoChecker,
SemanticConsistChecker,
)
@pytest.fixture(autouse=True)
def gptclient_query():
with patch.object(
GPT_CLIENT, "query", return_value="mocked gpt response"
) as mock:
yield mock
@pytest.fixture()
def gptclient_query_case2():
with patch.object(GPT_CLIENT, "query", return_value=None) as mock:
yield mock
@pytest.mark.parametrize(
"input_images",
[
"dummy_path/color_grid_6view.png",
["dummy_path/color_grid_6view.jpg"],
[
"dummy_path/color_grid_6view.png",
"dummy_path/color_grid_6view2.png",
],
[
Image.new("RGB", (64, 64), "red"),
Image.new("RGB", (64, 64), "blue"),
],
],
)
def test_geo_checker_varied_inputs(input_images):
geo_checker = MeshGeoChecker(GPT_CLIENT)
flag, result = geo_checker(input_images)
assert isinstance(flag, (bool, type(None)))
assert isinstance(result, str)
def test_seg_checker():
seg_checker = ImageSegChecker(GPT_CLIENT)
flag, result = seg_checker(
[
"dummy_path/sample_0_raw.png", # raw image
"dummy_path/sample_0_cond.png", # segmented image
]
)
assert isinstance(flag, (bool, type(None)))
assert isinstance(result, str)
def test_semantic_checker():
semantic_checker = SemanticConsistChecker(GPT_CLIENT)
flag, result = semantic_checker(
text="pen",
image=["dummy_path/pen.png"],
)
assert isinstance(flag, (bool, type(None)))
assert isinstance(result, str)
def test_semantic_checker_none_response(gptclient_query_case2):
semantic_checker = SemanticConsistChecker(GPT_CLIENT)
flag, result = semantic_checker(
text="pen",
image=["dummy_path/pen.png"],
)
assert isinstance(flag, (bool, type(None)))
assert isinstance(result, str)
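The fixtures above stub `GPT_CLIENT.query` with `patch.object` so the checker tests run offline and deterministically. The same pattern in a self-contained form (`FakeClient` is a stand-in for illustration, not the repo's `GPTclient`):

```python
from unittest.mock import patch

class FakeClient:
    """Pretend client whose real query would hit the network."""

    def query(self, text_prompt, image_base64=None):
        raise RuntimeError("would hit the network")

client = FakeClient()

# Inside the context, FakeClient.query is a MagicMock returning "YES";
# the real method is restored automatically on exit.
with patch.object(FakeClient, "query", return_value="YES") as mock_query:
    reply = client.query(text_prompt="geometry ok?")

# The mock keeps its call record after the patch is undone.
mock_query.assert_called_once_with(text_prompt="geometry ok?")
```

Patching at the class level (as the fixtures do with the shared `GPT_CLIENT` instance) means every checker constructed during the test sees the stub, which is why the varied-input tests never need real image files.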


@@ -0,0 +1,94 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
from unittest.mock import patch
import pytest
import yaml
from PIL import Image
from embodied_gen.utils.gpt_clients import CONFIG_FILE, GPTclient
@pytest.fixture(scope="module")
def config():
with open(CONFIG_FILE, "r") as f:
return yaml.safe_load(f)
@pytest.fixture
def env_vars(monkeypatch, config):
agent_type = config["agent_type"]
agent_config = config.get(agent_type, {})
monkeypatch.setenv(
"ENDPOINT", agent_config.get("endpoint", "fake_endpoint")
)
monkeypatch.setenv("API_KEY", agent_config.get("api_key", "fake_api_key"))
monkeypatch.setenv("API_VERSION", agent_config.get("api_version", "v1"))
monkeypatch.setenv(
"MODEL_NAME", agent_config.get("model_name", "test_model")
)
yield
@pytest.fixture
def gpt_client(env_vars):
client = GPTclient(
endpoint=os.environ.get("ENDPOINT"),
api_key=os.environ.get("API_KEY"),
api_version=os.environ.get("API_VERSION"),
model_name=os.environ.get("MODEL_NAME"),
check_connection=False,
)
return client
@pytest.mark.parametrize(
"text_prompt, image_base64",
[
("What is the capital of China?", None),
(
"What is the content in each image?",
"apps/assets/example_image/sample_02.jpg",
),
(
"What is the content in each image?",
[
"apps/assets/example_image/sample_02.jpg",
"apps/assets/example_image/sample_03.jpg",
],
),
(
"What is the content in each image?",
[
Image.new("RGB", (64, 64), "red"),
Image.new("RGB", (64, 64), "blue"),
],
),
],
)
def test_gptclient_query(gpt_client, text_prompt, image_base64):
# mock GPTclient.query
with patch.object(
GPTclient, "query", return_value="mocked response"
) as mock_query:
response = gpt_client.query(
text_prompt=text_prompt, image_base64=image_base64
)
assert response == "mocked response"
mock_query.assert_called_once_with(
text_prompt=text_prompt, image_base64=image_base64
)