From a3924ae4a852b361fbcd2565101e90b4fa5081d1 Mon Sep 17 00:00:00 2001 From: Xinjie Date: Fri, 11 Jul 2025 16:51:35 +0800 Subject: [PATCH] feat(pipe): Refine generation pipeline and add utests. (#23) - Added intelligent quality checkers and auto-retry pipeline for `image-to-3d` and `text-to-3d`. - Added unit tests for quality checkers. - `text-to-3d` now supports more `text-to-image` models, pipeline success rate improved to 94%. --- CHANGELOG.md | 29 ++ README.md | 44 ++- apps/common.py | 2 +- embodied_gen/data/differentiable_render.py | 11 +- embodied_gen/data/mesh_operator.py | 2 + embodied_gen/data/utils.py | 14 - embodied_gen/models/image_comm_model.py | 236 +++++++++++++ embodied_gen/models/text_model.py | 11 +- embodied_gen/scripts/imageto3d.py | 186 +++++----- embodied_gen/scripts/text2image.py | 2 +- embodied_gen/scripts/textto3d.py | 271 +++++++++++++++ embodied_gen/scripts/textto3d.sh | 49 ++- embodied_gen/scripts/texture_gen.sh | 9 +- embodied_gen/utils/enum.py | 107 ++++++ embodied_gen/utils/gpt_clients.py | 84 +++-- embodied_gen/utils/log.py | 48 +++ embodied_gen/utils/process_media.py | 273 ++++++++++----- embodied_gen/utils/tags.py | 2 +- embodied_gen/utils/trender.py | 90 +++++ .../validators/aesthetic_predictor.py | 14 +- embodied_gen/validators/quality_checkers.py | 321 ++++++++++++++---- embodied_gen/validators/urdf_convertor.py | 29 +- install.sh | 22 +- pyproject.toml | 10 +- pytest.ini | 12 + requirements.txt | 18 +- .../test_examples/test_aesthetic_predictor.py | 31 ++ tests/test_examples/test_quality_checkers.py | 119 +++++++ tests/test_unit/test_agents.py | 95 ++++++ tests/test_unit/test_gpt_client.py | 94 +++++ 30 files changed, 1863 insertions(+), 372 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 embodied_gen/models/image_comm_model.py create mode 100644 embodied_gen/scripts/textto3d.py create mode 100644 embodied_gen/utils/enum.py create mode 100644 embodied_gen/utils/log.py create mode 100644 embodied_gen/utils/trender.py create mode 100644 pytest.ini create mode 100644 tests/test_examples/test_aesthetic_predictor.py create mode 100644 tests/test_examples/test_quality_checkers.py create mode 100644 tests/test_unit/test_agents.py create mode 100644 tests/test_unit/test_gpt_client.py diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..799e178 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,29 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0). + + +## [0.1.1] - 2025-07-xx +### Feature +- Added intelligent quality checkers and auto-retry pipeline for `image-to-3d` and `text-to-3d`. +- Added unit tests for quality checkers. +- `text-to-3d` now supports more `text-to-image` models, pipeline success rate improved to 94%. + +## [0.1.0] - 2025-07-04 +### Feature +๐Ÿ–ผ๏ธ Single Image to Physics Realistic 3D Asset +- Generates watertight, simulation-ready 3D meshes with physical attributes. +- Auto-labels semantic and quality tags (geometry, texture, foreground, etc.). +- Produces 2K textures with highlight removal and multi-view fusion. + +๐Ÿ“ Text-to-3D Asset Generation +- Creates realistic 3D assets from natural language (English & Chinese). +- Filters assets via QA tags to ensure visual and geometric quality. + +๐ŸŽจ Texture Generation & Editing +- Generates 2K textures from mesh and text with semantic alignment. +- Plug-and-play modules adapt text-to-image models for 3D textures. +- Supports realistic and stylized texture outputs, including text textures. + diff --git a/README.md b/README.md index 20bcba0..0cfc0cd 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ [![๐Ÿค— Hugging Face](https://img.shields.io/badge/๐Ÿค—-Image_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) [![๐Ÿค— Hugging Face](https://img.shields.io/badge/๐Ÿค—-Text_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) [![๐Ÿค— Hugging Face](https://img.shields.io/badge/๐Ÿค—-Texture_Gen_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) +[![ไธญๆ–‡ไป‹็ป](https://img.shields.io/badge/ไธญๆ–‡ไป‹็ป-07C160?logo=wechat&logoColor=white)](https://mp.weixin.qq.com/s/HH1cPBhK2xcDbyCK4BBTbw) > ***EmbodiedGen*** is a generative engine to create diverse and interactive 3D worlds composed of high-quality 3D assets(mesh & 3DGS) with plausible physics, leveraging generative AI to address the challenges of generalization in embodied intelligence related research. @@ -29,7 +30,7 @@ ```sh git clone https://github.com/HorizonRobotics/EmbodiedGen.git cd EmbodiedGen -git checkout v0.1.0 +git checkout v0.1.1 git submodule update --init --recursive --progress conda create -n embodiedgen python=3.10.13 -y conda activate embodiedgen @@ -67,9 +68,8 @@ CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 & ### โšก API Generate physically plausible 3D assets from image input via the command-line API. ```sh -python3 embodied_gen/scripts/imageto3d.py \ - --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \ - --output_root outputs/imageto3d +img3d-cli --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \ + --n_retry 2 --output_root outputs/imageto3d # See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result ``` @@ -86,18 +86,29 @@ python3 embodied_gen/scripts/imageto3d.py \ ### โ˜๏ธ Service Deploy the text-to-3D generation service locally. -Text-to-image based on the Kolors model, supporting Chinese and English prompts. -Models downloaded automatically on first run, see `download_kolors_weights`, please be patient. +Text-to-image model based on the Kolors model, supporting Chinese and English prompts. +Models downloaded automatically on first run, please be patient. ```sh python apps/text_to_3d.py ``` ### โšก API -Text-to-image based on the Kolors model. +Text-to-image model based on SD3.5 Medium, English prompts only. +Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically. (ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py`) + +For large-scale 3D assets generation, set `--n_pipe_retry=2` to ensure high end-to-end 3D asset usability through automatic quality check and retries. For more diverse results, do not set `--seed_img`. + +```sh +text3d-cli --prompts "small bronze figurine of a lion" "A globe with wooden base" "wooden table with embroidery" \ + --n_image_retry 2 --n_asset_retry 2 --n_pipe_retry 1 --seed_img 0 \ + --output_root outputs/textto3d +``` + +Text-to-image model based on the Kolors model. ```sh bash embodied_gen/scripts/textto3d.sh \ --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "ๆฉ™่‰ฒ็”ตๅŠจๆ‰‹้’ป๏ผŒๆœ‰็ฃจๆŸ็ป†่Š‚" \ - --output_root outputs/textto3d + --output_root outputs/textto3d_k ``` --- @@ -118,12 +129,17 @@ python apps/texture_edit.py ``` ### โšก API +Support Chinese and English prompts. ```sh bash embodied_gen/scripts/texture_gen.sh \ --mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \ --prompt "ไธพ็€็‰Œๅญ็š„ๅ†™ๅฎž้ฃŽๆ ผๆœบๅ™จไบบ๏ผŒๅคง็œผ็›๏ผŒ็‰ŒๅญไธŠๅ†™็€โ€œHelloโ€็š„ๆ–‡ๅญ—" \ - --output_root "outputs/texture_gen/" \ - --uuid "robot_text" + --output_root "outputs/texture_gen/robot_text" + +bash embodied_gen/scripts/texture_gen.sh \ + --mesh_path "apps/assets/example_texture/meshes/horse.obj" \ + --prompt "A gray horse head with flying mane and brown eyes" \ + --output_root "outputs/texture_gen/gray_horse" ``` --- @@ -171,6 +187,12 @@ bash embodied_gen/scripts/texture_gen.sh \ --- +## For Developer +```sh +pip install .[dev] && pre-commit install +python -m pytest # Pass all unit-test are required. +``` + ## ๐Ÿ“š Citation If you use EmbodiedGen in your research or projects, please cite: @@ -192,7 +214,7 @@ If you use EmbodiedGen in your research or projects, please cite: ## ๐Ÿ™Œ Acknowledgement EmbodiedGen builds upon the following amazing projects and models: -๐ŸŒŸ [Trellis](https://github.com/microsoft/TRELLIS) | ๐ŸŒŸ [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | ๐ŸŒŸ [Segment Anything](https://github.com/facebookresearch/segment-anything) | ๐ŸŒŸ [Rembg](https://github.com/danielgatis/rembg) | ๐ŸŒŸ [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | ๐ŸŒŸ [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | ๐ŸŒŸ [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | ๐ŸŒŸ [Kolors](https://github.com/Kwai-Kolors/Kolors) | ๐ŸŒŸ [ChatGLM3](https://github.com/THUDM/ChatGLM3) | ๐ŸŒŸ [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | ๐ŸŒŸ [Pano2Room](https://github.com/TrickyGo/Pano2Room) | ๐ŸŒŸ [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | ๐ŸŒŸ [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | ๐ŸŒŸ [diffusers](https://github.com/huggingface/diffusers) | ๐ŸŒŸ [gsplat](https://github.com/nerfstudio-project/gsplat) | ๐ŸŒŸ [QWEN2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | ๐ŸŒŸ [GPT4o](https://platform.openai.com/docs/models/gpt-4o) +๐ŸŒŸ [Trellis](https://github.com/microsoft/TRELLIS) | ๐ŸŒŸ [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | ๐ŸŒŸ [Segment Anything](https://github.com/facebookresearch/segment-anything) | ๐ŸŒŸ [Rembg](https://github.com/danielgatis/rembg) | ๐ŸŒŸ [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | ๐ŸŒŸ [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | ๐ŸŒŸ [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | ๐ŸŒŸ [Kolors](https://github.com/Kwai-Kolors/Kolors) | ๐ŸŒŸ [ChatGLM3](https://github.com/THUDM/ChatGLM3) | ๐ŸŒŸ [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | ๐ŸŒŸ [Pano2Room](https://github.com/TrickyGo/Pano2Room) | ๐ŸŒŸ [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | ๐ŸŒŸ [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | ๐ŸŒŸ [diffusers](https://github.com/huggingface/diffusers) | ๐ŸŒŸ [gsplat](https://github.com/nerfstudio-project/gsplat) | ๐ŸŒŸ [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | ๐ŸŒŸ [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | ๐ŸŒŸ [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) --- diff --git a/apps/common.py b/apps/common.py index 9ecea89..f7c7f40 100644 --- a/apps/common.py +++ b/apps/common.py @@ -55,9 +55,9 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT from embodied_gen.utils.process_media import ( filter_image_small_connected_components, merge_images_video, - render_video, ) from embodied_gen.utils.tags import VERSION +from embodied_gen.utils.trender import render_video from embodied_gen.validators.quality_checkers import ( BaseChecker, ImageAestheticChecker, diff --git a/embodied_gen/data/differentiable_render.py b/embodied_gen/data/differentiable_render.py index 1762439..2ad386c 100644 --- a/embodied_gen/data/differentiable_render.py +++ b/embodied_gen/data/differentiable_render.py @@ -33,7 +33,6 @@ from tqdm import tqdm from embodied_gen.data.utils import ( CameraSetting, DiffrastRender, - RenderItems, as_list, calc_vertex_normals, import_kaolin_mesh, @@ -42,6 +41,7 @@ from embodied_gen.data.utils import ( render_pbr, save_images, ) +from embodied_gen.utils.enum import RenderItems os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser( @@ -470,7 +470,7 @@ def parse_args(): "--pbr_light_factor", type=float, default=1.0, - help="Light factor for mesh PBR rendering (default: 2.)", + help="Light factor for mesh PBR rendering (default: 1.)", ) parser.add_argument( "--with_mtl", @@ -482,6 +482,11 @@ def parse_args(): action="store_true", help="Whether to generate color .gif rendering file.", ) + parser.add_argument( + "--no_index_file", + action="store_true", + help="Whether skip the index file saving.", + ) parser.add_argument( "--gen_color_mp4", action="store_true", @@ -568,7 +573,7 @@ def entrypoint(**kwargs) -> None: gen_viewnormal_mp4=args.gen_viewnormal_mp4, gen_glonormal_mp4=args.gen_glonormal_mp4, light_factor=args.pbr_light_factor, - no_index_file=gen_video, + no_index_file=gen_video or args.no_index_file, ) image_render.render_mesh( mesh_path=args.mesh_path, diff --git a/embodied_gen/data/mesh_operator.py b/embodied_gen/data/mesh_operator.py index 888b203..38f0563 100644 --- a/embodied_gen/data/mesh_operator.py +++ b/embodied_gen/data/mesh_operator.py @@ -395,6 +395,8 @@ class MeshFixer(object): self.vertices_np, np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]), ) + mesh.clean(inplace=True) + mesh.clear_data() mesh = mesh.decimate(ratio, progress_bar=True) # Update vertices and faces diff --git a/embodied_gen/data/utils.py b/embodied_gen/data/utils.py index b8d632c..83cb39d 100644 --- a/embodied_gen/data/utils.py +++ b/embodied_gen/data/utils.py @@ -38,7 +38,6 @@ except ImportError: ChatGLMModel = None import logging from dataclasses import dataclass, field -from enum import Enum import trimesh from kaolin.render.camera import Camera @@ -57,7 +56,6 @@ __all__ = [ "load_mesh_to_unit_cube", "as_list", "CameraSetting", - "RenderItems", "import_kaolin_mesh", "save_mesh_with_mtl", "get_images_from_grid", @@ -738,18 +736,6 @@ class CameraSetting: self.Ks = Ks -@dataclass -class RenderItems(str, Enum): - IMAGE = "image_color" - ALPHA = "image_mask" - VIEW_NORMAL = "image_view_normal" - GLOBAL_NORMAL = "image_global_normal" - POSITION_MAP = "image_position" - DEPTH = "image_depth" - ALBEDO = "image_albedo" - DIFFUSE = "image_diffuse" - - def _compute_az_el_by_camera_params( camera_params: CameraSetting, flip_az: bool = False ): diff --git a/embodied_gen/models/image_comm_model.py b/embodied_gen/models/image_comm_model.py new file mode 100644 index 0000000..7a8c30c --- /dev/null +++ b/embodied_gen/models/image_comm_model.py @@ -0,0 +1,236 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# Text-to-Image generation models from Hugging Face community. + +import os +from abc import ABC, abstractmethod + +import torch +from diffusers import ( + ChromaPipeline, + Cosmos2TextToImagePipeline, + DPMSolverMultistepScheduler, + FluxPipeline, + KolorsPipeline, + StableDiffusion3Pipeline, +) +from diffusers.quantizers import PipelineQuantizationConfig +from huggingface_hub import snapshot_download +from PIL import Image +from transformers import AutoModelForCausalLM, SiglipProcessor + +__all__ = [ + "build_hf_image_pipeline", +] + + +class BasePipelineLoader(ABC): + def __init__(self, device="cuda"): + self.device = device + + @abstractmethod + def load(self): + pass + + +class BasePipelineRunner(ABC): + def __init__(self, pipe): + self.pipe = pipe + + @abstractmethod + def run(self, prompt: str, **kwargs) -> Image.Image: + pass + + +# ===== SD3.5-medium ===== +class SD35Loader(BasePipelineLoader): + def load(self): + pipe = StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-medium", + torch_dtype=torch.float16, + ) + pipe = pipe.to(self.device) + pipe.enable_model_cpu_offload() + pipe.enable_xformers_memory_efficient_attention() + pipe.enable_attention_slicing() + return pipe + + +class SD35Runner(BasePipelineRunner): + def run(self, prompt: str, **kwargs) -> Image.Image: + return self.pipe(prompt=prompt, **kwargs).images + + +# ===== Cosmos2 ===== +class CosmosLoader(BasePipelineLoader): + def __init__( + self, + model_id="nvidia/Cosmos-Predict2-2B-Text2Image", + local_dir="weights/cosmos2", + device="cuda", + ): + super().__init__(device) + self.model_id = model_id + self.local_dir = local_dir + + def _patch(self): + def patch_model(cls): + orig = cls.from_pretrained + + def new(*args, **kwargs): + kwargs.setdefault("attn_implementation", "flash_attention_2") + kwargs.setdefault("torch_dtype", torch.bfloat16) + return orig(*args, **kwargs) + + cls.from_pretrained = new + + def patch_processor(cls): + orig = cls.from_pretrained + + def new(*args, **kwargs): + kwargs.setdefault("use_fast", True) + return orig(*args, **kwargs) + + cls.from_pretrained = new + + patch_model(AutoModelForCausalLM) + patch_processor(SiglipProcessor) + + def load(self): + self._patch() + snapshot_download( + repo_id=self.model_id, + local_dir=self.local_dir, + local_dir_use_symlinks=False, + resume_download=True, + ) + + config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_4bit", + quant_kwargs={ + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": torch.bfloat16, + "bnb_4bit_use_double_quant": True, + }, + components_to_quantize=["text_encoder", "transformer", "unet"], + ) + + pipe = Cosmos2TextToImagePipeline.from_pretrained( + self.model_id, + torch_dtype=torch.bfloat16, + quantization_config=config, + use_safetensors=True, + safety_checker=None, + requires_safety_checker=False, + ).to(self.device) + return pipe + + +class CosmosRunner(BasePipelineRunner): + def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: + return self.pipe( + prompt=prompt, negative_prompt=negative_prompt, **kwargs + ).images + + +# ===== Kolors ===== +class KolorsLoader(BasePipelineLoader): + def load(self): + pipe = KolorsPipeline.from_pretrained( + "Kwai-Kolors/Kolors-diffusers", + torch_dtype=torch.float16, + variant="fp16", + ).to(self.device) + pipe.enable_model_cpu_offload() + pipe.enable_xformers_memory_efficient_attention() + pipe.scheduler = DPMSolverMultistepScheduler.from_config( + pipe.scheduler.config, use_karras_sigmas=True + ) + return pipe + + +class KolorsRunner(BasePipelineRunner): + def run(self, prompt: str, **kwargs) -> Image.Image: + return self.pipe(prompt=prompt, **kwargs).images + + +# ===== Flux ===== +class FluxLoader(BasePipelineLoader): + def load(self): + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 + ) + pipe.enable_model_cpu_offload() + pipe.enable_xformers_memory_efficient_attention() + pipe.enable_attention_slicing() + return pipe.to(self.device) + + +class FluxRunner(BasePipelineRunner): + def run(self, prompt: str, **kwargs) -> Image.Image: + return self.pipe(prompt=prompt, **kwargs).images + + +# ===== Chroma ===== +class ChromaLoader(BasePipelineLoader): + def load(self): + return ChromaPipeline.from_pretrained( + "lodestones/Chroma", torch_dtype=torch.bfloat16 + ).to(self.device) + + +class ChromaRunner(BasePipelineRunner): + def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: + return self.pipe( + prompt=prompt, negative_prompt=negative_prompt, **kwargs + ).images + + +PIPELINE_REGISTRY = { + "sd35": (SD35Loader, SD35Runner), + "cosmos": (CosmosLoader, CosmosRunner), + "kolors": (KolorsLoader, KolorsRunner), + "flux": (FluxLoader, FluxRunner), + "chroma": (ChromaLoader, ChromaRunner), +} + + +def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner: + if name not in PIPELINE_REGISTRY: + raise ValueError(f"Unsupported model: {name}") + loader_cls, runner_cls = PIPELINE_REGISTRY[name] + pipe = loader_cls(device=device).load() + + return runner_cls(pipe) + + +if __name__ == "__main__": + model_name = "sd35" + runner = build_hf_image_pipeline(model_name) + # NOTE: Just for pipeline testing, generation quality at low resolution is poor. + images = runner.run( + prompt="A robot holding a sign that says 'Hello'", + height=512, + width=512, + num_inference_steps=10, + guidance_scale=6, + num_images_per_prompt=1, + ) + + for i, img in enumerate(images): + img.save(f"image_{model_name}_{i}.jpg") diff --git a/embodied_gen/models/text_model.py b/embodied_gen/models/text_model.py index 7ea8c4c..3ad44f4 100644 --- a/embodied_gen/models/text_model.py +++ b/embodied_gen/models/text_model.py @@ -52,8 +52,11 @@ __all__ = [ "download_kolors_weights", ] - -PROMPT_APPEND = "Full view of one {}, no cropping, centered, no occlusion, isolated product photo, matte, 3D style, on a plain clean surface" +PROMPT_APPEND = ( + "Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, " + "no surroundings, matte, on a plain clean surface, 3D style revealing multiple surfaces" +) +PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality" def download_kolors_weights(local_dir: str = "weights/Kolors") -> None: @@ -182,9 +185,7 @@ def text2img_gen( ip_image_size: int = 512, seed: int = None, ) -> list[Image.Image]: - # prompt = "Single " + prompt + ", in the center of the image" - # prompt += ", high quality, high resolution, best quality, white background, 3D style" # noqa - prompt = PROMPT_APPEND.format(prompt.strip()) + prompt = PROMPT_KAPPEND.format(object=prompt.strip()) logger.info(f"Processing prompt: {prompt}") generator = None diff --git a/embodied_gen/scripts/imageto3d.py b/embodied_gen/scripts/imageto3d.py index 847252c..00a0958 100644 --- a/embodied_gen/scripts/imageto3d.py +++ b/embodied_gen/scripts/imageto3d.py @@ -16,13 +16,14 @@ import argparse -import logging import os +import random import sys from glob import glob from shutil import copy, copytree, rmtree import numpy as np +import torch import trimesh from PIL import Image from embodied_gen.data.backproject_v2 import entrypoint as backproject_api @@ -37,8 +38,10 @@ from embodied_gen.models.segment_model import ( from embodied_gen.models.sr_model import ImageRealESRGAN from embodied_gen.scripts.render_gs import entrypoint as render_gs_api from embodied_gen.utils.gpt_clients import GPT_CLIENT -from embodied_gen.utils.process_media import merge_images_video, render_video +from embodied_gen.utils.log import logger +from embodied_gen.utils.process_media import merge_images_video from embodied_gen.utils.tags import VERSION +from embodied_gen.utils.trender import render_video from embodied_gen.validators.quality_checkers import ( BaseChecker, ImageAestheticChecker, @@ -52,19 +55,14 @@ current_dir = os.path.dirname(current_file_path) sys.path.append(os.path.join(current_dir, "../..")) from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO -) -logger = logging.getLogger(__name__) - - os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser( "~/.cache/torch_extensions" ) os.environ["GRADIO_ANALYTICS_ENABLED"] = "false" os.environ["SPCONV_ALGO"] = "native" +random.seed(0) - +logger.info("Loading Models...") DELIGHT = DelightingModel() IMAGESR_MODEL = ImageRealESRGAN(outscale=4) @@ -74,7 +72,7 @@ SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu") PIPELINE = TrellisImageTo3DPipeline.from_pretrained( "microsoft/TRELLIS-image-large" ) -PIPELINE.cuda() +# PIPELINE.cuda() SEG_CHECKER = ImageSegChecker(GPT_CLIENT) GEO_CHECKER = MeshGeoChecker(GPT_CLIENT) AESTHETIC_CHECKER = ImageAestheticChecker() @@ -95,7 +93,6 @@ def parse_args(): parser.add_argument( "--output_root", type=str, - required=True, help="Root directory for saving outputs.", ) parser.add_argument( @@ -110,12 +107,26 @@ def parse_args(): default=None, help="The mass in kg to restore the mesh real weight.", ) - parser.add_argument("--asset_type", type=str, default=None) + parser.add_argument("--asset_type", type=str, nargs="+", default=None) parser.add_argument("--skip_exists", action="store_true") - parser.add_argument("--strict_seg", action="store_true") parser.add_argument("--version", type=str, default=VERSION) - parser.add_argument("--remove_intermediate", type=bool, default=True) - args = parser.parse_args() + parser.add_argument("--keep_intermediate", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--n_retry", + type=int, + default=2, + ) + args, unknown = parser.parse_known_args() + + return args + + +def entrypoint(**kwargs): + args = parse_args() + for k, v in kwargs.items(): + if hasattr(args, k) and v is not None: + setattr(args, k, v) assert ( args.image_path or args.image_root @@ -125,13 +136,7 @@ def parse_args(): args.image_path += glob(os.path.join(args.image_root, "*.jpg")) args.image_path += glob(os.path.join(args.image_root, "*.jpeg")) - return args - - -if __name__ == "__main__": - args = parse_args() - - for image_path in args.image_path: + for idx, image_path in enumerate(args.image_path): try: filename = os.path.basename(image_path).split(".")[0] output_root = args.output_root @@ -141,7 +146,7 @@ if __name__ == "__main__": mesh_out = f"{output_root}/{filename}.obj" if args.skip_exists and os.path.exists(mesh_out): - logger.info( + logger.warning( f"Skip {image_path}, already processed in {mesh_out}" ) continue @@ -149,67 +154,84 @@ if __name__ == "__main__": image = Image.open(image_path) image.save(f"{output_root}/{filename}_raw.png") - # Segmentation: Get segmented image using SAM or Rembg. + # Segmentation: Get segmented image using Rembg. seg_path = f"{output_root}/{filename}_cond.png" - if image.mode != "RGBA": - seg_image = RBG_REMOVER(image, save_path=seg_path) - seg_image = trellis_preprocess(seg_image) - else: - seg_image = image - seg_image.save(seg_path) + seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image + seg_image = trellis_preprocess(seg_image) + seg_image.save(seg_path) - # Run the pipeline - try: - outputs = PIPELINE.run( - seg_image, - preprocess_image=False, - # Optional parameters - # seed=1, - # sparse_structure_sampler_params={ - # "steps": 12, - # "cfg_strength": 7.5, - # }, - # slat_sampler_params={ - # "steps": 12, - # "cfg_strength": 3, - # }, + seed = args.seed + for try_idx in range(args.n_retry): + logger.info( + f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}" ) - except Exception as e: - logger.error( - f"[Pipeline Failed] process {image_path}: {e}, skip." - ) - continue + # Run the pipeline + try: + PIPELINE.cuda() + outputs = PIPELINE.run( + seg_image, + preprocess_image=False, + seed=( + random.randint(0, 100000) if seed is None else seed + ), + # Optional parameters + # sparse_structure_sampler_params={ + # "steps": 12, + # "cfg_strength": 7.5, + # }, + # slat_sampler_params={ + # "steps": 12, + # "cfg_strength": 3, + # }, + ) + PIPELINE.cpu() + torch.cuda.empty_cache() + except Exception as e: + logger.error( + f"[Pipeline Failed] process {image_path}: {e}, skip." + ) + continue - # Render and save color and mesh videos - gs_model = outputs["gaussian"][0] - mesh_model = outputs["mesh"][0] + gs_model = outputs["gaussian"][0] + mesh_model = outputs["mesh"][0] + + # Save the raw Gaussian model + gs_path = mesh_out.replace(".obj", "_gs.ply") + gs_model.save_ply(gs_path) + + # Rotate mesh and GS by 90 degrees around Z-axis. + rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]] + gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] + mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] + + # Addtional rotation for GS to align mesh. + gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix) + pose = GaussianOperator.trans_to_quatpose(gs_rot) + aligned_gs_path = gs_path.replace(".ply", "_aligned.ply") + GaussianOperator.resave_ply( + in_ply=gs_path, + out_ply=aligned_gs_path, + instance_pose=pose, + device="cpu", + ) + color_path = os.path.join(output_root, "color.png") + render_gs_api(aligned_gs_path, color_path) + + geo_flag, geo_result = GEO_CHECKER([color_path]) + logger.warning( + f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}" + ) + if geo_flag is True or geo_flag is None: + break + + seed = random.randint(0, 100000) if seed is not None else None + + # Render the video for generated 3D asset. color_images = render_video(gs_model)["color"] normal_images = render_video(mesh_model)["normal"] video_path = os.path.join(output_root, "gs_mesh.mp4") merge_images_video(color_images, normal_images, video_path) - # Save the raw Gaussian model - gs_path = mesh_out.replace(".obj", "_gs.ply") - gs_model.save_ply(gs_path) - - # Rotate mesh and GS by 90 degrees around Z-axis. - rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]] - gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] - mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] - - # Addtional rotation for GS to align mesh. - gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix) - pose = GaussianOperator.trans_to_quatpose(gs_rot) - aligned_gs_path = gs_path.replace(".ply", "_aligned.ply") - GaussianOperator.resave_ply( - in_ply=gs_path, - out_ply=aligned_gs_path, - instance_pose=pose, - device="cpu", - ) - color_path = os.path.join(output_root, "color.png") - render_gs_api(aligned_gs_path, color_path) - mesh = trimesh.Trimesh( vertices=mesh_model.vertices.cpu().numpy(), faces=mesh_model.faces.cpu().numpy(), @@ -249,8 +271,8 @@ if __name__ == "__main__": min_mass, max_mass = map(float, args.mass_range.split("-")) asset_attrs["min_mass"] = min_mass asset_attrs["max_mass"] = max_mass - if args.asset_type: - asset_attrs["category"] = args.asset_type + if isinstance(args.asset_type, list) and args.asset_type[idx]: + asset_attrs["category"] = args.asset_type[idx] if args.version: asset_attrs["version"] = args.version @@ -289,8 +311,8 @@ if __name__ == "__main__": ] images_list.append(images) - results = BaseChecker.validate(CHECKERS, images_list) - urdf_convertor.add_quality_tag(urdf_path, results) + qa_results = BaseChecker.validate(CHECKERS, images_list) + urdf_convertor.add_quality_tag(urdf_path, qa_results) # Organize the final result files result_dir = f"{output_root}/result" @@ -303,7 +325,7 @@ if __name__ == "__main__": f"{result_dir}/{urdf_convertor.output_mesh_dir}", ) copy(video_path, f"{result_dir}/video.mp4") - if args.remove_intermediate: + if not args.keep_intermediate: delete_dir(output_root, keep_subs=["result"]) except Exception as e: @@ -311,3 +333,7 @@ if __name__ == "__main__": continue logger.info(f"Processing complete. Outputs saved to {args.output_root}") + + +if __name__ == "__main__": + entrypoint() diff --git a/embodied_gen/scripts/text2image.py b/embodied_gen/scripts/text2image.py index 3b375a3..ac1587c 100644 --- a/embodied_gen/scripts/text2image.py +++ b/embodied_gen/scripts/text2image.py @@ -85,7 +85,7 @@ def parse_args(): parser.add_argument( "--seed", type=int, - default=0, + default=None, ) args = parser.parse_args() diff --git a/embodied_gen/scripts/textto3d.py b/embodied_gen/scripts/textto3d.py new file mode 100644 index 0000000..a4262a3 --- /dev/null +++ b/embodied_gen/scripts/textto3d.py @@ -0,0 +1,271 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +import argparse +import os +import random +from collections import defaultdict + +import torch +from PIL import Image +from embodied_gen.models.image_comm_model import build_hf_image_pipeline +from embodied_gen.models.segment_model import RembgRemover +from embodied_gen.models.text_model import PROMPT_APPEND +from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api +from embodied_gen.utils.gpt_clients import GPT_CLIENT +from embodied_gen.utils.log import logger +from embodied_gen.utils.process_media import render_asset3d +from embodied_gen.validators.quality_checkers import ( + ImageSegChecker, + SemanticConsistChecker, + TextGenAlignChecker, +) + +# Avoid huggingface/tokenizers: The current process just got forked. +os.environ["TOKENIZERS_PARALLELISM"] = "false" +random.seed(0) + + +__all__ = [ + "text_to_image", + "text_to_3d", +] + +logger.info("Loading Models...") +SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT) +SEG_CHECKER = ImageSegChecker(GPT_CLIENT) +TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT) +PIPE_IMG = build_hf_image_pipeline("sd35") +BG_REMOVER = RembgRemover() + + +def text_to_image( + prompt: str, + save_path: str, + n_retry: int, + img_denoise_step: int, + text_guidance_scale: float, + n_img_sample: int, + image_hw: tuple[int, int] = (1024, 1024), + seed: int = None, +) -> bool: + select_image = None + success_flag = False + assert save_path.endswith(".png"), "Image save path must end with `.png`." + for try_idx in range(n_retry): + if select_image is not None: + select_image[0].save(save_path.replace(".png", "_raw.png")) + select_image[1].save(save_path) + break + + f_prompt = PROMPT_APPEND.format(object=prompt) + logger.info( + f"Image GEN for {os.path.basename(save_path)}\n" + f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}" + ) + images = PIPE_IMG.run( + f_prompt, + num_inference_steps=img_denoise_step, + guidance_scale=text_guidance_scale, + num_images_per_prompt=n_img_sample, + height=image_hw[0], + width=image_hw[1], + generator=( + torch.Generator().manual_seed(seed) + if seed is not None + else None + ), + ) + + for idx in range(len(images)): + raw_image: Image.Image = images[idx] + image = BG_REMOVER(raw_image) + image.save(save_path) + semantic_flag, semantic_result = SEMANTIC_CHECKER( + prompt, [image.convert("RGB")] + ) + seg_flag, seg_result = SEG_CHECKER( + [raw_image, image.convert("RGB")] + ) + if ( + (semantic_flag and seg_flag) + or semantic_flag is None + or seg_flag is None + ): + select_image = [raw_image, image] + success_flag = True + break + + torch.cuda.empty_cache() + seed = random.randint(0, 100000) if seed is not None else None + + return success_flag + + +def text_to_3d(**kwargs) -> dict: + args = parse_args() + for k, v in kwargs.items(): + if hasattr(args, k) and v is not None: + setattr(args, k, v) + + if args.asset_names is None or len(args.asset_names) == 0: + args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))] + img_save_dir = os.path.join(args.output_root, "images") + asset_save_dir = os.path.join(args.output_root, "asset3d") + os.makedirs(img_save_dir, exist_ok=True) + os.makedirs(asset_save_dir, exist_ok=True) + results = defaultdict(dict) + for prompt, node in zip(args.prompts, args.asset_names): + success_flag = False + n_pipe_retry = args.n_pipe_retry + seed_img = args.seed_img + seed_3d = args.seed_3d + while success_flag is False and n_pipe_retry > 0: + logger.info( + f"GEN pipeline for node {node}\n" + f"Try round: {args.n_pipe_retry-n_pipe_retry+1}/{args.n_pipe_retry}, Prompt: {prompt}" + ) + # Text-to-image GEN + save_node = node.replace(" ", "_") + gen_image_path = f"{img_save_dir}/{save_node}.png" + textgen_flag = text_to_image( + prompt, + gen_image_path, + args.n_image_retry, + args.img_denoise_step, + args.text_guidance_scale, + args.n_img_sample, + seed=seed_img, + ) + + # Asset 3D GEN + node_save_dir = f"{asset_save_dir}/{save_node}" + asset_type = node if "sample3d_" not in node else None + imageto3d_api( + image_path=[gen_image_path], + output_root=node_save_dir, + asset_type=[asset_type], + seed=random.randint(0, 100000) if seed_3d is None else seed_3d, + n_retry=args.n_asset_retry, + keep_intermediate=args.keep_intermediate, + ) + mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj" + image_path = render_asset3d( + mesh_path, + output_root=f"{node_save_dir}/result", + num_images=6, + elevation=(30, -30), + output_subdir="renders", + no_index_file=True, + ) + + check_text = asset_type if asset_type is not None else prompt + qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path) + logger.warning( + f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}" + ) + results["assets"][node] = f"{node_save_dir}/result" + results["quality"][node] = qa_result + + if qa_flag is None or qa_flag is True: + success_flag = True + break + + n_pipe_retry -= 1 + seed_img = ( + random.randint(0, 100000) if seed_img is not None else None + ) + seed_3d = ( + random.randint(0, 100000) if seed_3d is not None else None + ) + + torch.cuda.empty_cache() + + return results + + +def parse_args(): + parser = argparse.ArgumentParser(description="3D Layout Generation Config") + parser.add_argument("--prompts", nargs="+", help="text descriptions") + parser.add_argument( + "--output_root", + type=str, + help="Directory to save outputs", + ) + parser.add_argument( + "--asset_names", + type=str, + nargs="+", + default=None, + help="Asset names to generate", + ) + parser.add_argument( + "--n_img_sample", + type=int, + default=3, + help="Number of image samples to generate", + ) + parser.add_argument( + "--text_guidance_scale", + type=float, + default=7, + help="Text-to-image guidance scale", + ) + parser.add_argument( + "--img_denoise_step", + type=int, + default=25, + help="Denoising steps for image generation", + ) + parser.add_argument( + "--n_image_retry", + type=int, + default=2, + help="Max retry count for image generation", + ) + parser.add_argument( + "--n_asset_retry", + type=int, + default=2, + help="Max retry count for 3D generation", + ) + parser.add_argument( + "--n_pipe_retry", + type=int, + default=1, + help="Max retry count for 3D asset generation", + ) + parser.add_argument( + "--seed_img", + type=int, + default=None, + help="Random seed for image generation", + ) + parser.add_argument( + "--seed_3d", + type=int, + default=0, + help="Random seed for 3D generation", + ) + parser.add_argument("--keep_intermediate", action="store_true") + + args, unknown = parser.parse_known_args() + + return args + + +if __name__ == "__main__": + text_to_3d() diff --git a/embodied_gen/scripts/textto3d.sh b/embodied_gen/scripts/textto3d.sh index d9648c7..1e84599 100644 --- a/embodied_gen/scripts/textto3d.sh +++ b/embodied_gen/scripts/textto3d.sh @@ -2,7 +2,9 @@ # Initialize variables prompts=() +asset_types=() output_root="" +seed=0 # Parse arguments while [[ $# -gt 0 ]]; do @@ -14,10 +16,21 @@ while [[ $# -gt 0 ]]; do shift done ;; + --asset_types) + shift + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + asset_types+=("$1") + shift + done + ;; --output_root) output_root="$2" shift 2 ;; + --seed) + seed="$2" + shift 2 + ;; *) echo "Unknown argument: $1" exit 1 @@ -28,7 +41,21 @@ done # Validate required arguments if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then echo "Missing required arguments." - echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" --output_root " + echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" \ + --asset_types \"type1\" \"type2\" --seed --output_root " + exit 1 +fi + +# If no asset_types provided, default to "" +if [[ ${#asset_types[@]} -eq 0 ]]; then + for (( i=0; i<${#prompts[@]}; i++ )); do + asset_types+=("") + done +fi + +# Ensure the number of asset_types matches the number of prompts +if [[ ${#prompts[@]} -ne ${#asset_types[@]} ]]; then + echo "The number of asset types must match the number of prompts." exit 1 fi @@ -37,20 +64,30 @@ echo "Prompts:" for p in "${prompts[@]}"; do echo " - $p" done +# echo "Asset types:" +# for at in "${asset_types[@]}"; do +# echo " - $at" +# done echo "Output root: ${output_root}" +echo "Seed: ${seed}" -# Concatenate prompts for Python command +# Concatenate prompts and asset types for Python command prompt_args="" -for p in "${prompts[@]}"; do - prompt_args+="\"$p\" " +asset_type_args="" +for i in "${!prompts[@]}"; do + prompt_args+="\"${prompts[$i]}\" " + asset_type_args+="\"${asset_types[$i]}\" " done + # Step 1: Text-to-Image eval python3 embodied_gen/scripts/text2image.py \ --prompts ${prompt_args} \ - --output_root "${output_root}/images" + --output_root "${output_root}/images" \ + --seed ${seed} # Step 2: Image-to-3D python3 embodied_gen/scripts/imageto3d.py \ --image_root "${output_root}/images" \ - --output_root "${output_root}/asset3d" + --output_root "${output_root}/asset3d" \ + --asset_type ${asset_type_args} diff --git a/embodied_gen/scripts/texture_gen.sh b/embodied_gen/scripts/texture_gen.sh index 8311e22..7374e84 100644 --- a/embodied_gen/scripts/texture_gen.sh +++ b/embodied_gen/scripts/texture_gen.sh @@ -10,10 +10,6 @@ while [[ $# -gt 0 ]]; do prompt="$2" shift 2 ;; - --uuid) - uuid="$2" - shift 2 - ;; --output_root) output_root="$2" shift 2 @@ -26,12 +22,13 @@ while [[ $# -gt 0 ]]; do done -if [[ -z "$mesh_path" || -z "$prompt" || -z "$uuid" || -z "$output_root" ]]; then +if [[ -z "$mesh_path" || -z "$prompt" || -z "$output_root" ]]; then echo "params missing" - echo "usage: bash run.sh --mesh_path --prompt --uuid --output_root " + echo "usage: bash run.sh --mesh_path --prompt --output_root " exit 1 fi +uuid=$(basename "$output_root") # Step 1: drender-cli for condition rendering drender-cli --mesh_path ${mesh_path} \ --output_root ${output_root}/condition \ diff --git a/embodied_gen/utils/enum.py b/embodied_gen/utils/enum.py new file mode 100644 index 0000000..7fc3347 --- /dev/null +++ b/embodied_gen/utils/enum.py @@ -0,0 +1,107 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +from dataclasses import dataclass, field +from enum import Enum + +from dataclasses_json import DataClassJsonMixin + +__all__ = [ + "RenderItems", + "Scene3DItemEnum", + "SpatialRelationEnum", + "RobotItemEnum", +] + + +@dataclass +class RenderItems(str, Enum): + IMAGE = "image_color" + ALPHA = "image_mask" + VIEW_NORMAL = "image_view_normal" + GLOBAL_NORMAL = "image_global_normal" + POSITION_MAP = "image_position" + DEPTH = "image_depth" + ALBEDO = "image_albedo" + DIFFUSE = "image_diffuse" + + +@dataclass +class Scene3DItemEnum(str, Enum): + BACKGROUND = "background" + CONTEXT = "context" + ROBOT = "robot" + MANIPULATED_OBJS = "manipulated_objs" + DISTRACTOR_OBJS = "distractor_objs" + OTHERS = "others" + + @classmethod + def object_list(cls, layout_relation: dict) -> list: + return ( + [ + layout_relation[cls.BACKGROUND.value], + layout_relation[cls.CONTEXT.value], + ] + + layout_relation[cls.MANIPULATED_OBJS.value] + + layout_relation[cls.DISTRACTOR_OBJS.value] + ) + + @classmethod + def object_mapping(cls, layout_relation): + relation_mapping = { + # layout_relation[cls.ROBOT.value]: cls.ROBOT.value, + layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value, + layout_relation[cls.CONTEXT.value]: cls.CONTEXT.value, + } + relation_mapping.update( + { + item: cls.MANIPULATED_OBJS.value + for item in layout_relation[cls.MANIPULATED_OBJS.value] + } + ) + relation_mapping.update( + { + item: cls.DISTRACTOR_OBJS.value + for item in layout_relation[cls.DISTRACTOR_OBJS.value] + } + ) + + return relation_mapping + + +@dataclass +class SpatialRelationEnum(str, Enum): + ON = "ON" # objects on the table + IN = "IN" # objects in the room + INSIDE = "INSIDE" # objects inside the shelf/rack + FLOOR = "FLOOR" # object floor room/bin + + +@dataclass +class RobotItemEnum(str, Enum): + FRANKA = "franka" + UR5 = "ur5" + PIPER = "piper" + + +@dataclass +class LayoutInfo(DataClassJsonMixin): + tree: dict[str, list] + relation: dict[str, str | list[str]] + objs_desc: dict[str, str] = field(default_factory=dict) + assets: dict[str, str] = field(default_factory=dict) + quality: dict[str, str] = field(default_factory=dict) + position: dict[str, list[float]] = field(default_factory=dict) diff --git a/embodied_gen/utils/gpt_clients.py b/embodied_gen/utils/gpt_clients.py index f7ce067..de435e2 100644 --- a/embodied_gen/utils/gpt_clients.py +++ b/embodied_gen/utils/gpt_clients.py @@ -30,12 +30,20 @@ from tenacity import ( stop_after_delay, wait_random_exponential, ) -from embodied_gen.utils.process_media import combine_images_to_base64 +from embodied_gen.utils.process_media import combine_images_to_grid -logging.basicConfig(level=logging.INFO) +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.basicConfig(level=logging.WARNING) logger = logging.getLogger(__name__) +__all__ = [ + "GPTclient", +] + +CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml" + + class GPTclient: """A client to interact with the GPT model via OpenAI or Azure API.""" @@ -45,6 +53,7 @@ class GPTclient: api_key: str, model_name: str = "yfb-gpt-4o", api_version: str = None, + check_connection: bool = True, verbose: bool = False, ): if api_version is not None: @@ -63,6 +72,9 @@ class GPTclient: self.model_name = model_name self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"} self.verbose = verbose + if check_connection: + self.check_connection() + logger.info(f"Using GPT model: {self.model_name}.") @retry( @@ -77,6 +89,7 @@ class GPTclient: text_prompt: str, image_base64: Optional[list[str | Image.Image]] = None, system_role: Optional[str] = None, + params: Optional[dict] = None, ) -> Optional[str]: """Queries the GPT model with a text and optional image prompts. @@ -86,6 +99,7 @@ class GPTclient: or local image paths or PIL.Image to accompany the text prompt. system_role (Optional[str]): Optional system-level instructions that specify the behavior of the assistant. + params (Optional[dict]): Additional parameters for GPT setting. Returns: Optional[str]: The response content generated by the model based on @@ -103,11 +117,11 @@ class GPTclient: # Process images if provided if image_base64 is not None: - image_base64 = ( - image_base64 - if isinstance(image_base64, list) - else [image_base64] - ) + if not isinstance(image_base64, list): + image_base64 = [image_base64] + # Hardcode tmp because of the openrouter can't input multi images. + if "openrouter" in self.endpoint: + image_base64 = combine_images_to_grid(image_base64) for img in image_base64: if isinstance(img, Image.Image): buffer = BytesIO() @@ -142,8 +156,11 @@ class GPTclient: "frequency_penalty": 0, "presence_penalty": 0, "stop": None, + "model": self.model_name, } - payload.update({"model": self.model_name}) + + if params: + payload.update(params) response = None try: @@ -159,8 +176,28 @@ class GPTclient: return response + def check_connection(self) -> None: + """Check whether the GPT API connection is working.""" + try: + response = self.completion_with_backoff( + messages=[ + {"role": "system", "content": "You are a test system."}, + {"role": "user", "content": "Hello"}, + ], + model=self.model_name, + temperature=0, + max_tokens=100, + ) + content = response.choices[0].message.content + logger.info(f"Connection check success.") + except Exception as e: + raise ConnectionError( + f"Failed to connect to GPT API at {self.endpoint}, " + f"please check setting in `{CONFIG_FILE}` and `README`." + ) -with open("embodied_gen/utils/gpt_config.yaml", "r") as f: + +with open(CONFIG_FILE, "r") as f: config = yaml.safe_load(f) agent_type = config["agent_type"] @@ -177,32 +214,5 @@ GPT_CLIENT = GPTclient( api_key=api_key, api_version=api_version, model_name=model_name, + check_connection=False, ) - -if __name__ == "__main__": - if "openrouter" in GPT_CLIENT.endpoint: - response = GPT_CLIENT.query( - text_prompt="What is the content in each image?", - image_base64=combine_images_to_base64( - [ - "apps/assets/example_image/sample_02.jpg", - "apps/assets/example_image/sample_03.jpg", - ] - ), # input raw image_path if only one image - ) - print(response) - else: - response = GPT_CLIENT.query( - text_prompt="What is the content in the images?", - image_base64=[ - Image.open("apps/assets/example_image/sample_02.jpg"), - Image.open("apps/assets/example_image/sample_03.jpg"), - ], - ) - print(response) - - # test2: text prompt - response = GPT_CLIENT.query( - text_prompt="What is the capital of China?" - ) - print(response) diff --git a/embodied_gen/utils/log.py b/embodied_gen/utils/log.py new file mode 100644 index 0000000..7c8998b --- /dev/null +++ b/embodied_gen/utils/log.py @@ -0,0 +1,48 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +import logging + +from colorlog import ColoredFormatter + +__all__ = [ + "logger", +] + +LOG_FORMAT = ( + "%(log_color)s[%(asctime)s] %(levelname)-8s | %(message)s%(reset)s" +) +DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + +formatter = ColoredFormatter( + LOG_FORMAT, + datefmt=DATE_FORMAT, + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "bold_red", + }, +) + +handler = logging.StreamHandler() +handler.setFormatter(formatter) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(handler) +logger.propagate = False diff --git a/embodied_gen/utils/process_media.py b/embodied_gen/utils/process_media.py index 2d47c69..edfdcff 100644 --- a/embodied_gen/utils/process_media.py +++ b/embodied_gen/utils/process_media.py @@ -15,34 +15,24 @@ # permissions and limitations under the License. -import base64 import logging import math import os -import sys +import textwrap from glob import glob -from io import BytesIO from typing import Union import cv2 import imageio +import matplotlib.pyplot as plt +import networkx as nx import numpy as np import PIL.Image as Image import spaces -import torch +from matplotlib.patches import Patch from moviepy.editor import VideoFileClip, clips_array -from tqdm import tqdm from embodied_gen.data.differentiable_render import entrypoint as render_api - -current_file_path = os.path.abspath(__file__) -current_dir = os.path.dirname(current_file_path) -sys.path.append(os.path.join(current_dir, "../..")) -from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer -from thirdparty.TRELLIS.trellis.representations import MeshExtractResult -from thirdparty.TRELLIS.trellis.utils.render_utils import ( - render_frames, - yaw_pitch_r_fov_to_extrinsics_intrinsics, -) +from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -53,9 +43,8 @@ __all__ = [ "merge_images_video", "filter_small_connected_components", "filter_image_small_connected_components", - "combine_images_to_base64", - "render_mesh", - "render_video", + "combine_images_to_grid", + "SceneTreeVisualizer", ] @@ -66,12 +55,14 @@ def render_asset3d( distance: float = 5.0, num_images: int = 1, elevation: list[float] = (0.0,), - pbr_light_factor: float = 1.5, + pbr_light_factor: float = 1.2, return_key: str = "image_color/*", output_subdir: str = "renders", gen_color_mp4: bool = False, gen_viewnormal_mp4: bool = False, gen_glonormal_mp4: bool = False, + no_index_file: bool = False, + with_mtl: bool = True, ) -> list[str]: input_args = dict( mesh_path=mesh_path, @@ -81,14 +72,13 @@ def render_asset3d( num_images=num_images, elevation=elevation, pbr_light_factor=pbr_light_factor, - with_mtl=True, + with_mtl=with_mtl, + gen_color_mp4=gen_color_mp4, + gen_viewnormal_mp4=gen_viewnormal_mp4, + gen_glonormal_mp4=gen_glonormal_mp4, + no_index_file=no_index_file, ) - if gen_color_mp4: - input_args["gen_color_mp4"] = True - if gen_viewnormal_mp4: - input_args["gen_viewnormal_mp4"] = True - if gen_glonormal_mp4: - input_args["gen_glonormal_mp4"] = True + try: _ = render_api(**input_args) except Exception as e: @@ -168,12 +158,15 @@ def filter_image_small_connected_components( return image -def combine_images_to_base64( +def combine_images_to_grid( images: list[str | Image.Image], cat_row_col: tuple[int, int] = None, target_wh: tuple[int, int] = (512, 512), -) -> str: +) -> list[str | Image.Image]: n_images = len(images) + if n_images == 1: + return images + if cat_row_col is None: n_col = math.ceil(math.sqrt(n_images)) n_row = math.ceil(n_images / n_col) @@ -182,88 +175,190 @@ def combine_images_to_base64( images = [ Image.open(p).convert("RGB") if isinstance(p, str) else p - for p in images[: n_row * n_col] + for p in images ] images = [img.resize(target_wh) for img in images] grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1] - grid = Image.new("RGB", (grid_w, grid_h), (255, 255, 255)) + grid = Image.new("RGB", (grid_w, grid_h), (0, 0, 0)) for idx, img in enumerate(images): row, col = divmod(idx, n_col) grid.paste(img, (col * target_wh[0], row * target_wh[1])) - buffer = BytesIO() - grid.save(buffer, format="PNG") - - return base64.b64encode(buffer.getvalue()).decode("utf-8") + return [grid] -@spaces.GPU -def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs): - renderer = MeshRenderer() - renderer.rendering_options.resolution = options.get("resolution", 512) - renderer.rendering_options.near = options.get("near", 1) - renderer.rendering_options.far = options.get("far", 100) - renderer.rendering_options.ssaa = options.get("ssaa", 4) - rets = {} - for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"): - res = renderer.render(sample, extr, intr) - if "normal" not in rets: - rets["normal"] = [] - normal = torch.lerp( - torch.zeros_like(res["normal"]), res["normal"], res["mask"] +class SceneTreeVisualizer: + def __init__(self, layout_info: LayoutInfo) -> None: + self.tree = layout_info.tree + self.relation = layout_info.relation + self.objs_desc = layout_info.objs_desc + self.G = nx.DiGraph() + self.root = self._find_root() + self._build_graph() + + self.role_colors = { + Scene3DItemEnum.BACKGROUND.value: "plum", + Scene3DItemEnum.CONTEXT.value: "lightblue", + Scene3DItemEnum.ROBOT.value: "lightcoral", + Scene3DItemEnum.MANIPULATED_OBJS.value: "lightgreen", + Scene3DItemEnum.DISTRACTOR_OBJS.value: "lightgray", + Scene3DItemEnum.OTHERS.value: "orange", + } + + def _find_root(self) -> str: + children = {c for cs in self.tree.values() for c, _ in cs} + parents = set(self.tree.keys()) + roots = parents - children + if not roots: + raise ValueError("No root node found.") + return next(iter(roots)) + + def _build_graph(self): + for parent, children in self.tree.items(): + for child, relation in children: + self.G.add_edge(parent, child, relation=relation) + + def _get_node_role(self, node: str) -> str: + if node == self.relation.get(Scene3DItemEnum.BACKGROUND.value): + return Scene3DItemEnum.BACKGROUND.value + if node == self.relation.get(Scene3DItemEnum.CONTEXT.value): + return Scene3DItemEnum.CONTEXT.value + if node == self.relation.get(Scene3DItemEnum.ROBOT.value): + return Scene3DItemEnum.ROBOT.value + if node in self.relation.get( + Scene3DItemEnum.MANIPULATED_OBJS.value, [] + ): + return Scene3DItemEnum.MANIPULATED_OBJS.value + if node in self.relation.get( + Scene3DItemEnum.DISTRACTOR_OBJS.value, [] + ): + return Scene3DItemEnum.DISTRACTOR_OBJS.value + return Scene3DItemEnum.OTHERS.value + + def _get_positions( + self, root, width=1.0, vert_gap=0.1, vert_loc=1, xcenter=0.5, pos=None + ): + if pos is None: + pos = {root: (xcenter, vert_loc)} + else: + pos[root] = (xcenter, vert_loc) + + children = list(self.G.successors(root)) + if children: + dx = width / len(children) + next_x = xcenter - width / 2 - dx / 2 + for child in children: + next_x += dx + pos = self._get_positions( + child, + width=dx, + vert_gap=vert_gap, + vert_loc=vert_loc - vert_gap, + xcenter=next_x, + pos=pos, + ) + return pos + + def render( + self, + save_path: str, + figsize=(8, 6), + dpi=300, + title: str = "Scene 3D Hierarchy Tree", + ): + node_colors = [ + self.role_colors[self._get_node_role(n)] for n in self.G.nodes + ] + pos = self._get_positions(self.root) + + plt.figure(figsize=figsize) + nx.draw( + self.G, + pos, + with_labels=True, + arrows=False, + node_size=2000, + node_color=node_colors, + font_size=10, + font_weight="bold", ) - normal = np.clip( - normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255 - ).astype(np.uint8) - rets["normal"].append(normal) - return rets + # Draw edge labels + edge_labels = nx.get_edge_attributes(self.G, "relation") + nx.draw_networkx_edge_labels( + self.G, + pos, + edge_labels=edge_labels, + font_size=9, + font_color="black", + ) + + # Draw small description text under each node (if available) + for node, (x, y) in pos.items(): + desc = self.objs_desc.get(node) + if desc: + wrapped = "\n".join(textwrap.wrap(desc, width=30)) + plt.text( + x, + y - 0.006, + wrapped, + fontsize=6, + ha="center", + va="top", + wrap=True, + color="black", + bbox=dict( + facecolor="dimgray", + edgecolor="darkgray", + alpha=0.1, + boxstyle="round,pad=0.2", + ), + ) + + plt.title(title, fontsize=12) + task_desc = self.relation.get("task_desc", "") + if task_desc: + plt.suptitle( + f"Task Description: {task_desc}", fontsize=10, y=0.999 + ) + + plt.axis("off") + + legend_handles = [ + Patch(facecolor=color, edgecolor='black', label=role) + for role, color in self.role_colors.items() + ] + plt.legend( + handles=legend_handles, + loc="lower center", + ncol=3, + bbox_to_anchor=(0.5, -0.1), + fontsize=9, + ) + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + plt.savefig(save_path, dpi=dpi, bbox_inches="tight") + plt.close() -@spaces.GPU -def render_video( - sample, - resolution=512, - bg_color=(0, 0, 0), - num_frames=300, - r=2, - fov=40, - **kwargs, -): - yaws = torch.linspace(0, 2 * 3.1415, num_frames) - yaws = yaws.tolist() - pitch = [0.5] * num_frames - extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics( - yaws, pitch, r, fov - ) - render_fn = ( - render_mesh if isinstance(sample, MeshExtractResult) else render_frames - ) - result = render_fn( - sample, - extrinsics, - intrinsics, - {"resolution": resolution, "bg_color": bg_color}, - **kwargs, - ) +def load_scene_dict(file_path: str) -> dict: + scene_dict = {} + with open(file_path, "r", encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line or ":" not in line: + continue + scene_id, desc = line.split(":", 1) + scene_dict[scene_id.strip()] = desc.strip() - return result + return scene_dict if __name__ == "__main__": - # Example usage: merge_video_video( "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa "merge.mp4", ) - - image_base64 = combine_images_to_base64( - [ - "apps/assets/example_image/sample_00.jpg", - "apps/assets/example_image/sample_01.jpg", - "apps/assets/example_image/sample_02.jpg", - ] - ) diff --git a/embodied_gen/utils/tags.py b/embodied_gen/utils/tags.py index 07a9a63..56deb45 100644 --- a/embodied_gen/utils/tags.py +++ b/embodied_gen/utils/tags.py @@ -1 +1 @@ -VERSION = "v0.1.0" +VERSION = "v0.1.1" diff --git a/embodied_gen/utils/trender.py b/embodied_gen/utils/trender.py new file mode 100644 index 0000000..53acc50 --- /dev/null +++ b/embodied_gen/utils/trender.py @@ -0,0 +1,90 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +import os +import sys + +import numpy as np +import spaces +import torch +from tqdm import tqdm + +current_file_path = os.path.abspath(__file__) +current_dir = os.path.dirname(current_file_path) +sys.path.append(os.path.join(current_dir, "../..")) +from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer +from thirdparty.TRELLIS.trellis.representations import MeshExtractResult +from thirdparty.TRELLIS.trellis.utils.render_utils import ( + render_frames, + yaw_pitch_r_fov_to_extrinsics_intrinsics, +) + +__all__ = [ + "render_video", +] + + +@spaces.GPU +def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs): + renderer = MeshRenderer() + renderer.rendering_options.resolution = options.get("resolution", 512) + renderer.rendering_options.near = options.get("near", 1) + renderer.rendering_options.far = options.get("far", 100) + renderer.rendering_options.ssaa = options.get("ssaa", 4) + rets = {} + for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"): + res = renderer.render(sample, extr, intr) + if "normal" not in rets: + rets["normal"] = [] + normal = torch.lerp( + torch.zeros_like(res["normal"]), res["normal"], res["mask"] + ) + normal = np.clip( + normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255 + ).astype(np.uint8) + rets["normal"].append(normal) + + return rets + + +@spaces.GPU +def render_video( + sample, + resolution=512, + bg_color=(0, 0, 0), + num_frames=300, + r=2, + fov=40, + **kwargs, +): + yaws = torch.linspace(0, 2 * 3.1415, num_frames) + yaws = yaws.tolist() + pitch = [0.5] * num_frames + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics( + yaws, pitch, r, fov + ) + render_fn = ( + render_mesh if isinstance(sample, MeshExtractResult) else render_frames + ) + result = render_fn( + sample, + extrinsics, + intrinsics, + {"resolution": resolution, "bg_color": bg_color}, + **kwargs, + ) + + return result diff --git a/embodied_gen/validators/aesthetic_predictor.py b/embodied_gen/validators/aesthetic_predictor.py index 5b9c557..921f363 100644 --- a/embodied_gen/validators/aesthetic_predictor.py +++ b/embodied_gen/validators/aesthetic_predictor.py @@ -102,7 +102,7 @@ class AestheticPredictor: def _load_sac_model(self, model_path, input_size): """Load the SAC model.""" model = self.MLP(input_size) - ckpt = torch.load(model_path) + ckpt = torch.load(model_path, weights_only=True) model.load_state_dict(ckpt) model.to(self.device) model.eval() @@ -135,15 +135,3 @@ class AestheticPredictor: ) return prediction.item() - - -if __name__ == "__main__": - # Configuration - img_path = "apps/assets/example_image/sample_00.jpg" - - # Initialize the predictor - predictor = AestheticPredictor() - - # Predict the aesthetic score - score = predictor.predict(img_path) - print("Aesthetic score predicted by the model:", score) diff --git a/embodied_gen/validators/quality_checkers.py b/embodied_gen/validators/quality_checkers.py index 4608c6b..88636f6 100644 --- a/embodied_gen/validators/quality_checkers.py +++ b/embodied_gen/validators/quality_checkers.py @@ -16,17 +16,26 @@ import logging -import os +import random -from tqdm import tqdm +import json_repair +from PIL import Image from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient -from embodied_gen.utils.process_media import render_asset3d from embodied_gen.validators.aesthetic_predictor import AestheticPredictor logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +__all__ = [ + "MeshGeoChecker", + "ImageSegChecker", + "ImageAestheticChecker", + "SemanticConsistChecker", + "TextGenAlignChecker", +] + + class BaseChecker: def __init__(self, prompt: str = None, verbose: bool = False) -> None: self.prompt = prompt @@ -37,16 +46,20 @@ class BaseChecker: "Subclasses must implement the query method." ) - def __call__(self, *args, **kwargs) -> bool: + def __call__(self, *args, **kwargs) -> tuple[bool, str]: response = self.query(*args, **kwargs) - if response is None: - response = "Error when calling gpt api." - - if self.verbose and response != "YES": + if self.verbose: logger.info(response) - flag = "YES" in response - response = "YES" if flag else response + if response is None: + flag = None + response = ( + "Error when calling GPT api, check config in " + "`embodied_gen/utils/gpt_config.yaml` or net connection." + ) + else: + flag = "YES" in response + response = "YES" if flag else response return flag, response @@ -92,21 +105,29 @@ class MeshGeoChecker(BaseChecker): self.gpt_client = gpt_client if self.prompt is None: self.prompt = """ - Refer to the provided multi-view rendering images to evaluate - whether the geometry of the 3D object asset is complete and - whether the asset can be placed stably on the ground. - Return "YES" only if reach the requirments, - otherwise "NO" and explain the reason very briefly. + You are an expert in evaluating the geometry quality of generated 3D asset. + You will be given rendered views of a generated 3D asset with black background. + Your task is to evaluate the quality of the 3D asset generation, + including geometry, structure, and appearance, based on the rendered views. + Criteria: + - Is the geometry complete and well-formed, without missing parts or redundant structures? + - Is the geometric structure of the object complete? + - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back, + soft edges) are acceptable if the object is structurally sound and recognizable. + - Only evaluate geometry. Do not assess texture quality. + - The asset should not contain any unrelated elements, such as + ground planes, platforms, or background props (e.g., paper, flooring). + + If all the above criteria are met, return "YES". Otherwise, return + "NO" followed by a brief explanation (no more than 20 words). + + Example: + Images show a yellow cup standing on a flat white plane -> NO + -> Response: NO: extra white surface under the object. + Image shows a chair with simplified back legs and soft edges โ†’ YES """ - def query(self, image_paths: str) -> str: - # Hardcode tmp because of the openrouter can't input multi images. - if "openrouter" in self.gpt_client.endpoint: - from embodied_gen.utils.process_media import ( - combine_images_to_base64, - ) - - image_paths = combine_images_to_base64(image_paths) + def query(self, image_paths: list[str | Image.Image]) -> str: return self.gpt_client.query( text_prompt=self.prompt, @@ -137,14 +158,19 @@ class ImageSegChecker(BaseChecker): self.gpt_client = gpt_client if self.prompt is None: self.prompt = """ - The first image is the original, and the second image is the - result after segmenting the main object. Evaluate the segmentation - quality to ensure the main object is clearly segmented without - significant truncation. Note that the foreground of the object - needs to be extracted instead of the background. - Minor imperfections can be ignored. If segmentation is acceptable, - return "YES" only; otherwise, return "NO" with - very brief explanation. + Task: Evaluate the quality of object segmentation between two images: + the first is the original, the second is the segmented result. + + Criteria: + - The main foreground object should be clearly extracted (not the background). + - The object must appear realistic, with reasonable geometry and color. + - The object should be geometrically complete โ€” no missing, truncated, or cropped parts. + - The object must be centered, with a margin on all sides. + - Ignore minor imperfections (e.g., small holes or fine edge artifacts). + + Output Rules: + If segmentation is acceptable, respond with "YES" (and nothing else). + If not acceptable, respond with "NO", followed by a brief reason (max 20 words). """ def query(self, image_paths: list[str]) -> str: @@ -152,13 +178,6 @@ class ImageSegChecker(BaseChecker): raise ValueError( "ImageSegChecker requires exactly two images: [raw_image, seg_image]." # noqa ) - # Hardcode tmp because of the openrouter can't input multi images. - if "openrouter" in self.gpt_client.endpoint: - from embodied_gen.utils.process_media import ( - combine_images_to_base64, - ) - - image_paths = combine_images_to_base64(image_paths) return self.gpt_client.query( text_prompt=self.prompt, @@ -201,42 +220,204 @@ class ImageAestheticChecker(BaseChecker): return avg_score > self.thresh, avg_score -if __name__ == "__main__": - geo_checker = MeshGeoChecker(GPT_CLIENT) - seg_checker = ImageSegChecker(GPT_CLIENT) - aesthetic_checker = ImageAestheticChecker() +class SemanticConsistChecker(BaseChecker): + def __init__( + self, + gpt_client: GPTclient, + prompt: str = None, + verbose: bool = False, + ) -> None: + super().__init__(prompt, verbose) + self.gpt_client = gpt_client + if self.prompt is None: + self.prompt = """ + You are an expert in image-text consistency assessment. + You will be given: + - A short text description of an object. + - An segmented image of the same object with the background removed. - checkers = [geo_checker, seg_checker, aesthetic_checker] + Criteria: + - The image must visually match the text description in terms of object type, structure, geometry, and color. + - The object must appear realistic, with reasonable geometry (e.g., a table must have a stable number of legs). + - Geometric completeness is required: the object must not have missing, truncated, or cropped parts. + - The object must be centered in the image frame with clear margins on all sides, + it should not touch or nearly touch any image edge. + - The image must contain exactly one object. Multiple distinct objects are not allowed. + A single composite object (e.g., a chair with legs) is acceptable. + - The object should be shown from a slightly angled (three-quarter) perspective, + not a flat, front-facing view showing only one surface. - output_root = "outputs/test_gpt" + Instructions: + - If all criteria are met, return `"YES"`. + - Otherwise, return "NO" with a brief explanation (max 20 words). - fails = [] - for idx in tqdm(range(150)): - mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj" # noqa - if not os.path.exists(mesh_path): - continue - image_paths = render_asset3d( - mesh_path, - f"{output_root}/{idx}", - num_images=8, - elevation=(30, -30), - distance=5.5, + Respond in exactly one of the following formats: + YES + or + NO: brief explanation. + + Input: + {} + """ + + def query(self, text: str, image: list[Image.Image | str]) -> str: + + return self.gpt_client.query( + text_prompt=self.prompt.format(text), + image_base64=image, ) - for cid, checker in enumerate(checkers): - if isinstance(checker, ImageSegChecker): - images = [ - f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png", # noqa - f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png", # noqa - ] - else: - images = image_paths - result, info = checker(images) - logger.info( - f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}" # noqa - ) - if result is False: - fails.append((idx, cid, info)) +class TextGenAlignChecker(BaseChecker): + def __init__( + self, + gpt_client: GPTclient, + prompt: str = None, + verbose: bool = False, + ) -> None: + super().__init__(prompt, verbose) + self.gpt_client = gpt_client + if self.prompt is None: + self.prompt = """ + You are an expert in evaluating the quality of generated 3D assets. + You will be given: + - A text description of an object: TEXT + - Rendered views of the generated 3D asset. - break + Your task is to: + 1. Determine whether the generated 3D asset roughly reflects the object class + or a semantically adjacent category described in the text. + 2. Evaluate the geometry quality of the 3D asset generation based on the rendered views. + + Criteria: + - Determine if the generated 3D asset belongs to the text described or a similar category. + - Focus on functional similarity: if the object serves the same general + purpose (e.g., writing, placing items), it should be accepted. + - Is the geometry complete and well-formed, with no missing parts, + distortions, visual artifacts, or redundant structures? + - Does the number of object instances match the description? + There should be only one object unless otherwise specified. + - Minor flaws in geometry or texture are acceptable, high tolerance for texture quality defects. + - Minor simplifications in geometry or texture (e.g. soft edges, less detail) + are acceptable if the object is still recognizable. + - The asset should not contain any unrelated elements, such as + ground planes, platforms, or background props (e.g., paper, flooring). + + Example: + Text: "yellow cup" + Image: shows a yellow cup standing on a flat white plane -> NO: extra surface under the object. + + Instructions: + - If the quality of generated asset is acceptable and faithfully represents the text, return "YES". + - Otherwise, return "NO" followed by a brief explanation (no more than 20 words). + + Respond in exactly one of the following formats: + YES + or + NO: brief explanation + + Input: + Text description: {} + """ + + def query(self, text: str, image: list[Image.Image | str]) -> str: + + return self.gpt_client.query( + text_prompt=self.prompt.format(text), + image_base64=image, + ) + + +class SemanticMatcher(BaseChecker): + def __init__( + self, + gpt_client: GPTclient, + prompt: str = None, + verbose: bool = False, + seed: int = None, + ) -> None: + super().__init__(prompt, verbose) + self.gpt_client = gpt_client + self.seed = seed + random.seed(seed) + if self.prompt is None: + self.prompt = """ + You are an expert in semantic similarity and scene retrieval. + You will be given: + - A dictionary where each key is a scene ID, and each value is a scene description. + - A query text describing a target scene. + + Your task: + return_num = 2 + - Find the most semantically similar scene IDs to the query text. + - If there are fewer than distinct relevant matches, repeat the closest ones to make a list of . + - Only output the list of scene IDs, sorted from most to less similar. + - Do NOT use markdown, JSON code blocks, or any formatting syntax, only return a plain list like ["id1", ...]. + + Input example: + Dictionary: + "{{ + "t_scene_008": "A study room with full bookshelves and a lamp in the corner.", + "t_scene_019": "A child's bedroom with pink walls and a small desk.", + "t_scene_020": "A living room with a wooden floor.", + "t_scene_021": "A living room with toys scattered on the floor.", + ... + "t_scene_office_001": "A very spacious, modern open-plan office with wide desks and no people, panoramic view." + }}" + Text: + "A traditional indoor room" + Output: + '["t_scene_office_001", ...]' + + Input: + Dictionary: + {context} + Text: + {text} + Output: + + """ + + def query( + self, text: str, context: dict, rand: bool = True, params: dict = None + ) -> str: + match_list = self.gpt_client.query( + self.prompt.format(context=context, text=text), + params=params, + ) + match_list = json_repair.loads(match_list) + result = random.choice(match_list) if rand else match_list[0] + + return result + + +def test_semantic_matcher( + bg_file: str = "outputs/bg_scenes/bg_scene_list.txt", +): + bg_file = "outputs/bg_scenes/bg_scene_list.txt" + scene_dict = {} + with open(bg_file, "r") as f: + for line in f: + line = line.strip() + if not line or ":" not in line: + continue + scene_id, desc = line.split(":", 1) + scene_dict[scene_id.strip()] = desc.strip() + + office_scene = scene_dict.get("t_scene_office_001") + text = "bright kitchen" + SCENE_MATCHER = SemanticMatcher(GPT_CLIENT) + # gpt_params = { + # "temperature": 0.8, + # "max_tokens": 500, + # "top_p": 0.8, + # "frequency_penalty": 0.3, + # "presence_penalty": 0.3, + # } + gpt_params = None + match_key = SCENE_MATCHER.query(text, str(scene_dict)) + print(match_key, ",", scene_dict[match_key]) + + +if __name__ == "__main__": + test_semantic_matcher() diff --git a/embodied_gen/validators/urdf_convertor.py b/embodied_gen/validators/urdf_convertor.py index 076c01a..b18be1e 100644 --- a/embodied_gen/validators/urdf_convertor.py +++ b/embodied_gen/validators/urdf_convertor.py @@ -297,20 +297,24 @@ class URDFGenerator(object): if not os.path.exists(urdf_path): raise FileNotFoundError(f"URDF file not found: {urdf_path}") - mesh_scale = 1.0 + mesh_attr = None tree = ET.parse(urdf_path) root = tree.getroot() extra_info = root.find(attr_root) if extra_info is not None: scale_element = extra_info.find(attr_name) if scale_element is not None: - mesh_scale = float(scale_element.text) + mesh_attr = scale_element.text + try: + mesh_attr = float(mesh_attr) + except ValueError as e: + pass - return mesh_scale + return mesh_attr @staticmethod def add_quality_tag( - urdf_path: str, results, output_path: str = None + urdf_path: str, results: list, output_path: str = None ) -> None: if output_path is None: output_path = urdf_path @@ -366,16 +370,9 @@ class URDFGenerator(object): output_root, num_images=self.render_view_num, output_subdir=self.output_render_dir, + no_index_file=True, ) - # Hardcode tmp because of the openrouter can't input multi images. - if "openrouter" in self.gpt_client.endpoint: - from embodied_gen.utils.process_media import ( - combine_images_to_base64, - ) - - image_path = combine_images_to_base64(image_path) - response = self.gpt_client.query(text_prompt, image_path) if response is None: asset_attrs = { @@ -412,14 +409,18 @@ class URDFGenerator(object): if __name__ == "__main__": urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4) urdf_path = urdf_gen( - mesh_path="outputs/imageto3d/cma/o5/URDF_o5/mesh/o5.obj", + mesh_path="outputs/layout2/asset3d/marker/result/mesh/marker.obj", output_root="outputs/test_urdf", - # category="coffee machine", + category="marker", # min_height=1.0, # max_height=1.2, version=VERSION, ) + URDFGenerator.add_quality_tag( + urdf_path, [[urdf_gen.__class__.__name__, "OK"]] + ) + # zip_files( # input_paths=[ # "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh", diff --git a/install.sh b/install.sh index 535d847..2568aa9 100644 --- a/install.sh +++ b/install.sh @@ -8,6 +8,12 @@ NC='\033[0m' echo -e "${GREEN}Starting installation process...${NC}" git config --global http.postBuffer 524288000 +echo -e "${GREEN}Installing flash-attn...${NC}" +pip install flash-attn==2.7.0.post2 --no-build-isolation || { + echo -e "${RED}Failed to install flash-attn${NC}" + exit 1 +} + echo -e "${GREEN}Installing dependencies from requirements.txt...${NC}" pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60 || { echo -e "${RED}Failed to install requirements${NC}" @@ -15,16 +21,16 @@ pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeo } -echo -e "${GREEN}Installing kaolin from GitHub...${NC}" -pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || { - echo -e "${RED}Failed to install kaolin${NC}" +echo -e "${GREEN}Installing kolors from GitHub...${NC}" +pip install kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d || { + echo -e "${RED}Failed to install kolors${NC}" exit 1 } -echo -e "${GREEN}Installing flash-attn...${NC}" -pip install flash-attn==2.7.0.post2 --no-build-isolation || { - echo -e "${RED}Failed to install flash-attn${NC}" +echo -e "${GREEN}Installing kaolin from GitHub...${NC}" +pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || { + echo -e "${RED}Failed to install kaolin${NC}" exit 1 } @@ -39,7 +45,6 @@ rm -rf "$TMP_DIR" || { rm -rf "$TMP_DIR" exit 1 } -echo -e "${GREEN}Installation completed successfully!${NC}" echo -e "${GREEN}Installing gsplat from GitHub...${NC}" @@ -50,8 +55,9 @@ pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.0 || { echo -e "${GREEN}Installing EmbodiedGen...${NC}" +pip install triton==2.1.0 pip install -e . || { - echo -e "${RED}Failed to install local package${NC}" + echo -e "${RED}Failed to install EmbodiedGen pyproject.toml${NC}" exit 1 } diff --git a/pyproject.toml b/pyproject.toml index c261905..adfbc51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ packages = ["embodied_gen"] [project] name = "embodied_gen" -version = "v0.1.0" +version = "v0.1.1" readme = "README.md" license = "Apache-2.0" license-files = ["LICENSE", "NOTICE"] @@ -17,16 +17,20 @@ requires-python = ">=3.10" [project.optional-dependencies] dev = [ - "cpplint==2.0.0", - "pre-commit==2.13.0", + "cpplint", + "pre-commit", "pydocstyle", "black", "isort", + "pytest", + "pytest-mock", ] [project.scripts] drender-cli = "embodied_gen.data.differentiable_render:entrypoint" backproject-cli = "embodied_gen.data.backproject_v2:entrypoint" +img3d-cli = "embodied_gen.scripts.imageto3d:entrypoint" +text3d-cli = "embodied_gen.scripts.textto3d:text_to_3d" [tool.pydocstyle] match = '(?!test_).*(?!_pb2)\.py' diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..dbd1233 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,12 @@ +[pytest] +testpaths = + tests/test_unit + +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = -ra --tb=short --disable-warnings --log-cli-level=INFO +filterwarnings = + ignore::DeprecationWarning +markers = + slow: marks tests as slow diff --git a/requirements.txt b/requirements.txt index 79e459a..0c62681 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,10 +8,10 @@ triton==2.1.0 dataclasses_json easydict opencv-python>4.5 -imageio==2.36.1 -imageio-ffmpeg==0.5.1 +imageio +imageio-ffmpeg rembg==2.0.61 -trimesh==4.4.4 +trimesh moviepy==1.0.3 pymeshfix==0.17.0 igraph==0.11.8 @@ -20,21 +20,19 @@ openai==1.58.1 transformers==4.42.4 gradio==5.12.0 sentencepiece==0.2.0 -diffusers==0.31.0 -xatlas==0.0.9 +diffusers==0.34.0 +xatlas onnxruntime==1.20.1 -tenacity==8.2.2 +tenacity accelerate==0.33.0 basicsr==1.4.2 realesrgan==0.3.0 pydantic==2.9.2 vtk==9.3.1 spaces +colorlog +json-repair utils3d@git+https://github.com/EasternJournalist/utils3d.git#egg=9a4eb15 clip@git+https://github.com/openai/CLIP.git -kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f nvdiffrast@git+https://github.com/NVlabs/nvdiffrast.git#egg=729261d -# https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu118-cp310-cp310-linux_x86_64.whl -# https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu11torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl -# https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu118/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl diff --git a/tests/test_examples/test_aesthetic_predictor.py b/tests/test_examples/test_aesthetic_predictor.py new file mode 100644 index 0000000..a61eff5 --- /dev/null +++ b/tests/test_examples/test_aesthetic_predictor.py @@ -0,0 +1,31 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +import logging + +from embodied_gen.validators.aesthetic_predictor import AestheticPredictor + +logger = logging.getLogger(__name__) + + +# @pytest.mark.manual +def test_aesthetic_predictor(): + image_path = "apps/assets/example_image/sample_02.jpg" + predictor = AestheticPredictor(device="cpu") + score = predictor.predict(image_path) + + assert isinstance(score, float) + logger.info(f"Aesthetic score: {score:.3f}") diff --git a/tests/test_examples/test_quality_checkers.py b/tests/test_examples/test_quality_checkers.py new file mode 100644 index 0000000..4031c01 --- /dev/null +++ b/tests/test_examples/test_quality_checkers.py @@ -0,0 +1,119 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +import logging +import tempfile + +import pytest +from embodied_gen.utils.gpt_clients import GPT_CLIENT +from embodied_gen.utils.process_media import render_asset3d +from embodied_gen.validators.quality_checkers import ( + ImageAestheticChecker, + ImageSegChecker, + MeshGeoChecker, + SemanticConsistChecker, + TextGenAlignChecker, +) + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def geo_checker(): + return MeshGeoChecker(GPT_CLIENT) + + +@pytest.fixture(scope="module") +def seg_checker(): + return ImageSegChecker(GPT_CLIENT) + + +@pytest.fixture(scope="module") +def aesthetic_checker(): + return ImageAestheticChecker() + + +@pytest.fixture(scope="module") +def semantic_checker(): + return SemanticConsistChecker(GPT_CLIENT) + + +@pytest.fixture(scope="module") +def textalign_checker(): + return TextGenAlignChecker(GPT_CLIENT) + + +def test_geo_checker(geo_checker): + flag, result = geo_checker( + [ + "apps/assets/example_image/sample_02.jpg", + ] + ) + logger.info(f"geo_checker: {flag}, {result}") + assert isinstance(flag, bool) + assert isinstance(result, str) + + +def test_aesthetic_checker(aesthetic_checker): + flag, result = aesthetic_checker("apps/assets/example_image/sample_02.jpg") + logger.info(f"aesthetic_checker: {flag}, {result}") + assert isinstance(flag, bool) + assert isinstance(result, float) + + +def test_seg_checker(seg_checker): + flag, result = seg_checker( + [ + "apps/assets/example_image/sample_02.jpg", + "apps/assets/example_image/sample_02.jpg", + ] + ) + logger.info(f"seg_checker: {flag}, {result}") + assert isinstance(flag, bool) + assert isinstance(result, str) + + +def test_semantic_checker(semantic_checker): + flag, result = semantic_checker( + text="can", + image=["apps/assets/example_image/sample_02.jpg"], + ) + logger.info(f"semantic_checker: {flag}, {result}") + assert isinstance(flag, bool) + assert isinstance(result, str) + + +@pytest.mark.parametrize( + "mesh_path, text_desc", + [ + ("apps/assets/example_texture/meshes/chair.obj", "chair"), + ("apps/assets/example_texture/meshes/clock.obj", "clock"), + ], +) +def test_textgen_checker(textalign_checker, mesh_path, text_desc): + with tempfile.TemporaryDirectory() as output_root: + image_list = render_asset3d( + mesh_path, + output_root=output_root, + num_images=6, + elevation=(30, -30), + output_subdir="renders", + no_index_file=True, + with_mtl=False, + ) + flag, result = textalign_checker(text_desc, image_list) + logger.info(f"textalign_checker: {flag}, {result}") diff --git a/tests/test_unit/test_agents.py b/tests/test_unit/test_agents.py new file mode 100644 index 0000000..f49fab5 --- /dev/null +++ b/tests/test_unit/test_agents.py @@ -0,0 +1,95 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +from unittest.mock import patch + +import pytest +from PIL import Image +from embodied_gen.utils.gpt_clients import GPT_CLIENT +from embodied_gen.validators.quality_checkers import ( + ImageSegChecker, + MeshGeoChecker, + SemanticConsistChecker, +) + + +@pytest.fixture(autouse=True) +def gptclient_query(): + with patch.object( + GPT_CLIENT, "query", return_value="mocked gpt response" + ) as mock: + yield mock + + +@pytest.fixture() +def gptclient_query_case2(): + with patch.object(GPT_CLIENT, "query", return_value=None) as mock: + yield mock + + +@pytest.mark.parametrize( + "input_images", + [ + "dummy_path/color_grid_6view.png", + ["dummy_path/color_grid_6view.jpg"], + [ + "dummy_path/color_grid_6view.png", + "dummy_path/color_grid_6view2.png", + ], + [ + Image.new("RGB", (64, 64), "red"), + Image.new("RGB", (64, 64), "blue"), + ], + ], +) +def test_geo_checker_varied_inputs(input_images): + geo_checker = MeshGeoChecker(GPT_CLIENT) + flag, result = geo_checker(input_images) + assert isinstance(flag, (bool, type(None))) + assert isinstance(result, str) + + +def test_seg_checker(): + seg_checker = ImageSegChecker(GPT_CLIENT) + flag, result = seg_checker( + [ + "dummy_path/sample_0_raw.png", # raw image + "dummy_path/sample_0_cond.png", # segmented image + ] + ) + assert isinstance(flag, (bool, type(None))) + assert isinstance(result, str) + + +def test_semantic_checker(): + semantic_checker = SemanticConsistChecker(GPT_CLIENT) + flag, result = semantic_checker( + text="pen", + image=["dummy_path/pen.png"], + ) + assert isinstance(flag, (bool, type(None))) + assert isinstance(result, str) + + +def test_semantic_checker(gptclient_query_case2): + semantic_checker = SemanticConsistChecker(GPT_CLIENT) + flag, result = semantic_checker( + text="pen", + image=["dummy_path/pen.png"], + ) + assert isinstance(flag, (bool, type(None))) + assert isinstance(result, str) diff --git a/tests/test_unit/test_gpt_client.py b/tests/test_unit/test_gpt_client.py new file mode 100644 index 0000000..e7b8180 --- /dev/null +++ b/tests/test_unit/test_gpt_client.py @@ -0,0 +1,94 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +import os +from unittest.mock import patch + +import pytest +import yaml +from PIL import Image +from embodied_gen.utils.gpt_clients import CONFIG_FILE, GPTclient + + +@pytest.fixture(scope="module") +def config(): + with open(CONFIG_FILE, "r") as f: + return yaml.safe_load(f) + + +@pytest.fixture +def env_vars(monkeypatch, config): + agent_type = config["agent_type"] + agent_config = config.get(agent_type, {}) + monkeypatch.setenv( + "ENDPOINT", agent_config.get("endpoint", "fake_endpoint") + ) + monkeypatch.setenv("API_KEY", agent_config.get("api_key", "fake_api_key")) + monkeypatch.setenv("API_VERSION", agent_config.get("api_version", "v1")) + monkeypatch.setenv( + "MODEL_NAME", agent_config.get("model_name", "test_model") + ) + yield + + +@pytest.fixture +def gpt_client(env_vars): + client = GPTclient( + endpoint=os.environ.get("ENDPOINT"), + api_key=os.environ.get("API_KEY"), + api_version=os.environ.get("API_VERSION"), + model_name=os.environ.get("MODEL_NAME"), + check_connection=False, + ) + return client + + +@pytest.mark.parametrize( + "text_prompt, image_base64", + [ + ("What is the capital of China?", None), + ( + "What is the content in each image?", + "apps/assets/example_image/sample_02.jpg", + ), + ( + "What is the content in each image?", + [ + "apps/assets/example_image/sample_02.jpg", + "apps/assets/example_image/sample_03.jpg", + ], + ), + ( + "What is the content in each image?", + [ + Image.new("RGB", (64, 64), "red"), + Image.new("RGB", (64, 64), "blue"), + ], + ), + ], +) +def test_gptclient_query(gpt_client, text_prompt, image_base64): + # mock GPTclient.query + with patch.object( + GPTclient, "query", return_value="mocked response" + ) as mock_query: + response = gpt_client.query( + text_prompt=text_prompt, image_base64=image_base64 + ) + assert response == "mocked response" + mock_query.assert_called_once_with( + text_prompt=text_prompt, image_base64=image_base64 + )