feat(pipe): Refine generation pipeline and add utests. (#23)

- Added intelligent quality checkers and auto-retry pipeline for `image-to-3d` and `text-to-3d`.
- Added unit tests for quality checkers.
- `text-to-3d` now supports more `text-to-image` models, pipeline success rate improved to 94%.

parent 8bcfac9190
commit a3924ae4a8
29
CHANGELOG.md
Normal file
@@ -0,0 +1,29 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0).
+
+
+## [0.1.1] - 2025-07-xx
+### Feature
+- Added intelligent quality checkers and auto-retry pipeline for `image-to-3d` and `text-to-3d`.
+- Added unit tests for quality checkers.
+- `text-to-3d` now supports more `text-to-image` models, pipeline success rate improved to 94%.
+
+## [0.1.0] - 2025-07-04
+### Feature
+🖼️ Single Image to Physics Realistic 3D Asset
+- Generates watertight, simulation-ready 3D meshes with physical attributes.
+- Auto-labels semantic and quality tags (geometry, texture, foreground, etc.).
+- Produces 2K textures with highlight removal and multi-view fusion.
+
+📝 Text-to-3D Asset Generation
+- Creates realistic 3D assets from natural language (English & Chinese).
+- Filters assets via QA tags to ensure visual and geometric quality.
+
+🎨 Texture Generation & Editing
+- Generates 2K textures from mesh and text with semantic alignment.
+- Plug-and-play modules adapt text-to-image models for 3D textures.
+- Supports realistic and stylized texture outputs, including text textures.
+
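Reviewer note: the changelog entry describes an auto-retry pipeline gated by quality checkers. The sketch below illustrates that general idea only; it is not the project's implementation. The names `generate_with_retry`, `generate` and `is_acceptable` are hypothetical, and picking a new seed after a rejected attempt is an assumption.

```python
import random
from typing import Callable, Optional


def generate_with_retry(
    generate: Callable[[int], str],
    is_acceptable: Callable[[str], bool],
    n_retry: int = 2,
    seed: int = 0,
) -> Optional[str]:
    """Run `generate` up to `n_retry` times; return the first accepted result."""
    rng = random.Random(seed)
    for _ in range(n_retry):
        result = generate(seed)
        if is_acceptable(result):
            return result
        # Assumption: a rejected attempt is retried with a different seed.
        seed = rng.randint(0, 100_000)
    # Every attempt was rejected by the checker.
    return None


# Hypothetical stand-ins for a generator and a quality checker.
attempts = []


def fake_generate(seed: int) -> str:
    attempts.append(seed)
    return f"asset-{len(attempts)}"


accepted = generate_with_retry(
    fake_generate, lambda asset: asset == "asset-2", n_retry=3
)
```

Returning `None` when every attempt fails leaves it to the caller to decide whether to skip the sample or raise.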
44
README.md
@@ -6,6 +6,7 @@
 [](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D)
 [](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D)
 [](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen)
+[](https://mp.weixin.qq.com/s/HH1cPBhK2xcDbyCK4BBTbw)


 > ***EmbodiedGen*** is a generative engine to create diverse and interactive 3D worlds composed of high-quality 3D assets(mesh & 3DGS) with plausible physics, leveraging generative AI to address the challenges of generalization in embodied intelligence related research.
@@ -29,7 +30,7 @@
 ```sh
 git clone https://github.com/HorizonRobotics/EmbodiedGen.git
 cd EmbodiedGen
-git checkout v0.1.0
+git checkout v0.1.1
 git submodule update --init --recursive --progress
 conda create -n embodiedgen python=3.10.13 -y
 conda activate embodiedgen
@@ -67,9 +68,8 @@ CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
 ### ⚡ API
 Generate physically plausible 3D assets from image input via the command-line API.
 ```sh
-python3 embodied_gen/scripts/imageto3d.py \
-    --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
-    --output_root outputs/imageto3d
+img3d-cli --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
+    --n_retry 2 --output_root outputs/imageto3d

 # See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result
 ```
@@ -86,18 +86,29 @@ python3 embodied_gen/scripts/imageto3d.py \
 ### ☁️ Service
 Deploy the text-to-3D generation service locally.

-Text-to-image based on the Kolors model, supporting Chinese and English prompts.
-Models downloaded automatically on first run, see `download_kolors_weights`, please be patient.
+Text-to-image model based on the Kolors model, supporting Chinese and English prompts.
+Models downloaded automatically on first run, please be patient.
 ```sh
 python apps/text_to_3d.py
 ```

 ### ⚡ API
-Text-to-image based on the Kolors model.
+Text-to-image model based on SD3.5 Medium, English prompts only.
+Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically. (ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py`)
+
+For large-scale 3D assets generation, set `--n_pipe_retry=2` to ensure high end-to-end 3D asset usability through automatic quality check and retries. For more diverse results, do not set `--seed_img`.
+
+```sh
+text3d-cli --prompts "small bronze figurine of a lion" "A globe with wooden base" "wooden table with embroidery" \
+    --n_image_retry 2 --n_asset_retry 2 --n_pipe_retry 1 --seed_img 0 \
+    --output_root outputs/textto3d
+```
+
+Text-to-image model based on the Kolors model.
 ```sh
 bash embodied_gen/scripts/textto3d.sh \
     --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
-    --output_root outputs/textto3d
+    --output_root outputs/textto3d_k
 ```

 ---
@@ -118,12 +129,17 @@ python apps/texture_edit.py
 ```

 ### ⚡ API
+Support Chinese and English prompts.
 ```sh
 bash embodied_gen/scripts/texture_gen.sh \
     --mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \
     --prompt "举着牌子的写实风格机器人,大眼睛,牌子上写着“Hello”的文字" \
-    --output_root "outputs/texture_gen/" \
-    --uuid "robot_text"
+    --output_root "outputs/texture_gen/robot_text"
+
+bash embodied_gen/scripts/texture_gen.sh \
+    --mesh_path "apps/assets/example_texture/meshes/horse.obj" \
+    --prompt "A gray horse head with flying mane and brown eyes" \
+    --output_root "outputs/texture_gen/gray_horse"
 ```

 ---
@@ -171,6 +187,12 @@ bash embodied_gen/scripts/texture_gen.sh \

 ---

+## For Developer
+```sh
+pip install .[dev] && pre-commit install
+python -m pytest # Pass all unit-test are required.
+```
+
 ## 📚 Citation

 If you use EmbodiedGen in your research or projects, please cite:
@@ -192,7 +214,7 @@ If you use EmbodiedGen in your research or projects, please cite:
 ## 🙌 Acknowledgement

 EmbodiedGen builds upon the following amazing projects and models:
-🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o)
+🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)

 ---

|
|||||||
@@ -55,9 +55,9 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT
 from embodied_gen.utils.process_media import (
     filter_image_small_connected_components,
     merge_images_video,
-    render_video,
 )
 from embodied_gen.utils.tags import VERSION
+from embodied_gen.utils.trender import render_video
 from embodied_gen.validators.quality_checkers import (
     BaseChecker,
     ImageAestheticChecker,
|
|||||||
@@ -33,7 +33,6 @@ from tqdm import tqdm
 from embodied_gen.data.utils import (
     CameraSetting,
     DiffrastRender,
-    RenderItems,
     as_list,
     calc_vertex_normals,
     import_kaolin_mesh,
@@ -42,6 +41,7 @@ from embodied_gen.data.utils import (
     render_pbr,
     save_images,
 )
+from embodied_gen.utils.enum import RenderItems

 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
 os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
@@ -470,7 +470,7 @@ def parse_args():
         "--pbr_light_factor",
         type=float,
         default=1.0,
-        help="Light factor for mesh PBR rendering (default: 2.)",
+        help="Light factor for mesh PBR rendering (default: 1.)",
     )
     parser.add_argument(
         "--with_mtl",
@@ -482,6 +482,11 @@ def parse_args():
         action="store_true",
         help="Whether to generate color .gif rendering file.",
     )
+    parser.add_argument(
+        "--no_index_file",
+        action="store_true",
+        help="Whether skip the index file saving.",
+    )
     parser.add_argument(
         "--gen_color_mp4",
         action="store_true",
@@ -568,7 +573,7 @@ def entrypoint(**kwargs) -> None:
         gen_viewnormal_mp4=args.gen_viewnormal_mp4,
         gen_glonormal_mp4=args.gen_glonormal_mp4,
         light_factor=args.pbr_light_factor,
-        no_index_file=gen_video,
+        no_index_file=gen_video or args.no_index_file,
     )
     image_render.render_mesh(
         mesh_path=args.mesh_path,
|
|||||||
@@ -395,6 +395,8 @@ class MeshFixer(object):
             self.vertices_np,
             np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
         )
+        mesh.clean(inplace=True)
+        mesh.clear_data()
         mesh = mesh.decimate(ratio, progress_bar=True)

         # Update vertices and faces
|
|||||||
@@ -38,7 +38,6 @@ except ImportError:
     ChatGLMModel = None
 import logging
 from dataclasses import dataclass, field
-from enum import Enum

 import trimesh
 from kaolin.render.camera import Camera
@@ -57,7 +56,6 @@ __all__ = [
     "load_mesh_to_unit_cube",
     "as_list",
     "CameraSetting",
-    "RenderItems",
     "import_kaolin_mesh",
     "save_mesh_with_mtl",
     "get_images_from_grid",
@@ -738,18 +736,6 @@ class CameraSetting:
         self.Ks = Ks


-@dataclass
-class RenderItems(str, Enum):
-    IMAGE = "image_color"
-    ALPHA = "image_mask"
-    VIEW_NORMAL = "image_view_normal"
-    GLOBAL_NORMAL = "image_global_normal"
-    POSITION_MAP = "image_position"
-    DEPTH = "image_depth"
-    ALBEDO = "image_albedo"
-    DIFFUSE = "image_diffuse"
-
-
 def _compute_az_el_by_camera_params(
     camera_params: CameraSetting, flip_az: bool = False
 ):
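Reviewer note: the hunk above removes `RenderItems`, an `Enum` with a `str` mixin, from this module, and the earlier import hunk now pulls it from `embodied_gen.utils.enum`. A minimal sketch of how such a string-valued enum behaves follows. Only three of the original members are reproduced, and `output_name` is a hypothetical helper, not project code.

```python
from enum import Enum


class RenderItems(str, Enum):
    # Subset of the members shown in the diff, reproduced for illustration.
    IMAGE = "image_color"
    ALPHA = "image_mask"
    DEPTH = "image_depth"


def output_name(item: RenderItems, idx: int) -> str:
    """Hypothetical helper: build a file name from an enum member."""
    # Use .value explicitly; formatting a mixed-in enum member directly
    # gives different text on different Python versions.
    return f"{item.value}_{idx:03d}.png"


# Members compare equal to their plain string values ...
same_as_str = RenderItems.IMAGE == "image_color"
# ... and can be looked up from the raw string.
looked_up = RenderItems("image_mask")
```

An `Enum` subclass does not need the `@dataclass` decorator that the removed definition carried; the members are already class-level constants.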
|
|||||||
236
embodied_gen/models/image_comm_model.py
Normal file
@@ -0,0 +1,236 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+# Text-to-Image generation models from Hugging Face community.
+
+import os
+from abc import ABC, abstractmethod
+
+import torch
+from diffusers import (
+    ChromaPipeline,
+    Cosmos2TextToImagePipeline,
+    DPMSolverMultistepScheduler,
+    FluxPipeline,
+    KolorsPipeline,
+    StableDiffusion3Pipeline,
+)
+from diffusers.quantizers import PipelineQuantizationConfig
+from huggingface_hub import snapshot_download
+from PIL import Image
+from transformers import AutoModelForCausalLM, SiglipProcessor
+
+__all__ = [
+    "build_hf_image_pipeline",
+]
+
+
+class BasePipelineLoader(ABC):
+    def __init__(self, device="cuda"):
+        self.device = device
+
+    @abstractmethod
+    def load(self):
+        pass
+
+
+class BasePipelineRunner(ABC):
+    def __init__(self, pipe):
+        self.pipe = pipe
+
+    @abstractmethod
+    def run(self, prompt: str, **kwargs) -> Image.Image:
+        pass
+
+
+# ===== SD3.5-medium =====
+class SD35Loader(BasePipelineLoader):
+    def load(self):
+        pipe = StableDiffusion3Pipeline.from_pretrained(
+            "stabilityai/stable-diffusion-3.5-medium",
+            torch_dtype=torch.float16,
+        )
+        pipe = pipe.to(self.device)
+        pipe.enable_model_cpu_offload()
+        pipe.enable_xformers_memory_efficient_attention()
+        pipe.enable_attention_slicing()
+        return pipe
+
+
+class SD35Runner(BasePipelineRunner):
+    def run(self, prompt: str, **kwargs) -> Image.Image:
+        return self.pipe(prompt=prompt, **kwargs).images
+
+
+# ===== Cosmos2 =====
+class CosmosLoader(BasePipelineLoader):
+    def __init__(
+        self,
+        model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
+        local_dir="weights/cosmos2",
+        device="cuda",
+    ):
+        super().__init__(device)
+        self.model_id = model_id
+        self.local_dir = local_dir
+
+    def _patch(self):
+        def patch_model(cls):
+            orig = cls.from_pretrained
+
+            def new(*args, **kwargs):
+                kwargs.setdefault("attn_implementation", "flash_attention_2")
+                kwargs.setdefault("torch_dtype", torch.bfloat16)
+                return orig(*args, **kwargs)
+
+            cls.from_pretrained = new
+
+        def patch_processor(cls):
+            orig = cls.from_pretrained
+
+            def new(*args, **kwargs):
+                kwargs.setdefault("use_fast", True)
+                return orig(*args, **kwargs)
+
+            cls.from_pretrained = new
+
+        patch_model(AutoModelForCausalLM)
+        patch_processor(SiglipProcessor)
+
+    def load(self):
+        self._patch()
+        snapshot_download(
+            repo_id=self.model_id,
+            local_dir=self.local_dir,
+            local_dir_use_symlinks=False,
+            resume_download=True,
+        )
+
+        config = PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_4bit",
+            quant_kwargs={
+                "load_in_4bit": True,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_compute_dtype": torch.bfloat16,
+                "bnb_4bit_use_double_quant": True,
+            },
+            components_to_quantize=["text_encoder", "transformer", "unet"],
+        )
+
+        pipe = Cosmos2TextToImagePipeline.from_pretrained(
+            self.model_id,
+            torch_dtype=torch.bfloat16,
+            quantization_config=config,
+            use_safetensors=True,
+            safety_checker=None,
+            requires_safety_checker=False,
+        ).to(self.device)
+        return pipe
+
+
+class CosmosRunner(BasePipelineRunner):
+    def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
+        return self.pipe(
+            prompt=prompt, negative_prompt=negative_prompt, **kwargs
+        ).images
+
+
+# ===== Kolors =====
+class KolorsLoader(BasePipelineLoader):
+    def load(self):
+        pipe = KolorsPipeline.from_pretrained(
+            "Kwai-Kolors/Kolors-diffusers",
+            torch_dtype=torch.float16,
+            variant="fp16",
+        ).to(self.device)
+        pipe.enable_model_cpu_offload()
+        pipe.enable_xformers_memory_efficient_attention()
+        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+            pipe.scheduler.config, use_karras_sigmas=True
+        )
+        return pipe
+
+
+class KolorsRunner(BasePipelineRunner):
+    def run(self, prompt: str, **kwargs) -> Image.Image:
+        return self.pipe(prompt=prompt, **kwargs).images
+
+
+# ===== Flux =====
+class FluxLoader(BasePipelineLoader):
+    def load(self):
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+        pipe = FluxPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+        )
+        pipe.enable_model_cpu_offload()
+        pipe.enable_xformers_memory_efficient_attention()
+        pipe.enable_attention_slicing()
+        return pipe.to(self.device)
+
+
+class FluxRunner(BasePipelineRunner):
+    def run(self, prompt: str, **kwargs) -> Image.Image:
+        return self.pipe(prompt=prompt, **kwargs).images
+
+
+# ===== Chroma =====
+class ChromaLoader(BasePipelineLoader):
+    def load(self):
+        return ChromaPipeline.from_pretrained(
+            "lodestones/Chroma", torch_dtype=torch.bfloat16
+        ).to(self.device)
+
+
+class ChromaRunner(BasePipelineRunner):
+    def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
+        return self.pipe(
+            prompt=prompt, negative_prompt=negative_prompt, **kwargs
+        ).images
+
+
+PIPELINE_REGISTRY = {
+    "sd35": (SD35Loader, SD35Runner),
+    "cosmos": (CosmosLoader, CosmosRunner),
+    "kolors": (KolorsLoader, KolorsRunner),
+    "flux": (FluxLoader, FluxRunner),
+    "chroma": (ChromaLoader, ChromaRunner),
+}
+
+
+def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
+    if name not in PIPELINE_REGISTRY:
+        raise ValueError(f"Unsupported model: {name}")
+    loader_cls, runner_cls = PIPELINE_REGISTRY[name]
+    pipe = loader_cls(device=device).load()
+
+    return runner_cls(pipe)
+
+
+if __name__ == "__main__":
+    model_name = "sd35"
+    runner = build_hf_image_pipeline(model_name)
+    # NOTE: Just for pipeline testing, generation quality at low resolution is poor.
+    images = runner.run(
+        prompt="A robot holding a sign that says 'Hello'",
+        height=512,
+        width=512,
+        num_inference_steps=10,
+        guidance_scale=6,
+        num_images_per_prompt=1,
+    )
+
+    for i, img in enumerate(images):
+        img.save(f"image_{model_name}_{i}.jpg")
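Reviewer note: the new module pairs a loader with a runner per model and resolves both through `PIPELINE_REGISTRY`. The sketch below shows the same registry-plus-factory shape without any model downloads. `EchoLoader`, `EchoRunner`, `REGISTRY` and `build_pipeline` are hypothetical stand-ins, and the "pipeline" is just a callable returning a list of strings.

```python
from abc import ABC, abstractmethod


class BaseLoader(ABC):
    def __init__(self, device: str = "cpu"):
        self.device = device

    @abstractmethod
    def load(self):
        """Return a callable pipeline object."""


class BaseRunner(ABC):
    def __init__(self, pipe):
        self.pipe = pipe

    @abstractmethod
    def run(self, prompt: str, **kwargs) -> list:
        """Run the pipeline and return a list of outputs."""


class EchoLoader(BaseLoader):
    def load(self):
        # Stand-in for a from_pretrained call: returns a trivial callable.
        device = self.device
        return lambda prompt, **kwargs: [f"{device}:{prompt}"]


class EchoRunner(BaseRunner):
    def run(self, prompt: str, **kwargs) -> list:
        return self.pipe(prompt, **kwargs)


# Each key maps to a (loader class, runner class) pair.
REGISTRY = {"echo": (EchoLoader, EchoRunner)}


def build_pipeline(name: str, device: str = "cpu") -> BaseRunner:
    if name not in REGISTRY:
        raise ValueError(f"Unsupported model: {name}")
    loader_cls, runner_cls = REGISTRY[name]
    return runner_cls(loader_cls(device=device).load())


outputs = build_pipeline("echo").run("a bronze lion")
```

With this layout, supporting another model means adding one loader, one runner and one registry entry; callers of the factory do not change.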
@@ -52,8 +52,11 @@ __all__ = [
     "download_kolors_weights",
 ]

-
-PROMPT_APPEND = "Full view of one {}, no cropping, centered, no occlusion, isolated product photo, matte, 3D style, on a plain clean surface"
+PROMPT_APPEND = (
+    "Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
+    "no surroundings, matte, on a plain clean surface, 3D style revealing multiple surfaces"
+)
+PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"


 def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
@@ -182,9 +185,7 @@ def text2img_gen(
     ip_image_size: int = 512,
     seed: int = None,
 ) -> list[Image.Image]:
-    # prompt = "Single " + prompt + ", in the center of the image"
-    # prompt += ", high quality, high resolution, best quality, white background, 3D style" # noqa
-    prompt = PROMPT_APPEND.format(prompt.strip())
+    prompt = PROMPT_KAPPEND.format(object=prompt.strip())
     logger.info(f"Processing prompt: {prompt}")

     generator = None
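Reviewer note: the two hunks above switch the prompt templates from a positional `{}` field to a named `{object}` field, so the call site has to pass `object=` as a keyword. A small sketch of why the call and the template must change together; `TEMPLATE` is a shortened copy of `PROMPT_KAPPEND` and `build_prompt` is a hypothetical helper.

```python
# Shortened copy of the new Kolors template, for illustration.
TEMPLATE = "Single {object}, in the center of the image, white background"


def build_prompt(text: str) -> str:
    """Fill the named field after trimming surrounding whitespace."""
    return TEMPLATE.format(object=text.strip())


prompt = build_prompt("  small bronze figurine of a lion ")

# A positional argument no longer matches a named field.
try:
    TEMPLATE.format("lion")
    positional_ok = True
except KeyError:
    positional_ok = False
```

`str.format` raises `KeyError` when a named field receives no matching keyword argument, which is why the old `PROMPT_APPEND.format(prompt.strip())` call would fail against the new template.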
|
|||||||
@@ -16,13 +16,14 @@


 import argparse
-import logging
 import os
+import random
 import sys
 from glob import glob
 from shutil import copy, copytree, rmtree

 import numpy as np
+import torch
 import trimesh
 from PIL import Image
 from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
@@ -37,8 +38,10 @@ from embodied_gen.models.segment_model import (
 from embodied_gen.models.sr_model import ImageRealESRGAN
 from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
 from embodied_gen.utils.gpt_clients import GPT_CLIENT
-from embodied_gen.utils.process_media import merge_images_video, render_video
+from embodied_gen.utils.log import logger
+from embodied_gen.utils.process_media import merge_images_video
 from embodied_gen.utils.tags import VERSION
+from embodied_gen.utils.trender import render_video
 from embodied_gen.validators.quality_checkers import (
     BaseChecker,
     ImageAestheticChecker,
@@ -52,19 +55,14 @@ current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, "../.."))
 from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline

-logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
-)
-logger = logging.getLogger(__name__)
-
-
 os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
     "~/.cache/torch_extensions"
 )
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
 os.environ["SPCONV_ALGO"] = "native"
-
+random.seed(0)

+logger.info("Loading Models...")
 DELIGHT = DelightingModel()
 IMAGESR_MODEL = ImageRealESRGAN(outscale=4)

@@ -74,7 +72,7 @@ SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
 PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
     "microsoft/TRELLIS-image-large"
 )
-PIPELINE.cuda()
+# PIPELINE.cuda()
 SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
 GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
 AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -95,7 +93,6 @@ def parse_args():
     parser.add_argument(
         "--output_root",
         type=str,
-        required=True,
         help="Root directory for saving outputs.",
     )
     parser.add_argument(
@@ -110,12 +107,26 @@ def parse_args():
         default=None,
         help="The mass in kg to restore the mesh real weight.",
     )
-    parser.add_argument("--asset_type", type=str, default=None)
+    parser.add_argument("--asset_type", type=str, nargs="+", default=None)
     parser.add_argument("--skip_exists", action="store_true")
-    parser.add_argument("--strict_seg", action="store_true")
     parser.add_argument("--version", type=str, default=VERSION)
-    parser.add_argument("--remove_intermediate", type=bool, default=True)
-    args = parser.parse_args()
+    parser.add_argument("--keep_intermediate", action="store_true")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument(
+        "--n_retry",
+        type=int,
+        default=2,
+    )
+    args, unknown = parser.parse_known_args()
+
+    return args
+
+
+def entrypoint(**kwargs):
+    args = parse_args()
+    for k, v in kwargs.items():
+        if hasattr(args, k) and v is not None:
+            setattr(args, k, v)

     assert (
         args.image_path or args.image_root
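Reviewer note: this hunk replaces `--remove_intermediate` (`type=bool`) with a `store_true` flag, switches to `parse_known_args`, and lets `entrypoint(**kwargs)` override parsed values. The sketch below reproduces those three behaviours in isolation. `build_parser` and the `argv` parameter are assumptions added so the example can run without touching `sys.argv`; both flags are kept side by side only for comparison.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # type=bool calls bool() on the raw string, so "False" parses as True.
    parser.add_argument("--remove_intermediate", type=bool, default=True)
    # store_true defaults to False and becomes True only when the flag is given.
    parser.add_argument("--keep_intermediate", action="store_true")
    parser.add_argument("--n_retry", type=int, default=2)
    return parser


def entrypoint(argv=None, **kwargs) -> argparse.Namespace:
    # parse_known_args tolerates flags this parser does not define.
    args, unknown = build_parser().parse_known_args(argv)
    # Keyword arguments override parsed values, mirroring the loop in the diff.
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)
    return args


pitfall = entrypoint(["--remove_intermediate", "False"])
overridden = entrypoint([], n_retry=5)
tolerant = entrypoint(["--some_other_flag", "1", "--keep_intermediate"])
```

Here `pitfall.remove_intermediate` is still `True`, which is the usual reason to prefer a `store_true` flag for booleans.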
@@ -125,13 +136,7 @@ def parse_args():
         args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
         args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))

-    return args
-
-
-if __name__ == "__main__":
-    args = parse_args()
-
-    for image_path in args.image_path:
+    for idx, image_path in enumerate(args.image_path):
         try:
             filename = os.path.basename(image_path).split(".")[0]
             output_root = args.output_root
@ -141,7 +146,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
mesh_out = f"{output_root}/{filename}.obj"
|
mesh_out = f"{output_root}/{filename}.obj"
|
||||||
if args.skip_exists and os.path.exists(mesh_out):
|
if args.skip_exists and os.path.exists(mesh_out):
|
||||||
logger.info(
|
logger.warning(
|
||||||
f"Skip {image_path}, already processed in {mesh_out}"
|
f"Skip {image_path}, already processed in {mesh_out}"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@ -149,22 +154,27 @@ if __name__ == "__main__":
|
|||||||
image = Image.open(image_path)
|
image = Image.open(image_path)
|
||||||
image.save(f"{output_root}/{filename}_raw.png")
|
image.save(f"{output_root}/{filename}_raw.png")
|
||||||
|
|
||||||
# Segmentation: Get segmented image using SAM or Rembg.
|
# Segmentation: Get segmented image using Rembg.
|
||||||
seg_path = f"{output_root}/{filename}_cond.png"
|
seg_path = f"{output_root}/{filename}_cond.png"
|
||||||
if image.mode != "RGBA":
|
seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
|
||||||
seg_image = RBG_REMOVER(image, save_path=seg_path)
|
|
||||||
seg_image = trellis_preprocess(seg_image)
|
seg_image = trellis_preprocess(seg_image)
|
||||||
else:
|
|
||||||
seg_image = image
|
|
||||||
seg_image.save(seg_path)
|
seg_image.save(seg_path)
|
||||||
|
|
||||||
|
seed = args.seed
|
||||||
|
for try_idx in range(args.n_retry):
|
||||||
|
logger.info(
|
||||||
|
f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
|
||||||
|
)
|
||||||
# Run the pipeline
|
# Run the pipeline
|
||||||
try:
|
try:
|
||||||
|
PIPELINE.cuda()
|
||||||
outputs = PIPELINE.run(
|
outputs = PIPELINE.run(
|
||||||
seg_image,
|
seg_image,
|
||||||
preprocess_image=False,
|
preprocess_image=False,
|
||||||
|
seed=(
|
||||||
|
random.randint(0, 100000) if seed is None else seed
|
||||||
|
),
|
||||||
# Optional parameters
|
# Optional parameters
|
||||||
# seed=1,
|
|
||||||
# sparse_structure_sampler_params={
|
# sparse_structure_sampler_params={
|
||||||
# "steps": 12,
|
# "steps": 12,
|
||||||
# "cfg_strength": 7.5,
|
# "cfg_strength": 7.5,
|
||||||
@ -174,19 +184,16 @@ if __name__ == "__main__":
|
|||||||
# "cfg_strength": 3,
|
# "cfg_strength": 3,
|
||||||
# },
|
# },
|
||||||
)
|
)
|
||||||
|
PIPELINE.cpu()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[Pipeline Failed] process {image_path}: {e}, skip."
|
f"[Pipeline Failed] process {image_path}: {e}, skip."
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Render and save color and mesh videos
|
|
||||||
gs_model = outputs["gaussian"][0]
|
gs_model = outputs["gaussian"][0]
|
||||||
mesh_model = outputs["mesh"][0]
|
mesh_model = outputs["mesh"][0]
|
||||||
color_images = render_video(gs_model)["color"]
|
|
||||||
normal_images = render_video(mesh_model)["normal"]
|
|
||||||
video_path = os.path.join(output_root, "gs_mesh.mp4")
|
|
||||||
merge_images_video(color_images, normal_images, video_path)
|
|
||||||
|
|
||||||
# Save the raw Gaussian model
|
# Save the raw Gaussian model
|
||||||
gs_path = mesh_out.replace(".obj", "_gs.ply")
|
gs_path = mesh_out.replace(".obj", "_gs.ply")
|
||||||
@ -210,6 +217,21 @@ if __name__ == "__main__":
|
|||||||
color_path = os.path.join(output_root, "color.png")
|
color_path = os.path.join(output_root, "color.png")
|
||||||
render_gs_api(aligned_gs_path, color_path)
|
render_gs_api(aligned_gs_path, color_path)
|
||||||
|
|
||||||
|
geo_flag, geo_result = GEO_CHECKER([color_path])
|
||||||
|
logger.warning(
|
||||||
|
f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
|
||||||
|
)
|
||||||
|
if geo_flag is True or geo_flag is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
seed = random.randint(0, 100000) if seed is not None else None
|
||||||
|
|
||||||
|
# Render the video for generated 3D asset.
|
||||||
|
color_images = render_video(gs_model)["color"]
|
||||||
|
normal_images = render_video(mesh_model)["normal"]
|
||||||
|
video_path = os.path.join(output_root, "gs_mesh.mp4")
|
||||||
|
merge_images_video(color_images, normal_images, video_path)
|
||||||
|
|
||||||
mesh = trimesh.Trimesh(
|
mesh = trimesh.Trimesh(
|
||||||
vertices=mesh_model.vertices.cpu().numpy(),
|
vertices=mesh_model.vertices.cpu().numpy(),
|
||||||
faces=mesh_model.faces.cpu().numpy(),
|
faces=mesh_model.faces.cpu().numpy(),
|
||||||
@ -249,8 +271,8 @@ if __name__ == "__main__":
|
|||||||
min_mass, max_mass = map(float, args.mass_range.split("-"))
|
min_mass, max_mass = map(float, args.mass_range.split("-"))
|
||||||
asset_attrs["min_mass"] = min_mass
|
asset_attrs["min_mass"] = min_mass
|
||||||
asset_attrs["max_mass"] = max_mass
|
asset_attrs["max_mass"] = max_mass
|
||||||
if args.asset_type:
|
if isinstance(args.asset_type, list) and args.asset_type[idx]:
|
||||||
asset_attrs["category"] = args.asset_type
|
asset_attrs["category"] = args.asset_type[idx]
|
||||||
if args.version:
|
if args.version:
|
||||||
asset_attrs["version"] = args.version
|
asset_attrs["version"] = args.version
|
||||||
|
|
||||||
@ -289,8 +311,8 @@ if __name__ == "__main__":
|
|||||||
]
|
]
|
||||||
images_list.append(images)
|
images_list.append(images)
|
||||||
|
|
||||||
results = BaseChecker.validate(CHECKERS, images_list)
|
qa_results = BaseChecker.validate(CHECKERS, images_list)
|
||||||
urdf_convertor.add_quality_tag(urdf_path, results)
|
urdf_convertor.add_quality_tag(urdf_path, qa_results)
|
||||||
|
|
||||||
# Organize the final result files
|
# Organize the final result files
|
||||||
result_dir = f"{output_root}/result"
|
result_dir = f"{output_root}/result"
|
||||||
@ -303,7 +325,7 @@ if __name__ == "__main__":
|
|||||||
f"{result_dir}/{urdf_convertor.output_mesh_dir}",
|
f"{result_dir}/{urdf_convertor.output_mesh_dir}",
|
||||||
)
|
)
|
||||||
copy(video_path, f"{result_dir}/video.mp4")
|
copy(video_path, f"{result_dir}/video.mp4")
|
||||||
if args.remove_intermediate:
|
if not args.keep_intermediate:
|
||||||
delete_dir(output_root, keep_subs=["result"])
|
delete_dir(output_root, keep_subs=["result"])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -311,3 +333,7 @@ if __name__ == "__main__":
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"Processing complete. Outputs saved to {args.output_root}")
|
logger.info(f"Processing complete. Outputs saved to {args.output_root}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
entrypoint()
|
||||||
|
|||||||
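The retry logic added to `imageto3d.py` above can be sketched in isolation. This is a minimal illustration, not the project's implementation: `generate_with_retry`, `generate`, and `check` are hypothetical names standing in for the `PIPELINE.run(...)` call and `GEO_CHECKER`. It assumes, as in the diff, that a checker returns `True` (pass), `False` (fail), or `None` (no verdict, treated as a pass), and that a fresh seed is drawn only when seeding is enabled.

```python
import random
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")


def generate_with_retry(
    generate: Callable[[Optional[int]], T],
    check: Callable[[T], Optional[bool]],
    n_retry: int = 2,
    seed: Optional[int] = 0,
) -> Tuple[Optional[T], int]:
    """Run `generate` up to `n_retry` times until `check` accepts the result.

    Returns the last generated result and the number of attempts made.
    """
    result = None
    attempts = 0
    for _ in range(n_retry):
        attempts += 1
        result = generate(seed)
        flag = check(result)
        # True means the checker passed; None means it could not decide,
        # which is treated as a pass (mirrors the condition in the diff).
        if flag is True or flag is None:
            break
        # Draw a new seed for the next attempt only if seeding is enabled.
        seed = random.randint(0, 100000) if seed is not None else None
    return result, attempts
```

Note that, as in the diff, the last result is returned even when every attempt fails the check; a caller that needs a hard guarantee would have to inspect the checker result itself.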
@ -85,7 +85,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=None,
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
271
embodied_gen/scripts/textto3d.py
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
|
||||||
|
from embodied_gen.models.segment_model import RembgRemover
|
||||||
|
from embodied_gen.models.text_model import PROMPT_APPEND
|
||||||
|
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
|
||||||
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||||
|
from embodied_gen.utils.log import logger
|
||||||
|
from embodied_gen.utils.process_media import render_asset3d
|
||||||
|
from embodied_gen.validators.quality_checkers import (
|
||||||
|
ImageSegChecker,
|
||||||
|
SemanticConsistChecker,
|
||||||
|
TextGenAlignChecker,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Avoid the huggingface/tokenizers "The current process just got forked" warning.
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
random.seed(0)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"text_to_image",
|
||||||
|
"text_to_3d",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info("Loading Models...")
|
||||||
|
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
|
||||||
|
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
||||||
|
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
|
||||||
|
PIPE_IMG = build_hf_image_pipeline("sd35")
|
||||||
|
BG_REMOVER = RembgRemover()
|
||||||
|
|
||||||
|
|
||||||
|
def text_to_image(
|
||||||
|
prompt: str,
|
||||||
|
save_path: str,
|
||||||
|
n_retry: int,
|
||||||
|
img_denoise_step: int,
|
||||||
|
text_guidance_scale: float,
|
||||||
|
n_img_sample: int,
|
||||||
|
image_hw: tuple[int, int] = (1024, 1024),
|
||||||
|
seed: int | None = None,
|
||||||
|
) -> bool:
|
||||||
|
select_image = None
|
||||||
|
success_flag = False
|
||||||
|
assert save_path.endswith(".png"), "Image save path must end with `.png`."
|
||||||
|
for try_idx in range(n_retry):
|
||||||
|
if select_image is not None:
|
||||||
|
select_image[0].save(save_path.replace(".png", "_raw.png"))
|
||||||
|
select_image[1].save(save_path)
|
||||||
|
break
|
||||||
|
|
||||||
|
f_prompt = PROMPT_APPEND.format(object=prompt)
|
||||||
|
logger.info(
|
||||||
|
f"Image GEN for {os.path.basename(save_path)}\n"
|
||||||
|
f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
|
||||||
|
)
|
||||||
|
images = PIPE_IMG.run(
|
||||||
|
f_prompt,
|
||||||
|
num_inference_steps=img_denoise_step,
|
||||||
|
guidance_scale=text_guidance_scale,
|
||||||
|
num_images_per_prompt=n_img_sample,
|
||||||
|
height=image_hw[0],
|
||||||
|
width=image_hw[1],
|
||||||
|
generator=(
|
||||||
|
torch.Generator().manual_seed(seed)
|
||||||
|
if seed is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx in range(len(images)):
|
||||||
|
raw_image: Image.Image = images[idx]
|
||||||
|
image = BG_REMOVER(raw_image)
|
||||||
|
image.save(save_path)
|
||||||
|
semantic_flag, semantic_result = SEMANTIC_CHECKER(
|
||||||
|
prompt, [image.convert("RGB")]
|
||||||
|
)
|
||||||
|
seg_flag, seg_result = SEG_CHECKER(
|
||||||
|
[raw_image, image.convert("RGB")]
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
(semantic_flag and seg_flag)
|
||||||
|
or semantic_flag is None
|
||||||
|
or seg_flag is None
|
||||||
|
):
|
||||||
|
select_image = [raw_image, image]
|
||||||
|
success_flag = True
|
||||||
|
break
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
seed = random.randint(0, 100000) if seed is not None else None
|
||||||
|
|
||||||
|
return success_flag
|
||||||
|
|
||||||
|
|
||||||
|
def text_to_3d(**kwargs) -> dict:
|
||||||
|
args = parse_args()
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if hasattr(args, k) and v is not None:
|
||||||
|
setattr(args, k, v)
|
||||||
|
|
||||||
|
if args.asset_names is None or len(args.asset_names) == 0:
|
||||||
|
args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
|
||||||
|
img_save_dir = os.path.join(args.output_root, "images")
|
||||||
|
asset_save_dir = os.path.join(args.output_root, "asset3d")
|
||||||
|
os.makedirs(img_save_dir, exist_ok=True)
|
||||||
|
os.makedirs(asset_save_dir, exist_ok=True)
|
||||||
|
results = defaultdict(dict)
|
||||||
|
for prompt, node in zip(args.prompts, args.asset_names):
|
||||||
|
success_flag = False
|
||||||
|
n_pipe_retry = args.n_pipe_retry
|
||||||
|
seed_img = args.seed_img
|
||||||
|
seed_3d = args.seed_3d
|
||||||
|
while success_flag is False and n_pipe_retry > 0:
|
||||||
|
logger.info(
|
||||||
|
f"GEN pipeline for node {node}\n"
|
||||||
|
f"Try round: {args.n_pipe_retry-n_pipe_retry+1}/{args.n_pipe_retry}, Prompt: {prompt}"
|
||||||
|
)
|
||||||
|
# Text-to-image GEN
|
||||||
|
save_node = node.replace(" ", "_")
|
||||||
|
gen_image_path = f"{img_save_dir}/{save_node}.png"
|
||||||
|
textgen_flag = text_to_image(
|
||||||
|
prompt,
|
||||||
|
gen_image_path,
|
||||||
|
args.n_image_retry,
|
||||||
|
args.img_denoise_step,
|
||||||
|
args.text_guidance_scale,
|
||||||
|
args.n_img_sample,
|
||||||
|
seed=seed_img,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Asset 3D GEN
|
||||||
|
node_save_dir = f"{asset_save_dir}/{save_node}"
|
||||||
|
asset_type = node if "sample3d_" not in node else None
|
||||||
|
imageto3d_api(
|
||||||
|
image_path=[gen_image_path],
|
||||||
|
output_root=node_save_dir,
|
||||||
|
asset_type=[asset_type],
|
||||||
|
seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
|
||||||
|
n_retry=args.n_asset_retry,
|
||||||
|
keep_intermediate=args.keep_intermediate,
|
||||||
|
)
|
||||||
|
mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
|
||||||
|
image_path = render_asset3d(
|
||||||
|
mesh_path,
|
||||||
|
output_root=f"{node_save_dir}/result",
|
||||||
|
num_images=6,
|
||||||
|
elevation=(30, -30),
|
||||||
|
output_subdir="renders",
|
||||||
|
no_index_file=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
check_text = asset_type if asset_type is not None else prompt
|
||||||
|
qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
|
||||||
|
logger.warning(
|
||||||
|
f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}"
|
||||||
|
)
|
||||||
|
results["assets"][node] = f"{node_save_dir}/result"
|
||||||
|
results["quality"][node] = qa_result
|
||||||
|
|
||||||
|
if qa_flag is None or qa_flag is True:
|
||||||
|
success_flag = True
|
||||||
|
break
|
||||||
|
|
||||||
|
n_pipe_retry -= 1
|
||||||
|
seed_img = (
|
||||||
|
random.randint(0, 100000) if seed_img is not None else None
|
||||||
|
)
|
||||||
|
seed_3d = (
|
||||||
|
random.randint(0, 100000) if seed_3d is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="3D Layout Generation Config")
|
||||||
|
parser.add_argument("--prompts", nargs="+", help="text descriptions")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_root",
|
||||||
|
type=str,
|
||||||
|
help="Directory to save outputs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--asset_names",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Asset names to generate",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_img_sample",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Number of image samples to generate",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text_guidance_scale",
|
||||||
|
type=float,
|
||||||
|
default=7,
|
||||||
|
help="Text-to-image guidance scale",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--img_denoise_step",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="Denoising steps for image generation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_image_retry",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Max retry count for image generation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_asset_retry",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Max retry count for 3D generation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_pipe_retry",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Max retry count for the whole text-to-3D pipeline",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed_img",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Random seed for image generation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed_3d",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Random seed for 3D generation",
|
||||||
|
)
|
||||||
|
parser.add_argument("--keep_intermediate", action="store_true")
|
||||||
|
|
||||||
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
text_to_3d()
|
||||||
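Both `entrypoint` in `imageto3d.py` and `text_to_3d` above let a Python caller override parsed command-line arguments through keyword arguments. A minimal sketch of that pattern follows. The `argv` parameter is an assumption added here so the example can run without touching `sys.argv`; the scripts in the diff call `parse_known_args()` with no arguments.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="demo")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--n_retry", type=int, default=2)
    parser.add_argument("--keep_intermediate", action="store_true")
    # parse_known_args tolerates flags meant for another script
    # instead of exiting with an error.
    args, _unknown = parser.parse_known_args(argv)
    return args


def entrypoint(argv=None, **kwargs):
    args = parse_args(argv)
    for key, value in kwargs.items():
        # Only known options are overridden, and None means "keep the CLI value".
        if hasattr(args, key) and value is not None:
            setattr(args, key, value)
    return args


args = entrypoint(argv=["--seed", "5"], n_retry=4, seed=None)
print(args.seed, args.n_retry)
```

One consequence of the `v is not None` guard is that a caller cannot reset an option to `None` through keyword arguments; the parsed value is kept in that case.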
@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
# Initialize variables
|
# Initialize variables
|
||||||
prompts=()
|
prompts=()
|
||||||
|
asset_types=()
|
||||||
output_root=""
|
output_root=""
|
||||||
|
seed=0
|
||||||
|
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
while [[ $# -gt 0 ]]; do
|
while [[ $# -gt 0 ]]; do
|
||||||
@ -14,10 +16,21 @@ while [[ $# -gt 0 ]]; do
|
|||||||
shift
|
shift
|
||||||
done
|
done
|
||||||
;;
|
;;
|
||||||
|
--asset_types)
|
||||||
|
shift
|
||||||
|
while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
|
||||||
|
asset_types+=("$1")
|
||||||
|
shift
|
||||||
|
done
|
||||||
|
;;
|
||||||
--output_root)
|
--output_root)
|
||||||
output_root="$2"
|
output_root="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
|
--seed)
|
||||||
|
seed="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Unknown argument: $1"
|
echo "Unknown argument: $1"
|
||||||
exit 1
|
exit 1
|
||||||
@ -28,7 +41,21 @@ done
|
|||||||
# Validate required arguments
|
# Validate required arguments
|
||||||
if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
|
if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
|
||||||
echo "Missing required arguments."
|
echo "Missing required arguments."
|
||||||
echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" --output_root <path>"
|
echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" \
|
||||||
|
--asset_types \"type1\" \"type2\" --seed <seed_value> --output_root <path>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If no asset_types provided, default to ""
|
||||||
|
if [[ ${#asset_types[@]} -eq 0 ]]; then
|
||||||
|
for (( i=0; i<${#prompts[@]}; i++ )); do
|
||||||
|
asset_types+=("")
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Ensure the number of asset_types matches the number of prompts
|
||||||
|
if [[ ${#prompts[@]} -ne ${#asset_types[@]} ]]; then
|
||||||
|
echo "The number of asset types must match the number of prompts."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -37,20 +64,30 @@ echo "Prompts:"
|
|||||||
for p in "${prompts[@]}"; do
|
for p in "${prompts[@]}"; do
|
||||||
echo " - $p"
|
echo " - $p"
|
||||||
done
|
done
|
||||||
|
# echo "Asset types:"
|
||||||
|
# for at in "${asset_types[@]}"; do
|
||||||
|
# echo " - $at"
|
||||||
|
# done
|
||||||
echo "Output root: ${output_root}"
|
echo "Output root: ${output_root}"
|
||||||
|
echo "Seed: ${seed}"
|
||||||
|
|
||||||
# Concatenate prompts for Python command
|
# Concatenate prompts and asset types for Python command
|
||||||
prompt_args=""
|
prompt_args=""
|
||||||
for p in "${prompts[@]}"; do
|
asset_type_args=""
|
||||||
prompt_args+="\"$p\" "
|
for i in "${!prompts[@]}"; do
|
||||||
|
prompt_args+="\"${prompts[$i]}\" "
|
||||||
|
asset_type_args+="\"${asset_types[$i]}\" "
|
||||||
done
|
done
|
||||||
|
|
||||||
|
|
||||||
# Step 1: Text-to-Image
|
# Step 1: Text-to-Image
|
||||||
eval python3 embodied_gen/scripts/text2image.py \
|
eval python3 embodied_gen/scripts/text2image.py \
|
||||||
--prompts ${prompt_args} \
|
--prompts ${prompt_args} \
|
||||||
--output_root "${output_root}/images"
|
--output_root "${output_root}/images" \
|
||||||
|
--seed ${seed}
|
||||||
|
|
||||||
# Step 2: Image-to-3D
|
# Step 2: Image-to-3D
|
||||||
python3 embodied_gen/scripts/imageto3d.py \
|
python3 embodied_gen/scripts/imageto3d.py \
|
||||||
--image_root "${output_root}/images" \
|
--image_root "${output_root}/images" \
|
||||||
--output_root "${output_root}/asset3d"
|
--output_root "${output_root}/asset3d" \
|
||||||
|
--asset_type ${asset_type_args}
|
||||||
|
|||||||
@ -10,10 +10,6 @@ while [[ $# -gt 0 ]]; do
|
|||||||
prompt="$2"
|
prompt="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
--uuid)
|
|
||||||
uuid="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--output_root)
|
--output_root)
|
||||||
output_root="$2"
|
output_root="$2"
|
||||||
shift 2
|
shift 2
|
||||||
@ -26,12 +22,13 @@ while [[ $# -gt 0 ]]; do
|
|||||||
done
|
done
|
||||||
|
|
||||||
|
|
||||||
if [[ -z "$mesh_path" || -z "$prompt" || -z "$uuid" || -z "$output_root" ]]; then
|
if [[ -z "$mesh_path" || -z "$prompt" || -z "$output_root" ]]; then
|
||||||
echo "params missing"
|
echo "params missing"
|
||||||
echo "usage: bash run.sh --mesh_path <path> --prompt <text> --uuid <id> --output_root <path>"
|
echo "usage: bash run.sh --mesh_path <path> --prompt <text> --output_root <path>"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
uuid=$(basename "$output_root")
|
||||||
# Step 1: drender-cli for condition rendering
|
# Step 1: drender-cli for condition rendering
|
||||||
drender-cli --mesh_path ${mesh_path} \
|
drender-cli --mesh_path ${mesh_path} \
|
||||||
--output_root ${output_root}/condition \
|
--output_root ${output_root}/condition \
|
||||||
|
|||||||
107
embodied_gen/utils/enum.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from dataclasses_json import DataClassJsonMixin
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RenderItems",
|
||||||
|
"Scene3DItemEnum",
|
||||||
|
"SpatialRelationEnum",
|
||||||
|
"RobotItemEnum",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RenderItems(str, Enum):
|
||||||
|
IMAGE = "image_color"
|
||||||
|
ALPHA = "image_mask"
|
||||||
|
VIEW_NORMAL = "image_view_normal"
|
||||||
|
GLOBAL_NORMAL = "image_global_normal"
|
||||||
|
POSITION_MAP = "image_position"
|
||||||
|
DEPTH = "image_depth"
|
||||||
|
ALBEDO = "image_albedo"
|
||||||
|
DIFFUSE = "image_diffuse"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Scene3DItemEnum(str, Enum):
|
||||||
|
BACKGROUND = "background"
|
||||||
|
CONTEXT = "context"
|
||||||
|
ROBOT = "robot"
|
||||||
|
MANIPULATED_OBJS = "manipulated_objs"
|
||||||
|
DISTRACTOR_OBJS = "distractor_objs"
|
||||||
|
OTHERS = "others"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def object_list(cls, layout_relation: dict) -> list:
|
||||||
|
return (
|
||||||
|
[
|
||||||
|
layout_relation[cls.BACKGROUND.value],
|
||||||
|
layout_relation[cls.CONTEXT.value],
|
||||||
|
]
|
||||||
|
+ layout_relation[cls.MANIPULATED_OBJS.value]
|
||||||
|
+ layout_relation[cls.DISTRACTOR_OBJS.value]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def object_mapping(cls, layout_relation):
|
||||||
|
relation_mapping = {
|
||||||
|
# layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
|
||||||
|
layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
|
||||||
|
layout_relation[cls.CONTEXT.value]: cls.CONTEXT.value,
|
||||||
|
}
|
||||||
|
relation_mapping.update(
|
||||||
|
{
|
||||||
|
item: cls.MANIPULATED_OBJS.value
|
||||||
|
for item in layout_relation[cls.MANIPULATED_OBJS.value]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
relation_mapping.update(
|
||||||
|
{
|
||||||
|
item: cls.DISTRACTOR_OBJS.value
|
||||||
|
for item in layout_relation[cls.DISTRACTOR_OBJS.value]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return relation_mapping
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SpatialRelationEnum(str, Enum):
|
||||||
|
ON = "ON" # objects on the table
|
||||||
|
IN = "IN" # objects in the room
|
||||||
|
INSIDE = "INSIDE" # objects inside the shelf/rack
|
||||||
|
FLOOR = "FLOOR" # objects on the floor of the room/bin
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RobotItemEnum(str, Enum):
|
||||||
|
FRANKA = "franka"
|
||||||
|
UR5 = "ur5"
|
||||||
|
PIPER = "piper"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LayoutInfo(DataClassJsonMixin):
|
||||||
|
tree: dict[str, list]
|
||||||
|
relation: dict[str, str | list[str]]
|
||||||
|
objs_desc: dict[str, str] = field(default_factory=dict)
|
||||||
|
assets: dict[str, str] = field(default_factory=dict)
|
||||||
|
quality: dict[str, str] = field(default_factory=dict)
|
||||||
|
position: dict[str, list[float]] = field(default_factory=dict)
|
||||||
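To make the role of `Scene3DItemEnum.object_mapping` above concrete, here is a small standalone sketch. `SceneItem` is a hypothetical, trimmed-down stand-in for the enum in the diff, and the sample `layout_relation` dictionary is invented for illustration; the assumed shape (a single name for background and context, lists of names for the object groups) is inferred from how the method indexes it.

```python
from enum import Enum


class SceneItem(str, Enum):
    BACKGROUND = "background"
    CONTEXT = "context"
    MANIPULATED_OBJS = "manipulated_objs"
    DISTRACTOR_OBJS = "distractor_objs"


def object_mapping(layout_relation: dict) -> dict:
    """Invert a role -> object(s) layout into an object -> role lookup."""
    mapping = {
        layout_relation[SceneItem.BACKGROUND.value]: SceneItem.BACKGROUND.value,
        layout_relation[SceneItem.CONTEXT.value]: SceneItem.CONTEXT.value,
    }
    for role in (SceneItem.MANIPULATED_OBJS, SceneItem.DISTRACTOR_OBJS):
        mapping.update({name: role.value for name in layout_relation[role.value]})
    return mapping


# Invented sample data, for illustration only.
layout_relation = {
    "background": "kitchen",
    "context": "table",
    "manipulated_objs": ["apple", "mug"],
    "distractor_objs": ["spoon"],
}
mapping = object_mapping(layout_relation)
print(mapping["mug"])
```

Because the members mix in `str`, a member such as `SceneItem.BACKGROUND` compares equal to the plain string `"background"`, which is convenient when the keys come from JSON.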
@ -30,12 +30,20 @@ from tenacity import (
|
|||||||
stop_after_delay,
|
stop_after_delay,
|
||||||
wait_random_exponential,
|
wait_random_exponential,
|
||||||
)
|
)
|
||||||
from embodied_gen.utils.process_media import combine_images_to_base64
|
from embodied_gen.utils.process_media import combine_images_to_grid
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"GPTclient",
|
||||||
|
]
|
||||||
|
|
||||||
|
CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
|
||||||
|
|
||||||
|
|
||||||
class GPTclient:
|
class GPTclient:
|
||||||
"""A client to interact with the GPT model via OpenAI or Azure API."""
|
"""A client to interact with the GPT model via OpenAI or Azure API."""
|
||||||
|
|
||||||
@ -45,6 +53,7 @@ class GPTclient:
|
|||||||
api_key: str,
|
api_key: str,
|
||||||
model_name: str = "yfb-gpt-4o",
|
model_name: str = "yfb-gpt-4o",
|
||||||
api_version: str = None,
|
api_version: str = None,
|
||||||
|
check_connection: bool = True,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
):
|
):
|
||||||
if api_version is not None:
|
if api_version is not None:
|
||||||
@ -63,6 +72,9 @@ class GPTclient:
|
|||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
|
self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
if check_connection:
|
||||||
|
self.check_connection()
|
||||||
|
|
||||||
logger.info(f"Using GPT model: {self.model_name}.")
|
logger.info(f"Using GPT model: {self.model_name}.")
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@ -77,6 +89,7 @@ class GPTclient:
|
|||||||
text_prompt: str,
|
text_prompt: str,
|
||||||
image_base64: Optional[list[str | Image.Image]] = None,
|
image_base64: Optional[list[str | Image.Image]] = None,
|
||||||
system_role: Optional[str] = None,
|
system_role: Optional[str] = None,
|
||||||
|
params: Optional[dict] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Queries the GPT model with a text and optional image prompts.
|
"""Queries the GPT model with a text and optional image prompts.
|
||||||
|
|
||||||
@ -86,6 +99,7 @@ class GPTclient:
|
|||||||
or local image paths or PIL.Image to accompany the text prompt.
|
or local image paths or PIL.Image to accompany the text prompt.
|
||||||
system_role (Optional[str]): Optional system-level instructions
|
system_role (Optional[str]): Optional system-level instructions
|
||||||
that specify the behavior of the assistant.
|
that specify the behavior of the assistant.
|
||||||
|
params (Optional[dict]): Additional parameters for GPT setting.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[str]: The response content generated by the model based on
|
Optional[str]: The response content generated by the model based on
|
||||||
@ -103,11 +117,11 @@ class GPTclient:
|
|||||||
|
|
||||||
# Process images if provided
|
# Process images if provided
|
||||||
if image_base64 is not None:
|
if image_base64 is not None:
|
||||||
image_base64 = (
|
if not isinstance(image_base64, list):
|
||||||
image_base64
|
image_base64 = [image_base64]
|
||||||
if isinstance(image_base64, list)
|
# Temporary workaround: OpenRouter cannot take multiple images, so merge them into one grid.
|
||||||
else [image_base64]
|
if "openrouter" in self.endpoint:
|
||||||
)
|
image_base64 = combine_images_to_grid(image_base64)
|
||||||
for img in image_base64:
|
for img in image_base64:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
@ -142,8 +156,11 @@ class GPTclient:
|
|||||||
"frequency_penalty": 0,
|
"frequency_penalty": 0,
|
||||||
"presence_penalty": 0,
|
"presence_penalty": 0,
|
||||||
"stop": None,
|
"stop": None,
|
||||||
|
"model": self.model_name,
|
||||||
}
|
}
|
||||||
payload.update({"model": self.model_name})
|
|
||||||
|
if params:
|
||||||
|
payload.update(params)
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
@ -159,8 +176,28 @@ class GPTclient:
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
def check_connection(self) -> None:
|
||||||
|
"""Check whether the GPT API connection is working."""
|
||||||
|
try:
|
||||||
|
response = self.completion_with_backoff(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a test system."},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
],
|
||||||
|
model=self.model_name,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=100,
|
||||||
|
)
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
logger.info("Connection check succeeded.")
|
||||||
|
except Exception as e:
|
||||||
|
raise ConnectionError(
|
||||||
|
f"Failed to connect to GPT API at {self.endpoint}, "
|
||||||
|
f"please check setting in `{CONFIG_FILE}` and `README`."
|
||||||
|
)
|
||||||
|
|
||||||
with open("embodied_gen/utils/gpt_config.yaml", "r") as f:
|
|
||||||
|
with open(CONFIG_FILE, "r") as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
agent_type = config["agent_type"]
|
agent_type = config["agent_type"]
|
||||||
@ -177,32 +214,5 @@ GPT_CLIENT = GPTclient(
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
check_connection=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if "openrouter" in GPT_CLIENT.endpoint:
|
|
||||||
response = GPT_CLIENT.query(
|
|
||||||
text_prompt="What is the content in each image?",
|
|
||||||
image_base64=combine_images_to_base64(
|
|
||||||
[
|
|
||||||
"apps/assets/example_image/sample_02.jpg",
|
|
||||||
"apps/assets/example_image/sample_03.jpg",
|
|
||||||
]
|
|
||||||
), # input raw image_path if only one image
|
|
||||||
)
|
|
||||||
print(response)
|
|
||||||
else:
|
|
||||||
response = GPT_CLIENT.query(
|
|
||||||
text_prompt="What is the content in the images?",
|
|
||||||
image_base64=[
|
|
||||||
Image.open("apps/assets/example_image/sample_02.jpg"),
|
|
||||||
Image.open("apps/assets/example_image/sample_03.jpg"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
print(response)
|
|
||||||
|
|
||||||
# test2: text prompt
|
|
||||||
response = GPT_CLIENT.query(
|
|
||||||
text_prompt="What is the capital of China?"
|
|
||||||
)
|
|
||||||
print(response)
|
|
||||||
|
|||||||
48
embodied_gen/utils/log.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from colorlog import ColoredFormatter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"logger",
|
||||||
|
]
|
||||||
|
|
||||||
|
LOG_FORMAT = (
|
||||||
|
"%(log_color)s[%(asctime)s] %(levelname)-8s | %(message)s%(reset)s"
|
||||||
|
)
|
||||||
|
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||||
|
|
||||||
|
formatter = ColoredFormatter(
|
||||||
|
LOG_FORMAT,
|
||||||
|
datefmt=DATE_FORMAT,
|
||||||
|
log_colors={
|
||||||
|
"DEBUG": "cyan",
|
||||||
|
"INFO": "green",
|
||||||
|
"WARNING": "yellow",
|
||||||
|
"ERROR": "red",
|
||||||
|
"CRITICAL": "bold_red",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.propagate = False
|
||||||
@ -15,34 +15,24 @@
|
|||||||
# permissions and limitations under the License.
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import textwrap
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from io import BytesIO
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import imageio
|
import imageio
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL.Image as Image
|
import PIL.Image as Image
|
||||||
import spaces
|
import spaces
|
||||||
import torch
|
from matplotlib.patches import Patch
|
||||||
from moviepy.editor import VideoFileClip, clips_array
|
from moviepy.editor import VideoFileClip, clips_array
|
||||||
from tqdm import tqdm
|
|
||||||
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
||||||
|
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
|
||||||
current_file_path = os.path.abspath(__file__)
|
|
||||||
current_dir = os.path.dirname(current_file_path)
|
|
||||||
sys.path.append(os.path.join(current_dir, "../.."))
|
|
||||||
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
|
|
||||||
from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
|
|
||||||
from thirdparty.TRELLIS.trellis.utils.render_utils import (
|
|
||||||
render_frames,
|
|
||||||
yaw_pitch_r_fov_to_extrinsics_intrinsics,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -53,9 +43,8 @@ __all__ = [
|
|||||||
"merge_images_video",
|
"merge_images_video",
|
||||||
"filter_small_connected_components",
|
"filter_small_connected_components",
|
||||||
"filter_image_small_connected_components",
|
"filter_image_small_connected_components",
|
||||||
"combine_images_to_base64",
|
"combine_images_to_grid",
|
||||||
"render_mesh",
|
"SceneTreeVisualizer",
|
||||||
"render_video",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -66,12 +55,14 @@ def render_asset3d(
|
|||||||
distance: float = 5.0,
|
distance: float = 5.0,
|
||||||
num_images: int = 1,
|
num_images: int = 1,
|
||||||
elevation: list[float] = (0.0,),
|
elevation: list[float] = (0.0,),
|
||||||
pbr_light_factor: float = 1.5,
|
pbr_light_factor: float = 1.2,
|
||||||
return_key: str = "image_color/*",
|
return_key: str = "image_color/*",
|
||||||
output_subdir: str = "renders",
|
output_subdir: str = "renders",
|
||||||
gen_color_mp4: bool = False,
|
gen_color_mp4: bool = False,
|
||||||
gen_viewnormal_mp4: bool = False,
|
gen_viewnormal_mp4: bool = False,
|
||||||
gen_glonormal_mp4: bool = False,
|
gen_glonormal_mp4: bool = False,
|
||||||
|
no_index_file: bool = False,
|
||||||
|
with_mtl: bool = True,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
input_args = dict(
|
input_args = dict(
|
||||||
mesh_path=mesh_path,
|
mesh_path=mesh_path,
|
||||||
@ -81,14 +72,13 @@ def render_asset3d(
|
|||||||
num_images=num_images,
|
num_images=num_images,
|
||||||
elevation=elevation,
|
elevation=elevation,
|
||||||
pbr_light_factor=pbr_light_factor,
|
pbr_light_factor=pbr_light_factor,
|
||||||
with_mtl=True,
|
with_mtl=with_mtl,
|
||||||
|
gen_color_mp4=gen_color_mp4,
|
||||||
|
gen_viewnormal_mp4=gen_viewnormal_mp4,
|
||||||
|
gen_glonormal_mp4=gen_glonormal_mp4,
|
||||||
|
no_index_file=no_index_file,
|
||||||
)
|
)
|
||||||
if gen_color_mp4:
|
|
||||||
input_args["gen_color_mp4"] = True
|
|
||||||
if gen_viewnormal_mp4:
|
|
||||||
input_args["gen_viewnormal_mp4"] = True
|
|
||||||
if gen_glonormal_mp4:
|
|
||||||
input_args["gen_glonormal_mp4"] = True
|
|
||||||
try:
|
try:
|
||||||
_ = render_api(**input_args)
|
_ = render_api(**input_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -168,12 +158,15 @@ def filter_image_small_connected_components(
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def combine_images_to_base64(
|
def combine_images_to_grid(
|
||||||
images: list[str | Image.Image],
|
images: list[str | Image.Image],
|
||||||
cat_row_col: tuple[int, int] = None,
|
cat_row_col: tuple[int, int] = None,
|
||||||
target_wh: tuple[int, int] = (512, 512),
|
target_wh: tuple[int, int] = (512, 512),
|
||||||
) -> str:
|
) -> list[str | Image.Image]:
|
||||||
n_images = len(images)
|
n_images = len(images)
|
||||||
|
if n_images == 1:
|
||||||
|
return images
|
||||||
|
|
||||||
if cat_row_col is None:
|
if cat_row_col is None:
|
||||||
n_col = math.ceil(math.sqrt(n_images))
|
n_col = math.ceil(math.sqrt(n_images))
|
||||||
n_row = math.ceil(n_images / n_col)
|
n_row = math.ceil(n_images / n_col)
|
||||||
@ -182,88 +175,190 @@ def combine_images_to_base64(
|
|||||||
|
|
||||||
images = [
|
images = [
|
||||||
Image.open(p).convert("RGB") if isinstance(p, str) else p
|
Image.open(p).convert("RGB") if isinstance(p, str) else p
|
||||||
for p in images[: n_row * n_col]
|
for p in images
|
||||||
]
|
]
|
||||||
images = [img.resize(target_wh) for img in images]
|
images = [img.resize(target_wh) for img in images]
|
||||||
|
|
||||||
grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
|
grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
|
||||||
grid = Image.new("RGB", (grid_w, grid_h), (255, 255, 255))
|
grid = Image.new("RGB", (grid_w, grid_h), (0, 0, 0))
|
||||||
|
|
||||||
for idx, img in enumerate(images):
|
for idx, img in enumerate(images):
|
||||||
row, col = divmod(idx, n_col)
|
row, col = divmod(idx, n_col)
|
||||||
grid.paste(img, (col * target_wh[0], row * target_wh[1]))
|
grid.paste(img, (col * target_wh[0], row * target_wh[1]))
|
||||||
|
|
||||||
buffer = BytesIO()
|
return [grid]
|
||||||
grid.save(buffer, format="PNG")
|
|
||||||
|
|
||||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
@spaces.GPU
|
class SceneTreeVisualizer:
|
||||||
def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
|
def __init__(self, layout_info: LayoutInfo) -> None:
|
||||||
renderer = MeshRenderer()
|
self.tree = layout_info.tree
|
||||||
renderer.rendering_options.resolution = options.get("resolution", 512)
|
self.relation = layout_info.relation
|
||||||
renderer.rendering_options.near = options.get("near", 1)
|
self.objs_desc = layout_info.objs_desc
|
||||||
renderer.rendering_options.far = options.get("far", 100)
|
self.G = nx.DiGraph()
|
||||||
renderer.rendering_options.ssaa = options.get("ssaa", 4)
|
self.root = self._find_root()
|
||||||
rets = {}
|
self._build_graph()
|
||||||
for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
|
|
||||||
res = renderer.render(sample, extr, intr)
|
self.role_colors = {
|
||||||
if "normal" not in rets:
|
Scene3DItemEnum.BACKGROUND.value: "plum",
|
||||||
rets["normal"] = []
|
Scene3DItemEnum.CONTEXT.value: "lightblue",
|
||||||
normal = torch.lerp(
|
Scene3DItemEnum.ROBOT.value: "lightcoral",
|
||||||
torch.zeros_like(res["normal"]), res["normal"], res["mask"]
|
Scene3DItemEnum.MANIPULATED_OBJS.value: "lightgreen",
|
||||||
|
Scene3DItemEnum.DISTRACTOR_OBJS.value: "lightgray",
|
||||||
|
Scene3DItemEnum.OTHERS.value: "orange",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _find_root(self) -> str:
|
||||||
|
children = {c for cs in self.tree.values() for c, _ in cs}
|
||||||
|
parents = set(self.tree.keys())
|
||||||
|
roots = parents - children
|
||||||
|
if not roots:
|
||||||
|
raise ValueError("No root node found.")
|
||||||
|
return next(iter(roots))
|
||||||
|
|
||||||
|
def _build_graph(self):
|
||||||
|
for parent, children in self.tree.items():
|
||||||
|
for child, relation in children:
|
||||||
|
self.G.add_edge(parent, child, relation=relation)
|
||||||
|
|
||||||
|
def _get_node_role(self, node: str) -> str:
|
||||||
|
if node == self.relation.get(Scene3DItemEnum.BACKGROUND.value):
|
||||||
|
return Scene3DItemEnum.BACKGROUND.value
|
||||||
|
if node == self.relation.get(Scene3DItemEnum.CONTEXT.value):
|
||||||
|
return Scene3DItemEnum.CONTEXT.value
|
||||||
|
if node == self.relation.get(Scene3DItemEnum.ROBOT.value):
|
||||||
|
return Scene3DItemEnum.ROBOT.value
|
||||||
|
if node in self.relation.get(
|
||||||
|
Scene3DItemEnum.MANIPULATED_OBJS.value, []
|
||||||
|
):
|
||||||
|
return Scene3DItemEnum.MANIPULATED_OBJS.value
|
||||||
|
if node in self.relation.get(
|
||||||
|
Scene3DItemEnum.DISTRACTOR_OBJS.value, []
|
||||||
|
):
|
||||||
|
return Scene3DItemEnum.DISTRACTOR_OBJS.value
|
||||||
|
return Scene3DItemEnum.OTHERS.value
|
||||||
|
|
||||||
|
def _get_positions(
|
||||||
|
self, root, width=1.0, vert_gap=0.1, vert_loc=1, xcenter=0.5, pos=None
|
||||||
|
):
|
||||||
|
if pos is None:
|
||||||
|
pos = {root: (xcenter, vert_loc)}
|
||||||
|
else:
|
||||||
|
pos[root] = (xcenter, vert_loc)
|
||||||
|
|
||||||
|
children = list(self.G.successors(root))
|
||||||
|
if children:
|
||||||
|
dx = width / len(children)
|
||||||
|
next_x = xcenter - width / 2 - dx / 2
|
||||||
|
for child in children:
|
||||||
|
next_x += dx
|
||||||
|
pos = self._get_positions(
|
||||||
|
child,
|
||||||
|
width=dx,
|
||||||
|
vert_gap=vert_gap,
|
||||||
|
vert_loc=vert_loc - vert_gap,
|
||||||
|
xcenter=next_x,
|
||||||
|
pos=pos,
|
||||||
)
|
)
|
||||||
normal = np.clip(
|
return pos
|
||||||
normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
|
|
||||||
).astype(np.uint8)
|
|
||||||
rets["normal"].append(normal)
|
|
||||||
|
|
||||||
return rets
|
def render(
|
||||||
|
self,
|
||||||
|
save_path: str,
|
||||||
|
figsize=(8, 6),
|
||||||
|
dpi=300,
|
||||||
|
title: str = "Scene 3D Hierarchy Tree",
|
||||||
|
):
|
||||||
|
node_colors = [
|
||||||
|
self.role_colors[self._get_node_role(n)] for n in self.G.nodes
|
||||||
|
]
|
||||||
|
pos = self._get_positions(self.root)
|
||||||
|
|
||||||
|
plt.figure(figsize=figsize)
|
||||||
@spaces.GPU
|
nx.draw(
|
||||||
def render_video(
|
self.G,
|
||||||
sample,
|
pos,
|
||||||
resolution=512,
|
with_labels=True,
|
||||||
bg_color=(0, 0, 0),
|
arrows=False,
|
||||||
num_frames=300,
|
node_size=2000,
|
||||||
r=2,
|
node_color=node_colors,
|
||||||
fov=40,
|
font_size=10,
|
||||||
**kwargs,
|
font_weight="bold",
|
||||||
):
|
|
||||||
yaws = torch.linspace(0, 2 * 3.1415, num_frames)
|
|
||||||
yaws = yaws.tolist()
|
|
||||||
pitch = [0.5] * num_frames
|
|
||||||
extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
|
|
||||||
yaws, pitch, r, fov
|
|
||||||
)
|
|
||||||
render_fn = (
|
|
||||||
render_mesh if isinstance(sample, MeshExtractResult) else render_frames
|
|
||||||
)
|
|
||||||
result = render_fn(
|
|
||||||
sample,
|
|
||||||
extrinsics,
|
|
||||||
intrinsics,
|
|
||||||
{"resolution": resolution, "bg_color": bg_color},
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
# Draw edge labels
|
||||||
|
edge_labels = nx.get_edge_attributes(self.G, "relation")
|
||||||
|
nx.draw_networkx_edge_labels(
|
||||||
|
self.G,
|
||||||
|
pos,
|
||||||
|
edge_labels=edge_labels,
|
||||||
|
font_size=9,
|
||||||
|
font_color="black",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Draw small description text under each node (if available)
|
||||||
|
for node, (x, y) in pos.items():
|
||||||
|
desc = self.objs_desc.get(node)
|
||||||
|
if desc:
|
||||||
|
wrapped = "\n".join(textwrap.wrap(desc, width=30))
|
||||||
|
plt.text(
|
||||||
|
x,
|
||||||
|
y - 0.006,
|
||||||
|
wrapped,
|
||||||
|
fontsize=6,
|
||||||
|
ha="center",
|
||||||
|
va="top",
|
||||||
|
wrap=True,
|
||||||
|
color="black",
|
||||||
|
bbox=dict(
|
||||||
|
facecolor="dimgray",
|
||||||
|
edgecolor="darkgray",
|
||||||
|
alpha=0.1,
|
||||||
|
boxstyle="round,pad=0.2",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.title(title, fontsize=12)
|
||||||
|
task_desc = self.relation.get("task_desc", "")
|
||||||
|
if task_desc:
|
||||||
|
plt.suptitle(
|
||||||
|
f"Task Description: {task_desc}", fontsize=10, y=0.999
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.axis("off")
|
||||||
|
|
||||||
|
legend_handles = [
|
||||||
|
Patch(facecolor=color, edgecolor='black', label=role)
|
||||||
|
for role, color in self.role_colors.items()
|
||||||
|
]
|
||||||
|
plt.legend(
|
||||||
|
handles=legend_handles,
|
||||||
|
loc="lower center",
|
||||||
|
ncol=3,
|
||||||
|
bbox_to_anchor=(0.5, -0.1),
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
|
plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def load_scene_dict(file_path: str) -> dict:
|
||||||
|
scene_dict = {}
|
||||||
|
with open(file_path, "r", encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line or ":" not in line:
|
||||||
|
continue
|
||||||
|
scene_id, desc = line.split(":", 1)
|
||||||
|
scene_dict[scene_id.strip()] = desc.strip()
|
||||||
|
|
||||||
|
return scene_dict
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Example usage:
|
|
||||||
merge_video_video(
|
merge_video_video(
|
||||||
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
|
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
|
||||||
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa
|
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa
|
||||||
"merge.mp4",
|
"merge.mp4",
|
||||||
)
|
)
|
||||||
|
|
||||||
image_base64 = combine_images_to_base64(
|
|
||||||
[
|
|
||||||
"apps/assets/example_image/sample_00.jpg",
|
|
||||||
"apps/assets/example_image/sample_01.jpg",
|
|
||||||
"apps/assets/example_image/sample_02.jpg",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
VERSION = "v0.1.0"
|
VERSION = "v0.1.1"
|
||||||
|
|||||||
90
embodied_gen/utils/trender.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import spaces
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
current_file_path = os.path.abspath(__file__)
|
||||||
|
current_dir = os.path.dirname(current_file_path)
|
||||||
|
sys.path.append(os.path.join(current_dir, "../.."))
|
||||||
|
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
|
||||||
|
from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
|
||||||
|
from thirdparty.TRELLIS.trellis.utils.render_utils import (
|
||||||
|
render_frames,
|
||||||
|
yaw_pitch_r_fov_to_extrinsics_intrinsics,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"render_video",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@spaces.GPU
|
||||||
|
def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
|
||||||
|
renderer = MeshRenderer()
|
||||||
|
renderer.rendering_options.resolution = options.get("resolution", 512)
|
||||||
|
renderer.rendering_options.near = options.get("near", 1)
|
||||||
|
renderer.rendering_options.far = options.get("far", 100)
|
||||||
|
renderer.rendering_options.ssaa = options.get("ssaa", 4)
|
||||||
|
rets = {}
|
||||||
|
for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
|
||||||
|
res = renderer.render(sample, extr, intr)
|
||||||
|
if "normal" not in rets:
|
||||||
|
rets["normal"] = []
|
||||||
|
normal = torch.lerp(
|
||||||
|
torch.zeros_like(res["normal"]), res["normal"], res["mask"]
|
||||||
|
)
|
||||||
|
normal = np.clip(
|
||||||
|
normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
|
||||||
|
).astype(np.uint8)
|
||||||
|
rets["normal"].append(normal)
|
||||||
|
|
||||||
|
return rets
|
||||||
|
|
||||||
|
|
||||||
|
@spaces.GPU
|
||||||
|
def render_video(
|
||||||
|
sample,
|
||||||
|
resolution=512,
|
||||||
|
bg_color=(0, 0, 0),
|
||||||
|
num_frames=300,
|
||||||
|
r=2,
|
||||||
|
fov=40,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
yaws = torch.linspace(0, 2 * 3.1415, num_frames)
|
||||||
|
yaws = yaws.tolist()
|
||||||
|
pitch = [0.5] * num_frames
|
||||||
|
extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
|
||||||
|
yaws, pitch, r, fov
|
||||||
|
)
|
||||||
|
render_fn = (
|
||||||
|
render_mesh if isinstance(sample, MeshExtractResult) else render_frames
|
||||||
|
)
|
||||||
|
result = render_fn(
|
||||||
|
sample,
|
||||||
|
extrinsics,
|
||||||
|
intrinsics,
|
||||||
|
{"resolution": resolution, "bg_color": bg_color},
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
@ -102,7 +102,7 @@ class AestheticPredictor:
|
|||||||
def _load_sac_model(self, model_path, input_size):
|
def _load_sac_model(self, model_path, input_size):
|
||||||
"""Load the SAC model."""
|
"""Load the SAC model."""
|
||||||
model = self.MLP(input_size)
|
model = self.MLP(input_size)
|
||||||
ckpt = torch.load(model_path)
|
ckpt = torch.load(model_path, weights_only=True)
|
||||||
model.load_state_dict(ckpt)
|
model.load_state_dict(ckpt)
|
||||||
model.to(self.device)
|
model.to(self.device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -135,15 +135,3 @@ class AestheticPredictor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return prediction.item()
|
return prediction.item()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Configuration
|
|
||||||
img_path = "apps/assets/example_image/sample_00.jpg"
|
|
||||||
|
|
||||||
# Initialize the predictor
|
|
||||||
predictor = AestheticPredictor()
|
|
||||||
|
|
||||||
# Predict the aesthetic score
|
|
||||||
score = predictor.predict(img_path)
|
|
||||||
print("Aesthetic score predicted by the model:", score)
|
|
||||||
|
|||||||
@ -16,17 +16,26 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import random
|
||||||
|
|
||||||
from tqdm import tqdm
|
import json_repair
|
||||||
|
from PIL import Image
|
||||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
|
||||||
from embodied_gen.utils.process_media import render_asset3d
|
|
||||||
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
|
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MeshGeoChecker",
|
||||||
|
"ImageSegChecker",
|
||||||
|
"ImageAestheticChecker",
|
||||||
|
"SemanticConsistChecker",
|
||||||
|
"TextGenAlignChecker",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class BaseChecker:
|
class BaseChecker:
|
||||||
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
|
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
@ -37,14 +46,18 @@ class BaseChecker:
|
|||||||
"Subclasses must implement the query method."
|
"Subclasses must implement the query method."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs) -> bool:
|
def __call__(self, *args, **kwargs) -> tuple[bool, str]:
|
||||||
response = self.query(*args, **kwargs)
|
response = self.query(*args, **kwargs)
|
||||||
if response is None:
|
if self.verbose:
|
||||||
response = "Error when calling gpt api."
|
|
||||||
|
|
||||||
if self.verbose and response != "YES":
|
|
||||||
logger.info(response)
|
logger.info(response)
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
flag = None
|
||||||
|
response = (
|
||||||
|
"Error when calling the GPT API; check the config in "
|
||||||
|
"`embodied_gen/utils/gpt_config.yaml` or the network connection."
|
||||||
|
)
|
||||||
|
else:
|
||||||
flag = "YES" in response
|
flag = "YES" in response
|
||||||
response = "YES" if flag else response
|
response = "YES" if flag else response
|
||||||
|
|
||||||
@ -92,21 +105,29 @@ class MeshGeoChecker(BaseChecker):
|
|||||||
self.gpt_client = gpt_client
|
self.gpt_client = gpt_client
|
||||||
if self.prompt is None:
|
if self.prompt is None:
|
||||||
self.prompt = """
|
self.prompt = """
|
||||||
Refer to the provided multi-view rendering images to evaluate
|
You are an expert in evaluating the geometry quality of generated 3D assets.
|
||||||
whether the geometry of the 3D object asset is complete and
|
You will be given rendered views of a generated 3D asset with black background.
|
||||||
whether the asset can be placed stably on the ground.
|
Your task is to evaluate the quality of the 3D asset generation,
|
||||||
Return "YES" only if reach the requirments,
|
including geometry, structure, and appearance, based on the rendered views.
|
||||||
otherwise "NO" and explain the reason very briefly.
|
Criteria:
|
||||||
|
- Is the geometry complete and well-formed, without missing parts or redundant structures?
|
||||||
|
- Is the geometric structure of the object complete?
|
||||||
|
- Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
|
||||||
|
soft edges) are acceptable if the object is structurally sound and recognizable.
|
||||||
|
- Only evaluate geometry. Do not assess texture quality.
|
||||||
|
- The asset should not contain any unrelated elements, such as
|
||||||
|
ground planes, platforms, or background props (e.g., paper, flooring).
|
||||||
|
|
||||||
|
If all the above criteria are met, return "YES". Otherwise, return
|
||||||
|
"NO" followed by a brief explanation (no more than 20 words).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Images show a yellow cup standing on a flat white plane -> NO: extra white surface under the object.
|
||||||
|
Image shows a chair with simplified back legs and soft edges -> YES
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def query(self, image_paths: str) -> str:
|
def query(self, image_paths: list[str | Image.Image]) -> str:
|
||||||
# Hardcode tmp because of the openrouter can't input multi images.
|
|
||||||
if "openrouter" in self.gpt_client.endpoint:
|
|
||||||
from embodied_gen.utils.process_media import (
|
|
||||||
combine_images_to_base64,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_paths = combine_images_to_base64(image_paths)
|
|
||||||
|
|
||||||
return self.gpt_client.query(
|
return self.gpt_client.query(
|
||||||
text_prompt=self.prompt,
|
text_prompt=self.prompt,
|
||||||
@ -137,14 +158,19 @@ class ImageSegChecker(BaseChecker):
|
|||||||
self.gpt_client = gpt_client
|
self.gpt_client = gpt_client
|
||||||
if self.prompt is None:
|
if self.prompt is None:
|
||||||
self.prompt = """
|
self.prompt = """
|
||||||
The first image is the original, and the second image is the
|
Task: Evaluate the quality of object segmentation between two images:
|
||||||
result after segmenting the main object. Evaluate the segmentation
|
the first is the original, the second is the segmented result.
|
||||||
quality to ensure the main object is clearly segmented without
|
|
||||||
significant truncation. Note that the foreground of the object
|
Criteria:
|
||||||
needs to be extracted instead of the background.
|
- The main foreground object should be clearly extracted (not the background).
|
||||||
Minor imperfections can be ignored. If segmentation is acceptable,
|
- The object must appear realistic, with reasonable geometry and color.
|
||||||
return "YES" only; otherwise, return "NO" with
|
- The object should be geometrically complete — no missing, truncated, or cropped parts.
|
||||||
very brief explanation.
|
- The object must be centered, with a margin on all sides.
|
||||||
|
- Ignore minor imperfections (e.g., small holes or fine edge artifacts).
|
||||||
|
|
||||||
|
Output Rules:
|
||||||
|
If segmentation is acceptable, respond with "YES" (and nothing else).
|
||||||
|
If not acceptable, respond with "NO", followed by a brief reason (max 20 words).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def query(self, image_paths: list[str]) -> str:
|
def query(self, image_paths: list[str]) -> str:
|
||||||
@ -152,13 +178,6 @@ class ImageSegChecker(BaseChecker):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"ImageSegChecker requires exactly two images: [raw_image, seg_image]." # noqa
|
"ImageSegChecker requires exactly two images: [raw_image, seg_image]." # noqa
|
||||||
)
|
)
|
||||||
# Hardcode tmp because of the openrouter can't input multi images.
|
|
||||||
if "openrouter" in self.gpt_client.endpoint:
|
|
||||||
from embodied_gen.utils.process_media import (
|
|
||||||
combine_images_to_base64,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_paths = combine_images_to_base64(image_paths)
|
|
||||||
|
|
||||||
return self.gpt_client.query(
|
return self.gpt_client.query(
|
||||||
text_prompt=self.prompt,
|
text_prompt=self.prompt,
|
||||||
@ -201,42 +220,204 @@ class ImageAestheticChecker(BaseChecker):
|
|||||||
return avg_score > self.thresh, avg_score
|
return avg_score > self.thresh, avg_score
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
class SemanticConsistChecker(BaseChecker):
|
||||||
geo_checker = MeshGeoChecker(GPT_CLIENT)
|
def __init__(
|
||||||
seg_checker = ImageSegChecker(GPT_CLIENT)
|
self,
|
||||||
aesthetic_checker = ImageAestheticChecker()
|
gpt_client: GPTclient,
|
||||||
|
prompt: str = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(prompt, verbose)
|
||||||
|
self.gpt_client = gpt_client
|
||||||
|
if self.prompt is None:
|
||||||
|
self.prompt = """
|
||||||
|
You are an expert in image-text consistency assessment.
|
||||||
|
You will be given:
|
||||||
|
- A short text description of an object.
|
||||||
|
- A segmented image of the same object with the background removed.
|
||||||
|
|
||||||
checkers = [geo_checker, seg_checker, aesthetic_checker]
|
Criteria:
|
||||||
|
- The image must visually match the text description in terms of object type, structure, geometry, and color.
|
||||||
|
- The object must appear realistic, with reasonable geometry (e.g., a table must have a stable number of legs).
|
||||||
|
- Geometric completeness is required: the object must not have missing, truncated, or cropped parts.
|
||||||
|
- The object must be centered in the image frame with clear margins on all sides,
|
||||||
|
it should not touch or nearly touch any image edge.
|
||||||
|
- The image must contain exactly one object. Multiple distinct objects are not allowed.
|
||||||
|
A single composite object (e.g., a chair with legs) is acceptable.
|
||||||
|
- The object should be shown from a slightly angled (three-quarter) perspective,
|
||||||
|
not a flat, front-facing view showing only one surface.
|
||||||
|
|
||||||
output_root = "outputs/test_gpt"
|
Instructions:
|
||||||
|
- If all criteria are met, return `"YES"`.
|
||||||
|
- Otherwise, return "NO" with a brief explanation (max 20 words).
|
||||||
|
|
||||||
fails = []
|
Respond in exactly one of the following formats:
|
||||||
for idx in tqdm(range(150)):
|
YES
|
||||||
mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj" # noqa
|
or
|
||||||
if not os.path.exists(mesh_path):
|
NO: brief explanation.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
{}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def query(self, text: str, image: list[Image.Image | str]) -> str:
|
||||||
|
|
||||||
|
return self.gpt_client.query(
|
||||||
|
text_prompt=self.prompt.format(text),
|
||||||
|
image_base64=image,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TextGenAlignChecker(BaseChecker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gpt_client: GPTclient,
|
||||||
|
prompt: str = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(prompt, verbose)
|
||||||
|
self.gpt_client = gpt_client
|
||||||
|
if self.prompt is None:
|
||||||
|
self.prompt = """
|
||||||
|
You are an expert in evaluating the quality of generated 3D assets.
|
||||||
|
You will be given:
|
||||||
|
- A text description of an object: TEXT
|
||||||
|
- Rendered views of the generated 3D asset.
|
||||||
|
|
||||||
|
Your task is to:
|
||||||
|
1. Determine whether the generated 3D asset roughly reflects the object class
|
||||||
|
or a semantically adjacent category described in the text.
|
||||||
|
2. Evaluate the geometry quality of the 3D asset generation based on the rendered views.
|
||||||
|
|
||||||
|
Criteria:
|
||||||
|
- Determine whether the generated 3D asset belongs to the category described in the text or a similar one.
|
||||||
|
- Focus on functional similarity: if the object serves the same general
|
||||||
|
purpose (e.g., writing, placing items), it should be accepted.
|
||||||
|
- Is the geometry complete and well-formed, with no missing parts,
|
||||||
|
distortions, visual artifacts, or redundant structures?
|
||||||
|
- Does the number of object instances match the description?
|
||||||
|
There should be only one object unless otherwise specified.
|
||||||
|
- Minor flaws in geometry or texture are acceptable; apply a high tolerance for texture quality defects.
|
||||||
|
- Minor simplifications in geometry or texture (e.g. soft edges, less detail)
|
||||||
|
are acceptable if the object is still recognizable.
|
||||||
|
- The asset should not contain any unrelated elements, such as
|
||||||
|
ground planes, platforms, or background props (e.g., paper, flooring).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Text: "yellow cup"
|
||||||
|
Image: shows a yellow cup standing on a flat white plane -> NO: extra surface under the object.
|
||||||
|
|
||||||
|
Instructions:
|
||||||
|
- If the quality of the generated asset is acceptable and it faithfully represents the text, return "YES".
|
||||||
|
- Otherwise, return "NO" followed by a brief explanation (no more than 20 words).
|
||||||
|
|
||||||
|
Respond in exactly one of the following formats:
|
||||||
|
YES
|
||||||
|
or
|
||||||
|
NO: brief explanation
|
||||||
|
|
||||||
|
Input:
|
||||||
|
Text description: {}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def query(self, text: str, image: list[Image.Image | str]) -> str:
|
||||||
|
|
||||||
|
return self.gpt_client.query(
|
||||||
|
text_prompt=self.prompt.format(text),
|
||||||
|
image_base64=image,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticMatcher(BaseChecker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gpt_client: GPTclient,
|
||||||
|
prompt: str = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
seed: int = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(prompt, verbose)
|
||||||
|
self.gpt_client = gpt_client
|
||||||
|
self.seed = seed
|
||||||
|
random.seed(seed)
|
||||||
|
if self.prompt is None:
|
||||||
|
self.prompt = """
|
||||||
|
You are an expert in semantic similarity and scene retrieval.
|
||||||
|
You will be given:
|
||||||
|
- A dictionary where each key is a scene ID, and each value is a scene description.
|
||||||
|
- A query text describing a target scene.
|
||||||
|
|
||||||
|
Your task:
|
||||||
|
return_num = 2
|
||||||
|
- Find the <return_num> most semantically similar scene IDs to the query text.
|
||||||
|
- If there are fewer than <return_num> distinct relevant matches, repeat the closest ones to make a list of <return_num>.
|
||||||
|
- Only output the list of <return_num> scene IDs, sorted from most to least similar.
|
||||||
|
- Do NOT use markdown, JSON code blocks, or any formatting syntax, only return a plain list like ["id1", ...].
|
||||||
|
|
||||||
|
Input example:
|
||||||
|
Dictionary:
|
||||||
|
"{{
|
||||||
|
"t_scene_008": "A study room with full bookshelves and a lamp in the corner.",
|
||||||
|
"t_scene_019": "A child's bedroom with pink walls and a small desk.",
|
||||||
|
"t_scene_020": "A living room with a wooden floor.",
|
||||||
|
"t_scene_021": "A living room with toys scattered on the floor.",
|
||||||
|
...
|
||||||
|
"t_scene_office_001": "A very spacious, modern open-plan office with wide desks and no people, panoramic view."
|
||||||
|
}}"
|
||||||
|
Text:
|
||||||
|
"A traditional indoor room"
|
||||||
|
Output:
|
||||||
|
'["t_scene_office_001", ...]'
|
||||||
|
|
||||||
|
Input:
|
||||||
|
Dictionary:
|
||||||
|
{context}
|
||||||
|
Text:
|
||||||
|
{text}
|
||||||
|
Output:
|
||||||
|
<topk_key_list>
|
||||||
|
"""
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self, text: str, context: dict, rand: bool = True, params: dict = None
|
||||||
|
) -> str:
|
||||||
|
match_list = self.gpt_client.query(
|
||||||
|
self.prompt.format(context=context, text=text),
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
match_list = json_repair.loads(match_list)
|
||||||
|
result = random.choice(match_list) if rand else match_list[0]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def test_semantic_matcher(
|
||||||
|
bg_file: str = "outputs/bg_scenes/bg_scene_list.txt",
|
||||||
|
):
|
||||||
|
|
||||||
|
scene_dict = {}
|
||||||
|
with open(bg_file, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line or ":" not in line:
|
||||||
continue
|
continue
|
||||||
image_paths = render_asset3d(
|
scene_id, desc = line.split(":", 1)
|
||||||
mesh_path,
|
scene_dict[scene_id.strip()] = desc.strip()
|
||||||
f"{output_root}/{idx}",
|
|
||||||
num_images=8,
|
|
||||||
elevation=(30, -30),
|
|
||||||
distance=5.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
for cid, checker in enumerate(checkers):
|
office_scene = scene_dict.get("t_scene_office_001")
|
||||||
if isinstance(checker, ImageSegChecker):
|
text = "bright kitchen"
|
||||||
images = [
|
SCENE_MATCHER = SemanticMatcher(GPT_CLIENT)
|
||||||
f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png", # noqa
|
# gpt_params = {
|
||||||
f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png", # noqa
|
# "temperature": 0.8,
|
||||||
]
|
# "max_tokens": 500,
|
||||||
else:
|
# "top_p": 0.8,
|
||||||
images = image_paths
|
# "frequency_penalty": 0.3,
|
||||||
result, info = checker(images)
|
# "presence_penalty": 0.3,
|
||||||
logger.info(
|
# }
|
||||||
f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}" # noqa
|
gpt_params = None
|
||||||
)
|
match_key = SCENE_MATCHER.query(text, str(scene_dict))
|
||||||
|
print(match_key, ",", scene_dict[match_key])
|
||||||
|
|
||||||
if result is False:
|
|
||||||
fails.append((idx, cid, info))
|
|
||||||
|
|
||||||
break
|
if __name__ == "__main__":
|
||||||
|
test_semantic_matcher()
|
||||||
|
|||||||
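The checker prompts above ask the model to reply with either `YES` or `NO: brief explanation`. The sketch below shows one way such a reply could be turned into a `(flag, info)` pair. It is an illustration only: `parse_checker_response` and the `"no response"` message are hypothetical names chosen here, and the repository's `BaseChecker` may parse replies differently. Returning `None` as the flag for an unrecognized reply is an assumption that mirrors the unit tests, which accept `bool` or `None`.

```python
import re
from typing import Optional, Tuple

# Matches a leading YES/NO verdict, then optional punctuation, then the explanation.
_VERDICT = re.compile(r"^(YES|NO)\b[\s:.,-]*(.*)$", re.IGNORECASE | re.DOTALL)


def parse_checker_response(response: Optional[str]) -> Tuple[Optional[bool], str]:
    """Map a raw model reply to (flag, info). Hypothetical helper."""
    if response is None:
        # The client returned nothing, so the verdict is unknown.
        return None, "no response"
    reply = response.strip()
    match = _VERDICT.match(reply)
    if match is None:
        # The reply does not follow the requested format; keep it for logging.
        return None, reply
    verdict = match.group(1).upper()
    explanation = match.group(2).strip()
    return verdict == "YES", explanation


accepted = parse_checker_response("YES")
rejected = parse_checker_response("NO: extra surface under the object.")
unknown = parse_checker_response("mocked gpt response")
```

Keeping the unknown case separate from a rejection lets a retry loop distinguish "the asset failed the check" from "the reply could not be interpreted".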
@ -297,20 +297,24 @@ class URDFGenerator(object):
|
|||||||
if not os.path.exists(urdf_path):
|
if not os.path.exists(urdf_path):
|
||||||
raise FileNotFoundError(f"URDF file not found: {urdf_path}")
|
raise FileNotFoundError(f"URDF file not found: {urdf_path}")
|
||||||
|
|
||||||
mesh_scale = 1.0
|
mesh_attr = None
|
||||||
tree = ET.parse(urdf_path)
|
tree = ET.parse(urdf_path)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
extra_info = root.find(attr_root)
|
extra_info = root.find(attr_root)
|
||||||
if extra_info is not None:
|
if extra_info is not None:
|
||||||
scale_element = extra_info.find(attr_name)
|
scale_element = extra_info.find(attr_name)
|
||||||
if scale_element is not None:
|
if scale_element is not None:
|
||||||
mesh_scale = float(scale_element.text)
|
mesh_attr = scale_element.text
|
||||||
|
try:
|
||||||
|
mesh_attr = float(mesh_attr)
|
||||||
|
except (TypeError, ValueError):  # keep missing or non-numeric values as-is
|
||||||
|
pass
|
||||||
|
|
||||||
return mesh_scale
|
return mesh_attr
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_quality_tag(
|
def add_quality_tag(
|
||||||
urdf_path: str, results, output_path: str = None
|
urdf_path: str, results: list, output_path: str = None
|
||||||
) -> None:
|
) -> None:
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
output_path = urdf_path
|
output_path = urdf_path
|
||||||
@ -366,16 +370,9 @@ class URDFGenerator(object):
|
|||||||
output_root,
|
output_root,
|
||||||
num_images=self.render_view_num,
|
num_images=self.render_view_num,
|
||||||
output_subdir=self.output_render_dir,
|
output_subdir=self.output_render_dir,
|
||||||
|
no_index_file=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Temporary workaround: openrouter cannot accept multiple input images.
|
|
||||||
if "openrouter" in self.gpt_client.endpoint:
|
|
||||||
from embodied_gen.utils.process_media import (
|
|
||||||
combine_images_to_base64,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_path = combine_images_to_base64(image_path)
|
|
||||||
|
|
||||||
response = self.gpt_client.query(text_prompt, image_path)
|
response = self.gpt_client.query(text_prompt, image_path)
|
||||||
if response is None:
|
if response is None:
|
||||||
asset_attrs = {
|
asset_attrs = {
|
||||||
@ -412,14 +409,18 @@ class URDFGenerator(object):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
|
urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
|
||||||
urdf_path = urdf_gen(
|
urdf_path = urdf_gen(
|
||||||
mesh_path="outputs/imageto3d/cma/o5/URDF_o5/mesh/o5.obj",
|
mesh_path="outputs/layout2/asset3d/marker/result/mesh/marker.obj",
|
||||||
output_root="outputs/test_urdf",
|
output_root="outputs/test_urdf",
|
||||||
# category="coffee machine",
|
category="marker",
|
||||||
# min_height=1.0,
|
# min_height=1.0,
|
||||||
# max_height=1.2,
|
# max_height=1.2,
|
||||||
version=VERSION,
|
version=VERSION,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
URDFGenerator.add_quality_tag(
|
||||||
|
urdf_path, [[urdf_gen.__class__.__name__, "OK"]]
|
||||||
|
)
|
||||||
|
|
||||||
# zip_files(
|
# zip_files(
|
||||||
# input_paths=[
|
# input_paths=[
|
||||||
# "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",
|
# "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",
|
||||||
|
|||||||
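The `URDFGenerator` hunk above changes the attribute reader so that it returns the raw text of an element when that text is not numeric, instead of always returning a float scale. A minimal, self-contained sketch of that behaviour follows. It parses an in-memory string rather than a file, and the function name `read_urdf_attr`, the element names `extra_info`, `scale` and `category`, and the sample document are assumptions made for illustration, not the project's actual URDF schema.

```python
import xml.etree.ElementTree as ET
from typing import Optional, Union

# Hypothetical URDF fragment used only for this example.
SAMPLE_URDF = """
<robot name="demo">
  <link name="base">
    <extra_info>
      <scale>0.25</scale>
      <category>marker</category>
    </extra_info>
  </link>
</robot>
"""


def read_urdf_attr(
    urdf_text: str, attr_root: str, attr_name: str
) -> Optional[Union[float, str]]:
    """Return an attribute as float when numeric, else as the raw string."""
    root = ET.fromstring(urdf_text)
    extra_info = root.find(attr_root)
    if extra_info is None:
        return None
    element = extra_info.find(attr_name)
    if element is None or element.text is None:
        return None
    try:
        return float(element.text)
    except ValueError:
        # Non-numeric attributes (for example a category name) stay as text.
        return element.text


scale = read_urdf_attr(SAMPLE_URDF, "link/extra_info", "scale")
category = read_urdf_attr(SAMPLE_URDF, "link/extra_info", "category")
missing = read_urdf_attr(SAMPLE_URDF, "link/extra_info", "height")
```

Callers that need a number should check the returned type, since the same function can now yield a float, a string, or `None`.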
22
install.sh
@ -8,6 +8,12 @@ NC='\033[0m'
|
|||||||
echo -e "${GREEN}Starting installation process...${NC}"
|
echo -e "${GREEN}Starting installation process...${NC}"
|
||||||
git config --global http.postBuffer 524288000
|
git config --global http.postBuffer 524288000
|
||||||
|
|
||||||
|
echo -e "${GREEN}Installing flash-attn...${NC}"
|
||||||
|
pip install flash-attn==2.7.0.post2 --no-build-isolation || {
|
||||||
|
echo -e "${RED}Failed to install flash-attn${NC}"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
echo -e "${GREEN}Installing dependencies from requirements.txt...${NC}"
|
echo -e "${GREEN}Installing dependencies from requirements.txt...${NC}"
|
||||||
pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60 || {
|
pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeout=60 || {
|
||||||
echo -e "${RED}Failed to install requirements${NC}"
|
echo -e "${RED}Failed to install requirements${NC}"
|
||||||
@ -15,16 +21,16 @@ pip install -r requirements.txt --use-deprecated=legacy-resolver --default-timeo
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installing kaolin from GitHub...${NC}"
|
echo -e "${GREEN}Installing kolors from GitHub...${NC}"
|
||||||
pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || {
|
pip install kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d || {
|
||||||
echo -e "${RED}Failed to install kaolin${NC}"
|
echo -e "${RED}Failed to install kolors${NC}"
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installing flash-attn...${NC}"
|
echo -e "${GREEN}Installing kaolin from GitHub...${NC}"
|
||||||
pip install flash-attn==2.7.0.post2 --no-build-isolation || {
|
pip install kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 || {
|
||||||
echo -e "${RED}Failed to install flash-attn${NC}"
|
echo -e "${RED}Failed to install kaolin${NC}"
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,7 +45,6 @@ rm -rf "$TMP_DIR" || {
|
|||||||
rm -rf "$TMP_DIR"
|
rm -rf "$TMP_DIR"
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
echo -e "${GREEN}Installation completed successfully!${NC}"
|
|
||||||
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installing gsplat from GitHub...${NC}"
|
echo -e "${GREEN}Installing gsplat from GitHub...${NC}"
|
||||||
@ -50,8 +55,9 @@ pip install git+https://github.com/nerfstudio-project/gsplat.git@v1.5.0 || {
|
|||||||
|
|
||||||
|
|
||||||
echo -e "${GREEN}Installing EmbodiedGen...${NC}"
|
echo -e "${GREEN}Installing EmbodiedGen...${NC}"
|
||||||
|
pip install triton==2.1.0
|
||||||
pip install -e . || {
|
pip install -e . || {
|
||||||
echo -e "${RED}Failed to install local package${NC}"
|
echo -e "${RED}Failed to install EmbodiedGen pyproject.toml${NC}"
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ packages = ["embodied_gen"]
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "embodied_gen"
|
name = "embodied_gen"
|
||||||
version = "v0.1.0"
|
version = "v0.1.1"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
license-files = ["LICENSE", "NOTICE"]
|
license-files = ["LICENSE", "NOTICE"]
|
||||||
@ -17,16 +17,20 @@ requires-python = ">=3.10"
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
"cpplint==2.0.0",
|
"cpplint",
|
||||||
"pre-commit==2.13.0",
|
"pre-commit",
|
||||||
"pydocstyle",
|
"pydocstyle",
|
||||||
"black",
|
"black",
|
||||||
"isort",
|
"isort",
|
||||||
|
"pytest",
|
||||||
|
"pytest-mock",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
drender-cli = "embodied_gen.data.differentiable_render:entrypoint"
|
drender-cli = "embodied_gen.data.differentiable_render:entrypoint"
|
||||||
backproject-cli = "embodied_gen.data.backproject_v2:entrypoint"
|
backproject-cli = "embodied_gen.data.backproject_v2:entrypoint"
|
||||||
|
img3d-cli = "embodied_gen.scripts.imageto3d:entrypoint"
|
||||||
|
text3d-cli = "embodied_gen.scripts.textto3d:text_to_3d"
|
||||||
|
|
||||||
[tool.pydocstyle]
|
[tool.pydocstyle]
|
||||||
match = '(?!test_).*(?!_pb2)\.py'
|
match = '(?!test_).*(?!_pb2)\.py'
|
||||||
|
|||||||
12
pytest.ini
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
[pytest]
|
||||||
|
testpaths =
|
||||||
|
tests/test_unit
|
||||||
|
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
addopts = -ra --tb=short --disable-warnings --log-cli-level=INFO
|
||||||
|
filterwarnings =
|
||||||
|
ignore::DeprecationWarning
|
||||||
|
markers =
|
||||||
|
slow: marks tests as slow
|
||||||
@ -8,10 +8,10 @@ triton==2.1.0
|
|||||||
dataclasses_json
|
dataclasses_json
|
||||||
easydict
|
easydict
|
||||||
opencv-python>4.5
|
opencv-python>4.5
|
||||||
imageio==2.36.1
|
imageio
|
||||||
imageio-ffmpeg==0.5.1
|
imageio-ffmpeg
|
||||||
rembg==2.0.61
|
rembg==2.0.61
|
||||||
trimesh==4.4.4
|
trimesh
|
||||||
moviepy==1.0.3
|
moviepy==1.0.3
|
||||||
pymeshfix==0.17.0
|
pymeshfix==0.17.0
|
||||||
igraph==0.11.8
|
igraph==0.11.8
|
||||||
@ -20,21 +20,19 @@ openai==1.58.1
|
|||||||
transformers==4.42.4
|
transformers==4.42.4
|
||||||
gradio==5.12.0
|
gradio==5.12.0
|
||||||
sentencepiece==0.2.0
|
sentencepiece==0.2.0
|
||||||
diffusers==0.31.0
|
diffusers==0.34.0
|
||||||
xatlas==0.0.9
|
xatlas
|
||||||
onnxruntime==1.20.1
|
onnxruntime==1.20.1
|
||||||
tenacity==8.2.2
|
tenacity
|
||||||
accelerate==0.33.0
|
accelerate==0.33.0
|
||||||
basicsr==1.4.2
|
basicsr==1.4.2
|
||||||
realesrgan==0.3.0
|
realesrgan==0.3.0
|
||||||
pydantic==2.9.2
|
pydantic==2.9.2
|
||||||
vtk==9.3.1
|
vtk==9.3.1
|
||||||
spaces
|
spaces
|
||||||
|
colorlog
|
||||||
|
json-repair
|
||||||
utils3d@git+https://github.com/EasternJournalist/utils3d.git#egg=9a4eb15
|
utils3d@git+https://github.com/EasternJournalist/utils3d.git#egg=9a4eb15
|
||||||
clip@git+https://github.com/openai/CLIP.git
|
clip@git+https://github.com/openai/CLIP.git
|
||||||
kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d
|
|
||||||
segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f
|
segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f
|
||||||
nvdiffrast@git+https://github.com/NVlabs/nvdiffrast.git#egg=729261d
|
nvdiffrast@git+https://github.com/NVlabs/nvdiffrast.git#egg=729261d
|
||||||
# https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu118-cp310-cp310-linux_x86_64.whl
|
|
||||||
# https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu11torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
|
||||||
# https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu118/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl
|
|
||||||
|
|||||||
31
tests/test_examples/test_aesthetic_predictor.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.mark.manual
|
||||||
|
def test_aesthetic_predictor():
|
||||||
|
image_path = "apps/assets/example_image/sample_02.jpg"
|
||||||
|
predictor = AestheticPredictor(device="cpu")
|
||||||
|
score = predictor.predict(image_path)
|
||||||
|
|
||||||
|
assert isinstance(score, float)
|
||||||
|
logger.info(f"Aesthetic score: {score:.3f}")
|
||||||
119
tests/test_examples/test_quality_checkers.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||||
|
from embodied_gen.utils.process_media import render_asset3d
|
||||||
|
from embodied_gen.validators.quality_checkers import (
|
||||||
|
ImageAestheticChecker,
|
||||||
|
ImageSegChecker,
|
||||||
|
MeshGeoChecker,
|
||||||
|
SemanticConsistChecker,
|
||||||
|
TextGenAlignChecker,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def geo_checker():
|
||||||
|
return MeshGeoChecker(GPT_CLIENT)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def seg_checker():
|
||||||
|
return ImageSegChecker(GPT_CLIENT)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def aesthetic_checker():
|
||||||
|
return ImageAestheticChecker()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def semantic_checker():
|
||||||
|
return SemanticConsistChecker(GPT_CLIENT)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def textalign_checker():
|
||||||
|
return TextGenAlignChecker(GPT_CLIENT)
|
||||||
|
|
||||||
|
|
||||||
|
def test_geo_checker(geo_checker):
|
||||||
|
flag, result = geo_checker(
|
||||||
|
[
|
||||||
|
"apps/assets/example_image/sample_02.jpg",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger.info(f"geo_checker: {flag}, {result}")
|
||||||
|
assert isinstance(flag, bool)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_aesthetic_checker(aesthetic_checker):
|
||||||
|
flag, result = aesthetic_checker("apps/assets/example_image/sample_02.jpg")
|
||||||
|
logger.info(f"aesthetic_checker: {flag}, {result}")
|
||||||
|
assert isinstance(flag, bool)
|
||||||
|
assert isinstance(result, float)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seg_checker(seg_checker):
|
||||||
|
flag, result = seg_checker(
|
||||||
|
[
|
||||||
|
"apps/assets/example_image/sample_02.jpg",
|
||||||
|
"apps/assets/example_image/sample_02.jpg",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger.info(f"seg_checker: {flag}, {result}")
|
||||||
|
assert isinstance(flag, bool)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_semantic_checker(semantic_checker):
|
||||||
|
flag, result = semantic_checker(
|
||||||
|
text="can",
|
||||||
|
image=["apps/assets/example_image/sample_02.jpg"],
|
||||||
|
)
|
||||||
|
logger.info(f"semantic_checker: {flag}, {result}")
|
||||||
|
assert isinstance(flag, bool)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mesh_path, text_desc",
|
||||||
|
[
|
||||||
|
("apps/assets/example_texture/meshes/chair.obj", "chair"),
|
||||||
|
("apps/assets/example_texture/meshes/clock.obj", "clock"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_textgen_checker(textalign_checker, mesh_path, text_desc):
|
||||||
|
with tempfile.TemporaryDirectory() as output_root:
|
||||||
|
image_list = render_asset3d(
|
||||||
|
mesh_path,
|
||||||
|
output_root=output_root,
|
||||||
|
num_images=6,
|
||||||
|
elevation=(30, -30),
|
||||||
|
output_subdir="renders",
|
||||||
|
no_index_file=True,
|
||||||
|
with_mtl=False,
|
||||||
|
)
|
||||||
|
flag, result = textalign_checker(text_desc, image_list)
|
||||||
|
logger.info(f"textalign_checker: {flag}, {result}")
|
||||||
95
tests/test_unit/test_agents.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||||
|
from embodied_gen.validators.quality_checkers import (
|
||||||
|
ImageSegChecker,
|
||||||
|
MeshGeoChecker,
|
||||||
|
SemanticConsistChecker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def gptclient_query():
|
||||||
|
with patch.object(
|
||||||
|
GPT_CLIENT, "query", return_value="mocked gpt response"
|
||||||
|
) as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def gptclient_query_case2():
|
||||||
|
with patch.object(GPT_CLIENT, "query", return_value=None) as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_images",
|
||||||
|
[
|
||||||
|
"dummy_path/color_grid_6view.png",
|
||||||
|
["dummy_path/color_grid_6view.jpg"],
|
||||||
|
[
|
||||||
|
"dummy_path/color_grid_6view.png",
|
||||||
|
"dummy_path/color_grid_6view2.png",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
Image.new("RGB", (64, 64), "red"),
|
||||||
|
Image.new("RGB", (64, 64), "blue"),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_geo_checker_varied_inputs(input_images):
|
||||||
|
geo_checker = MeshGeoChecker(GPT_CLIENT)
|
||||||
|
flag, result = geo_checker(input_images)
|
||||||
|
assert isinstance(flag, (bool, type(None)))
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seg_checker():
|
||||||
|
seg_checker = ImageSegChecker(GPT_CLIENT)
|
||||||
|
flag, result = seg_checker(
|
||||||
|
[
|
||||||
|
"dummy_path/sample_0_raw.png", # raw image
|
||||||
|
"dummy_path/sample_0_cond.png", # segmented image
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert isinstance(flag, (bool, type(None)))
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_semantic_checker():
|
||||||
|
semantic_checker = SemanticConsistChecker(GPT_CLIENT)
|
||||||
|
flag, result = semantic_checker(
|
||||||
|
text="pen",
|
||||||
|
image=["dummy_path/pen.png"],
|
||||||
|
)
|
||||||
|
assert isinstance(flag, (bool, type(None)))
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_semantic_checker_no_response(gptclient_query_case2):
|
||||||
|
semantic_checker = SemanticConsistChecker(GPT_CLIENT)
|
||||||
|
flag, result = semantic_checker(
|
||||||
|
text="pen",
|
||||||
|
image=["dummy_path/pen.png"],
|
||||||
|
)
|
||||||
|
assert isinstance(flag, (bool, type(None)))
|
||||||
|
assert isinstance(result, str)
|
||||||
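The fixtures in `tests/test_unit/test_agents.py` replace `GPT_CLIENT.query` with a mock so the checkers can be exercised without network access, once with a string reply and once with `None`. The sketch below reproduces that pattern with the standard library only. `StubClient` and `run_check` are hypothetical stand-ins written for this example; they are not the project's `GPTclient` or checker classes.

```python
from unittest.mock import patch


class StubClient:
    """Hypothetical stand-in for a chat client."""

    def query(self, text_prompt, image_base64=None):
        raise RuntimeError("network access is not available in unit tests")


def run_check(client, text, images):
    """Toy checker: ask the client, then map the reply to (flag, info)."""
    reply = client.query(text_prompt=text, image_base64=images)
    if reply is None:
        return None, "no response"
    return reply.strip().upper().startswith("YES"), reply


client = StubClient()

# Case 1: the mocked client answers with a fixed string.
with patch.object(client, "query", return_value="YES") as mock_query:
    ok_result = run_check(client, "pen", ["dummy_path/pen.png"])
    mock_query.assert_called_once_with(
        text_prompt="pen", image_base64=["dummy_path/pen.png"]
    )

# Case 2: the mocked client returns None, as a failed request might.
with patch.object(client, "query", return_value=None):
    none_result = run_check(client, "pen", ["dummy_path/pen.png"])
```

Each case needs its own test function name. If two test functions in one module share a name, the later definition replaces the earlier one and pytest collects only the last.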
94
tests/test_unit/test_gpt_client.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
# Project EmbodiedGen
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from PIL import Image
|
||||||
|
from embodied_gen.utils.gpt_clients import CONFIG_FILE, GPTclient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def config():
|
||||||
|
with open(CONFIG_FILE, "r") as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def env_vars(monkeypatch, config):
|
||||||
|
agent_type = config["agent_type"]
|
||||||
|
agent_config = config.get(agent_type, {})
|
||||||
|
monkeypatch.setenv(
|
||||||
|
"ENDPOINT", agent_config.get("endpoint", "fake_endpoint")
|
||||||
|
)
|
||||||
|
monkeypatch.setenv("API_KEY", agent_config.get("api_key", "fake_api_key"))
|
||||||
|
monkeypatch.setenv("API_VERSION", agent_config.get("api_version", "v1"))
|
||||||
|
monkeypatch.setenv(
|
||||||
|
"MODEL_NAME", agent_config.get("model_name", "test_model")
|
||||||
|
)
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def gpt_client(env_vars):
|
||||||
|
client = GPTclient(
|
||||||
|
endpoint=os.environ.get("ENDPOINT"),
|
||||||
|
api_key=os.environ.get("API_KEY"),
|
||||||
|
api_version=os.environ.get("API_VERSION"),
|
||||||
|
model_name=os.environ.get("MODEL_NAME"),
|
||||||
|
check_connection=False,
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"text_prompt, image_base64",
|
||||||
|
[
|
||||||
|
("What is the capital of China?", None),
|
||||||
|
(
|
||||||
|
"What is the content in each image?",
|
||||||
|
"apps/assets/example_image/sample_02.jpg",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"What is the content in each image?",
|
||||||
|
[
|
||||||
|
"apps/assets/example_image/sample_02.jpg",
|
||||||
|
"apps/assets/example_image/sample_03.jpg",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"What is the content in each image?",
|
||||||
|
[
|
||||||
|
Image.new("RGB", (64, 64), "red"),
|
||||||
|
Image.new("RGB", (64, 64), "blue"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_gptclient_query(gpt_client, text_prompt, image_base64):
|
||||||
|
# mock GPTclient.query
|
||||||
|
with patch.object(
|
||||||
|
GPTclient, "query", return_value="mocked response"
|
||||||
|
) as mock_query:
|
||||||
|
response = gpt_client.query(
|
||||||
|
text_prompt=text_prompt, image_base64=image_base64
|
||||||
|
)
|
||||||
|
assert response == "mocked response"
|
||||||
|
mock_query.assert_called_once_with(
|
||||||
|
text_prompt=text_prompt, image_base64=image_base64
|
||||||
|
)
|
||||||
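The `env_vars` fixture above looks up the active agent in the loaded configuration and falls back to placeholder values when a key is absent. A small sketch of that lookup as a plain function is shown here. The function name `resolve_agent_settings` and the sample configuration are assumptions for illustration; the real layout of the file behind `CONFIG_FILE` is not shown in this diff.

```python
def resolve_agent_settings(config: dict) -> dict:
    """Pick the active agent's settings, using the fixture's fallback values."""
    agent_type = config["agent_type"]
    agent_config = config.get(agent_type, {})
    return {
        "ENDPOINT": agent_config.get("endpoint", "fake_endpoint"),
        "API_KEY": agent_config.get("api_key", "fake_api_key"),
        "API_VERSION": agent_config.get("api_version", "v1"),
        "MODEL_NAME": agent_config.get("model_name", "test_model"),
    }


# Hypothetical configuration with only two keys set for the active agent.
sample_config = {
    "agent_type": "demo_agent",
    "demo_agent": {
        "endpoint": "https://example.invalid/v1",
        "model_name": "demo-model",
    },
}
settings = resolve_agent_settings(sample_config)

# An agent section that is missing entirely falls back to every default.
bare_settings = resolve_agent_settings({"agent_type": "absent_agent"})
```

Because every key has a fallback, the client fixture can be constructed even on a machine with no credentials configured.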