feat(pipe): Refine generation pipeline and add utests. (#23)
- Added intelligent quality checkers and auto-retry pipeline for `image-to-3d` and `text-to-3d`.
- Added unit tests for quality checkers.
- `text-to-3d` now supports more `text-to-image` models; pipeline success rate improved to 94%.
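The generate, quality-check, retry loop this commit adds can be sketched in isolation. Below, `generate` and `check` are hypothetical stand-ins for the image/3D generators and the GPT-based quality checkers, not functions from this repo; on a failed check the seed is resampled, mirroring the scripts in this commit:

```python
import random


def generate(prompt: str, seed: int) -> str:
    # Hypothetical stand-in for the text-to-image / image-to-3d generator.
    return f"asset({prompt}, seed={seed})"


def check(asset: str) -> bool:
    # Hypothetical stand-in for a quality checker; accepts any non-empty asset.
    return bool(asset)


def generate_with_retry(prompt: str, n_retry: int = 2, seed: int = 0):
    for try_idx in range(n_retry):
        asset = generate(prompt, seed)
        if check(asset):
            return asset
        # On a failed quality check, resample the seed and try again.
        seed = random.randint(0, 100000)
    return None


print(generate_with_retry("a lion figurine"))  # asset(a lion figurine, seed=0)
```

The same shape appears twice in this commit: per-image retries in `imageto3d.py` and per-prompt retries in `textto3d.py`.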
This commit is contained in:
parent 8bcfac9190
commit a3924ae4a8

CHANGELOG.md (new file, 29 lines)
@@ -0,0 +1,29 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0).

## [0.1.1] - 2025-07-xx

### Feature
- Added intelligent quality checkers and auto-retry pipeline for `image-to-3d` and `text-to-3d`.
- Added unit tests for quality checkers.
- `text-to-3d` now supports more `text-to-image` models; pipeline success rate improved to 94%.

## [0.1.0] - 2025-07-04

### Feature
🖼️ Single Image to Physics Realistic 3D Asset
- Generates watertight, simulation-ready 3D meshes with physical attributes.
- Auto-labels semantic and quality tags (geometry, texture, foreground, etc.).
- Produces 2K textures with highlight removal and multi-view fusion.

📝 Text-to-3D Asset Generation
- Creates realistic 3D assets from natural language (English & Chinese).
- Filters assets via QA tags to ensure visual and geometric quality.

🎨 Texture Generation & Editing
- Generates 2K textures from mesh and text with semantic alignment.
- Plug-and-play modules adapt text-to-image models for 3D textures.
- Supports realistic and stylized texture outputs, including text textures.
README.md (44 lines changed)
@@ -6,6 +6,7 @@
 [](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D)
 [](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D)
 [](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen)
+[](https://mp.weixin.qq.com/s/HH1cPBhK2xcDbyCK4BBTbw)


 > ***EmbodiedGen*** is a generative engine that creates diverse, interactive 3D worlds composed of high-quality 3D assets (mesh & 3DGS) with plausible physics, leveraging generative AI to address the generalization challenges of embodied-intelligence research.
@@ -29,7 +30,7 @@
 ```sh
 git clone https://github.com/HorizonRobotics/EmbodiedGen.git
 cd EmbodiedGen
-git checkout v0.1.0
+git checkout v0.1.1
 git submodule update --init --recursive --progress
 conda create -n embodiedgen python=3.10.13 -y
 conda activate embodiedgen
```
@@ -67,9 +68,8 @@ CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
 ### ⚡ API
 Generate physically plausible 3D assets from image input via the command-line API.
 ```sh
-python3 embodied_gen/scripts/imageto3d.py \
-    --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
-    --output_root outputs/imageto3d
+img3d-cli --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
+    --n_retry 2 --output_root outputs/imageto3d

 # See result (.urdf / mesh.obj / mesh.glb / gs.ply) in ${output_root}/sample_xx/result
 ```
@@ -86,18 +86,29 @@ python3 embodied_gen/scripts/imageto3d.py \
 ### ☁️ Service
 Deploy the text-to-3D generation service locally.

-Text-to-image based on the Kolors model, supporting Chinese and English prompts.
-Models downloaded automatically on first run, see `download_kolors_weights`, please be patient.
+Text-to-image model based on the Kolors model, supporting Chinese and English prompts.
+Models are downloaded automatically on the first run; please be patient.
 ```sh
 python apps/text_to_3d.py
 ```

 ### ⚡ API
-Text-to-image based on the Kolors model.
+Text-to-image model based on SD3.5 Medium, English prompts only.
+Usage requires agreement to the [model license (click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium); models are downloaded automatically. (Note: models with more permissive licenses can be found in `embodied_gen/models/image_comm_model.py`.)
+
+For large-scale 3D asset generation, set `--n_pipe_retry=2` to ensure high end-to-end asset usability through automatic quality checks and retries. For more diverse results, do not set `--seed_img`.
+
+```sh
+text3d-cli --prompts "small bronze figurine of a lion" "A globe with wooden base" "wooden table with embroidery" \
+    --n_image_retry 2 --n_asset_retry 2 --n_pipe_retry 1 --seed_img 0 \
+    --output_root outputs/textto3d
+```
+
+Text-to-image model based on the Kolors model.
 ```sh
 bash embodied_gen/scripts/textto3d.sh \
     --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
-    --output_root outputs/textto3d
+    --output_root outputs/textto3d_k
 ```

 ---
@@ -118,12 +129,17 @@ python apps/texture_edit.py
 ```

 ### ⚡ API
 Supports Chinese and English prompts.
 ```sh
 bash embodied_gen/scripts/texture_gen.sh \
     --mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \
     --prompt "举着牌子的写实风格机器人,大眼睛,牌子上写着“Hello”的文字" \
-    --output_root "outputs/texture_gen/" \
-    --uuid "robot_text"
+    --output_root "outputs/texture_gen/robot_text"
+
+bash embodied_gen/scripts/texture_gen.sh \
+    --mesh_path "apps/assets/example_texture/meshes/horse.obj" \
+    --prompt "A gray horse head with flying mane and brown eyes" \
+    --output_root "outputs/texture_gen/gray_horse"
 ```

 ---
@@ -171,6 +187,12 @@ bash embodied_gen/scripts/texture_gen.sh \

 ---

+## For Developer
+```sh
+pip install .[dev] && pre-commit install
+python -m pytest  # Passing all unit tests is required.
+```
+
 ## 📚 Citation

 If you use EmbodiedGen in your research or projects, please cite:
@@ -192,7 +214,7 @@ If you use EmbodiedGen in your research or projects, please cite:
 ## 🙌 Acknowledgement

 EmbodiedGen builds upon the following amazing projects and models:
-🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o)
+🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)

 ---
@@ -55,9 +55,9 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT
 from embodied_gen.utils.process_media import (
     filter_image_small_connected_components,
     merge_images_video,
-    render_video,
 )
 from embodied_gen.utils.tags import VERSION
+from embodied_gen.utils.trender import render_video
 from embodied_gen.validators.quality_checkers import (
     BaseChecker,
     ImageAestheticChecker,
@@ -33,7 +33,6 @@ from tqdm import tqdm
 from embodied_gen.data.utils import (
     CameraSetting,
     DiffrastRender,
-    RenderItems,
     as_list,
     calc_vertex_normals,
     import_kaolin_mesh,
@@ -42,6 +41,7 @@ from embodied_gen.data.utils import (
     render_pbr,
     save_images,
 )
+from embodied_gen.utils.enum import RenderItems

 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
 os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
@@ -470,7 +470,7 @@ def parse_args():
         "--pbr_light_factor",
         type=float,
         default=1.0,
-        help="Light factor for mesh PBR rendering (default: 2.)",
+        help="Light factor for mesh PBR rendering (default: 1.)",
     )
     parser.add_argument(
         "--with_mtl",
@@ -482,6 +482,11 @@ def parse_args():
         action="store_true",
         help="Whether to generate color .gif rendering file.",
     )
+    parser.add_argument(
+        "--no_index_file",
+        action="store_true",
+        help="Whether to skip saving the index file.",
+    )
     parser.add_argument(
         "--gen_color_mp4",
         action="store_true",
@@ -568,7 +573,7 @@ def entrypoint(**kwargs) -> None:
         gen_viewnormal_mp4=args.gen_viewnormal_mp4,
         gen_glonormal_mp4=args.gen_glonormal_mp4,
         light_factor=args.pbr_light_factor,
-        no_index_file=gen_video,
+        no_index_file=gen_video or args.no_index_file,
     )
     image_render.render_mesh(
         mesh_path=args.mesh_path,
@@ -395,6 +395,8 @@ class MeshFixer(object):
             self.vertices_np,
             np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
         )
+        mesh.clean(inplace=True)
+        mesh.clear_data()
         mesh = mesh.decimate(ratio, progress_bar=True)

         # Update vertices and faces
@@ -38,7 +38,6 @@ except ImportError:
     ChatGLMModel = None
 import logging
 from dataclasses import dataclass, field
-from enum import Enum

 import trimesh
 from kaolin.render.camera import Camera
@@ -57,7 +56,6 @@ __all__ = [
     "load_mesh_to_unit_cube",
     "as_list",
     "CameraSetting",
-    "RenderItems",
     "import_kaolin_mesh",
     "save_mesh_with_mtl",
     "get_images_from_grid",
@@ -738,18 +736,6 @@ class CameraSetting:
         self.Ks = Ks


-@dataclass
-class RenderItems(str, Enum):
-    IMAGE = "image_color"
-    ALPHA = "image_mask"
-    VIEW_NORMAL = "image_view_normal"
-    GLOBAL_NORMAL = "image_global_normal"
-    POSITION_MAP = "image_position"
-    DEPTH = "image_depth"
-    ALBEDO = "image_albedo"
-    DIFFUSE = "image_diffuse"
-
-
 def _compute_az_el_by_camera_params(
     camera_params: CameraSetting, flip_az: bool = False
 ):
embodied_gen/models/image_comm_model.py (new file, 236 lines)

@@ -0,0 +1,236 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Text-to-Image generation models from the Hugging Face community.

import os
from abc import ABC, abstractmethod

import torch
from diffusers import (
    ChromaPipeline,
    Cosmos2TextToImagePipeline,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    KolorsPipeline,
    StableDiffusion3Pipeline,
)
from diffusers.quantizers import PipelineQuantizationConfig
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, SiglipProcessor

__all__ = [
    "build_hf_image_pipeline",
]


class BasePipelineLoader(ABC):
    def __init__(self, device="cuda"):
        self.device = device

    @abstractmethod
    def load(self):
        pass


class BasePipelineRunner(ABC):
    def __init__(self, pipe):
        self.pipe = pipe

    @abstractmethod
    def run(self, prompt: str, **kwargs) -> Image.Image:
        pass


# ===== SD3.5-medium =====
class SD35Loader(BasePipelineLoader):
    def load(self):
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16,
        )
        pipe = pipe.to(self.device)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe


class SD35Runner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Cosmos2 =====
class CosmosLoader(BasePipelineLoader):
    def __init__(
        self,
        model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
        local_dir="weights/cosmos2",
        device="cuda",
    ):
        super().__init__(device)
        self.model_id = model_id
        self.local_dir = local_dir

    def _patch(self):
        def patch_model(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                kwargs.setdefault("attn_implementation", "flash_attention_2")
                kwargs.setdefault("torch_dtype", torch.bfloat16)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        def patch_processor(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                kwargs.setdefault("use_fast", True)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        patch_model(AutoModelForCausalLM)
        patch_processor(SiglipProcessor)

    def load(self):
        self._patch()
        snapshot_download(
            repo_id=self.model_id,
            local_dir=self.local_dir,
            local_dir_use_symlinks=False,
            resume_download=True,
        )

        config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
                "bnb_4bit_use_double_quant": True,
            },
            components_to_quantize=["text_encoder", "transformer", "unet"],
        )

        pipe = Cosmos2TextToImagePipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch.bfloat16,
            quantization_config=config,
            use_safetensors=True,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.device)
        return pipe


class CosmosRunner(BasePipelineRunner):
    def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images


# ===== Kolors =====
class KolorsLoader(BasePipelineLoader):
    def load(self):
        pipe = KolorsPipeline.from_pretrained(
            "Kwai-Kolors/Kolors-diffusers",
            torch_dtype=torch.float16,
            variant="fp16",
        ).to(self.device)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True
        )
        return pipe


class KolorsRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Flux =====
class FluxLoader(BasePipelineLoader):
    def load(self):
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe.to(self.device)


class FluxRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Chroma =====
class ChromaLoader(BasePipelineLoader):
    def load(self):
        return ChromaPipeline.from_pretrained(
            "lodestones/Chroma", torch_dtype=torch.bfloat16
        ).to(self.device)


class ChromaRunner(BasePipelineRunner):
    def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images


PIPELINE_REGISTRY = {
    "sd35": (SD35Loader, SD35Runner),
    "cosmos": (CosmosLoader, CosmosRunner),
    "kolors": (KolorsLoader, KolorsRunner),
    "flux": (FluxLoader, FluxRunner),
    "chroma": (ChromaLoader, ChromaRunner),
}


def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
    if name not in PIPELINE_REGISTRY:
        raise ValueError(f"Unsupported model: {name}")
    loader_cls, runner_cls = PIPELINE_REGISTRY[name]
    pipe = loader_cls(device=device).load()

    return runner_cls(pipe)


if __name__ == "__main__":
    model_name = "sd35"
    runner = build_hf_image_pipeline(model_name)
    # NOTE: For pipeline testing only; generation quality at low resolution is poor.
    images = runner.run(
        prompt="A robot holding a sign that says 'Hello'",
        height=512,
        width=512,
        num_inference_steps=10,
        guidance_scale=6,
        num_images_per_prompt=1,
    )

    for i, img in enumerate(images):
        img.save(f"image_{model_name}_{i}.jpg")
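The loader/runner split above keeps every text-to-image backend behind a two-method interface, so adding a model is one registry entry plus two small classes. A minimal self-contained sketch of the same pattern; the `Echo*` classes are hypothetical stand-ins that "generate" a string instead of an image, so the wiring can be exercised without GPU models:

```python
from abc import ABC, abstractmethod


class BaseLoader(ABC):
    def __init__(self, device="cpu"):
        self.device = device

    @abstractmethod
    def load(self):
        ...


class BaseRunner(ABC):
    def __init__(self, pipe):
        self.pipe = pipe

    @abstractmethod
    def run(self, prompt: str, **kwargs):
        ...


class EchoLoader(BaseLoader):
    def load(self):
        # A real loader would return a diffusers pipeline moved to self.device.
        return lambda prompt: f"[{self.device}] {prompt}"


class EchoRunner(BaseRunner):
    def run(self, prompt: str, **kwargs):
        return self.pipe(prompt)


REGISTRY = {"echo": (EchoLoader, EchoRunner)}


def build_pipeline(name: str, device="cpu") -> BaseRunner:
    if name not in REGISTRY:
        raise ValueError(f"Unsupported model: {name}")
    loader_cls, runner_cls = REGISTRY[name]
    return runner_cls(loader_cls(device=device).load())


print(build_pipeline("echo").run("a lion figurine"))  # [cpu] a lion figurine
```

The design choice mirrors `PIPELINE_REGISTRY` above: callers depend only on `run(prompt, **kwargs)`, so heavyweight loading details (quantization, offload, schedulers) stay confined to each loader.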
@@ -52,8 +52,11 @@ __all__ = [
     "download_kolors_weights",
 ]

-PROMPT_APPEND = "Full view of one {}, no cropping, centered, no occlusion, isolated product photo, matte, 3D style, on a plain clean surface"
+PROMPT_APPEND = (
+    "Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
+    "no surroundings, matte, on a plain clean surface, 3D style revealing multiple surfaces"
+)
+PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"


 def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
@@ -182,9 +185,7 @@ def text2img_gen(
     ip_image_size: int = 512,
     seed: int = None,
 ) -> list[Image.Image]:
-    # prompt = "Single " + prompt + ", in the center of the image"
-    # prompt += ", high quality, high resolution, best quality, white background, 3D style"  # noqa
-    prompt = PROMPT_APPEND.format(prompt.strip())
+    prompt = PROMPT_KAPPEND.format(object=prompt.strip())
     logger.info(f"Processing prompt: {prompt}")

     generator = None
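Note the changed call in `text2img_gen`: the new templates use a named `{object}` field, so they must be filled with the keyword form of `str.format`; the old positional call would no longer work. A quick illustration (the template string is copied from the hunk above; the object name is an example):

```python
# Template copied from the hunk above.
PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"

# Named fields require the keyword form of str.format.
prompt = PROMPT_KAPPEND.format(object="bronze lion figurine")
print(prompt)

# The old positional call, PROMPT_KAPPEND.format("x"), would raise
# KeyError: 'object' because the field is named, not "{}".
```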
@@ -16,13 +16,14 @@

 import argparse
 import logging
 import os
 import random
 import sys
 from glob import glob
 from shutil import copy, copytree, rmtree

 import numpy as np
+import torch
 import trimesh
 from PIL import Image
 from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
@@ -37,8 +38,10 @@ from embodied_gen.models.segment_model import (
 from embodied_gen.models.sr_model import ImageRealESRGAN
 from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
 from embodied_gen.utils.gpt_clients import GPT_CLIENT
-from embodied_gen.utils.process_media import merge_images_video, render_video
+from embodied_gen.utils.log import logger
+from embodied_gen.utils.process_media import merge_images_video
 from embodied_gen.utils.tags import VERSION
+from embodied_gen.utils.trender import render_video
 from embodied_gen.validators.quality_checkers import (
     BaseChecker,
     ImageAestheticChecker,
@@ -52,19 +55,14 @@ current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, "../.."))
 from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline

-logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
-)
-logger = logging.getLogger(__name__)
-
-
 os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
     "~/.cache/torch_extensions"
 )
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
 os.environ["SPCONV_ALGO"] = "native"
 random.seed(0)


 logger.info("Loading Models...")
 DELIGHT = DelightingModel()
 IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
@@ -74,7 +72,7 @@ SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
 PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
     "microsoft/TRELLIS-image-large"
 )
-PIPELINE.cuda()
+# PIPELINE.cuda()
 SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
 GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
 AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -95,7 +93,6 @@ def parse_args():
     parser.add_argument(
         "--output_root",
         type=str,
-        required=True,
         help="Root directory for saving outputs.",
     )
     parser.add_argument(
@@ -110,12 +107,26 @@ def parse_args():
         default=None,
         help="The mass in kg to restore the mesh real weight.",
     )
-    parser.add_argument("--asset_type", type=str, default=None)
+    parser.add_argument("--asset_type", type=str, nargs="+", default=None)
     parser.add_argument("--skip_exists", action="store_true")
     parser.add_argument("--strict_seg", action="store_true")
     parser.add_argument("--version", type=str, default=VERSION)
-    parser.add_argument("--remove_intermediate", type=bool, default=True)
-    args = parser.parse_args()
+    parser.add_argument("--keep_intermediate", action="store_true")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument(
+        "--n_retry",
+        type=int,
+        default=2,
+    )
+    args, unknown = parser.parse_known_args()
+
+    return args
+
+
+def entrypoint(**kwargs):
+    args = parse_args()
+    for k, v in kwargs.items():
+        if hasattr(args, k) and v is not None:
+            setattr(args, k, v)

     assert (
         args.image_path or args.image_root
@@ -125,13 +136,7 @@ def parse_args():
         args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
         args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))

-    return args
-
-
-if __name__ == "__main__":
-    args = parse_args()
-
-    for image_path in args.image_path:
+    for idx, image_path in enumerate(args.image_path):
         try:
             filename = os.path.basename(image_path).split(".")[0]
             output_root = args.output_root
@@ -141,7 +146,7 @@ if __name__ == "__main__":

             mesh_out = f"{output_root}/(unknown).obj"
             if args.skip_exists and os.path.exists(mesh_out):
-                logger.info(
+                logger.warning(
                     f"Skip {image_path}, already processed in {mesh_out}"
                 )
                 continue
@@ -149,22 +154,27 @@ if __name__ == "__main__":
             image = Image.open(image_path)
             image.save(f"{output_root}/(unknown)_raw.png")

-            # Segmentation: Get segmented image using SAM or Rembg.
+            # Segmentation: Get segmented image using Rembg.
             seg_path = f"{output_root}/(unknown)_cond.png"
-            if image.mode != "RGBA":
-                seg_image = RBG_REMOVER(image, save_path=seg_path)
-                seg_image = trellis_preprocess(seg_image)
-            else:
-                seg_image = image
+            seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
+            seg_image = trellis_preprocess(seg_image)
+            seg_image.save(seg_path)

+            seed = args.seed
+            for try_idx in range(args.n_retry):
+                logger.info(
+                    f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
+                )
                 # Run the pipeline
                 try:
+                    PIPELINE.cuda()
                     outputs = PIPELINE.run(
                         seg_image,
                         preprocess_image=False,
+                        seed=(
+                            random.randint(0, 100000) if seed is None else seed
+                        ),
                         # Optional parameters
                         # seed=1,
                         # sparse_structure_sampler_params={
                         #     "steps": 12,
                         #     "cfg_strength": 7.5,
@@ -174,19 +184,16 @@ if __name__ == "__main__":
                         #     "cfg_strength": 3,
                         # },
                     )
+                    PIPELINE.cpu()
+                    torch.cuda.empty_cache()
                 except Exception as e:
                     logger.error(
                         f"[Pipeline Failed] process {image_path}: {e}, skip."
                     )
                     continue

-            # Render and save color and mesh videos
             gs_model = outputs["gaussian"][0]
             mesh_model = outputs["mesh"][0]
-            color_images = render_video(gs_model)["color"]
-            normal_images = render_video(mesh_model)["normal"]
-            video_path = os.path.join(output_root, "gs_mesh.mp4")
-            merge_images_video(color_images, normal_images, video_path)

             # Save the raw Gaussian model
             gs_path = mesh_out.replace(".obj", "_gs.ply")
@@ -210,6 +217,21 @@ if __name__ == "__main__":
             color_path = os.path.join(output_root, "color.png")
             render_gs_api(aligned_gs_path, color_path)

+                geo_flag, geo_result = GEO_CHECKER([color_path])
+                logger.warning(
+                    f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
+                )
+                if geo_flag is True or geo_flag is None:
+                    break
+
+                seed = random.randint(0, 100000) if seed is not None else None
+
+            # Render the video for generated 3D asset.
+            color_images = render_video(gs_model)["color"]
+            normal_images = render_video(mesh_model)["normal"]
+            video_path = os.path.join(output_root, "gs_mesh.mp4")
+            merge_images_video(color_images, normal_images, video_path)
+
             mesh = trimesh.Trimesh(
                 vertices=mesh_model.vertices.cpu().numpy(),
                 faces=mesh_model.faces.cpu().numpy(),
@@ -249,8 +271,8 @@ if __name__ == "__main__":
             min_mass, max_mass = map(float, args.mass_range.split("-"))
             asset_attrs["min_mass"] = min_mass
             asset_attrs["max_mass"] = max_mass
-            if args.asset_type:
-                asset_attrs["category"] = args.asset_type
+            if isinstance(args.asset_type, list) and args.asset_type[idx]:
+                asset_attrs["category"] = args.asset_type[idx]
             if args.version:
                 asset_attrs["version"] = args.version
@@ -289,8 +311,8 @@ if __name__ == "__main__":
             ]
             images_list.append(images)

-            results = BaseChecker.validate(CHECKERS, images_list)
-            urdf_convertor.add_quality_tag(urdf_path, results)
+            qa_results = BaseChecker.validate(CHECKERS, images_list)
+            urdf_convertor.add_quality_tag(urdf_path, qa_results)

             # Organize the final result files
             result_dir = f"{output_root}/result"
@@ -303,7 +325,7 @@ if __name__ == "__main__":
                 f"{result_dir}/{urdf_convertor.output_mesh_dir}",
             )
             copy(video_path, f"{result_dir}/video.mp4")
-            if args.remove_intermediate:
+            if not args.keep_intermediate:
                 delete_dir(output_root, keep_subs=["result"])

         except Exception as e:
@@ -311,3 +333,7 @@ if __name__ == "__main__":
             continue

     logger.info(f"Processing complete. Outputs saved to {args.output_root}")
+
+
+if __name__ == "__main__":
+    entrypoint()
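The new `entrypoint(**kwargs)` wrapper above lets Python callers override parsed CLI defaults, which is how `textto3d.py` can reuse the image-to-3D stage. A simplified sketch of the pattern; the argument names mirror the script, but `parse_args` here is a stand-in, not the repo code:

```python
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_root", type=str, default="outputs")
    parser.add_argument("--n_retry", type=int, default=2)
    # parse_known_args ignores unrelated flags instead of erroring out;
    # [] is passed here so the sketch does not read sys.argv.
    args, unknown = parser.parse_known_args([])
    return args


def entrypoint(**kwargs):
    args = parse_args()
    # Keyword overrides win over CLI defaults, but only for known fields
    # and only when a non-None value is supplied.
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)
    return args


print(entrypoint(n_retry=5).n_retry)  # 5
```

The `hasattr`/`is not None` guard means unknown keywords are silently ignored and `None` never clobbers a default, so a caller can pass its whole option dict through.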
@@ -85,7 +85,7 @@ def parse_args():
     parser.add_argument(
         "--seed",
         type=int,
-        default=0,
+        default=None,
     )
     args = parser.parse_args()
271
embodied_gen/scripts/textto3d.py
Normal file
271
embodied_gen/scripts/textto3d.py
Normal file
@ -0,0 +1,271 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import argparse
import os
import random
from collections import defaultdict

import torch
from PIL import Image

from embodied_gen.models.image_comm_model import build_hf_image_pipeline
from embodied_gen.models.segment_model import RembgRemover
from embodied_gen.models.text_model import PROMPT_APPEND
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import render_asset3d
from embodied_gen.validators.quality_checkers import (
    ImageSegChecker,
    SemanticConsistChecker,
    TextGenAlignChecker,
)

# Avoid huggingface/tokenizers warning: "The current process just got forked."
os.environ["TOKENIZERS_PARALLELISM"] = "false"
random.seed(0)


__all__ = [
    "text_to_image",
    "text_to_3d",
]

logger.info("Loading Models...")
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
PIPE_IMG = build_hf_image_pipeline("sd35")
BG_REMOVER = RembgRemover()


def text_to_image(
    prompt: str,
    save_path: str,
    n_retry: int,
    img_denoise_step: int,
    text_guidance_scale: float,
    n_img_sample: int,
    image_hw: tuple[int, int] = (1024, 1024),
    seed: int | None = None,
) -> bool:
    select_image = None
    success_flag = False
    assert save_path.endswith(".png"), "Image save path must end with `.png`."
    for try_idx in range(n_retry):
        if select_image is not None:
            select_image[0].save(save_path.replace(".png", "_raw.png"))
            select_image[1].save(save_path)
            break

        f_prompt = PROMPT_APPEND.format(object=prompt)
        logger.info(
            f"Image GEN for {os.path.basename(save_path)}\n"
            f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
        )
        images = PIPE_IMG.run(
            f_prompt,
            num_inference_steps=img_denoise_step,
            guidance_scale=text_guidance_scale,
            num_images_per_prompt=n_img_sample,
            height=image_hw[0],
            width=image_hw[1],
            generator=(
                torch.Generator().manual_seed(seed)
                if seed is not None
                else None
            ),
        )

        for idx in range(len(images)):
            raw_image: Image.Image = images[idx]
            image = BG_REMOVER(raw_image)
            image.save(save_path)
            semantic_flag, semantic_result = SEMANTIC_CHECKER(
                prompt, [image.convert("RGB")]
            )
            seg_flag, seg_result = SEG_CHECKER(
                [raw_image, image.convert("RGB")]
            )
            if (
                (semantic_flag and seg_flag)
                or semantic_flag is None
                or seg_flag is None
            ):
                select_image = [raw_image, image]
                success_flag = True
                break

        torch.cuda.empty_cache()
        seed = random.randint(0, 100000) if seed is not None else None

    return success_flag


def text_to_3d(**kwargs) -> dict:
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    if args.asset_names is None or len(args.asset_names) == 0:
        args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
    img_save_dir = os.path.join(args.output_root, "images")
    asset_save_dir = os.path.join(args.output_root, "asset3d")
    os.makedirs(img_save_dir, exist_ok=True)
    os.makedirs(asset_save_dir, exist_ok=True)
    results = defaultdict(dict)
    for prompt, node in zip(args.prompts, args.asset_names):
        success_flag = False
        n_pipe_retry = args.n_pipe_retry
        seed_img = args.seed_img
        seed_3d = args.seed_3d
        while success_flag is False and n_pipe_retry > 0:
            logger.info(
                f"GEN pipeline for node {node}\n"
                f"Try round: {args.n_pipe_retry - n_pipe_retry + 1}/{args.n_pipe_retry}, Prompt: {prompt}"
            )
            # Text-to-image GEN
            save_node = node.replace(" ", "_")
            gen_image_path = f"{img_save_dir}/{save_node}.png"
            textgen_flag = text_to_image(
                prompt,
                gen_image_path,
                args.n_image_retry,
                args.img_denoise_step,
                args.text_guidance_scale,
                args.n_img_sample,
                seed=seed_img,
            )

            # Asset 3D GEN
            node_save_dir = f"{asset_save_dir}/{save_node}"
            asset_type = node if "sample3d_" not in node else None
            imageto3d_api(
                image_path=[gen_image_path],
                output_root=node_save_dir,
                asset_type=[asset_type],
                seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
                n_retry=args.n_asset_retry,
                keep_intermediate=args.keep_intermediate,
            )
            mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
            image_path = render_asset3d(
                mesh_path,
                output_root=f"{node_save_dir}/result",
                num_images=6,
                elevation=(30, -30),
                output_subdir="renders",
                no_index_file=True,
            )

            check_text = asset_type if asset_type is not None else prompt
            qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
            logger.warning(
                f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}"
            )
            results["assets"][node] = f"{node_save_dir}/result"
            results["quality"][node] = qa_result

            if qa_flag is None or qa_flag is True:
                success_flag = True
                break

            n_pipe_retry -= 1
            seed_img = (
                random.randint(0, 100000) if seed_img is not None else None
            )
            seed_3d = (
                random.randint(0, 100000) if seed_3d is not None else None
            )

        torch.cuda.empty_cache()

    return results


def parse_args():
    parser = argparse.ArgumentParser(description="3D Layout Generation Config")
    parser.add_argument("--prompts", nargs="+", help="text descriptions")
    parser.add_argument(
        "--output_root",
        type=str,
        help="Directory to save outputs",
    )
    parser.add_argument(
        "--asset_names",
        type=str,
        nargs="+",
        default=None,
        help="Asset names to generate",
    )
    parser.add_argument(
        "--n_img_sample",
        type=int,
        default=3,
        help="Number of image samples to generate",
    )
    parser.add_argument(
        "--text_guidance_scale",
        type=float,
        default=7,
        help="Text-to-image guidance scale",
    )
    parser.add_argument(
        "--img_denoise_step",
        type=int,
        default=25,
        help="Denoising steps for image generation",
    )
    parser.add_argument(
        "--n_image_retry",
        type=int,
        default=2,
        help="Max retry count for image generation",
    )
    parser.add_argument(
        "--n_asset_retry",
        type=int,
        default=2,
        help="Max retry count for 3D generation",
    )
    parser.add_argument(
        "--n_pipe_retry",
        type=int,
        default=1,
        help="Max retry count for the full text-to-3D pipeline",
    )
    parser.add_argument(
        "--seed_img",
        type=int,
        default=None,
        help="Random seed for image generation",
    )
    parser.add_argument(
        "--seed_3d",
        type=int,
        default=0,
        help="Random seed for 3D generation",
    )
    parser.add_argument("--keep_intermediate", action="store_true")

    args, unknown = parser.parse_known_args()

    return args


if __name__ == "__main__":
    text_to_3d()
@@ -2,7 +2,9 @@

# Initialize variables
prompts=()
asset_types=()
output_root=""
seed=0

# Parse arguments
while [[ $# -gt 0 ]]; do
@@ -14,10 +16,21 @@ while [[ $# -gt 0 ]]; do
            shift
        done
        ;;
    --asset_types)
        shift
        while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
            asset_types+=("$1")
            shift
        done
        ;;
    --output_root)
        output_root="$2"
        shift 2
        ;;
    --seed)
        seed="$2"
        shift 2
        ;;
    *)
        echo "Unknown argument: $1"
        exit 1
@@ -28,7 +41,21 @@ done
# Validate required arguments
if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
    echo "Missing required arguments."
    echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" --output_root <path>"
    echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" \
    --asset_types \"type1\" \"type2\" --seed <seed_value> --output_root <path>"
    exit 1
fi

# If no asset_types provided, default to ""
if [[ ${#asset_types[@]} -eq 0 ]]; then
    for (( i=0; i<${#prompts[@]}; i++ )); do
        asset_types+=("")
    done
fi

# Ensure the number of asset_types matches the number of prompts
if [[ ${#prompts[@]} -ne ${#asset_types[@]} ]]; then
    echo "The number of asset types must match the number of prompts."
    exit 1
fi

@@ -37,20 +64,30 @@ echo "Prompts:"
for p in "${prompts[@]}"; do
    echo " - $p"
done
# echo "Asset types:"
# for at in "${asset_types[@]}"; do
#     echo " - $at"
# done
echo "Output root: ${output_root}"
echo "Seed: ${seed}"

# Concatenate prompts for Python command
# Concatenate prompts and asset types for Python command
prompt_args=""
for p in "${prompts[@]}"; do
    prompt_args+="\"$p\" "
asset_type_args=""
for i in "${!prompts[@]}"; do
    prompt_args+="\"${prompts[$i]}\" "
    asset_type_args+="\"${asset_types[$i]}\" "
done


# Step 1: Text-to-Image
eval python3 embodied_gen/scripts/text2image.py \
    --prompts ${prompt_args} \
    --output_root "${output_root}/images"
    --output_root "${output_root}/images" \
    --seed ${seed}

# Step 2: Image-to-3D
python3 embodied_gen/scripts/imageto3d.py \
    --image_root "${output_root}/images" \
    --output_root "${output_root}/asset3d"
    --output_root "${output_root}/asset3d" \
    --asset_type ${asset_type_args}
@@ -10,10 +10,6 @@ while [[ $# -gt 0 ]]; do
        prompt="$2"
        shift 2
        ;;
    --uuid)
        uuid="$2"
        shift 2
        ;;
    --output_root)
        output_root="$2"
        shift 2
@@ -26,12 +22,13 @@ while [[ $# -gt 0 ]]; do
done


if [[ -z "$mesh_path" || -z "$prompt" || -z "$uuid" || -z "$output_root" ]]; then
if [[ -z "$mesh_path" || -z "$prompt" || -z "$output_root" ]]; then
    echo "params missing"
    echo "usage: bash run.sh --mesh_path <path> --prompt <text> --uuid <id> --output_root <path>"
    echo "usage: bash run.sh --mesh_path <path> --prompt <text> --output_root <path>"
    exit 1
fi

uuid=$(basename "$output_root")
# Step 1: drender-cli for condition rendering
drender-cli --mesh_path ${mesh_path} \
    --output_root ${output_root}/condition \
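With the `--uuid` flag removed, the texture script now derives the id from the last path component of `--output_root`. A quick illustration of that `basename` derivation (the path is made up):

```shell
output_root="outputs/texture_gen/demo_mug"
uuid=$(basename "$output_root")
echo "$uuid"  # -> demo_mug
```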
107 embodied_gen/utils/enum.py Normal file
@@ -0,0 +1,107 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import dataclass, field
from enum import Enum

from dataclasses_json import DataClassJsonMixin

__all__ = [
    "RenderItems",
    "Scene3DItemEnum",
    "SpatialRelationEnum",
    "RobotItemEnum",
]


class RenderItems(str, Enum):
    IMAGE = "image_color"
    ALPHA = "image_mask"
    VIEW_NORMAL = "image_view_normal"
    GLOBAL_NORMAL = "image_global_normal"
    POSITION_MAP = "image_position"
    DEPTH = "image_depth"
    ALBEDO = "image_albedo"
    DIFFUSE = "image_diffuse"


class Scene3DItemEnum(str, Enum):
    BACKGROUND = "background"
    CONTEXT = "context"
    ROBOT = "robot"
    MANIPULATED_OBJS = "manipulated_objs"
    DISTRACTOR_OBJS = "distractor_objs"
    OTHERS = "others"

    @classmethod
    def object_list(cls, layout_relation: dict) -> list:
        return (
            [
                layout_relation[cls.BACKGROUND.value],
                layout_relation[cls.CONTEXT.value],
            ]
            + layout_relation[cls.MANIPULATED_OBJS.value]
            + layout_relation[cls.DISTRACTOR_OBJS.value]
        )

    @classmethod
    def object_mapping(cls, layout_relation):
        relation_mapping = {
            # layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
            layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
            layout_relation[cls.CONTEXT.value]: cls.CONTEXT.value,
        }
        relation_mapping.update(
            {
                item: cls.MANIPULATED_OBJS.value
                for item in layout_relation[cls.MANIPULATED_OBJS.value]
            }
        )
        relation_mapping.update(
            {
                item: cls.DISTRACTOR_OBJS.value
                for item in layout_relation[cls.DISTRACTOR_OBJS.value]
            }
        )

        return relation_mapping


class SpatialRelationEnum(str, Enum):
    ON = "ON"  # objects on the table
    IN = "IN"  # objects in the room
    INSIDE = "INSIDE"  # objects inside the shelf/rack
    FLOOR = "FLOOR"  # objects on the floor of a room/bin


class RobotItemEnum(str, Enum):
    FRANKA = "franka"
    UR5 = "ur5"
    PIPER = "piper"


@dataclass
class LayoutInfo(DataClassJsonMixin):
    tree: dict[str, list]
    relation: dict[str, str | list[str]]
    objs_desc: dict[str, str] = field(default_factory=dict)
    assets: dict[str, str] = field(default_factory=dict)
    quality: dict[str, str] = field(default_factory=dict)
    position: dict[str, list[float]] = field(default_factory=dict)
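To make `Scene3DItemEnum.object_mapping` concrete: it flattens a role-keyed `layout_relation` dict into an asset-name → role lookup. A standalone toy below (the enum is re-declared locally so the snippet runs without the package, and the `layout_relation` contents are invented):

```python
from enum import Enum


class Scene3DItemEnum(str, Enum):  # local stand-in for embodied_gen.utils.enum
    BACKGROUND = "background"
    CONTEXT = "context"
    MANIPULATED_OBJS = "manipulated_objs"
    DISTRACTOR_OBJS = "distractor_objs"


# Invented example layout: single background/context assets, lists for the rest.
layout_relation = {
    "background": "kitchen_room",
    "context": "wooden_table",
    "manipulated_objs": ["red_mug"],
    "distractor_objs": ["plate", "spoon"],
}

# Same flattening `object_mapping` performs: asset name -> scene role.
mapping = {
    layout_relation[Scene3DItemEnum.BACKGROUND.value]: Scene3DItemEnum.BACKGROUND.value,
    layout_relation[Scene3DItemEnum.CONTEXT.value]: Scene3DItemEnum.CONTEXT.value,
}
mapping.update(
    {m: Scene3DItemEnum.MANIPULATED_OBJS.value for m in layout_relation["manipulated_objs"]}
)
mapping.update(
    {d: Scene3DItemEnum.DISTRACTOR_OBJS.value for d in layout_relation["distractor_objs"]}
)

print(mapping["red_mug"])      # -> manipulated_objs
print(mapping["kitchen_room"])  # -> background
```

This inverse view is what downstream scene-assembly code uses to look up the role of a generated asset by name.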
@@ -30,12 +30,20 @@ from tenacity import (
    stop_after_delay,
    wait_random_exponential,
)
from embodied_gen.utils.process_media import combine_images_to_base64
from embodied_gen.utils.process_media import combine_images_to_grid

logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


__all__ = [
    "GPTclient",
]

CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"


class GPTclient:
    """A client to interact with the GPT model via OpenAI or Azure API."""

@@ -45,6 +53,7 @@ class GPTclient:
        api_key: str,
        model_name: str = "yfb-gpt-4o",
        api_version: Optional[str] = None,
        check_connection: bool = True,
        verbose: bool = False,
    ):
        if api_version is not None:
@@ -63,6 +72,9 @@ class GPTclient:
        self.model_name = model_name
        self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
        self.verbose = verbose
        if check_connection:
            self.check_connection()

        logger.info(f"Using GPT model: {self.model_name}.")

    @retry(
@@ -77,6 +89,7 @@ class GPTclient:
        text_prompt: str,
        image_base64: Optional[list[str | Image.Image]] = None,
        system_role: Optional[str] = None,
        params: Optional[dict] = None,
    ) -> Optional[str]:
        """Queries the GPT model with a text and optional image prompts.

@@ -86,6 +99,7 @@ class GPTclient:
                or local image paths or PIL.Image to accompany the text prompt.
            system_role (Optional[str]): Optional system-level instructions
                that specify the behavior of the assistant.
            params (Optional[dict]): Additional parameters for GPT setting.

        Returns:
            Optional[str]: The response content generated by the model based on
@@ -103,11 +117,11 @@ class GPTclient:

        # Process images if provided
        if image_base64 is not None:
            image_base64 = (
                image_base64
                if isinstance(image_base64, list)
                else [image_base64]
            )
            if not isinstance(image_base64, list):
                image_base64 = [image_base64]
            # Temporary workaround: OpenRouter cannot take multiple images,
            # so tile them into a single grid image first.
            if "openrouter" in self.endpoint:
                image_base64 = combine_images_to_grid(image_base64)
            for img in image_base64:
                if isinstance(img, Image.Image):
                    buffer = BytesIO()
@@ -142,8 +156,11 @@ class GPTclient:
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
            "model": self.model_name,
        }
        payload.update({"model": self.model_name})

        if params:
            payload.update(params)

        response = None
        try:
@@ -159,8 +176,28 @@ class GPTclient:

        return response

    def check_connection(self) -> None:
        """Check whether the GPT API connection is working."""
        try:
            response = self.completion_with_backoff(
                messages=[
                    {"role": "system", "content": "You are a test system."},
                    {"role": "user", "content": "Hello"},
                ],
                model=self.model_name,
                temperature=0,
                max_tokens=100,
            )
            content = response.choices[0].message.content
            logger.info("Connection check succeeded.")
        except Exception as e:
            raise ConnectionError(
                f"Failed to connect to GPT API at {self.endpoint}, "
                f"please check the settings in `{CONFIG_FILE}` and the `README`."
            ) from e


with open("embodied_gen/utils/gpt_config.yaml", "r") as f:
with open(CONFIG_FILE, "r") as f:
    config = yaml.safe_load(f)

agent_type = config["agent_type"]
@@ -177,32 +214,5 @@ GPT_CLIENT = GPTclient(
    api_key=api_key,
    api_version=api_version,
    model_name=model_name,
    check_connection=False,
)

if __name__ == "__main__":
    if "openrouter" in GPT_CLIENT.endpoint:
        response = GPT_CLIENT.query(
            text_prompt="What is the content in each image?",
            image_base64=combine_images_to_base64(
                [
                    "apps/assets/example_image/sample_02.jpg",
                    "apps/assets/example_image/sample_03.jpg",
                ]
            ),  # input raw image_path if only one image
        )
        print(response)
    else:
        response = GPT_CLIENT.query(
            text_prompt="What is the content in the images?",
            image_base64=[
                Image.open("apps/assets/example_image/sample_02.jpg"),
                Image.open("apps/assets/example_image/sample_03.jpg"),
            ],
        )
        print(response)

    # test2: text prompt
    response = GPT_CLIENT.query(
        text_prompt="What is the capital of China?"
    )
    print(response)
48 embodied_gen/utils/log.py Normal file
@@ -0,0 +1,48 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import logging

from colorlog import ColoredFormatter

__all__ = [
    "logger",
]

LOG_FORMAT = (
    "%(log_color)s[%(asctime)s] %(levelname)-8s | %(message)s%(reset)s"
)
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

formatter = ColoredFormatter(
    LOG_FORMAT,
    datefmt=DATE_FORMAT,
    log_colors={
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "bold_red",
    },
)

handler = logging.StreamHandler()
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False
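The module exposes one preconfigured `logger` that the pipeline scripts import. The same pattern works with the stdlib `logging.Formatter` when `colorlog` is unavailable; a minimal colorless stand-in (the logger name is arbitrary):

```python
import logging

# Colorless equivalent of the colorlog format string above.
LOG_FORMAT = "[%(asctime)s] %(levelname)-8s | %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT))

logger = logging.getLogger("embodied_gen_demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False  # keep records out of the root logger, as the module does

logger.info("Loading Models...")
logger.warning("quality check failed, retrying")
```

Setting `propagate = False` is what prevents double-printing when other libraries call `logging.basicConfig` on the root logger.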
@@ -15,34 +15,24 @@
# permissions and limitations under the License.


import base64
import logging
import math
import os
import sys
import textwrap
from glob import glob
from io import BytesIO
from typing import Union

import cv2
import imageio
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import PIL.Image as Image
import spaces
import torch
from matplotlib.patches import Patch
from moviepy.editor import VideoFileClip, clips_array
from tqdm import tqdm
from embodied_gen.data.differentiable_render import entrypoint as render_api

current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
from thirdparty.TRELLIS.trellis.utils.render_utils import (
    render_frames,
    yaw_pitch_r_fov_to_extrinsics_intrinsics,
)
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -53,9 +43,8 @@ __all__ = [
    "merge_images_video",
    "filter_small_connected_components",
    "filter_image_small_connected_components",
    "combine_images_to_base64",
    "render_mesh",
    "render_video",
    "combine_images_to_grid",
    "SceneTreeVisualizer",
]


@@ -66,12 +55,14 @@ def render_asset3d(
    distance: float = 5.0,
    num_images: int = 1,
    elevation: list[float] = (0.0,),
    pbr_light_factor: float = 1.5,
    pbr_light_factor: float = 1.2,
    return_key: str = "image_color/*",
    output_subdir: str = "renders",
    gen_color_mp4: bool = False,
    gen_viewnormal_mp4: bool = False,
    gen_glonormal_mp4: bool = False,
    no_index_file: bool = False,
    with_mtl: bool = True,
) -> list[str]:
    input_args = dict(
        mesh_path=mesh_path,
@@ -81,14 +72,13 @@ def render_asset3d(
        num_images=num_images,
        elevation=elevation,
        pbr_light_factor=pbr_light_factor,
        with_mtl=True,
        with_mtl=with_mtl,
        gen_color_mp4=gen_color_mp4,
        gen_viewnormal_mp4=gen_viewnormal_mp4,
        gen_glonormal_mp4=gen_glonormal_mp4,
        no_index_file=no_index_file,
    )
    if gen_color_mp4:
        input_args["gen_color_mp4"] = True
    if gen_viewnormal_mp4:
        input_args["gen_viewnormal_mp4"] = True
    if gen_glonormal_mp4:
        input_args["gen_glonormal_mp4"] = True

    try:
        _ = render_api(**input_args)
    except Exception as e:
@@ -168,12 +158,15 @@ def filter_image_small_connected_components(
    return image


def combine_images_to_base64(
def combine_images_to_grid(
    images: list[str | Image.Image],
    cat_row_col: tuple[int, int] = None,
    target_wh: tuple[int, int] = (512, 512),
) -> str:
) -> list[str | Image.Image]:
    n_images = len(images)
    if n_images == 1:
        return images

    if cat_row_col is None:
        n_col = math.ceil(math.sqrt(n_images))
        n_row = math.ceil(n_images / n_col)
@@ -182,88 +175,190 @@ def combine_images_to_base64(

    images = [
        Image.open(p).convert("RGB") if isinstance(p, str) else p
        for p in images[: n_row * n_col]
        for p in images
    ]
    images = [img.resize(target_wh) for img in images]

    grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
    grid = Image.new("RGB", (grid_w, grid_h), (255, 255, 255))
    grid = Image.new("RGB", (grid_w, grid_h), (0, 0, 0))

    for idx, img in enumerate(images):
        row, col = divmod(idx, n_col)
        grid.paste(img, (col * target_wh[0], row * target_wh[1]))

    buffer = BytesIO()
    grid.save(buffer, format="PNG")

    return base64.b64encode(buffer.getvalue()).decode("utf-8")
    return [grid]


@spaces.GPU
def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
    renderer = MeshRenderer()
    renderer.rendering_options.resolution = options.get("resolution", 512)
    renderer.rendering_options.near = options.get("near", 1)
    renderer.rendering_options.far = options.get("far", 100)
    renderer.rendering_options.ssaa = options.get("ssaa", 4)
    rets = {}
    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
        res = renderer.render(sample, extr, intr)
        if "normal" not in rets:
            rets["normal"] = []
        normal = torch.lerp(
            torch.zeros_like(res["normal"]), res["normal"], res["mask"]
        )
        normal = np.clip(
            normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
        ).astype(np.uint8)
        rets["normal"].append(normal)

    return rets


@spaces.GPU
def render_video(
    sample,
    resolution=512,
    bg_color=(0, 0, 0),
    num_frames=300,
    r=2,
    fov=40,
    **kwargs,
):
    yaws = torch.linspace(0, 2 * 3.1415, num_frames)
    yaws = yaws.tolist()
    pitch = [0.5] * num_frames
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
        yaws, pitch, r, fov
    )
    render_fn = (
        render_mesh if isinstance(sample, MeshExtractResult) else render_frames
    )
    result = render_fn(
        sample,
        extrinsics,
        intrinsics,
        {"resolution": resolution, "bg_color": bg_color},
        **kwargs,
    )

    return result


class SceneTreeVisualizer:
    def __init__(self, layout_info: LayoutInfo) -> None:
        self.tree = layout_info.tree
        self.relation = layout_info.relation
        self.objs_desc = layout_info.objs_desc
        self.G = nx.DiGraph()
        self.root = self._find_root()
        self._build_graph()

        self.role_colors = {
            Scene3DItemEnum.BACKGROUND.value: "plum",
            Scene3DItemEnum.CONTEXT.value: "lightblue",
            Scene3DItemEnum.ROBOT.value: "lightcoral",
            Scene3DItemEnum.MANIPULATED_OBJS.value: "lightgreen",
            Scene3DItemEnum.DISTRACTOR_OBJS.value: "lightgray",
            Scene3DItemEnum.OTHERS.value: "orange",
        }

    def _find_root(self) -> str:
        children = {c for cs in self.tree.values() for c, _ in cs}
        parents = set(self.tree.keys())
        roots = parents - children
        if not roots:
            raise ValueError("No root node found.")
        return next(iter(roots))

    def _build_graph(self):
        for parent, children in self.tree.items():
            for child, relation in children:
                self.G.add_edge(parent, child, relation=relation)

    def _get_node_role(self, node: str) -> str:
        if node == self.relation.get(Scene3DItemEnum.BACKGROUND.value):
            return Scene3DItemEnum.BACKGROUND.value
        if node == self.relation.get(Scene3DItemEnum.CONTEXT.value):
            return Scene3DItemEnum.CONTEXT.value
        if node == self.relation.get(Scene3DItemEnum.ROBOT.value):
            return Scene3DItemEnum.ROBOT.value
        if node in self.relation.get(
            Scene3DItemEnum.MANIPULATED_OBJS.value, []
        ):
            return Scene3DItemEnum.MANIPULATED_OBJS.value
        if node in self.relation.get(
            Scene3DItemEnum.DISTRACTOR_OBJS.value, []
        ):
            return Scene3DItemEnum.DISTRACTOR_OBJS.value
        return Scene3DItemEnum.OTHERS.value

    def _get_positions(
        self, root, width=1.0, vert_gap=0.1, vert_loc=1, xcenter=0.5, pos=None
    ):
        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)

        children = list(self.G.successors(root))
        if children:
            dx = width / len(children)
            next_x = xcenter - width / 2 - dx / 2
            for child in children:
                next_x += dx
                pos = self._get_positions(
                    child,
                    width=dx,
                    vert_gap=vert_gap,
                    vert_loc=vert_loc - vert_gap,
                    xcenter=next_x,
                    pos=pos,
                )

        return pos

    def render(
        self,
        save_path: str,
        figsize=(8, 6),
        dpi=300,
        title: str = "Scene 3D Hierarchy Tree",
    ):
        node_colors = [
            self.role_colors[self._get_node_role(n)] for n in self.G.nodes
        ]
        pos = self._get_positions(self.root)

        plt.figure(figsize=figsize)
        nx.draw(
            self.G,
            pos,
            with_labels=True,
            arrows=False,
            node_size=2000,
            node_color=node_colors,
            font_size=10,
            font_weight="bold",
        )

        # Draw edge labels
        edge_labels = nx.get_edge_attributes(self.G, "relation")
        nx.draw_networkx_edge_labels(
            self.G,
            pos,
            edge_labels=edge_labels,
            font_size=9,
            font_color="black",
        )

        # Draw small description text under each node (if available)
        for node, (x, y) in pos.items():
            desc = self.objs_desc.get(node)
            if desc:
                wrapped = "\n".join(textwrap.wrap(desc, width=30))
                plt.text(
                    x,
                    y - 0.006,
                    wrapped,
                    fontsize=6,
                    ha="center",
                    va="top",
                    wrap=True,
                    color="black",
                    bbox=dict(
                        facecolor="dimgray",
                        edgecolor="darkgray",
                        alpha=0.1,
                        boxstyle="round,pad=0.2",
                    ),
                )

        plt.title(title, fontsize=12)
        task_desc = self.relation.get("task_desc", "")
        if task_desc:
            plt.suptitle(
                f"Task Description: {task_desc}", fontsize=10, y=0.999
            )

        plt.axis("off")

        legend_handles = [
            Patch(facecolor=color, edgecolor="black", label=role)
            for role, color in self.role_colors.items()
        ]
        plt.legend(
            handles=legend_handles,
            loc="lower center",
            ncol=3,
            bbox_to_anchor=(0.5, -0.1),
            fontsize=9,
        )

        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
        plt.close()


def load_scene_dict(file_path: str) -> dict:
    scene_dict = {}
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or ":" not in line:
                continue
            scene_id, desc = line.split(":", 1)
            scene_dict[scene_id.strip()] = desc.strip()

    return scene_dict


if __name__ == "__main__":
    # Example usage:
    merge_video_video(
        "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4",  # noqa
        "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4",  # noqa
        "merge.mp4",
    )

    image_base64 = combine_images_to_base64(
        [
            "apps/assets/example_image/sample_00.jpg",
            "apps/assets/example_image/sample_01.jpg",
            "apps/assets/example_image/sample_02.jpg",
        ]
    )
@@ -1 +1 @@
VERSION = "v0.1.0"
VERSION = "v0.1.1"
90 embodied_gen/utils/trender.py Normal file
@@ -0,0 +1,90 @@
|
||||
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import os
import sys

import numpy as np
import spaces
import torch
from tqdm import tqdm

current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
from thirdparty.TRELLIS.trellis.utils.render_utils import (
    render_frames,
    yaw_pitch_r_fov_to_extrinsics_intrinsics,
)

__all__ = [
    "render_video",
]


@spaces.GPU
def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
    renderer = MeshRenderer()
    renderer.rendering_options.resolution = options.get("resolution", 512)
    renderer.rendering_options.near = options.get("near", 1)
    renderer.rendering_options.far = options.get("far", 100)
    renderer.rendering_options.ssaa = options.get("ssaa", 4)
    rets = {}
    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
        res = renderer.render(sample, extr, intr)
        if "normal" not in rets:
            rets["normal"] = []
        normal = torch.lerp(
            torch.zeros_like(res["normal"]), res["normal"], res["mask"]
        )
        normal = np.clip(
            normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
        ).astype(np.uint8)
        rets["normal"].append(normal)

    return rets


@spaces.GPU
def render_video(
    sample,
    resolution=512,
    bg_color=(0, 0, 0),
    num_frames=300,
    r=2,
    fov=40,
    **kwargs,
):
    yaws = torch.linspace(0, 2 * 3.1415, num_frames)
    yaws = yaws.tolist()
    pitch = [0.5] * num_frames
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
        yaws, pitch, r, fov
    )
    render_fn = (
        render_mesh if isinstance(sample, MeshExtractResult) else render_frames
    )
    result = render_fn(
        sample,
        extrinsics,
        intrinsics,
        {"resolution": resolution, "bg_color": bg_color},
        **kwargs,
    )

    return result
@@ -102,7 +102,7 @@ class AestheticPredictor:
    def _load_sac_model(self, model_path, input_size):
        """Load the SAC model."""
        model = self.MLP(input_size)
-       ckpt = torch.load(model_path)
+       ckpt = torch.load(model_path, weights_only=True)
        model.load_state_dict(ckpt)
        model.to(self.device)
        model.eval()
@@ -135,15 +135,3 @@ class AestheticPredictor:
        )

        return prediction.item()
-
-
-if __name__ == "__main__":
-    # Configuration
-    img_path = "apps/assets/example_image/sample_00.jpg"
-
-    # Initialize the predictor
-    predictor = AestheticPredictor()
-
-    # Predict the aesthetic score
-    score = predictor.predict(img_path)
-    print("Aesthetic score predicted by the model:", score)
@@ -16,17 +16,26 @@


import logging
import os
+import random

-from tqdm import tqdm
+import json_repair
+from PIL import Image
from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
from embodied_gen.utils.process_media import render_asset3d
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


__all__ = [
    "MeshGeoChecker",
    "ImageSegChecker",
    "ImageAestheticChecker",
+   "SemanticConsistChecker",
+   "TextGenAlignChecker",
]


class BaseChecker:
    def __init__(self, prompt: str = None, verbose: bool = False) -> None:
        self.prompt = prompt
@@ -37,14 +46,18 @@ class BaseChecker:
            "Subclasses must implement the query method."
        )

-   def __call__(self, *args, **kwargs) -> bool:
+   def __call__(self, *args, **kwargs) -> tuple[bool, str]:
        response = self.query(*args, **kwargs)
-       if response is None:
-           response = "Error when calling gpt api."
-
-       if self.verbose and response != "YES":
+       if self.verbose:
            logger.info(response)

+       if response is None:
+           flag = None
+           response = (
+               "Error when calling GPT api, check config in "
+               "`embodied_gen/utils/gpt_config.yaml` or net connection."
+           )
+       else:
+           flag = "YES" in response
+           response = "YES" if flag else response
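`BaseChecker.__call__` above maps a raw model reply onto a `(flag, explanation)` pair: `None` signals a failed API call, a reply containing `"YES"` passes, and anything else fails with the reply kept as the reason. A standalone sketch of that convention (no GPT client involved; the reply strings are made up for illustration):

```python
def to_flag(response):
    # Mirrors the checker convention: None -> API failure (flag is None),
    # "YES" anywhere in the reply -> passed, otherwise failed and the
    # reply itself is kept as the explanation.
    if response is None:
        return None, "Error when calling GPT api."
    flag = "YES" in response
    return flag, "YES" if flag else response


print(to_flag("YES"))
print(to_flag("NO: extra white surface under the object."))
print(to_flag(None))
```

Returning the tuple (rather than a bare bool) is what lets the retry pipeline log *why* a generation was rejected before retrying.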
@@ -92,21 +105,29 @@ class MeshGeoChecker(BaseChecker):
        self.gpt_client = gpt_client
        if self.prompt is None:
            self.prompt = """
-           Refer to the provided multi-view rendering images to evaluate
-           whether the geometry of the 3D object asset is complete and
-           whether the asset can be placed stably on the ground.
-           Return "YES" only if reach the requirments,
-           otherwise "NO" and explain the reason very briefly.
+           You are an expert in evaluating the geometry quality of generated 3D assets.
+           You will be given rendered views of a generated 3D asset with a black background.
+           Your task is to evaluate the quality of the 3D asset generation,
+           including geometry, structure, and appearance, based on the rendered views.
+           Criteria:
+           - Is the geometry complete and well-formed, without missing parts or redundant structures?
+           - Is the geometric structure of the object complete?
+           - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
+             soft edges) are acceptable if the object is structurally sound and recognizable.
+           - Only evaluate geometry. Do not assess texture quality.
+           - The asset should not contain any unrelated elements, such as
+             ground planes, platforms, or background props (e.g., paper, flooring).
+
+           If all the above criteria are met, return "YES". Otherwise, return
+           "NO" followed by a brief explanation (no more than 20 words).
+
+           Example:
+           Images show a yellow cup standing on a flat white plane -> NO
+           -> Response: NO: extra white surface under the object.
+           Image shows a chair with simplified back legs and soft edges -> YES
            """

-   def query(self, image_paths: str) -> str:
-       # Hardcode tmp because the openrouter can't input multi images.
-       if "openrouter" in self.gpt_client.endpoint:
-           from embodied_gen.utils.process_media import (
-               combine_images_to_base64,
-           )
-
-           image_paths = combine_images_to_base64(image_paths)
+   def query(self, image_paths: list[str | Image.Image]) -> str:

        return self.gpt_client.query(
            text_prompt=self.prompt,
@@ -137,14 +158,19 @@ class ImageSegChecker(BaseChecker):
        self.gpt_client = gpt_client
        if self.prompt is None:
            self.prompt = """
-           The first image is the original, and the second image is the
-           result after segmenting the main object. Evaluate the segmentation
-           quality to ensure the main object is clearly segmented without
-           significant truncation. Note that the foreground of the object
-           needs to be extracted instead of the background.
-           Minor imperfections can be ignored. If segmentation is acceptable,
-           return "YES" only; otherwise, return "NO" with
-           very brief explanation.
+           Task: Evaluate the quality of object segmentation between two images:
+           the first is the original, the second is the segmented result.
+
+           Criteria:
+           - The main foreground object should be clearly extracted (not the background).
+           - The object must appear realistic, with reasonable geometry and color.
+           - The object should be geometrically complete, with no missing, truncated, or cropped parts.
+           - The object must be centered, with a margin on all sides.
+           - Ignore minor imperfections (e.g., small holes or fine edge artifacts).
+
+           Output Rules:
+           If segmentation is acceptable, respond with "YES" (and nothing else).
+           If not acceptable, respond with "NO", followed by a brief reason (max 20 words).
            """

    def query(self, image_paths: list[str]) -> str:
@@ -152,13 +178,6 @@ class ImageSegChecker(BaseChecker):
            raise ValueError(
                "ImageSegChecker requires exactly two images: [raw_image, seg_image]."  # noqa
            )
-       # Hardcode tmp because the openrouter can't input multi images.
-       if "openrouter" in self.gpt_client.endpoint:
-           from embodied_gen.utils.process_media import (
-               combine_images_to_base64,
-           )
-
-           image_paths = combine_images_to_base64(image_paths)

        return self.gpt_client.query(
            text_prompt=self.prompt,
@@ -201,42 +220,204 @@ class ImageAestheticChecker:
        return avg_score > self.thresh, avg_score


-if __name__ == "__main__":
-   geo_checker = MeshGeoChecker(GPT_CLIENT)
-   seg_checker = ImageSegChecker(GPT_CLIENT)
-   aesthetic_checker = ImageAestheticChecker()
+class SemanticConsistChecker(BaseChecker):
+   def __init__(
+       self,
+       gpt_client: GPTclient,
+       prompt: str = None,
+       verbose: bool = False,
+   ) -> None:
+       super().__init__(prompt, verbose)
+       self.gpt_client = gpt_client
+       if self.prompt is None:
+           self.prompt = """
+           You are an expert in image-text consistency assessment.
+           You will be given:
+           - A short text description of an object.
+           - A segmented image of the same object with the background removed.

-   checkers = [geo_checker, seg_checker, aesthetic_checker]
+           Criteria:
+           - The image must visually match the text description in terms of object type, structure, geometry, and color.
+           - The object must appear realistic, with reasonable geometry (e.g., a table must have a stable number of legs).
+           - Geometric completeness is required: the object must not have missing, truncated, or cropped parts.
+           - The object must be centered in the image frame with clear margins on all sides;
+             it should not touch or nearly touch any image edge.
+           - The image must contain exactly one object. Multiple distinct objects are not allowed.
+             A single composite object (e.g., a chair with legs) is acceptable.
+           - The object should be shown from a slightly angled (three-quarter) perspective,
+             not a flat, front-facing view showing only one surface.

-   output_root = "outputs/test_gpt"
+           Instructions:
+           - If all criteria are met, return "YES".
+           - Otherwise, return "NO" with a brief explanation (max 20 words).

-   fails = []
-   for idx in tqdm(range(150)):
-       mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj"  # noqa
-       if not os.path.exists(mesh_path):
+           Respond in exactly one of the following formats:
+           YES
+           or
+           NO: brief explanation.
+
+           Input:
+           {}
+           """
+
+   def query(self, text: str, image: list[Image.Image | str]) -> str:
+       return self.gpt_client.query(
+           text_prompt=self.prompt.format(text),
+           image_base64=image,
+       )
+
+
+class TextGenAlignChecker(BaseChecker):
+   def __init__(
+       self,
+       gpt_client: GPTclient,
+       prompt: str = None,
+       verbose: bool = False,
+   ) -> None:
+       super().__init__(prompt, verbose)
+       self.gpt_client = gpt_client
+       if self.prompt is None:
+           self.prompt = """
+           You are an expert in evaluating the quality of generated 3D assets.
+           You will be given:
+           - A text description of an object: TEXT
+           - Rendered views of the generated 3D asset.
+
+           Your task is to:
+           1. Determine whether the generated 3D asset roughly reflects the object class
+              or a semantically adjacent category described in the text.
+           2. Evaluate the geometry quality of the 3D asset generation based on the rendered views.
+
+           Criteria:
+           - Determine if the generated 3D asset belongs to the category described in the text or a similar one.
+           - Focus on functional similarity: if the object serves the same general
+             purpose (e.g., writing, placing items), it should be accepted.
+           - Is the geometry complete and well-formed, with no missing parts,
+             distortions, visual artifacts, or redundant structures?
+           - Does the number of object instances match the description?
+             There should be only one object unless otherwise specified.
+           - Minor flaws in geometry or texture are acceptable; high tolerance for texture quality defects.
+           - Minor simplifications in geometry or texture (e.g., soft edges, less detail)
+             are acceptable if the object is still recognizable.
+           - The asset should not contain any unrelated elements, such as
+             ground planes, platforms, or background props (e.g., paper, flooring).
+
+           Example:
+           Text: "yellow cup"
+           Image: shows a yellow cup standing on a flat white plane -> NO: extra surface under the object.
+
+           Instructions:
+           - If the quality of the generated asset is acceptable and faithfully represents the text, return "YES".
+           - Otherwise, return "NO" followed by a brief explanation (no more than 20 words).
+
+           Respond in exactly one of the following formats:
+           YES
+           or
+           NO: brief explanation
+
+           Input:
+           Text description: {}
+           """
+
+   def query(self, text: str, image: list[Image.Image | str]) -> str:
+       return self.gpt_client.query(
+           text_prompt=self.prompt.format(text),
+           image_base64=image,
+       )
+
+
+class SemanticMatcher(BaseChecker):
+   def __init__(
+       self,
+       gpt_client: GPTclient,
+       prompt: str = None,
+       verbose: bool = False,
+       seed: int = None,
+   ) -> None:
+       super().__init__(prompt, verbose)
+       self.gpt_client = gpt_client
+       self.seed = seed
+       random.seed(seed)
+       if self.prompt is None:
+           self.prompt = """
+           You are an expert in semantic similarity and scene retrieval.
+           You will be given:
+           - A dictionary where each key is a scene ID, and each value is a scene description.
+           - A query text describing a target scene.
+
+           Your task:
+           return_num = 2
+           - Find the <return_num> most semantically similar scene IDs to the query text.
+           - If there are fewer than <return_num> distinct relevant matches, repeat the closest ones to make a list of <return_num>.
+           - Only output the list of <return_num> scene IDs, sorted from most to less similar.
+           - Do NOT use markdown, JSON code blocks, or any formatting syntax; only return a plain list like ["id1", ...].
+
+           Input example:
+           Dictionary:
+           "{{
+               "t_scene_008": "A study room with full bookshelves and a lamp in the corner.",
+               "t_scene_019": "A child's bedroom with pink walls and a small desk.",
+               "t_scene_020": "A living room with a wooden floor.",
+               "t_scene_021": "A living room with toys scattered on the floor.",
+               ...
+               "t_scene_office_001": "A very spacious, modern open-plan office with wide desks and no people, panoramic view."
+           }}"
+           Text:
+           "A traditional indoor room"
+           Output:
+           '["t_scene_office_001", ...]'
+
+           Input:
+           Dictionary:
+           {context}
+           Text:
+           {text}
+           Output:
+           <topk_key_list>
+           """
+
+   def query(
+       self, text: str, context: dict, rand: bool = True, params: dict = None
+   ) -> str:
+       match_list = self.gpt_client.query(
+           self.prompt.format(context=context, text=text),
+           params=params,
+       )
+       match_list = json_repair.loads(match_list)
+       result = random.choice(match_list) if rand else match_list[0]
+
+       return result
+
+
+def test_semantic_matcher(
+   bg_file: str = "outputs/bg_scenes/bg_scene_list.txt",
+):
+   bg_file = "outputs/bg_scenes/bg_scene_list.txt"
+   scene_dict = {}
+   with open(bg_file, "r") as f:
+       for line in f:
+           line = line.strip()
+           if not line or ":" not in line:
+               continue
-           image_paths = render_asset3d(
-               mesh_path,
-               f"{output_root}/{idx}",
-               num_images=8,
-               elevation=(30, -30),
-               distance=5.5,
-           )
+           scene_id, desc = line.split(":", 1)
+           scene_dict[scene_id.strip()] = desc.strip()

-       for cid, checker in enumerate(checkers):
-           if isinstance(checker, ImageSegChecker):
-               images = [
-                   f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png",  # noqa
-                   f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png",  # noqa
-               ]
-           else:
-               images = image_paths
-           result, info = checker(images)
-           logger.info(
-               f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}"  # noqa
-           )
+   office_scene = scene_dict.get("t_scene_office_001")
+   text = "bright kitchen"
+   SCENE_MATCHER = SemanticMatcher(GPT_CLIENT)
+   # gpt_params = {
+   #     "temperature": 0.8,
+   #     "max_tokens": 500,
+   #     "top_p": 0.8,
+   #     "frequency_penalty": 0.3,
+   #     "presence_penalty": 0.3,
+   # }
+   gpt_params = None
+   match_key = SCENE_MATCHER.query(text, str(scene_dict))
+   print(match_key, ",", scene_dict[match_key])

-       if result is False:
-           fails.append((idx, cid, info))

-           break
+if __name__ == "__main__":
+   test_semantic_matcher()
@@ -297,20 +297,24 @@ class URDFGenerator(object):
        if not os.path.exists(urdf_path):
            raise FileNotFoundError(f"URDF file not found: {urdf_path}")

-       mesh_scale = 1.0
+       mesh_attr = None
        tree = ET.parse(urdf_path)
        root = tree.getroot()
        extra_info = root.find(attr_root)
        if extra_info is not None:
            scale_element = extra_info.find(attr_name)
            if scale_element is not None:
-               mesh_scale = float(scale_element.text)
+               mesh_attr = scale_element.text
+               try:
+                   mesh_attr = float(mesh_attr)
+               except ValueError:
+                   pass

-       return mesh_scale
+       return mesh_attr

    @staticmethod
    def add_quality_tag(
-       urdf_path: str, results, output_path: str = None
+       urdf_path: str, results: list, output_path: str = None
    ) -> None:
        if output_path is None:
            output_path = urdf_path
@@ -366,16 +370,9 @@ class URDFGenerator(object):
            output_root,
            num_images=self.render_view_num,
            output_subdir=self.output_render_dir,
+           no_index_file=True,
        )

-       # Hardcode tmp because the openrouter can't input multi images.
-       if "openrouter" in self.gpt_client.endpoint:
-           from embodied_gen.utils.process_media import (
-               combine_images_to_base64,
-           )
-
-           image_path = combine_images_to_base64(image_path)

        response = self.gpt_client.query(text_prompt, image_path)
        if response is None:
            asset_attrs = {
@@ -412,14 +409,18 @@ class URDFGenerator(object):
if __name__ == "__main__":
    urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
    urdf_path = urdf_gen(
-       mesh_path="outputs/imageto3d/cma/o5/URDF_o5/mesh/o5.obj",
+       mesh_path="outputs/layout2/asset3d/marker/result/mesh/marker.obj",
        output_root="outputs/test_urdf",
-       # category="coffee machine",
+       category="marker",
        # min_height=1.0,
        # max_height=1.2,
        version=VERSION,
    )

+   URDFGenerator.add_quality_tag(
+       urdf_path, [[urdf_gen.__class__.__name__, "OK"]]
+   )

    # zip_files(
    #     input_paths=[
    #         "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",
22
install.sh
@@ -8,6 +8,12 @@ NC='\033[0m'
echo -e "${GREEN}Starting installation process...${NC}"
git config --global http.postBuffer 524288000

+echo -e "${GREEN}Installing flash-attn...${NC}"
+pip install flash-attn==2.7.0.post2 --no-build-isolation || {
+    echo -e "${RED}Failed to install flash-attn${NC}"
+    exit 1
+}
+
echo -e "${GREEN}Installing dependencies from requirements.txt...${NC}"
pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60 || {
    echo -e "${RED}Failed to install requirements${NC}"
@@ -15,16 +21,16 @@ pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeo
}


-echo -e "${GREEN}Installing kaolin from GitHub...${NC}"
-pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || {
-    echo -e "${RED}Failed to install kaolin${NC}"
+echo -e "${GREEN}Installing kolors from GitHub...${NC}"
+pip install kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d || {
+    echo -e "${RED}Failed to install kolors${NC}"
    exit 1
}


-echo -e "${GREEN}Installing flash-attn...${NC}"
-pip install flash-attn==2.7.0.post2 --no-build-isolation || {
-    echo -e "${RED}Failed to install flash-attn${NC}"
+echo -e "${GREEN}Installing kaolin from GitHub...${NC}"
+pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || {
+    echo -e "${RED}Failed to install kaolin${NC}"
    exit 1
}

@@ -39,7 +45,6 @@ rm -rf "$TMP_DIR" || {
    rm -rf "$TMP_DIR"
    exit 1
}
-echo -e "${GREEN}Installation completed successfully!${NC}"


echo -e "${GREEN}Installing gsplat from GitHub...${NC}"
@@ -50,8 +55,9 @@ pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.0 || {


echo -e "${GREEN}Installing EmbodiedGen...${NC}"
+pip install triton==2.1.0
pip install -e . || {
-    echo -e "${RED}Failed to install local package${NC}"
+    echo -e "${RED}Failed to install EmbodiedGen pyproject.toml${NC}"
    exit 1
}
@@ -7,7 +7,7 @@ packages = ["embodied_gen"]

[project]
name = "embodied_gen"
-version = "v0.1.0"
+version = "v0.1.1"
readme = "README.md"
license = "Apache-2.0"
license-files = ["LICENSE", "NOTICE"]
@@ -17,16 +17,20 @@ requires-python = ">=3.10"

[project.optional-dependencies]
dev = [
-    "cpplint==2.0.0",
-    "pre-commit==2.13.0",
+    "cpplint",
+    "pre-commit",
    "pydocstyle",
    "black",
    "isort",
+    "pytest",
+    "pytest-mock",
]

[project.scripts]
drender-cli = "embodied_gen.data.differentiable_render:entrypoint"
backproject-cli = "embodied_gen.data.backproject_v2:entrypoint"
img3d-cli = "embodied_gen.scripts.imageto3d:entrypoint"
+text3d-cli = "embodied_gen.scripts.textto3d:text_to_3d"

[tool.pydocstyle]
match = '(?!test_).*(?!_pb2)\.py'
12
pytest.ini
Normal file
@@ -0,0 +1,12 @@
[pytest]
testpaths =
    tests/test_unit

python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -ra --tb=short --disable-warnings --log-cli-level=INFO
filterwarnings =
    ignore::DeprecationWarning
markers =
    slow: marks tests as slow
@@ -8,10 +8,10 @@ triton==2.1.0
dataclasses_json
easydict
opencv-python>4.5
-imageio==2.36.1
-imageio-ffmpeg==0.5.1
+imageio
+imageio-ffmpeg
rembg==2.0.61
-trimesh==4.4.4
+trimesh
moviepy==1.0.3
pymeshfix==0.17.0
igraph==0.11.8
@@ -20,21 +20,19 @@ openai==1.58.1
transformers==4.42.4
gradio==5.12.0
sentencepiece==0.2.0
-diffusers==0.31.0
-xatlas==0.0.9
+diffusers==0.34.0
+xatlas
onnxruntime==1.20.1
-tenacity==8.2.2
+tenacity
accelerate==0.33.0
basicsr==1.4.2
realesrgan==0.3.0
pydantic==2.9.2
vtk==9.3.1
spaces
colorlog
+json-repair
utils3d@git+https://github.com/EasternJournalist/utils3d.git#egg=9a4eb15
clip@git+https://github.com/openai/CLIP.git
-kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d
segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f
nvdiffrast@git+https://github.com/NVlabs/nvdiffrast.git#egg=729261d
# https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu118-cp310-cp310-linux_x86_64.whl
# https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu11torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
# https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu118/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl
31
tests/test_examples/test_aesthetic_predictor.py
Normal file
@@ -0,0 +1,31 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import logging

from embodied_gen.validators.aesthetic_predictor import AestheticPredictor

logger = logging.getLogger(__name__)


# @pytest.mark.manual
def test_aesthetic_predictor():
    image_path = "apps/assets/example_image/sample_02.jpg"
    predictor = AestheticPredictor(device="cpu")
    score = predictor.predict(image_path)

    assert isinstance(score, float)
    logger.info(f"Aesthetic score: {score:.3f}")
119
tests/test_examples/test_quality_checkers.py
Normal file
@@ -0,0 +1,119 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.


import logging
import tempfile

import pytest
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.process_media import render_asset3d
from embodied_gen.validators.quality_checkers import (
    ImageAestheticChecker,
    ImageSegChecker,
    MeshGeoChecker,
    SemanticConsistChecker,
    TextGenAlignChecker,
)

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def geo_checker():
    return MeshGeoChecker(GPT_CLIENT)


@pytest.fixture(scope="module")
def seg_checker():
    return ImageSegChecker(GPT_CLIENT)


@pytest.fixture(scope="module")
def aesthetic_checker():
    return ImageAestheticChecker()


@pytest.fixture(scope="module")
def semantic_checker():
    return SemanticConsistChecker(GPT_CLIENT)


@pytest.fixture(scope="module")
def textalign_checker():
    return TextGenAlignChecker(GPT_CLIENT)


def test_geo_checker(geo_checker):
    flag, result = geo_checker(
        [
            "apps/assets/example_image/sample_02.jpg",
        ]
    )
    logger.info(f"geo_checker: {flag}, {result}")
    assert isinstance(flag, bool)
    assert isinstance(result, str)


def test_aesthetic_checker(aesthetic_checker):
    flag, result = aesthetic_checker("apps/assets/example_image/sample_02.jpg")
    logger.info(f"aesthetic_checker: {flag}, {result}")
    assert isinstance(flag, bool)
    assert isinstance(result, float)


def test_seg_checker(seg_checker):
    flag, result = seg_checker(
        [
            "apps/assets/example_image/sample_02.jpg",
            "apps/assets/example_image/sample_02.jpg",
        ]
    )
    logger.info(f"seg_checker: {flag}, {result}")
    assert isinstance(flag, bool)
    assert isinstance(result, str)


def test_semantic_checker(semantic_checker):
    flag, result = semantic_checker(
        text="can",
        image=["apps/assets/example_image/sample_02.jpg"],
    )
    logger.info(f"semantic_checker: {flag}, {result}")
    assert isinstance(flag, bool)
    assert isinstance(result, str)


@pytest.mark.parametrize(
    "mesh_path, text_desc",
    [
        ("apps/assets/example_texture/meshes/chair.obj", "chair"),
        ("apps/assets/example_texture/meshes/clock.obj", "clock"),
    ],
)
def test_textgen_checker(textalign_checker, mesh_path, text_desc):
    with tempfile.TemporaryDirectory() as output_root:
        image_list = render_asset3d(
            mesh_path,
            output_root=output_root,
            num_images=6,
            elevation=(30, -30),
            output_subdir="renders",
            no_index_file=True,
            with_mtl=False,
        )
        flag, result = textalign_checker(text_desc, image_list)
        logger.info(f"textalign_checker: {flag}, {result}")
95
tests/test_unit/test_agents.py
Normal file
95
tests/test_unit/test_agents.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Project EmbodiedGen
|
||||
#
|
||||
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||
from embodied_gen.validators.quality_checkers import (
|
||||
ImageSegChecker,
|
||||
MeshGeoChecker,
|
||||
SemanticConsistChecker,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def gptclient_query():
|
||||
with patch.object(
|
||||
GPT_CLIENT, "query", return_value="mocked gpt response"
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def gptclient_query_case2():
|
||||
with patch.object(GPT_CLIENT, "query", return_value=None) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_images",
|
||||
[
|
||||
"dummy_path/color_grid_6view.png",
|
||||
["dummy_path/color_grid_6view.jpg"],
|
||||
[
|
||||
"dummy_path/color_grid_6view.png",
|
||||
"dummy_path/color_grid_6view2.png",
|
||||
],
|
||||
[
|
||||
Image.new("RGB", (64, 64), "red"),
|
||||
Image.new("RGB", (64, 64), "blue"),
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_geo_checker_varied_inputs(input_images):
|
||||
geo_checker = MeshGeoChecker(GPT_CLIENT)
|
||||
flag, result = geo_checker(input_images)
|
||||
assert isinstance(flag, (bool, type(None)))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_seg_checker():
|
||||
seg_checker = ImageSegChecker(GPT_CLIENT)
|
||||
flag, result = seg_checker(
|
||||
[
|
||||
"dummy_path/sample_0_raw.png", # raw image
|
||||
"dummy_path/sample_0_cond.png", # segmented image
|
||||
]
|
||||
)
|
||||
assert isinstance(flag, (bool, type(None)))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_semantic_checker():
|
||||
semantic_checker = SemanticConsistChecker(GPT_CLIENT)
|
||||
flag, result = semantic_checker(
|
||||
text="pen",
|
||||
image=["dummy_path/pen.png"],
|
||||
)
|
||||
assert isinstance(flag, (bool, type(None)))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_semantic_checker_none_response(gptclient_query_case2):
    # Covers the case where the GPT client query returns None.
    semantic_checker = SemanticConsistChecker(GPT_CLIENT)
    flag, result = semantic_checker(
        text="pen",
        image=["dummy_path/pen.png"],
    )
    assert isinstance(flag, (bool, type(None)))
    assert isinstance(result, str)
94  tests/test_unit/test_gpt_client.py  Normal file
@@ -0,0 +1,94 @@
|
||||
# Project EmbodiedGen
|
||||
#
|
||||
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from PIL import Image
|
||||
from embodied_gen.utils.gpt_clients import CONFIG_FILE, GPTclient
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config():
|
||||
with open(CONFIG_FILE, "r") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env_vars(monkeypatch, config):
|
||||
agent_type = config["agent_type"]
|
||||
agent_config = config.get(agent_type, {})
|
||||
monkeypatch.setenv(
|
||||
"ENDPOINT", agent_config.get("endpoint", "fake_endpoint")
|
||||
)
|
||||
monkeypatch.setenv("API_KEY", agent_config.get("api_key", "fake_api_key"))
|
||||
monkeypatch.setenv("API_VERSION", agent_config.get("api_version", "v1"))
|
||||
monkeypatch.setenv(
|
||||
"MODEL_NAME", agent_config.get("model_name", "test_model")
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gpt_client(env_vars):
|
||||
client = GPTclient(
|
||||
endpoint=os.environ.get("ENDPOINT"),
|
||||
api_key=os.environ.get("API_KEY"),
|
||||
api_version=os.environ.get("API_VERSION"),
|
||||
model_name=os.environ.get("MODEL_NAME"),
|
||||
check_connection=False,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_prompt, image_base64",
|
||||
[
|
||||
("What is the capital of China?", None),
|
||||
(
|
||||
"What is the content in each image?",
|
||||
"apps/assets/example_image/sample_02.jpg",
|
||||
),
|
||||
(
|
||||
"What is the content in each image?",
|
||||
[
|
||||
"apps/assets/example_image/sample_02.jpg",
|
||||
"apps/assets/example_image/sample_03.jpg",
|
||||
],
|
||||
),
|
||||
(
|
||||
"What is the content in each image?",
|
||||
[
|
||||
Image.new("RGB", (64, 64), "red"),
|
||||
Image.new("RGB", (64, 64), "blue"),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_gptclient_query(gpt_client, text_prompt, image_base64):
|
||||
# mock GPTclient.query
|
||||
with patch.object(
|
||||
GPTclient, "query", return_value="mocked response"
|
||||
) as mock_query:
|
||||
response = gpt_client.query(
|
||||
text_prompt=text_prompt, image_base64=image_base64
|
||||
)
|
||||
assert response == "mocked response"
|
||||
mock_query.assert_called_once_with(
|
||||
text_prompt=text_prompt, image_base64=image_base64
|
||||
)
|
||||