diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a47aa67 --- /dev/null +++ b/.gitignore @@ -0,0 +1,62 @@ +build/ +dummy/ +!scripts/build +builddir/ +conan-deps/ +distribute/ +lib/ +bin/ +dist/ +deps/* +docs/build +python/dist/ +docs/index.rst +python/MANIFEST.in +*.egg-info +*.pyc +*.pyi +*.json +*.bak +*.zip +wheels*/ + +# Compiled Object files +*.slo +*.lo +*.o + +# Compiled Dynamic libraries +*.so +*.dylib + +# Compiled Static libraries +*.lai +*.la +*.a + +.cproject +.project +.settings/ +*.db +*.bak +.arcconfig +.vscode/ + +# files +*.pack +*.pcd +*.html +*.ply +*.mp4 +# node +node_modules + +# local files +build.sh +__pycache__/ +output* +*.log +scripts/tools/ +weights/ +apps/assets/example_texture/ +apps/sessions/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..2ccd752 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "thirdparty/TRELLIS"] + path = thirdparty/TRELLIS + url = https://github.com/microsoft/TRELLIS.git + branch = main diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..28d6ba6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,78 @@ +repos: + - repo: git@gitlab.hobot.cc:ptd/3rd/pre-commit/pre-commit-hooks.git + rev: v2.3.0 # Use the ref you want to point at + hooks: + - id: trailing-whitespace + - id: check-added-large-files + name: Check for added large files + description: Prevent giant files from being committed + entry: check-added-large-files + language: python + args: ["--maxkb=1024"] + - repo: local + hooks: + - id: cpplint-cpp-source + name: cpplint + description: Check cpp code style. + entry: python3 scripts/lint_src/lint.py + language: system + exclude: (?x)(^tools/|^thirdparty/|^patch_files/) + files: \.(c|cc|cxx|cpp|cu|h|hpp)$ + args: [--project=asset_recons, --path] + - repo: local + hooks: + - id: pycodestyle-python + name: pep8-exclude-docs + description: Check python code style. 
+ entry: pycodestyle + language: system + exclude: (?x)(^docs/|^thirdparty/|^scripts/build/) + files: \.(py)$ + types: [file, python] + args: [--config=setup.cfg] + + + # pre-commit install --hook-type commit-msg to enable it + # - repo: local + # hooks: + # - id: commit-check + # name: check for commit msg format + # language: pygrep + # entry: '\A(?!(feat|fix|docs|style|refactor|perf|test|chore)\(.*\): (\[[a-zA-Z][a-zA-Z0-9_]+-[1-9][0-9]*\]|\[cr_id_skip\]) [A-Z]+.*)' + # args: [--multiline] + # stages: [commit-msg] + + - repo: local + hooks: + - id: pydocstyle-python + name: pydocstyle-change-exclude-docs + description: Check python doc style. + entry: pydocstyle + language: system + exclude: (?x)(^docs/|^thirdparty/) + files: \.(py)$ + types: [file, python] + args: [--config=pyproject.toml] + - repo: local + hooks: + - id: black + name: black-exclude-docs + description: black format + entry: black + language: system + exclude: (?x)(^docs/|^thirdparty/) + files: \.(py)$ + types: [file, python] + args: [--config=pyproject.toml] + + - repo: local + hooks: + - id: isort + name: isort + description: isort format + entry: isort + language: system + exclude: (?x)(^thirdparty/) + files: \.(py)$ + types: [file, python] + args: [--settings-file=pyproject.toml] \ No newline at end of file diff --git a/LICENSE b/LICENSE index 261eeb9..2132bc3 100644 --- a/LICENSE +++ b/LICENSE @@ -1,3 +1,5 @@ +Copyright (c) 2024 Horizon Robotics and EmbodiedGen Contributors. All rights reserved. + Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ @@ -186,7 +188,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2024 Horizon Robotics and EmbodiedGen Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..35d708e --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +graft embodied_gen \ No newline at end of file diff --git a/README.md b/README.md index c4ae5c8..c86525d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,178 @@ -# EmbodiedGen -Towards a Generative 3D World Engine for Embodied Intelligence +# EmbodiedGen: Towards a Generative 3D World Engine for Embodied Intelligence + +[![๐ŸŒ Project Page](https://img.shields.io/badge/๐ŸŒ-Project_Page-blue)](https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html) +[![๐Ÿ“„ arXiv](https://img.shields.io/badge/๐Ÿ“„-arXiv-b31b1b)](#) +[![๐ŸŽฅ Video](https://img.shields.io/badge/๐ŸŽฅ-Video-red)](https://www.youtube.com/watch?v=SnHhzHeb_aI) +[![๐Ÿค— Hugging Face](https://img.shields.io/badge/๐Ÿค—-Image_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) +[![๐Ÿค— Hugging Face](https://img.shields.io/badge/๐Ÿค—-Text_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) +[![๐Ÿค— Hugging Face](https://img.shields.io/badge/๐Ÿค—-Texture_Gen_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) + +Overall Framework + + +**EmbodiedGen** generates interactive 3D worlds with real-world scale and physical realism at low cost. + +--- + +## โœจ Table of Contents of EmbodiedGen +- [๐Ÿ–ผ๏ธ Image-to-3D](#image-to-3d) +- [๐Ÿ“ Text-to-3D](#text-to-3d) +- [๐ŸŽจ Texture Generation](#texture-generation) +- [๐ŸŒ 3D Scene Generation](#3d-scene-generation) +- [โš™๏ธ Articulated Object Generation](#articulated-object-generation) +- [๐Ÿž๏ธ Layout Generation](#layout-generation) + +## ๐Ÿš€ Quick Start + +```sh +git clone https://github.com/HorizonRobotics/EmbodiedGen +cd EmbodiedGen +conda create -n embodiedgen python=3.10.13 -y +conda activate embodiedgen +pip install -r requirements.txt --use-deprecated=legacy-resolver +pip install -e . 
+``` + +--- + +## ๐ŸŸข Setup GPT Agent + +Update the API key in file: `embodied_gen/utils/gpt_config.yaml`. + +You can choose between two backends for the GPT agent: + +- **`gpt-4o`** (Recommended) โ€“ Use this if you have access to **Azure OpenAI**. +- **`qwen2.5-vl`** โ€“ An open alternative with free usage via [OpenRouter](https://openrouter.ai/settings/keys) (50 free requests per day) + + +--- + +

๐Ÿ–ผ๏ธ Image-to-3D

+ +[![๐Ÿค— Hugging Face](https://img.shields.io/badge/๐Ÿค—-Image_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) Generate physically plausible 3D asset from input image. + +### Local Service +Run the image-to-3D generation service locally. The first run will download required models. + +```sh +# Run in foreground +python apps/image_to_3d.py +# Or run in the background +CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 & +``` + +### Local API +Generate a 3D model from an image using the command-line API. + +```sh +python3 embodied_gen/scripts/imageto3d.py \ + --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \ + --output_root outputs/imageto3d/ + +# See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result +``` + +--- + + +

๐Ÿ“ Text-to-3D

+ +[![๐Ÿค— Hugging Face](https://img.shields.io/badge/๐Ÿค—-Text_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) Create 3D assets from text descriptions for a wide range of geometry and styles. + +### Local Service +Run the text-to-3D generation service locally. + +```sh +python apps/text_to_3d.py +``` + +### Local API + +```sh +bash embodied_gen/scripts/textto3d.sh \ + --prompts "small bronze figurine of a lion" "ๅธฆๆœจ่ดจๅบ•ๅบง๏ผŒๅ…ทๆœ‰็ป็บฌ็บฟ็š„ๅœฐ็ƒไปช" "ๆฉ™่‰ฒ็”ตๅŠจๆ‰‹้’ป๏ผŒๆœ‰็ฃจๆŸ็ป†่Š‚" \ + --output_root outputs/textto3d/ +``` + +--- + + +

🎨 Texture Generation

+ +[![๐Ÿค— Hugging Face](https://img.shields.io/badge/๐Ÿค—-Texture_Gen_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) Generate visually rich textures for 3D mesh. + +### Local Service +Run the texture generation service locally. + +```sh +python apps/texture_edit.py +``` + +### Local API +Generate textures for a 3D mesh using a text prompt. + +```sh +bash embodied_gen/scripts/texture_gen.sh \ + --mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \ + --prompt "ไธพ็€็‰Œๅญ็š„็บข่‰ฒๅ†™ๅฎž้ฃŽๆ ผๆœบๅ™จไบบ๏ผŒ็‰ŒๅญไธŠๅ†™็€โ€œHelloโ€" \ + --output_root "outputs/texture_gen/" \ + --uuid "robot_text" +``` + +--- + +

๐ŸŒ 3D Scene Generation

+ +๐Ÿšง *Coming Soon* + +--- + + +

โš™๏ธ Articulated Object Generation

+ +๐Ÿšง *Coming Soon* + +--- + + +

๐Ÿž๏ธ Layout Generation

+ +๐Ÿšง *Coming Soon* + +--- + +## ๐Ÿ“š Citation + +If you use EmbodiedGen in your research or projects, please cite: + +```bibtex +Coming Soon +``` + +--- + +## ๐Ÿ™Œ Acknowledgement + +EmbodiedGen builds upon the following amazing projects and models: + +- ๐ŸŒŸ [Trellis](https://github.com/microsoft/TRELLIS) +- ๐ŸŒŸ [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) +- ๐ŸŒŸ [Segment Anything Model](https://github.com/facebookresearch/segment-anything) +- ๐ŸŒŸ [Rembg: a tool to remove images background](https://github.com/danielgatis/rembg) +- ๐ŸŒŸ [RMBG-1.4: BRIA Background Removal](https://huggingface.co/briaai/RMBG-1.4) +- ๐ŸŒŸ [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) +- ๐ŸŒŸ [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) +- ๐ŸŒŸ [Kolors](https://github.com/Kwai-Kolors/Kolors) +- ๐ŸŒŸ [ChatGLM3](https://github.com/THUDM/ChatGLM3) +- ๐ŸŒŸ [Aesthetic Score Model](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) +- ๐ŸŒŸ [Pano2Room](https://github.com/TrickyGo/Pano2Room) +- ๐ŸŒŸ [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) +- ๐ŸŒŸ [kaolin](https://github.com/NVIDIAGameWorks/kaolin) +- ๐ŸŒŸ [diffusers](https://github.com/huggingface/diffusers) +- ๐ŸŒŸ GPT: QWEN2.5VL, GPT4o + +--- + +## โš–๏ธ License + +This project is licensed under the [Apache License 2.0](LICENSE). See the `LICENSE` file for details. \ No newline at end of file diff --git a/apps/common.py b/apps/common.py new file mode 100644 index 0000000..11b7d6a --- /dev/null +++ b/apps/common.py @@ -0,0 +1,899 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +import gc +import logging +import os +import shutil +import subprocess +import sys +from glob import glob + +import cv2 +import gradio as gr +import numpy as np +import spaces +import torch +import torch.nn.functional as F +import trimesh +from easydict import EasyDict as edict +from gradio.themes import Soft +from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc +from PIL import Image +from embodied_gen.data.backproject_v2 import entrypoint as backproject_api +from embodied_gen.data.differentiable_render import entrypoint as render_api +from embodied_gen.data.utils import trellis_preprocess +from embodied_gen.models.delight_model import DelightingModel +from embodied_gen.models.gs_model import GaussianOperator +from embodied_gen.models.segment_model import ( + BMGG14Remover, + RembgRemover, + SAMPredictor, +) +from embodied_gen.models.sr_model import ImageRealESRGAN, ImageStableSR +from embodied_gen.scripts.render_gs import entrypoint as render_gs_api +from embodied_gen.scripts.render_mv import build_texture_gen_pipe, infer_pipe +from embodied_gen.scripts.text2image import ( + build_text2img_ip_pipeline, + build_text2img_pipeline, + text2img_gen, +) +from embodied_gen.utils.gpt_clients import GPT_CLIENT +from embodied_gen.utils.process_media import ( + filter_image_small_connected_components, + merge_images_video, + render_video, +) +from embodied_gen.utils.tags import VERSION +from embodied_gen.validators.quality_checkers import ( + BaseChecker, + ImageAestheticChecker, + ImageSegChecker, + MeshGeoChecker, +) +from 
embodied_gen.validators.urdf_convertor import URDFGenerator, zip_files + +current_file_path = os.path.abspath(__file__) +current_dir = os.path.dirname(current_file_path) +sys.path.append(os.path.join(current_dir, "..")) +from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline +from thirdparty.TRELLIS.trellis.representations import ( + Gaussian, + MeshExtractResult, +) +from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import ( + build_scaling_rotation, + inverse_sigmoid, + strip_symmetric, +) +from thirdparty.TRELLIS.trellis.utils import postprocessing_utils + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + + +os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser( + "~/.cache/torch_extensions" +) +os.environ["GRADIO_ANALYTICS_ENABLED"] = "false" +os.environ["SPCONV_ALGO"] = "native" + +MAX_SEED = 100000 +DELIGHT = DelightingModel() +IMAGESR_MODEL = ImageRealESRGAN(outscale=4) +# IMAGESR_MODEL = ImageStableSR() + + +def patched_setup_functions(self): + def inverse_softplus(x): + return x + torch.log(-torch.expm1(-x)) + + def build_covariance_from_scaling_rotation( + scaling, scaling_modifier, rotation + ): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + if self.scaling_activation_type == "exp": + self.scaling_activation = torch.exp + self.inverse_scaling_activation = torch.log + elif self.scaling_activation_type == "softplus": + self.scaling_activation = F.softplus + self.inverse_scaling_activation = inverse_softplus + + self.covariance_activation = build_covariance_from_scaling_rotation + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + self.rotation_activation = F.normalize + + self.scale_bias = self.inverse_scaling_activation( + torch.tensor(self.scaling_bias) + 
).to(self.device) + self.rots_bias = torch.zeros((4)).to(self.device) + self.rots_bias[0] = 1 + self.opacity_bias = self.inverse_opacity_activation( + torch.tensor(self.opacity_bias) + ).to(self.device) + + +Gaussian.setup_functions = patched_setup_functions + + +def download_kolors_weights() -> None: + logger.info(f"Download kolors weights from huggingface...") + subprocess.run( + [ + "huggingface-cli", + "download", + "--resume-download", + "Kwai-Kolors/Kolors", + "--local-dir", + "weights/Kolors", + ], + check=True, + ) + subprocess.run( + [ + "huggingface-cli", + "download", + "--resume-download", + "Kwai-Kolors/Kolors-IP-Adapter-Plus", + "--local-dir", + "weights/Kolors-IP-Adapter-Plus", + ], + check=True, + ) + + +if os.getenv("GRADIO_APP") == "imageto3d": + RBG_REMOVER = RembgRemover() + RBG14_REMOVER = BMGG14Remover() + SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu") + PIPELINE = TrellisImageTo3DPipeline.from_pretrained( + "microsoft/TRELLIS-image-large" + ) + # PIPELINE.cuda() + SEG_CHECKER = ImageSegChecker(GPT_CLIENT) + GEO_CHECKER = MeshGeoChecker(GPT_CLIENT) + AESTHETIC_CHECKER = ImageAestheticChecker() + CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER] + TMP_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d" + ) +elif os.getenv("GRADIO_APP") == "textto3d": + RBG_REMOVER = RembgRemover() + RBG14_REMOVER = BMGG14Remover() + PIPELINE = TrellisImageTo3DPipeline.from_pretrained( + "microsoft/TRELLIS-image-large" + ) + # PIPELINE.cuda() + text_model_dir = "weights/Kolors" + if not os.path.exists(text_model_dir): + download_kolors_weights() + + PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3) + PIPELINE_IMG = build_text2img_pipeline(text_model_dir) + SEG_CHECKER = ImageSegChecker(GPT_CLIENT) + GEO_CHECKER = MeshGeoChecker(GPT_CLIENT) + AESTHETIC_CHECKER = ImageAestheticChecker() + CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER] + TMP_DIR = os.path.join( + 
os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d" + ) +elif os.getenv("GRADIO_APP") == "texture_edit": + if not os.path.exists("weights/Kolors"): + download_kolors_weights() + + PIPELINE_IP = build_texture_gen_pipe( + base_ckpt_dir="./weights", + ip_adapt_scale=0.7, + device="cuda", + ) + PIPELINE = build_texture_gen_pipe( + base_ckpt_dir="./weights", + ip_adapt_scale=0, + device="cuda", + ) + TMP_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit" + ) + +os.makedirs(TMP_DIR, exist_ok=True) + + +lighting_css = """ + +""" + +image_css = """ + +""" + +custom_theme = Soft( + primary_hue=stone, + secondary_hue=gray, + radius_size="md", + text_size="sm", + spacing_size="sm", +) + + +def start_session(req: gr.Request) -> None: + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(user_dir, exist_ok=True) + + +def end_session(req: gr.Request) -> None: + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + if os.path.exists(user_dir): + shutil.rmtree(user_dir) + + +@spaces.GPU +def preprocess_image_fn( + image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg" +) -> tuple[Image.Image, Image.Image]: + if isinstance(image, str): + image = Image.open(image) + elif isinstance(image, np.ndarray): + image = Image.fromarray(image) + + image_cache = image.copy().resize((512, 512)) + + bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER + image = bg_remover(image) + image = trellis_preprocess(image) + + return image, image_cache + + +def preprocess_sam_image_fn( + image: Image.Image, +) -> tuple[Image.Image, Image.Image]: + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + + sam_image = SAM_PREDICTOR.preprocess_image(image) + image_cache = Image.fromarray(sam_image).resize((512, 512)) + SAM_PREDICTOR.predictor.set_image(sam_image) + + return sam_image, image_cache + + +def active_btn_by_content(content: gr.Image) -> gr.Button: + interactive = True if content is not 
None else False + + return gr.Button(interactive=interactive) + + +def active_btn_by_text_content(content: gr.Textbox) -> gr.Button: + if content is not None and len(content) > 0: + interactive = True + else: + interactive = False + + return gr.Button(interactive=interactive) + + +def get_selected_image( + choice: str, sample1: str, sample2: str, sample3: str +) -> str: + if choice == "sample1": + return sample1 + elif choice == "sample2": + return sample2 + elif choice == "sample3": + return sample3 + else: + raise ValueError(f"Invalid choice: {choice}") + + +def get_cached_image(image_path: str) -> Image.Image: + if isinstance(image_path, Image.Image): + return image_path + return Image.open(image_path).resize((512, 512)) + + +@spaces.GPU +def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict: + return { + "gaussian": { + **gs.init_params, + "_xyz": gs._xyz.cpu().numpy(), + "_features_dc": gs._features_dc.cpu().numpy(), + "_scaling": gs._scaling.cpu().numpy(), + "_rotation": gs._rotation.cpu().numpy(), + "_opacity": gs._opacity.cpu().numpy(), + }, + "mesh": { + "vertices": mesh.vertices.cpu().numpy(), + "faces": mesh.faces.cpu().numpy(), + }, + } + + +def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]: + gs = Gaussian( + aabb=state["gaussian"]["aabb"], + sh_degree=state["gaussian"]["sh_degree"], + mininum_kernel_size=state["gaussian"]["mininum_kernel_size"], + scaling_bias=state["gaussian"]["scaling_bias"], + opacity_bias=state["gaussian"]["opacity_bias"], + scaling_activation=state["gaussian"]["scaling_activation"], + device=device, + ) + gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device) + gs._features_dc = torch.tensor( + state["gaussian"]["_features_dc"], device=device + ) + gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device) + gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device) + gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device) + + mesh = edict( 
+ vertices=torch.tensor(state["mesh"]["vertices"], device=device), + faces=torch.tensor(state["mesh"]["faces"], device=device), + ) + + return gs, mesh + + +def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int: + return np.random.randint(0, max_seed) if randomize_seed else seed + + +def select_point( + image: np.ndarray, + sel_pix: list, + point_type: str, + evt: gr.SelectData, +): + if point_type == "foreground_point": + sel_pix.append((evt.index, 1)) # append the foreground_point + elif point_type == "background_point": + sel_pix.append((evt.index, 0)) # append the background_point + else: + sel_pix.append((evt.index, 1)) # default foreground_point + + masks = SAM_PREDICTOR.generate_masks(image, sel_pix) + seg_image = SAM_PREDICTOR.get_segmented_image(image, masks) + + for point, label in sel_pix: + color = (255, 0, 0) if label == 0 else (0, 255, 0) + marker_type = 1 if label == 0 else 5 + cv2.drawMarker( + image, + point, + color, + markerType=marker_type, + markerSize=15, + thickness=10, + ) + + torch.cuda.empty_cache() + + return (image, masks), seg_image + + +@spaces.GPU +def image_to_3d( + image: Image.Image, + seed: int, + ss_guidance_strength: float, + ss_sampling_steps: int, + slat_guidance_strength: float, + slat_sampling_steps: int, + raw_image_cache: Image.Image, + sam_image: Image.Image = None, + is_sam_image: bool = False, + req: gr.Request = None, +) -> tuple[dict, str]: + if is_sam_image: + seg_image = filter_image_small_connected_components(sam_image) + seg_image = Image.fromarray(seg_image, mode="RGBA") + seg_image = trellis_preprocess(seg_image) + else: + seg_image = image + + if isinstance(seg_image, np.ndarray): + seg_image = Image.fromarray(seg_image) + + output_root = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(output_root, exist_ok=True) + seg_image.save(f"{output_root}/seg_image.png") + raw_image_cache.save(f"{output_root}/raw_image.png") + PIPELINE.cuda() + outputs = PIPELINE.run( + seg_image, + 
seed=seed, + formats=["gaussian", "mesh"], + preprocess_image=False, + sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "cfg_strength": ss_guidance_strength, + }, + slat_sampler_params={ + "steps": slat_sampling_steps, + "cfg_strength": slat_guidance_strength, + }, + ) + # Set to cpu for memory saving. + PIPELINE.cpu() + + gs_model = outputs["gaussian"][0] + mesh_model = outputs["mesh"][0] + color_images = render_video(gs_model)["color"] + normal_images = render_video(mesh_model)["normal"] + + video_path = os.path.join(output_root, "gs_mesh.mp4") + merge_images_video(color_images, normal_images, video_path) + state = pack_state(gs_model, mesh_model) + + gc.collect() + torch.cuda.empty_cache() + + return state, video_path + + +@spaces.GPU +def extract_3d_representations( + state: dict, enable_delight: bool, texture_size: int, req: gr.Request +): + output_root = TMP_DIR + output_root = os.path.join(output_root, str(req.session_hash)) + gs_model, mesh_model = unpack_state(state, device="cuda") + + mesh = postprocessing_utils.to_glb( + gs_model, + mesh_model, + simplify=0.9, + texture_size=1024, + verbose=True, + ) + filename = "sample" + gs_path = os.path.join(output_root, f"{filename}_gs.ply") + gs_model.save_ply(gs_path) + + # Rotate mesh and GS by 90 degrees around Z-axis. + rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]] + # Addtional rotation for GS to align mesh. 
+ gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array( + rot_matrix + ) + pose = GaussianOperator.trans_to_quatpose(gs_rot) + aligned_gs_path = gs_path.replace(".ply", "_aligned.ply") + GaussianOperator.resave_ply( + in_ply=gs_path, + out_ply=aligned_gs_path, + instance_pose=pose, + ) + + mesh.vertices = mesh.vertices @ np.array(rot_matrix) + mesh_obj_path = os.path.join(output_root, f"{filename}.obj") + mesh.export(mesh_obj_path) + mesh_glb_path = os.path.join(output_root, f"{filename}.glb") + mesh.export(mesh_glb_path) + + torch.cuda.empty_cache() + + return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path + + +def extract_3d_representations_v2( + state: dict, + enable_delight: bool, + texture_size: int, + req: gr.Request, +): + output_root = TMP_DIR + user_dir = os.path.join(output_root, str(req.session_hash)) + gs_model, mesh_model = unpack_state(state, device="cpu") + + filename = "sample" + gs_path = os.path.join(user_dir, f"{filename}_gs.ply") + gs_model.save_ply(gs_path) + + # Rotate mesh and GS by 90 degrees around Z-axis. + rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]] + gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] + mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] + + # Addtional rotation for GS to align mesh. 
+ gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix) + pose = GaussianOperator.trans_to_quatpose(gs_rot) + aligned_gs_path = gs_path.replace(".ply", "_aligned.ply") + GaussianOperator.resave_ply( + in_ply=gs_path, + out_ply=aligned_gs_path, + instance_pose=pose, + device="cpu", + ) + color_path = os.path.join(user_dir, "color.png") + render_gs_api(aligned_gs_path, color_path) + + mesh = trimesh.Trimesh( + vertices=mesh_model.vertices.cpu().numpy(), + faces=mesh_model.faces.cpu().numpy(), + ) + mesh.vertices = mesh.vertices @ np.array(mesh_add_rot) + mesh.vertices = mesh.vertices @ np.array(rot_matrix) + + mesh_obj_path = os.path.join(user_dir, f"{filename}.obj") + mesh.export(mesh_obj_path) + + mesh = backproject_api( + delight_model=DELIGHT, + imagesr_model=IMAGESR_MODEL, + color_path=color_path, + mesh_path=mesh_obj_path, + output_path=mesh_obj_path, + skip_fix_mesh=False, + delight=enable_delight, + texture_wh=[texture_size, texture_size], + ) + + mesh_glb_path = os.path.join(user_dir, f"{filename}.glb") + mesh.export(mesh_glb_path) + + return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path + + +def extract_urdf( + gs_path: str, + mesh_obj_path: str, + asset_cat_text: str, + height_range_text: str, + mass_range_text: str, + asset_version_text: str, + req: gr.Request = None, +): + output_root = TMP_DIR + if req is not None: + output_root = os.path.join(output_root, str(req.session_hash)) + + # Convert to URDF and recover attrs by GPT. 
+ filename = "sample" + urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4) + asset_attrs = { + "version": VERSION, + "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply", + } + if asset_version_text: + asset_attrs["version"] = asset_version_text + if asset_cat_text: + asset_attrs["category"] = asset_cat_text.lower() + if height_range_text: + try: + min_height, max_height = map(float, height_range_text.split("-")) + asset_attrs["min_height"] = min_height + asset_attrs["max_height"] = max_height + except ValueError: + return "Invalid height input format. Use the format: min-max." + if mass_range_text: + try: + min_mass, max_mass = map(float, mass_range_text.split("-")) + asset_attrs["min_mass"] = min_mass + asset_attrs["max_mass"] = max_mass + except ValueError: + return "Invalid mass input format. Use the format: min-max." + + urdf_path = urdf_convertor( + mesh_path=mesh_obj_path, + output_root=f"{output_root}/URDF_{filename}", + **asset_attrs, + ) + + # Rescale GS and save to URDF/mesh folder. + real_height = urdf_convertor.get_attr_from_urdf( + urdf_path, attr_name="real_height" + ) + out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa + GaussianOperator.resave_ply( + in_ply=gs_path, + out_ply=out_gs, + real_height=real_height, + device="cpu", + ) + + # Quality check and update .urdf file. 
+ mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa + trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb")) + # image_paths = render_asset3d( + # mesh_path=mesh_out, + # output_root=f"{output_root}/URDF_{filename}", + # output_subdir="qa_renders", + # num_images=8, + # elevation=(30, -30), + # distance=5.5, + # ) + + image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color" # noqa + image_paths = glob(f"{image_dir}/*.png") + images_list = [] + for checker in CHECKERS: + images = image_paths + if isinstance(checker, ImageSegChecker): + images = [ + f"{TMP_DIR}/{req.session_hash}/raw_image.png", + f"{TMP_DIR}/{req.session_hash}/seg_image.png", + ] + images_list.append(images) + + results = BaseChecker.validate(CHECKERS, images_list) + urdf_convertor.add_quality_tag(urdf_path, results) + + # Zip urdf files + urdf_zip = zip_files( + input_paths=[ + f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}", + f"{output_root}/URDF_{filename}/{filename}.urdf", + ], + output_zip=f"{output_root}/urdf_{filename}.zip", + ) + + estimated_type = urdf_convertor.estimated_attrs["category"] + estimated_height = urdf_convertor.estimated_attrs["height"] + estimated_mass = urdf_convertor.estimated_attrs["mass"] + estimated_mu = urdf_convertor.estimated_attrs["mu"] + + return ( + urdf_zip, + estimated_type, + estimated_height, + estimated_mass, + estimated_mu, + ) + + +@spaces.GPU +def text2image_fn( + prompt: str, + guidance_scale: float, + infer_step: int = 50, + ip_image: Image.Image | str = None, + ip_adapt_scale: float = 0.3, + image_wh: int | tuple[int, int] = [1024, 1024], + rmbg_tag: str = "rembg", + seed: int = None, + n_sample: int = 3, + req: gr.Request = None, +): + if isinstance(image_wh, int): + image_wh = (image_wh, image_wh) + output_root = TMP_DIR + if req is not None: + output_root = os.path.join(output_root, str(req.session_hash)) + os.makedirs(output_root, 
exist_ok=True) + + pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP + if ip_image is not None: + pipeline.set_ip_adapter_scale([ip_adapt_scale]) + + images = text2img_gen( + prompt=prompt, + n_sample=n_sample, + guidance_scale=guidance_scale, + pipeline=pipeline, + ip_image=ip_image, + image_wh=image_wh, + infer_step=infer_step, + seed=seed, + ) + + for idx in range(len(images)): + image = images[idx] + images[idx], _ = preprocess_image_fn(image, rmbg_tag) + + save_paths = [] + for idx, image in enumerate(images): + save_path = f"{output_root}/sample_{idx}.png" + image.save(save_path) + save_paths.append(save_path) + + logger.info(f"Images saved to {output_root}") + + gc.collect() + torch.cuda.empty_cache() + + return save_paths + save_paths + + +@spaces.GPU +def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"): + output_root = os.path.join(TMP_DIR, str(req.session_hash)) + + _ = render_api( + mesh_path=mesh_path, + output_root=f"{output_root}/condition", + uuid=str(uuid), + ) + + gc.collect() + torch.cuda.empty_cache() + + return None, None, None + + +@spaces.GPU +def generate_texture_mvimages( + prompt: str, + controlnet_cond_scale: float = 0.55, + guidance_scale: float = 9, + strength: float = 0.9, + num_inference_steps: int = 50, + seed: int = 0, + ip_adapt_scale: float = 0, + ip_img_path: str = None, + uid: str = "sample", + sub_idxs: tuple[tuple[int]] = ((0, 1, 2), (3, 4, 5)), + req: gr.Request = None, +) -> list[str]: + output_root = os.path.join(TMP_DIR, str(req.session_hash)) + use_ip_adapter = True if ip_img_path and ip_adapt_scale > 0 else False + PIPELINE_IP.set_ip_adapter_scale([ip_adapt_scale]) + img_save_paths = infer_pipe( + index_file=f"{output_root}/condition/index.json", + controlnet_cond_scale=controlnet_cond_scale, + guidance_scale=guidance_scale, + strength=strength, + num_inference_steps=num_inference_steps, + ip_adapt_scale=ip_adapt_scale, + ip_img_path=ip_img_path, + uid=uid, + prompt=prompt, + 
save_dir=f"{output_root}/multi_view", + sub_idxs=sub_idxs, + pipeline=PIPELINE_IP if use_ip_adapter else PIPELINE, + seed=seed, + ) + + gc.collect() + torch.cuda.empty_cache() + + return img_save_paths + img_save_paths + + +def backproject_texture( + mesh_path: str, + input_image: str, + texture_size: int, + uuid: str = "sample", + req: gr.Request = None, +) -> str: + output_root = os.path.join(TMP_DIR, str(req.session_hash)) + output_dir = os.path.join(output_root, "texture_mesh") + os.makedirs(output_dir, exist_ok=True) + command = [ + "backproject-cli", + "--mesh_path", + mesh_path, + "--input_image", + input_image, + "--output_root", + output_dir, + "--uuid", + f"{uuid}", + "--texture_size", + str(texture_size), + "--skip_fix_mesh", + ] + + _ = subprocess.run( + command, capture_output=True, text=True, encoding="utf-8" + ) + output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj") + output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb") + _ = trimesh.load(output_obj_mesh).export(output_glb_mesh) + + zip_file = zip_files( + input_paths=[ + output_glb_mesh, + output_obj_mesh, + os.path.join(output_dir, "material.mtl"), + os.path.join(output_dir, "material_0.png"), + ], + output_zip=os.path.join(output_dir, f"{uuid}.zip"), + ) + + gc.collect() + torch.cuda.empty_cache() + + return output_glb_mesh, output_obj_mesh, zip_file + + +@spaces.GPU +def backproject_texture_v2( + mesh_path: str, + input_image: str, + texture_size: int, + enable_delight: bool = True, + fix_mesh: bool = False, + uuid: str = "sample", + req: gr.Request = None, +) -> str: + output_root = os.path.join(TMP_DIR, str(req.session_hash)) + output_dir = os.path.join(output_root, "texture_mesh") + os.makedirs(output_dir, exist_ok=True) + + textured_mesh = backproject_api( + delight_model=DELIGHT, + imagesr_model=IMAGESR_MODEL, + color_path=input_image, + mesh_path=mesh_path, + output_path=f"{output_dir}/{uuid}.obj", + skip_fix_mesh=not fix_mesh, + delight=enable_delight, + texture_wh=[texture_size, 
texture_size], + ) + + output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj") + output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb") + _ = textured_mesh.export(output_glb_mesh) + + zip_file = zip_files( + input_paths=[ + output_glb_mesh, + output_obj_mesh, + os.path.join(output_dir, "material.mtl"), + os.path.join(output_dir, "material_0.png"), + ], + output_zip=os.path.join(output_dir, f"{uuid}.zip"), + ) + + gc.collect() + torch.cuda.empty_cache() + + return output_glb_mesh, output_obj_mesh, zip_file + + +@spaces.GPU +def render_result_video( + mesh_path: str, video_size: int, req: gr.Request, uuid: str = "" +) -> str: + output_root = os.path.join(TMP_DIR, str(req.session_hash)) + output_dir = os.path.join(output_root, "texture_mesh") + + _ = render_api( + mesh_path=mesh_path, + output_root=output_dir, + num_images=90, + elevation=[20], + with_mtl=True, + pbr_light_factor=1, + uuid=str(uuid), + gen_color_mp4=True, + gen_glonormal_mp4=True, + distance=5.5, + resolution_hw=(video_size, video_size), + ) + + gc.collect() + torch.cuda.empty_cache() + + return f"{output_dir}/color.mp4" diff --git a/apps/image_to_3d.py b/apps/image_to_3d.py new file mode 100644 index 0000000..88883f4 --- /dev/null +++ b/apps/image_to_3d.py @@ -0,0 +1,501 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ + +import os + +os.environ["GRADIO_APP"] = "imageto3d" +from glob import glob + +import gradio as gr +from common import ( + MAX_SEED, + VERSION, + active_btn_by_content, + custom_theme, + end_session, + extract_3d_representations_v2, + extract_urdf, + get_seed, + image_css, + image_to_3d, + lighting_css, + preprocess_image_fn, + preprocess_sam_image_fn, + select_point, + start_session, +) + +with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo: + gr.Markdown( + """ + ## ***EmbodiedGen***: Image-to-3D Asset + **๐Ÿ”– Version**: {VERSION} +

+ + ๐ŸŒ Project Page + + + ๐Ÿ“„ arXiv + + + ๐Ÿ’ป GitHub + + + ๐ŸŽฅ Video + +

+ + ๐Ÿ–ผ๏ธ Generate physically plausible 3D asset from single input image. + + """.format( + VERSION=VERSION + ), + elem_classes=["header"], + ) + + gr.HTML(image_css) + gr.HTML(lighting_css) + with gr.Row(): + with gr.Column(scale=2): + with gr.Tabs() as input_tabs: + with gr.Tab( + label="Image(auto seg)", id=0 + ) as single_image_input_tab: + raw_image_cache = gr.Image( + format="png", + image_mode="RGB", + type="pil", + visible=False, + ) + image_prompt = gr.Image( + label="Input Image", + format="png", + image_mode="RGBA", + type="pil", + height=400, + elem_classes=["image_fit"], + ) + gr.Markdown( + """ + If you are not satisfied with the auto segmentation + result, please switch to the `Image(SAM seg)` tab.""" + ) + with gr.Tab( + label="Image(SAM seg)", id=1 + ) as samimage_input_tab: + with gr.Row(): + with gr.Column(scale=1): + image_prompt_sam = gr.Image( + label="Input Image", + type="numpy", + height=400, + elem_classes=["image_fit"], + ) + image_seg_sam = gr.Image( + label="SAM Seg Image", + image_mode="RGBA", + type="pil", + height=400, + visible=False, + ) + with gr.Column(scale=1): + image_mask_sam = gr.AnnotatedImage( + elem_classes=["image_fit"] + ) + + fg_bg_radio = gr.Radio( + ["foreground_point", "background_point"], + label="Select foreground(green) or background(red) points, by default foreground", # noqa + value="foreground_point", + ) + gr.Markdown( + """ Click the `Input Image` to select SAM points, + after get the satisified segmentation, click `Generate` + button to generate the 3D asset. \n + Note: If the segmented foreground is too small relative + to the entire image area, the generation will fail. 
+ """ + ) + + with gr.Accordion(label="Generation Settings", open=False): + with gr.Row(): + seed = gr.Slider( + 0, MAX_SEED, label="Seed", value=0, step=1 + ) + texture_size = gr.Slider( + 1024, + 4096, + label="UV texture size", + value=2048, + step=256, + ) + rmbg_tag = gr.Radio( + choices=["rembg", "rmbg14"], + value="rembg", + label="Background Removal Model", + ) + with gr.Row(): + randomize_seed = gr.Checkbox( + label="Randomize Seed", value=False + ) + project_delight = gr.Checkbox( + label="Backproject delighting", + value=False, + ) + gr.Markdown("Geo Structure Generation") + with gr.Row(): + ss_guidance_strength = gr.Slider( + 0.0, + 10.0, + label="Guidance Strength", + value=7.5, + step=0.1, + ) + ss_sampling_steps = gr.Slider( + 1, 50, label="Sampling Steps", value=12, step=1 + ) + gr.Markdown("Visual Appearance Generation") + with gr.Row(): + slat_guidance_strength = gr.Slider( + 0.0, + 10.0, + label="Guidance Strength", + value=3.0, + step=0.1, + ) + slat_sampling_steps = gr.Slider( + 1, 50, label="Sampling Steps", value=12, step=1 + ) + + generate_btn = gr.Button( + "๐Ÿš€ 1. Generate(~0.5 mins)", + variant="primary", + interactive=False, + ) + model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False) + with gr.Row(): + extract_rep3d_btn = gr.Button( + "๐Ÿ” 2. Extract 3D Representation(~2 mins)", + variant="primary", + interactive=False, + ) + with gr.Accordion( + label="Enter Asset Attributes(optional)", open=False + ): + asset_cat_text = gr.Textbox( + label="Enter Asset Category (e.g., chair)" + ) + height_range_text = gr.Textbox( + label="Enter **Height Range** in meter (e.g., 0.5-0.6)" + ) + mass_range_text = gr.Textbox( + label="Enter **Mass Range** in kg (e.g., 1.1-1.2)" + ) + asset_version_text = gr.Textbox( + label=f"Enter version (e.g., {VERSION})" + ) + with gr.Row(): + extract_urdf_btn = gr.Button( + "๐Ÿงฉ 3. 
Extract URDF with physics(~1 mins)", + variant="primary", + interactive=False, + ) + with gr.Row(): + gr.Markdown( + "#### Estimated Asset 3D Attributes(No input required)" + ) + with gr.Row(): + est_type_text = gr.Textbox( + label="Asset category", interactive=False + ) + est_height_text = gr.Textbox( + label="Real height(.m)", interactive=False + ) + est_mass_text = gr.Textbox( + label="Mass(.kg)", interactive=False + ) + est_mu_text = gr.Textbox( + label="Friction coefficient", interactive=False + ) + with gr.Row(): + download_urdf = gr.DownloadButton( + label="โฌ‡๏ธ 4. Download URDF", + variant="primary", + interactive=False, + ) + + gr.Markdown( + """ NOTE: If `Asset Attributes` are provided, the provided + properties will be used; otherwise, the GPT-preset properties + will be applied. \n + The `Download URDF` file is restored to the real scale and + has quality inspection, open with an editor to view details. + """ + ) + + with gr.Row() as single_image_example: + examples = gr.Examples( + label="Image Gallery", + examples=[ + [image_path] + for image_path in sorted( + glob("apps/assets/example_image/*") + ) + ], + inputs=[image_prompt, rmbg_tag], + fn=preprocess_image_fn, + outputs=[image_prompt, raw_image_cache], + run_on_click=True, + examples_per_page=10, + ) + + with gr.Row(visible=False) as single_sam_image_example: + examples = gr.Examples( + label="Image Gallery", + examples=[ + [image_path] + for image_path in sorted( + glob("apps/assets/example_image/*") + ) + ], + inputs=[image_prompt_sam], + fn=preprocess_sam_image_fn, + outputs=[image_prompt_sam, raw_image_cache], + run_on_click=True, + examples_per_page=10, + ) + with gr.Column(scale=1): + video_output = gr.Video( + label="Generated 3D Asset", + autoplay=True, + loop=True, + height=300, + ) + model_output_gs = gr.Model3D( + label="Gaussian Representation", height=300, interactive=False + ) + aligned_gs = gr.Textbox(visible=False) + gr.Markdown( + """ The rendering of `Gaussian Representation` 
takes additional 10s. """ # noqa + ) + with gr.Row(): + model_output_mesh = gr.Model3D( + label="Mesh Representation", + height=300, + interactive=False, + clear_color=[0.8, 0.8, 0.8, 1], + elem_id="lighter_mesh", + ) + + is_samimage = gr.State(False) + output_buf = gr.State() + selected_points = gr.State(value=[]) + + demo.load(start_session) + demo.unload(end_session) + + single_image_input_tab.select( + lambda: tuple( + [False, gr.Row.update(visible=True), gr.Row.update(visible=False)] + ), + outputs=[is_samimage, single_image_example, single_sam_image_example], + ) + samimage_input_tab.select( + lambda: tuple( + [True, gr.Row.update(visible=True), gr.Row.update(visible=False)] + ), + outputs=[is_samimage, single_sam_image_example, single_image_example], + ) + + image_prompt.upload( + preprocess_image_fn, + inputs=[image_prompt, rmbg_tag], + outputs=[image_prompt, raw_image_cache], + ) + image_prompt.change( + lambda: tuple( + [ + gr.Button(interactive=False), + gr.Button(interactive=False), + gr.Button(interactive=False), + None, + "", + None, + None, + "", + "", + "", + "", + "", + "", + "", + "", + ] + ), + outputs=[ + extract_rep3d_btn, + extract_urdf_btn, + download_urdf, + model_output_gs, + aligned_gs, + model_output_mesh, + video_output, + asset_cat_text, + height_range_text, + mass_range_text, + asset_version_text, + est_type_text, + est_height_text, + est_mass_text, + est_mu_text, + ], + ) + image_prompt.change( + active_btn_by_content, + inputs=image_prompt, + outputs=generate_btn, + ) + + image_prompt_sam.upload( + preprocess_sam_image_fn, + inputs=[image_prompt_sam], + outputs=[image_prompt_sam, raw_image_cache], + ) + image_prompt_sam.change( + lambda: tuple( + [ + gr.Button(interactive=False), + gr.Button(interactive=False), + gr.Button(interactive=False), + None, + None, + None, + "", + "", + "", + "", + "", + "", + "", + "", + None, + [], + ] + ), + outputs=[ + extract_rep3d_btn, + extract_urdf_btn, + download_urdf, + model_output_gs, + 
model_output_mesh, + video_output, + asset_cat_text, + height_range_text, + mass_range_text, + asset_version_text, + est_type_text, + est_height_text, + est_mass_text, + est_mu_text, + image_mask_sam, + selected_points, + ], + ) + + image_prompt_sam.select( + select_point, + [ + image_prompt_sam, + selected_points, + fg_bg_radio, + ], + [image_mask_sam, image_seg_sam], + ) + image_seg_sam.change( + active_btn_by_content, + inputs=image_seg_sam, + outputs=generate_btn, + ) + + generate_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).success( + image_to_3d, + inputs=[ + image_prompt, + seed, + ss_guidance_strength, + ss_sampling_steps, + slat_guidance_strength, + slat_sampling_steps, + raw_image_cache, + image_seg_sam, + is_samimage, + ], + outputs=[output_buf, video_output], + ).success( + lambda: gr.Button(interactive=True), + outputs=[extract_rep3d_btn], + ) + + extract_rep3d_btn.click( + extract_3d_representations_v2, + inputs=[ + output_buf, + project_delight, + texture_size, + ], + outputs=[ + model_output_mesh, + model_output_gs, + model_output_obj, + aligned_gs, + ], + ).success( + lambda: gr.Button(interactive=True), + outputs=[extract_urdf_btn], + ) + + extract_urdf_btn.click( + extract_urdf, + inputs=[ + aligned_gs, + model_output_obj, + asset_cat_text, + height_range_text, + mass_range_text, + asset_version_text, + ], + outputs=[ + download_urdf, + est_type_text, + est_height_text, + est_mass_text, + est_mu_text, + ], + queue=True, + show_progress="full", + ).success( + lambda: gr.Button(interactive=True), + outputs=[download_urdf], + ) + + +if __name__ == "__main__": + demo.launch(server_name="10.34.8.82", server_port=8081) diff --git a/apps/text_to_3d.py b/apps/text_to_3d.py new file mode 100644 index 0000000..a36a99f --- /dev/null +++ b/apps/text_to_3d.py @@ -0,0 +1,481 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +import os + +os.environ["GRADIO_APP"] = "textto3d" + + +import gradio as gr +from common import ( + MAX_SEED, + VERSION, + active_btn_by_text_content, + custom_theme, + end_session, + extract_3d_representations_v2, + extract_urdf, + get_cached_image, + get_seed, + get_selected_image, + image_css, + image_to_3d, + lighting_css, + start_session, + text2image_fn, +) + +with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo: + gr.Markdown( + """ + ## ***EmbodiedGen***: Text-to-3D Asset + **๐Ÿ”– Version**: {VERSION} +

+ + ๐ŸŒ Project Page + + + ๐Ÿ“„ arXiv + + + ๐Ÿ’ป GitHub + + + ๐ŸŽฅ Video + +

+ + ๐Ÿ“ Create 3D assets from text descriptions for a wide range of geometry and styles. + + """.format( + VERSION=VERSION + ), + elem_classes=["header"], + ) + gr.HTML(image_css) + gr.HTML(lighting_css) + with gr.Row(): + with gr.Column(scale=1): + raw_image_cache = gr.Image( + format="png", + image_mode="RGB", + type="pil", + visible=False, + ) + text_prompt = gr.Textbox( + label="Text Prompt (Chinese or English)", + placeholder="Input text prompt here", + ) + ip_image = gr.Image( + label="Reference Image(optional)", + format="png", + image_mode="RGB", + type="filepath", + height=250, + elem_classes=["image_fit"], + ) + gr.Markdown( + "Note: The `reference image` is optional, if use, " + "please provide image in nearly square resolution." + ) + + with gr.Accordion(label="Image Generation Settings", open=False): + with gr.Row(): + seed = gr.Slider( + 0, MAX_SEED, label="Seed", value=0, step=1 + ) + randomize_seed = gr.Checkbox( + label="Randomize Seed", value=False + ) + rmbg_tag = gr.Radio( + choices=["rembg", "rmbg14"], + value="rembg", + label="Background Removal Model", + ) + ip_adapt_scale = gr.Slider( + 0, 1, label="IP-adapter Scale", value=0.3, step=0.05 + ) + img_guidance_scale = gr.Slider( + 1, 30, label="Text Guidance Scale", value=12, step=0.2 + ) + img_inference_steps = gr.Slider( + 10, 100, label="Sampling Steps", value=50, step=5 + ) + img_resolution = gr.Slider( + 512, + 1536, + label="Image Resolution", + value=1024, + step=128, + ) + + generate_img_btn = gr.Button( + "๐ŸŽจ 1. 
Generate Images(~1min)", + variant="primary", + interactive=False, + ) + dropdown = gr.Radio( + choices=["sample1", "sample2", "sample3"], + value="sample1", + label="Choose your favorite sample style.", + ) + select_img = gr.Image( + visible=False, + format="png", + image_mode="RGBA", + type="pil", + height=300, + ) + + # text to 3d + with gr.Accordion(label="Generation Settings", open=False): + seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) + texture_size = gr.Slider( + 1024, 4096, label="UV texture size", value=2048, step=256 + ) + with gr.Row(): + randomize_seed = gr.Checkbox( + label="Randomize Seed", value=False + ) + project_delight = gr.Checkbox( + label="backproject delight", value=False + ) + gr.Markdown("Geo Structure Generation") + with gr.Row(): + ss_guidance_strength = gr.Slider( + 0.0, + 10.0, + label="Guidance Strength", + value=7.5, + step=0.1, + ) + ss_sampling_steps = gr.Slider( + 1, 50, label="Sampling Steps", value=12, step=1 + ) + gr.Markdown("Visual Appearance Generation") + with gr.Row(): + slat_guidance_strength = gr.Slider( + 0.0, + 10.0, + label="Guidance Strength", + value=3.0, + step=0.1, + ) + slat_sampling_steps = gr.Slider( + 1, 50, label="Sampling Steps", value=12, step=1 + ) + + generate_btn = gr.Button( + "๐Ÿš€ 2. Generate 3D(~0.5 mins)", + variant="primary", + interactive=False, + ) + model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False) + with gr.Row(): + extract_rep3d_btn = gr.Button( + "๐Ÿ” 3. 
Extract 3D Representation(~1 mins)", + variant="primary", + interactive=False, + ) + with gr.Accordion( + label="Enter Asset Attributes(optional)", open=False + ): + asset_cat_text = gr.Textbox( + label="Enter Asset Category (e.g., chair)" + ) + height_range_text = gr.Textbox( + label="Enter Height Range in meter (e.g., 0.5-0.6)" + ) + mass_range_text = gr.Textbox( + label="Enter Mass Range in kg (e.g., 1.1-1.2)" + ) + asset_version_text = gr.Textbox( + label=f"Enter version (e.g., {VERSION})" + ) + with gr.Row(): + extract_urdf_btn = gr.Button( + "๐Ÿงฉ 4. Extract URDF with physics(~1 mins)", + variant="primary", + interactive=False, + ) + with gr.Row(): + download_urdf = gr.DownloadButton( + label="โฌ‡๏ธ 5. Download URDF", + variant="primary", + interactive=False, + ) + + with gr.Column(scale=3): + with gr.Row(): + image_sample1 = gr.Image( + label="sample1", + format="png", + image_mode="RGBA", + type="filepath", + height=300, + interactive=False, + elem_classes=["image_fit"], + ) + image_sample2 = gr.Image( + label="sample2", + format="png", + image_mode="RGBA", + type="filepath", + height=300, + interactive=False, + elem_classes=["image_fit"], + ) + image_sample3 = gr.Image( + label="sample3", + format="png", + image_mode="RGBA", + type="filepath", + height=300, + interactive=False, + elem_classes=["image_fit"], + ) + usample1 = gr.Image( + format="png", + image_mode="RGBA", + type="filepath", + visible=False, + ) + usample2 = gr.Image( + format="png", + image_mode="RGBA", + type="filepath", + visible=False, + ) + usample3 = gr.Image( + format="png", + image_mode="RGBA", + type="filepath", + visible=False, + ) + gr.Markdown( + "The generated image may be of poor quality due to auto " + "segmentation. Try adjusting the text prompt or seed." 
+ ) + with gr.Row(): + video_output = gr.Video( + label="Generated 3D Asset", + autoplay=True, + loop=True, + height=300, + interactive=False, + ) + model_output_gs = gr.Model3D( + label="Gaussian Representation", + height=300, + interactive=False, + ) + aligned_gs = gr.Textbox(visible=False) + + model_output_mesh = gr.Model3D( + label="Mesh Representation", + clear_color=[0.8, 0.8, 0.8, 1], + height=300, + interactive=False, + elem_id="lighter_mesh", + ) + + gr.Markdown("Estimated Asset 3D Attributes(No input required)") + with gr.Row(): + est_type_text = gr.Textbox( + label="Asset category", interactive=False + ) + est_height_text = gr.Textbox( + label="Real height(.m)", interactive=False + ) + est_mass_text = gr.Textbox( + label="Mass(.kg)", interactive=False + ) + est_mu_text = gr.Textbox( + label="Friction coefficient", interactive=False + ) + + prompt_examples = [ + "satin gold tea cup with saucer", + "small bronze figurine of a lion", + "brown leather bag", + "Miniature cup with floral design", + "ๅธฆๆœจ่ดจๅบ•ๅบง, ๅ…ทๆœ‰็ป็บฌ็บฟ็š„ๅœฐ็ƒไปช", + "ๆฉ™่‰ฒ็”ตๅŠจๆ‰‹้’ป, ๆœ‰็ฃจๆŸ็ป†่Š‚", + "ๆ‰‹ๅทฅๅˆถไฝœ็š„็šฎ้ฉ็ฌ”่ฎฐๆœฌ", + ] + examples = gr.Examples( + label="Gallery", + examples=prompt_examples, + inputs=[text_prompt], + examples_per_page=10, + ) + + output_buf = gr.State() + + demo.load(start_session) + demo.unload(end_session) + + text_prompt.change( + active_btn_by_text_content, + inputs=[text_prompt], + outputs=[generate_img_btn], + ) + + generate_img_btn.click( + lambda: tuple( + [ + gr.Button(interactive=False), + gr.Button(interactive=False), + gr.Button(interactive=False), + gr.Button(interactive=False), + None, + "", + None, + None, + "", + "", + "", + "", + "", + "", + "", + "", + None, + None, + None, + ] + ), + outputs=[ + extract_rep3d_btn, + extract_urdf_btn, + download_urdf, + generate_btn, + model_output_gs, + aligned_gs, + model_output_mesh, + video_output, + asset_cat_text, + height_range_text, + mass_range_text, + asset_version_text, + 
est_type_text, + est_height_text, + est_mass_text, + est_mu_text, + image_sample1, + image_sample2, + image_sample3, + ], + ).success( + text2image_fn, + inputs=[ + text_prompt, + img_guidance_scale, + img_inference_steps, + ip_image, + ip_adapt_scale, + img_resolution, + rmbg_tag, + seed, + ], + outputs=[ + image_sample1, + image_sample2, + image_sample3, + usample1, + usample2, + usample3, + ], + ).success( + lambda: gr.Button(interactive=True), + outputs=[generate_btn], + ) + + generate_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).success( + get_selected_image, + inputs=[dropdown, usample1, usample2, usample3], + outputs=select_img, + ).success( + get_cached_image, + inputs=[select_img], + outputs=[raw_image_cache], + ).success( + image_to_3d, + inputs=[ + select_img, + seed, + ss_guidance_strength, + ss_sampling_steps, + slat_guidance_strength, + slat_sampling_steps, + raw_image_cache, + ], + outputs=[output_buf, video_output], + ).success( + lambda: gr.Button(interactive=True), + outputs=[extract_rep3d_btn], + ) + + extract_rep3d_btn.click( + extract_3d_representations_v2, + inputs=[ + output_buf, + project_delight, + texture_size, + ], + outputs=[ + model_output_mesh, + model_output_gs, + model_output_obj, + aligned_gs, + ], + ).success( + lambda: gr.Button(interactive=True), + outputs=[extract_urdf_btn], + ) + + extract_urdf_btn.click( + extract_urdf, + inputs=[ + aligned_gs, + model_output_obj, + asset_cat_text, + height_range_text, + mass_range_text, + asset_version_text, + ], + outputs=[ + download_urdf, + est_type_text, + est_height_text, + est_mass_text, + est_mu_text, + ], + queue=True, + show_progress="full", + ).success( + lambda: gr.Button(interactive=True), + outputs=[download_urdf], + ) + + +if __name__ == "__main__": + demo.launch(server_name="10.34.8.82", server_port=8082) diff --git a/apps/texture_edit.py b/apps/texture_edit.py new file mode 100644 index 0000000..8d14bca --- /dev/null +++ b/apps/texture_edit.py 
@@ -0,0 +1,382 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +import os + +os.environ["GRADIO_APP"] = "texture_edit" +import gradio as gr +from common import ( + MAX_SEED, + VERSION, + backproject_texture_v2, + custom_theme, + end_session, + generate_condition, + generate_texture_mvimages, + get_seed, + get_selected_image, + image_css, + lighting_css, + render_result_video, + start_session, +) + + +def active_btn_by_content(mesh_content: gr.Model3D, text_content: gr.Textbox): + if ( + mesh_content is not None + and text_content is not None + and len(text_content) > 0 + ): + interactive = True + else: + interactive = False + + return gr.Button(interactive=interactive) + + +with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo: + gr.Markdown( + """ + ## ***EmbodiedGen***: Texture Generation + **๐Ÿ”– Version**: {VERSION} +

+ + ๐ŸŒ Project Page + + + ๐Ÿ“„ arXiv + + + ๐Ÿ’ป GitHub + + + ๐ŸŽฅ Video + +

+ + ๐ŸŽจ Generate visually rich textures for 3D mesh. + + """.format( + VERSION=VERSION + ), + elem_classes=["header"], + ) + gr.HTML(image_css) + gr.HTML(lighting_css) + with gr.Row(): + with gr.Column(scale=1): + mesh_input = gr.Model3D( + label="Upload Mesh File(.obj or .glb)", height=300 + ) + local_mesh = gr.Textbox(visible=False) + text_prompt = gr.Textbox( + label="Text Prompt (Chinese or English)", + placeholder="Input text prompt here", + ) + ip_image = gr.Image( + label="Reference Image(optional)", + format="png", + image_mode="RGB", + type="filepath", + height=250, + elem_classes=["image_fit"], + ) + gr.Markdown( + "Note: The `reference image` is optional. If provided, please " + "increase the `Condition Scale` in Generation Settings." + ) + + with gr.Accordion(label="Generation Settings", open=False): + with gr.Row(): + seed = gr.Slider( + 0, MAX_SEED, label="Seed", value=0, step=1 + ) + randomize_seed = gr.Checkbox( + label="Randomize Seed", value=False + ) + ip_adapt_scale = gr.Slider( + 0, 1, label="IP-adapter Scale", value=0.7, step=0.05 + ) + cond_scale = gr.Slider( + 0.0, + 1.0, + label="Geo Condition Scale", + value=0.60, + step=0.01, + ) + guidance_scale = gr.Slider( + 1, 30, label="Text Guidance Scale", value=9, step=0.2 + ) + guidance_strength = gr.Slider( + 0.0, + 1.0, + label="Strength", + value=0.9, + step=0.05, + ) + num_inference_steps = gr.Slider( + 10, 100, label="Sampling Steps", value=50, step=5 + ) + texture_size = gr.Slider( + 1024, 4096, label="UV texture size", value=2048, step=256 + ) + video_size = gr.Slider( + 512, 2048, label="Video Resolution", value=512, step=256 + ) + + generate_mv_btn = gr.Button( + "๐ŸŽจ 1. 
Generate MV Images(~1min)", + variant="primary", + interactive=False, + ) + + with gr.Column(scale=3): + with gr.Row(): + image_sample1 = gr.Image( + label="sample1", + format="png", + image_mode="RGBA", + type="filepath", + height=300, + interactive=False, + elem_classes=["image_fit"], + ) + image_sample2 = gr.Image( + label="sample2", + format="png", + image_mode="RGBA", + type="filepath", + height=300, + interactive=False, + elem_classes=["image_fit"], + ) + image_sample3 = gr.Image( + label="sample3", + format="png", + image_mode="RGBA", + type="filepath", + height=300, + interactive=False, + elem_classes=["image_fit"], + ) + + usample1 = gr.Image( + format="png", + image_mode="RGBA", + type="filepath", + visible=False, + ) + usample2 = gr.Image( + format="png", + image_mode="RGBA", + type="filepath", + visible=False, + ) + usample3 = gr.Image( + format="png", + image_mode="RGBA", + type="filepath", + visible=False, + ) + + gr.Markdown( + "Note: Select samples with consistent textures from various " + "perspectives and no obvious reflections." + ) + with gr.Row(): + with gr.Column(scale=1): + with gr.Row(): + dropdown = gr.Radio( + choices=["sample1", "sample2", "sample3"], + value="sample1", + label="Choose your favorite sample style.", + ) + select_img = gr.Image( + visible=False, + format="png", + image_mode="RGBA", + type="filepath", + height=300, + ) + with gr.Row(): + project_delight = gr.Checkbox( + label="delight", value=True + ) + fix_mesh = gr.Checkbox( + label="simplify mesh", value=False + ) + + with gr.Column(scale=1): + texture_bake_btn = gr.Button( + "๐Ÿ› ๏ธ 2. Texture Baking(~2min)", + variant="primary", + interactive=False, + ) + download_btn = gr.DownloadButton( + label="โฌ‡๏ธ 3. 
Download Mesh", + variant="primary", + interactive=False, + ) + + with gr.Row(): + mesh_output = gr.Model3D( + label="Mesh Edit Result", + clear_color=[0.8, 0.8, 0.8, 1], + height=380, + interactive=False, + elem_id="lighter_mesh", + ) + mesh_outpath = gr.Textbox(visible=False) + video_output = gr.Video( + label="Mesh Edit Video", + autoplay=True, + loop=True, + height=380, + ) + + with gr.Row(): + prompt_examples = [] + with open("apps/assets/example_texture/text_prompts.txt", "r") as f: + for line in f: + parts = line.strip().split("\\") + prompt_examples.append([parts[0].strip(), parts[1].strip()]) + + examples = gr.Examples( + label="Mesh Gallery", + examples=prompt_examples, + inputs=[mesh_input, text_prompt], + examples_per_page=10, + ) + + demo.load(start_session) + demo.unload(end_session) + + mesh_input.change( + lambda: tuple( + [ + None, + None, + None, + gr.Button(interactive=False), + gr.Button(interactive=False), + None, + None, + None, + ] + ), + outputs=[ + mesh_outpath, + mesh_output, + video_output, + texture_bake_btn, + download_btn, + image_sample1, + image_sample2, + image_sample3, + ], + ).success( + active_btn_by_content, + inputs=[mesh_input, text_prompt], + outputs=[generate_mv_btn], + ) + + text_prompt.change( + active_btn_by_content, + inputs=[mesh_input, text_prompt], + outputs=[generate_mv_btn], + ) + + generate_mv_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).success( + lambda: tuple( + [ + None, + None, + None, + gr.Button(interactive=False), + gr.Button(interactive=False), + ] + ), + outputs=[ + mesh_outpath, + mesh_output, + video_output, + texture_bake_btn, + download_btn, + ], + ).success( + generate_condition, + inputs=[mesh_input], + outputs=[image_sample1, image_sample2, image_sample3], + ).success( + generate_texture_mvimages, + inputs=[ + text_prompt, + cond_scale, + guidance_scale, + guidance_strength, + num_inference_steps, + seed, + ip_adapt_scale, + ip_image, + ], + outputs=[ + 
image_sample1, + image_sample2, + image_sample3, + usample1, + usample2, + usample3, + ], + ).success( + lambda: gr.Button(interactive=True), + outputs=[texture_bake_btn], + ) + + texture_bake_btn.click( + lambda: tuple([None, None, None, gr.Button(interactive=False)]), + outputs=[mesh_outpath, mesh_output, video_output, download_btn], + ).success( + get_selected_image, + inputs=[dropdown, usample1, usample2, usample3], + outputs=select_img, + ).success( + backproject_texture_v2, + inputs=[ + mesh_input, + select_img, + texture_size, + project_delight, + fix_mesh, + ], + outputs=[mesh_output, mesh_outpath, download_btn], + ).success( + lambda: gr.DownloadButton(interactive=True), + outputs=[download_btn], + ).success( + render_result_video, + inputs=[mesh_outpath, video_size], + outputs=[video_output], + ) + + +if __name__ == "__main__": + demo.launch(server_name="10.34.8.82", server_port=8083) diff --git a/embodied_gen/data/backproject.py b/embodied_gen/data/backproject.py new file mode 100644 index 0000000..b02ae27 --- /dev/null +++ b/embodied_gen/data/backproject.py @@ -0,0 +1,518 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
class TextureBaker:
    """Bake a UV texture for a mesh from multiple posed color observations.

    The mesh is rasterized from every camera described by ``camera_params``
    and the observed colors are transferred into UV space, either by direct
    nearest-texel accumulation (``fast``) or by optimizing a texture with an
    L1 photometric loss plus total-variation regularization (``opt``).

    Attributes:
        vertices (torch.Tensor): The vertices of the mesh.
        faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
        uvs (torch.Tensor): The UV coordinates of the mesh.
        camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
        device (str): The device tensors are placed on.
        w2cs (torch.Tensor): World-to-camera transformation matrices.
        projections (torch.Tensor): Camera projection matrices.

    Example:
        >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)  # noqa
        >>> texture_baker = TextureBaker(vertices, faces, uvs, camera_params)
        >>> images = get_images_from_grid(args.color_path, image_size)
        >>> texture = texture_baker.bake_texture(
        ...     images, texture_size=args.texture_size, mode=args.baker_mode
        ... )
        >>> texture = post_process_texture(texture)
    """

    def __init__(
        self,
        vertices: np.ndarray,
        faces: np.ndarray,
        uvs: np.ndarray,
        camera_params: CameraSetting,
        device: str = "cuda",
    ) -> None:
        """Move the mesh to ``device`` and precompute the camera matrices.

        Args:
            vertices: Mesh vertices, as a numpy array or a torch tensor.
            faces: Triangle indices, as a numpy array or a torch tensor.
            uvs: Per-vertex UV coordinates, numpy array or torch tensor.
            camera_params: Camera description passed to `init_kal_camera`.
            device: Device on which all tensors are stored.
        """
        self.vertices = (
            torch.tensor(vertices, device=device)
            if isinstance(vertices, np.ndarray)
            else vertices.to(device)
        )
        self.faces = (
            torch.tensor(faces.astype(np.int32), device=device)
            if isinstance(faces, np.ndarray)
            else faces.to(device)
        )
        self.uvs = (
            torch.tensor(uvs, device=device)
            if isinstance(uvs, np.ndarray)
            else uvs.to(device)
        )
        self.camera_params = camera_params
        self.device = device

        camera = init_kal_camera(camera_params)
        matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
        matrix_mv = kaolin_to_opencv_view(matrix_mv)
        matrix_p = (
            camera.intrinsics.projection_matrix()
        )  # (n_cam 4 4) cam2pixel
        self.w2cs = matrix_mv.to(self.device)
        self.projections = matrix_p.to(self.device)

    @staticmethod
    def parametrize_mesh(
        vertices: np.ndarray, faces: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """UV-unwrap a mesh with xatlas.

        Args:
            vertices: Mesh vertices.
            faces: Triangle indices into ``vertices``.

        Returns:
            The vertices re-indexed by the xatlas vertex mapping, the new
            face indices, and the UV coordinates returned by xatlas.
        """
        vmapping, indices, uvs = xatlas.parametrize(vertices, faces)

        return vertices[vmapping], indices, uvs

    def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
        """Bake by accumulating every visible pixel into its nearest texel.

        Each view contributes only the pixels that are both covered by the
        rasterized mesh and set in that view's own foreground mask. Texels
        that receive no observation are filled with OpenCV inpainting.

        Returns:
            np.ndarray: ``(texture_size, texture_size, 3)`` uint8 texture.
        """
        num_texels = texture_size * texture_size
        # Allocate on `self.device` (not the default CUDA device) so the
        # buffers live next to the observations they are scattered from.
        texture = torch.zeros(
            (num_texels, 3), dtype=torch.float32, device=self.device
        )
        texture_weights = torch.zeros(
            num_texels, dtype=torch.float32, device=self.device
        )
        rastctx = utils3d.torch.RastContext(backend="cuda")
        for observation, w2c, projection, view_mask in tqdm(
            zip(observations, w2cs, projections, masks),
            total=len(observations),
            desc="Texture baking (fast)",
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
                uv_map = rast["uv"][0].detach().flip(0)
                # Use the mask of the *current* view; the first view's mask
                # is only valid for the first observation.
                mask = rast["mask"][0].detach().bool() & view_mask

            # nearest neighbor interpolation
            uv_map = (uv_map * texture_size).floor().long()
            # A UV of exactly 1.0 would otherwise index one texel past the
            # end of the texture.
            uv_map = uv_map.clamp(0, texture_size - 1)
            obs = observation[mask]
            uv_map = uv_map[mask]
            idx = (
                uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            )
            texture = texture.scatter_add(
                0, idx.view(-1, 1).expand(-1, 3), obs
            )
            texture_weights = texture_weights.scatter_add(
                0,
                idx,
                torch.ones(
                    obs.shape[0], dtype=torch.float32, device=texture.device
                ),
            )

        # Average the accumulated colors of every texel that was observed.
        observed = texture_weights > 0
        texture[observed] /= texture_weights[observed][:, None]
        texture = np.clip(
            texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
            0,
            255,
        ).astype(np.uint8)

        # inpaint
        hole_mask = (
            (texture_weights == 0)
            .cpu()
            .numpy()
            .astype(np.uint8)
            .reshape(texture_size, texture_size)
        )
        texture = cv2.inpaint(texture, hole_mask, 3, cv2.INPAINT_TELEA)

        return texture

    def _bake_opt(
        self,
        observations,
        w2cs,
        projections,
        texture_size,
        lambda_tv,
        masks,
        total_steps,
    ):
        """Bake by optimizing a texture against the observations.

        Every step samples one view at random, renders the current texture
        through that view's precomputed UV lookup and minimizes the masked
        L1 error, optionally adding ``lambda_tv`` times a total-variation
        penalty. The learning rate follows a cosine schedule from 1e-2 to
        1e-5. Texels outside the UV layout are inpainted at the end.

        Returns:
            np.ndarray: ``(texture_size, texture_size, 3)`` uint8 texture.
        """
        rastctx = utils3d.torch.RastContext(backend="cuda")
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        _uv = []
        _uv_dr = []
        for observation, w2c, projection in tqdm(
            zip(observations, w2cs, projections),
            total=len(w2cs),
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
                _uv.append(rast["uv"].detach())
                _uv_dr.append(rast["uv_dr"].detach())

        # Only views that were actually rasterized can be sampled below;
        # `zip` stops at the shorter of observations / cameras.
        num_views = len(_uv)

        texture = torch.nn.Parameter(
            torch.zeros(
                (1, texture_size, texture_size, 3),
                dtype=torch.float32,
                device=self.device,
            )
        )
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def cosine_annealing(step, total_steps, start_lr, end_lr):
            return end_lr + 0.5 * (start_lr - end_lr) * (
                1 + np.cos(np.pi * step / total_steps)
            )

        def tv_loss(tex):
            return torch.nn.functional.l1_loss(
                tex[:, :-1, :, :], tex[:, 1:, :, :]
            ) + torch.nn.functional.l1_loss(
                tex[:, :, :-1, :], tex[:, :, 1:, :]
            )

        with tqdm(total=total_steps, desc="Texture baking") as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                selected = np.random.randint(0, num_views)
                uv, uv_dr, observation, mask = (
                    _uv[selected],
                    _uv_dr[selected],
                    observations[selected],
                    masks[selected],
                )
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(
                    render[mask], observation[mask]
                )
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()

                optimizer.param_groups[0]["lr"] = cosine_annealing(
                    step, total_steps, 1e-2, 1e-5
                )
                pbar.set_postfix({"loss": loss.item()})
                pbar.update()
        texture = np.clip(
            texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)
        # Rasterize the UV layout itself to find texels no triangle covers.
        hole_mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx,
            (self.uvs * 2 - 1)[None],
            self.faces,
            texture_size,
            texture_size,
        )["mask"][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, hole_mask, 3, cv2.INPAINT_TELEA)

        return texture

    def bake_texture(
        self,
        images: List[np.array],
        texture_size: int = 1024,
        mode: Literal["fast", "opt"] = "opt",
        lambda_tv: float = 1e-2,
        opt_step: int = 2000,
    ):
        """Bake a texture from one color image per camera view.

        Args:
            images: Color images with values in [0, 255]; pixels whose
                channels are all zero are treated as background.
            texture_size: Side length of the square output texture.
            mode: ``"fast"`` for direct accumulation, ``"opt"`` for the
                optimization-based baker.
            lambda_tv: Total-variation weight (``opt`` mode only).
            opt_step: Number of optimization steps (``opt`` mode only).

        Returns:
            np.ndarray: ``(texture_size, texture_size, 3)`` uint8 texture.

        Raises:
            ValueError: If ``mode`` is neither ``"fast"`` nor ``"opt"``.
        """
        # Reject an unknown mode before uploading any image to the device.
        if mode not in ("fast", "opt"):
            raise ValueError(f"Unknown mode: {mode}")

        masks = [np.any(img > 0, axis=-1) for img in images]
        masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
        images = [
            torch.tensor(obs / 255.0).float().to(self.device) for obs in images
        ]

        if mode == "fast":
            return self._bake_fast(
                images, self.w2cs, self.projections, texture_size, masks
            )

        return self._bake_opt(
            images,
            self.w2cs,
            self.projections,
            texture_size,
            lambda_tv,
            masks,
            opt_step,
        )
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the texture back-projection tool.

    When ``--uuid`` is omitted, one identifier per mesh is derived from the
    mesh file name (the part of the base name before the first dot).

    Returns:
        argparse.Namespace: Parsed options. ``mesh_path``, ``color_path``
        and ``uuid`` are lists of equal length.

    Raises:
        SystemExit: Raised by argparse on invalid options, including when
        the numbers of meshes, color images and uuids differ.
    """
    parser = argparse.ArgumentParser(description="Render settings")

    parser.add_argument(
        "--mesh_path",
        type=str,
        nargs="+",
        required=True,
        help="Paths to the mesh files for rendering.",
    )
    parser.add_argument(
        "--color_path",
        type=str,
        nargs="+",
        required=True,
        help="Paths to the multi-view color grid images, one per mesh.",
    )
    parser.add_argument(
        "--output_root",
        type=str,
        default="./outputs",
        help="Root directory for output",
    )
    parser.add_argument(
        "--uuid",
        type=str,
        nargs="+",
        default=None,
        help="uuid for rendering saving.",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        type=float,
        nargs="+",
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the output images (default: (512, 512))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--texture_size",
        type=int,
        default=1024,
        help="Texture size for texture baking (default: 1024)",
    )
    parser.add_argument(
        "--baker_mode",
        type=str,
        choices=["fast", "opt"],
        default="opt",
        help="Texture baking mode, `fast` or `opt` (default: opt)",
    )
    parser.add_argument(
        "--opt_step",
        type=int,
        default=2500,
        help="Optimization steps for texture baking (default: 2500)",
    )
    # The historical (misspelled) flag is kept for compatibility; the
    # correctly spelled alias stores into the same attribute.
    parser.add_argument(
        "--mesh_sipmlify_ratio",
        "--mesh_simplify_ratio",
        dest="mesh_sipmlify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
    parser.add_argument(
        "--no_coor_trans",
        action="store_true",
        help="Do not transform the asset coordinate system.",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
    )
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip the mesh geometry fixing step.",
    )

    args = parser.parse_args()

    if args.uuid is None:
        args.uuid = [
            os.path.basename(path).split(".")[0] for path in args.mesh_path
        ]

    # The caller pairs these lists element-wise; a length mismatch would
    # otherwise silently drop the trailing entries.
    if not (len(args.mesh_path) == len(args.color_path) == len(args.uuid)):
        parser.error(
            "--mesh_path, --color_path and --uuid must have the same number "
            f"of entries, got {len(args.mesh_path)}, "
            f"{len(args.color_path)} and {len(args.uuid)}."
        )

    return args
vertices / scale + vertices = vertices + center + + output_path = os.path.join(args.output_root, f"{uuid}.obj") + mesh = save_mesh_with_mtl(vertices, faces, uvs, texture, output_path) + + return + + +if __name__ == "__main__": + entrypoint() diff --git a/embodied_gen/data/backproject_v2.py b/embodied_gen/data/backproject_v2.py new file mode 100644 index 0000000..efe1b34 --- /dev/null +++ b/embodied_gen/data/backproject_v2.py @@ -0,0 +1,702 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ + +import argparse +import logging +import math +import os + +import cv2 +import numpy as np +import nvdiffrast.torch as dr +import spaces +import torch +import torch.nn.functional as F +import trimesh +import xatlas +from PIL import Image +from embodied_gen.data.mesh_operator import MeshFixer +from embodied_gen.data.utils import ( + CameraSetting, + DiffrastRender, + get_images_from_grid, + init_kal_camera, + normalize_vertices_array, + post_process_texture, + save_mesh_with_mtl, +) +from embodied_gen.models.delight_model import DelightingModel +from embodied_gen.models.sr_model import ImageRealESRGAN + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + + +__all__ = [ + "TextureBacker", +] + + +def _transform_vertices( + mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False +) -> torch.Tensor: + """Transform 3D vertices using a projection matrix.""" + t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype) + if pos.size(-1) == 3: + pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1) + + result = pos @ t_mtx.T + + return result if keepdim else result.unsqueeze(0) + + +def _bilinear_interpolation_scattering( + image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor +) -> torch.Tensor: + """Bilinear interpolation scattering for grid-based value accumulation.""" + device = values.device + dtype = values.dtype + C = values.shape[-1] + + indices = coords * torch.tensor( + [image_h - 1, image_w - 1], dtype=dtype, device=device + ) + i, j = indices.unbind(-1) + + i0, j0 = ( + indices.floor() + .long() + .clamp(0, image_h - 2) + .clamp(0, image_w - 2) + .unbind(-1) + ) + i1, j1 = i0 + 1, j0 + 1 + + w_i = i - i0.float() + w_j = j - j0.float() + weights = torch.stack( + [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j], + dim=1, + ) + + indices_comb = torch.stack( + [ + torch.stack([i0, j0], dim=1), + torch.stack([i0, j1], 
dim=1), + torch.stack([i1, j0], dim=1), + torch.stack([i1, j1], dim=1), + ], + dim=1, + ) + + grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype) + cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype) + + for k in range(4): + idx = indices_comb[:, k] + w = weights[:, k].unsqueeze(-1) + + stride = torch.tensor([image_w, 1], device=device, dtype=torch.long) + flat_idx = (idx * stride).sum(-1) + + grid.view(-1, C).scatter_add_( + 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w + ) + cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w) + + mask = cnt.squeeze(-1) > 0 + grid[mask] = grid[mask] / cnt[mask].repeat(1, C) + + return grid + + +def _texture_inpaint_smooth( + texture: np.ndarray, + mask: np.ndarray, + vertices: np.ndarray, + faces: np.ndarray, + uv_map: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Perform texture inpainting using vertex-based color propagation.""" + image_h, image_w, C = texture.shape + N = vertices.shape[0] + + # Initialize vertex data structures + vtx_mask = np.zeros(N, dtype=np.float32) + vtx_colors = np.zeros((N, C), dtype=np.float32) + unprocessed = [] + adjacency = [[] for _ in range(N)] + + # Build adjacency graph and initial color assignment + for face_idx in range(faces.shape[0]): + for k in range(3): + uv_idx_k = faces[face_idx, k] + v_idx = faces[face_idx, k] + + # Convert UV to pixel coordinates with boundary clamping + u = np.clip( + int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1 + ) + v = np.clip( + int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))), + 0, + image_h - 1, + ) + + if mask[v, u]: + vtx_mask[v_idx] = 1.0 + vtx_colors[v_idx] = texture[v, u] + elif v_idx not in unprocessed: + unprocessed.append(v_idx) + + # Build undirected adjacency graph + neighbor = faces[face_idx, (k + 1) % 3] + if neighbor not in adjacency[v_idx]: + adjacency[v_idx].append(neighbor) + if v_idx not in adjacency[neighbor]: + adjacency[neighbor].append(v_idx) + + # Color propagation 
class TextureBacker:
    """Texture baking pipeline for multi-view projection and fusion.

    This class performs UV-based texture generation for a 3D mesh using
    multi-view color images, depth, and normal information. The pipeline
    includes mesh normalization and UV unwrapping, visibility-aware
    back-projection, confidence-weighted texture fusion, and inpainting
    of missing texture regions.

    Args:
        camera_params (CameraSetting): Camera intrinsics and extrinsics used
            for rendering each view.
        view_weights (list[float]): A list of weights for each view, used
            to blend confidence maps during texture fusion.
        render_wh (tuple[int, int], optional): Resolution (width, height) for
            intermediate rendering passes. Defaults to (2048, 2048).
        texture_wh (tuple[int, int], optional): Output texture resolution
            (width, height); a list is accepted as well.
            Defaults to (2048, 2048).
        bake_angle_thresh (int, optional): Maximum angle (in degrees) between
            view direction and surface normal for projection to be
            considered valid. Defaults to 75.
        mask_thresh (float, optional): Threshold applied to visibility masks
            during rendering. Defaults to 0.5.
        smooth_texture (bool, optional): If True, apply post-processing
            (`post_process_texture`) to the final texture. Defaults to True.
    """

    def __init__(
        self,
        camera_params: CameraSetting,
        view_weights: list[float],
        render_wh: tuple[int, int] = (2048, 2048),
        texture_wh: tuple[int, int] = (2048, 2048),
        bake_angle_thresh: int = 75,
        mask_thresh: float = 0.5,
        smooth_texture: bool = True,
    ) -> None:
        self.camera_params = camera_params
        self.renderer = None  # Created lazily by `_lazy_init_render`.
        self.view_weights = view_weights
        self.device = camera_params.device
        self.render_wh = render_wh
        self.texture_wh = texture_wh
        self.mask_thresh = mask_thresh
        self.smooth_texture = smooth_texture

        # Normalization applied by `load_mesh`, kept to restore the mesh.
        self.scale = None
        self.center = None

        self.bake_angle_thresh = bake_angle_thresh
        # Half-width of the dilation kernel used to discard pixels close to
        # silhouettes and depth edges: 2 px at 512 px, scaled with the
        # render resolution.
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.render_wh[0], self.render_wh[1])
        )

    def _lazy_init_render(self, camera_params, mask_thresh):
        """Create the nvdiffrast renderer on first use."""
        if self.renderer is None:
            camera = init_kal_camera(camera_params)
            mv = camera.view_matrix()  # (n 4 4) world2cam
            p = camera.intrinsics.projection_matrix()
            # NOTE: negate P[1, 1] as the y axis is flipped in the
            # `nvdiffrast` output.
            p[:, 1, 1] = -p[:, 1, 1]
            self.renderer = DiffrastRender(
                p_matrix=p,
                mv_matrix=mv,
                resolution_hw=camera_params.resolution_hw,
                context=dr.RasterizeCudaContext(),
                mask_thresh=mask_thresh,
                grad_db=False,
                device=self.device,
                antialias_mask=True,
            )

    def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
        """Normalize the mesh in place and UV-unwrap it with xatlas.

        The normalization scale and center are stored on the instance so
        that `get_mesh_np_attrs` can undo them later. The V coordinate
        returned by xatlas is flipped before being stored on the mesh.
        """
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        self.scale, self.center = scale, center

        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
        uvs[:, 1] = 1 - uvs[:, 1]
        mesh.vertices = mesh.vertices[vmapping]
        mesh.faces = indices
        mesh.visual.uv = uvs

        return mesh

    def get_mesh_np_attrs(
        self,
        mesh: trimesh.Trimesh,
        scale: float = None,
        center: np.ndarray = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return copies of vertices, faces and V-flipped UVs of ``mesh``.

        Args:
            mesh: Mesh carrying ``visual.uv``.
            scale: If given, vertices are divided by it.
            center: If given, it is added to the vertices (after scaling).

        Returns:
            ``(vertices, faces, uv_map)`` as numpy arrays.
        """
        vertices = mesh.vertices.copy()
        faces = mesh.faces.copy()
        uv_map = mesh.visual.uv.copy()
        uv_map[:, 1] = 1.0 - uv_map[:, 1]

        if scale is not None:
            vertices = vertices / scale
        if center is not None:
            vertices = vertices + center

        return vertices, faces, uv_map

    def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
        """Return a Canny edge map of ``depth_image`` as a float tensor.

        The depth image is scaled by 255 and cast to uint8 before edge
        detection; the result is in [0, 1] with a trailing channel axis.
        """
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        sketch_image = (
            torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
        )

        return sketch_image.unsqueeze(-1)

    def compute_enhanced_viewnormal(
        self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
    ) -> torch.Tensor:
        """Render per-view normal maps expressed in each camera frame.

        For every view the vertices are transformed by that view's matrix,
        face normals are recomputed there, averaged into vertex normals and
        interpolated over the rasterized image.

        Returns:
            torch.Tensor: Normal images of all views concatenated on dim 0.
        """
        rast, _ = self.renderer.compute_dr_raster(vertices, faces)
        # Loop-invariant conversions of the face indices.
        faces_cpu = faces.cpu()
        faces_int = faces.to(torch.int32)

        rendered_view_normals = []
        for idx in range(len(mv_mtx)):
            pos_cam = _transform_vertices(mv_mtx[idx], vertices, keepdim=True)
            pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
            v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
            face_norm = F.normalize(
                torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
            )
            vertex_norm = (
                torch.from_numpy(
                    trimesh.geometry.mean_vertex_normals(
                        len(pos_cam), faces_cpu, face_norm.cpu()
                    )
                )
                .to(vertices.device)
                .contiguous()
            )
            im_base_normals, _ = dr.interpolate(
                vertex_norm[None, ...].float(),
                rast[idx : idx + 1],
                faces_int,
            )
            rendered_view_normals.append(im_base_normals)

        return torch.cat(rendered_view_normals, dim=0)

    def back_project(
        self, image, vis_mask, depth, normal, uv
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project one view into UV space.

        Pixels are discarded when they lie near the silhouette or a depth
        edge, or when the angle between the surface normal and the +z axis
        exceeds ``bake_angle_thresh``.

        Returns:
            The scattered color texture and the scattered cosine
            (confidence) map, both at the texture resolution.
        """
        image = np.array(image)
        image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
        if image.ndim == 2:
            image = image.unsqueeze(-1)
        image = image / 255

        depth_inv = (1.0 - depth) * vis_mask
        sketch_image = self._render_depth_edges(depth_inv)

        # Cosine between the surface normal and the +z axis. The axis is
        # built in the dtype of `normal` rather than as an integer tensor.
        view_dir = torch.tensor(
            [[0.0, 0.0, 1.0]], device=self.device, dtype=normal.dtype
        )
        cos = F.cosine_similarity(view_dir, normal.view(-1, 3)).view_as(
            normal[..., :1]
        )
        cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0

        k = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones((1, 1, k, k), device=self.device)

        # Erode the visibility mask by dilating its complement.
        vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
        vis_mask = F.conv2d(
            1.0 - vis_mask,
            kernel,
            padding=k // 2,
        )
        vis_mask = 1.0 - (vis_mask > 0).float()
        vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)

        # Dilate the depth edges and remove them from the visible area.
        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
        sketch_image = (sketch_image > 0).float()
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
        vis_mask = vis_mask * (sketch_image < 0.5)

        cos[vis_mask == 0] = 0
        valid_pixels = (vis_mask != 0).view(-1)

        return (
            self._scatter_texture(uv, image, valid_pixels),
            self._scatter_texture(uv, cos, valid_pixels),
        )

    def _scatter_texture(self, uv, data, mask):
        """Scatter the masked pixels of ``data`` into a texture-sized grid."""

        def _select_valid(tensor, valid):
            return tensor.view(-1, tensor.shape[-1])[valid]

        return _bilinear_interpolation_scattering(
            self.texture_wh[1],
            self.texture_wh[0],
            # Swap (u, v) into (row, column) order.
            _select_valid(uv, mask)[..., [1, 0]],
            _select_valid(data, mask),
        )

    @torch.no_grad()
    def fast_bake_texture(
        self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fuse per-view textures with their confidence maps.

        Views are merged in order. A view is skipped when more than 99% of
        the texels it covers have already been painted by earlier views.

        Returns:
            The confidence-weighted average texture and a boolean map of the
            texels that received any confidence.
        """
        channel = textures[0].shape[-1]
        # Build the shapes explicitly: `texture_wh` may be a tuple or a
        # list, and the scattered textures are laid out as (height, width).
        tex_w, tex_h = self.texture_wh
        texture_merge = torch.zeros(
            (tex_h, tex_w, channel), device=self.device
        )
        trust_map_merge = torch.zeros((tex_h, tex_w, 1), device=self.device)
        for texture, cos_map in zip(textures, confidence_maps):
            view_sum = (cos_map > 0).sum()
            if view_sum == 0:
                # Nothing to contribute; also avoids a 0 / 0 ratio below.
                continue
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            if painted_sum / view_sum > 0.99:
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)

        return texture_merge, trust_map_merge > 1e-8

    def uv_inpaint(
        self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
    ) -> np.ndarray:
        """Fill unpainted texels.

        Colors are first propagated along mesh edges with
        `_texture_inpaint_smooth`; the remaining holes are filled by OpenCV
        Navier-Stokes inpainting.

        Returns:
            np.ndarray: uint8 texture.
        """
        vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)

        texture, mask = _texture_inpaint_smooth(
            texture, mask, vertices, faces, uv_map
        )
        texture = texture.clip(0, 1)
        texture = cv2.inpaint(
            (texture * 255).astype(np.uint8),
            255 - mask,
            3,
            cv2.INPAINT_NS,
        )

        return texture

    @spaces.GPU
    def compute_texture(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Back-project all views and fuse them into one texture.

        Args:
            colors: One color image per camera view.
            mesh: UV-unwrapped mesh, as returned by `load_mesh`.

        Returns:
            The fused float texture and a uint8 mask (255 where painted).
        """
        self._lazy_init_render(self.camera_params, self.mask_thresh)

        vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
        faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
        uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()

        rendered_depth, masks = self.renderer.render_depth(vertices, faces)
        norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
        render_uvs, _ = self.renderer.render_uv(vertices, faces, uv_map)
        view_normals = self.compute_enhanced_viewnormal(
            self.renderer.mv_mtx, vertices, faces
        )

        textures, weighted_cos_maps = [], []
        for color, mask, dep, normal, uv, weight in zip(
            colors,
            masks,
            norm_deps,
            view_normals,
            render_uvs,
            self.view_weights,
        ):
            texture, cos_map = self.back_project(color, mask, dep, normal, uv)
            textures.append(texture)
            # Raising to the 4th power favors head-on observations.
            weighted_cos_maps.append(weight * (cos_map**4))

        texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)

        texture_np = texture.cpu().numpy()
        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)

        return texture_np, mask_np

    def __call__(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
        output_path: str,
    ) -> trimesh.Trimesh:
        """Runs the texture baking and exports the textured mesh.

        Args:
            colors (list[Image.Image]): List of input view images.
            mesh (trimesh.Trimesh): Input mesh to be textured.
            output_path (str): Path to save the output textured mesh
                (.obj or .glb).

        Returns:
            trimesh.Trimesh: The textured mesh with UV and texture image.
        """
        mesh = self.load_mesh(mesh)
        texture_np, mask_np = self.compute_texture(colors, mesh)

        texture_np = self.uv_inpaint(mesh, texture_np, mask_np)
        if self.smooth_texture:
            texture_np = post_process_texture(texture_np)

        # Undo the normalization applied in `load_mesh` before export.
        vertices, faces, uv_map = self.get_mesh_np_attrs(
            mesh, self.scale, self.center
        )
        textured_mesh = save_mesh_with_mtl(
            vertices, faces, uv_map, texture_np, output_path
        )

        return textured_mesh
def entrypoint(
    delight_model: DelightingModel = None,
    imagesr_model: ImageRealESRGAN = None,
    **kwargs,
) -> trimesh.Trimesh:
    """Back-project a multi-view color grid onto a mesh and export it.

    Options come from the command line (`parse_args`). Any keyword argument
    whose name matches a parsed option and whose value is not None
    overrides that option.

    Args:
        delight_model: Delighting model to reuse; created on demand when
            ``--delight`` is set and none is given.
        imagesr_model: Super-resolution model to reuse; created on demand
            when none is given.
        **kwargs: Overrides for the parsed command-line options.

    Returns:
        trimesh.Trimesh: The textured mesh.

    Raises:
        ValueError: If the color image, mesh or output path is missing.
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    # These options have no default; fail with a clear message instead of
    # an opaque error inside the image or mesh loader.
    missing = [
        name
        for name in ("color_path", "mesh_path", "output_path")
        if getattr(args, name) is None
    ]
    if missing:
        raise ValueError(f"Missing required arguments: {', '.join(missing)}")

    # Setup camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]

    # A bare file name has an empty dirname; fall back to the current
    # directory so `os.makedirs` and the delight image path stay valid.
    save_dir = os.path.dirname(args.output_path) or "."
    os.makedirs(save_dir, exist_ok=True)

    color_grid = Image.open(args.color_path)
    if args.delight:
        if delight_model is None:
            delight_model = DelightingModel()
        color_grid = delight_model(color_grid)
        if not args.no_save_delight_img:
            color_grid.save(os.path.join(save_dir, "color_grid_delight.png"))

    multiviews = get_images_from_grid(color_grid, img_size=512)

    # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
    if imagesr_model is None:
        imagesr_model = ImageRealESRGAN(outscale=4)
    multiviews = [imagesr_model(img).convert("RGB") for img in multiviews]

    mesh = trimesh.load(args.mesh_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    if not args.skip_fix_mesh:
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
        mesh.vertices, mesh.faces = mesh_fixer(
            filter_ratio=args.mesh_sipmlify_ratio,
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
            norm_mesh_ratio=0.5,
        )
        # Restore scale.
        mesh.vertices = mesh.vertices / scale
        mesh.vertices = mesh.vertices + center

    # Baking texture to mesh.
    texture_backer = TextureBacker(
        camera_params=camera_params,
        view_weights=view_weights,
        render_wh=camera_params.resolution_hw,
        texture_wh=args.texture_wh,
        smooth_texture=not args.no_smooth_texture,
    )

    textured_mesh = texture_backer(multiviews, mesh, args.output_path)

    if args.save_glb_path is not None:
        glb_dir = os.path.dirname(args.save_glb_path)
        if glb_dir:
            os.makedirs(glb_dir, exist_ok=True)
        textured_mesh.export(args.save_glb_path)

    return textured_mesh
class Asset3dGenDataset(Dataset):
    """Dataset of rendered 3D-asset views paired with text embeddings.

    Each item provides a color image grid (target) and a control grid
    built from the view-normal, position and mask renderings, plus the
    text-embedding tensors loaded from the sample's `text_feat` file.

    Args:
        index_file (str): JSON file mapping asset uid to its meta info.
        target_hw (Tuple[int, int]): Expected (height, width) of one item.
        transform (Callable, optional): Transform for color images.
        control_transform (Callable, optional): Transform for control
            images.
        max_train_samples (int, optional): Upper bound on the number of
            valid samples kept. None or 0 keeps all of them.
        sub_idxs (List[List[int]], optional): View indices laid out as
            grid rows, e.g. [[0, 1, 2], [3, 4, 5]]. When None, a single
            random view is drawn per item.
        seed (int, optional): Seed passed to `random.seed`.

    Raises:
        FileNotFoundError: If `index_file` does not exist.
    """

    def __init__(
        self,
        index_file: str,
        target_hw: Tuple[int, int],
        transform: Callable = None,
        control_transform: Callable = None,
        max_train_samples: int = None,
        sub_idxs: List[List[int]] = None,
        seed: int = 79,
    ) -> None:
        if not os.path.exists(index_file):
            raise FileNotFoundError(f"{index_file} index_file not found.")

        self.index_file = index_file
        # Stored as a tuple so it can be concatenated with other tuples
        # and compared with tensor shapes even when a list was passed.
        self.target_hw = tuple(target_hw)
        self.transform = transform
        self.control_transform = control_transform
        self.max_train_samples = max_train_samples
        self.meta_info = self.prepare_data_index(index_file)
        self.data_list = sorted(self.meta_info.keys())
        self.sub_idxs = sub_idxs  # sub_idxs [[0,1,2], [3,4,5], [...], ...]
        self.image_num = 6  # hardcode temp.
        # NOTE: this reseeds the process-wide `random` module state.
        random.seed(seed)
        logger.info(f"Trainset: {len(self)} asset3d instances.")

    def __len__(self) -> int:
        return len(self.meta_info)

    def prepare_data_index(self, index_file: str) -> Dict[str, Any]:
        """Load the index file and keep entries whose status is "success".

        At most `self.max_train_samples` valid entries are kept when that
        limit is set; entries that are skipped do not count toward it.

        Args:
            index_file (str): Path to the JSON index.

        Returns:
            Dict[str, Any]: Meta info of the kept entries, keyed by uid.
        """
        with open(index_file, "r") as fin:
            meta_info = json.load(fin)

        limit = self.max_train_samples
        meta_info_filtered = dict()
        for uid, info in meta_info.items():
            if info.get("status") != "success":
                continue
            # Count kept samples, not raw entries, so failed entries do
            # not silently shrink the requested subset.
            if limit and len(meta_info_filtered) >= limit:
                break

            meta_info_filtered[uid] = info

        logger.info(
            f"Load {len(meta_info)} assets, keep {len(meta_info_filtered)} valids."  # noqa
        )

        return meta_info_filtered

    def fetch_sample_images(
        self,
        uid: str,
        attrs: List[str],
        sub_index: int = None,
        transform: Callable = None,
    ) -> torch.Tensor:
        """Load the images of several attributes and stack them by channel.

        Args:
            uid (str): Asset uid in `self.meta_info`.
            attrs (List[str]): Meta-info keys holding image paths.
                "image_mask" is read as single-channel, others as RGB.
            sub_index (int, optional): Index into each attribute's entry
                when the entry holds one path per view.
            transform (Callable, optional): Applied to every PIL image.
                The result is concatenated with `torch.cat`, so in
                practice it must produce tensors.

        Returns:
            torch.Tensor: Images concatenated along dim 0.
        """
        sample = self.meta_info[uid]
        images = []
        for attr in attrs:
            item = sample[attr]
            if sub_index is not None:
                item = item[sub_index]
            mode = "L" if attr == "image_mask" else "RGB"
            image = Image.open(item).convert(mode)
            if transform is not None:
                image = transform(image)
            if len(image.shape) == 2:
                # Add the channel axis in front: images are concatenated
                # along dim 0 below, i.e. channel-first.
                image = image[None]
            images.append(image)

        images = torch.cat(images, dim=0)

        return images

    def fetch_sample_grid_images(
        self,
        uid: str,
        attrs: List[str],
        sub_idxs: List[List[int]],
        transform: Callable = None,
    ) -> torch.Tensor:
        """Assemble a grid image from several views of one asset.

        Args:
            uid (str): Asset uid in `self.meta_info`.
            attrs (List[str]): Attributes stacked by channel per view.
            sub_idxs (List[List[int]]): View indices, one list per row.
            transform (Callable): Transform applied to each image;
                must not be None.

        Returns:
            torch.Tensor: Views joined along dim 2 within a row and rows
                joined along dim 1.
        """
        assert transform is not None

        grid_image = []
        for row_idxs in sub_idxs:
            row_image = [
                self.fetch_sample_images(uid, attrs, row_idx, transform)
                for row_idx in row_idxs
            ]
            row_image = torch.cat(row_image, dim=2)  # (c h w)
            grid_image.append(row_image)

        grid_image = torch.cat(grid_image, dim=1)

        return grid_image

    def compute_text_embeddings(
        self, embed_path: str, original_size: Tuple[int, int]
    ) -> Dict[str, torch.Tensor]:
        """Load text embeddings and build the additional time ids.

        Args:
            embed_path (str): File readable by `torch.load` containing
                "prompt_embeds" and "pooled_prompt_embeds".
            original_size (Tuple[int, int]): Original image size.

        Returns:
            Dict[str, torch.Tensor]: "prompt_embeds", "text_embeds" and
                "time_ids"; the latter is `original_size`, the crop
                origin (0, 0) and `target_hw` in one row.
        """
        data_dict = torch.load(embed_path)
        prompt_embeds = data_dict["prompt_embeds"][0]
        add_text_embeds = data_dict["pooled_prompt_embeds"][0]

        # Need changed if random crop, set as crop_top_left [y1, x1], center crop as [0, 0].  # noqa
        crops_coords_top_left = (0, 0)
        # Coerce to tuples so list inputs do not break the concatenation.
        add_time_ids = list(
            tuple(original_size)
            + crops_coords_top_left
            + tuple(self.target_hw)
        )
        add_time_ids = torch.tensor([add_time_ids])

        unet_added_cond_kwargs = {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
        }

        return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}

    def visualize_item(
        self,
        control: torch.Tensor,
        color: torch.Tensor,
        save_dir: str = None,
    ) -> Tuple["Image.Image", ...]:
        """Convert one item to PIL images and optionally save them.

        Args:
            control (torch.Tensor): Control tensor; channels 0-2 are
                shown as normal, 3-5 as position and the rest as mask.
            color (torch.Tensor): Color tensor, mapped by (x + 1) / 2
                before conversion.
            save_dir (str, optional): Directory for the .jpg outputs.

        Returns:
            Tuple[Image.Image, ...]: (normal, position, mask, color).
        """
        to_pil = transforms.ToPILImage()

        color = (color + 1) / 2
        color_pil = to_pil(color)
        normal_pil = to_pil(control[0:3])
        position_pil = to_pil(control[3:6])
        mask_pil = to_pil(control[6:])

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            color_pil.save(f"{save_dir}/rgb.jpg")
            normal_pil.save(f"{save_dir}/normal.jpg")
            position_pil.save(f"{save_dir}/position.jpg")
            mask_pil.save(f"{save_dir}/mask.jpg")
            logger.info(f"Visualization in {save_dir}")

        return normal_pil, position_pil, mask_pil, color_pil

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        uid = self.data_list[index]

        sub_idxs = self.sub_idxs
        if sub_idxs is None:
            # No fixed layout: use one randomly chosen view.
            sub_idxs = [[random.randint(0, self.image_num - 1)]]

        input_image = self.fetch_sample_grid_images(
            uid,
            attrs=["image_view_normal", "image_position", "image_mask"],
            sub_idxs=sub_idxs,
            transform=self.control_transform,
        )
        # Compare as tuples: torch.Size never equals a list.
        assert tuple(input_image.shape[1:]) == tuple(self.target_hw)

        output_image = self.fetch_sample_grid_images(
            uid,
            attrs=["image_color"],
            sub_idxs=sub_idxs,
            transform=self.transform,
        )

        sample = self.meta_info[uid]
        text_feats = self.compute_text_embeddings(
            sample["text_feat"], tuple(sample["image_hw"])
        )

        data = dict(
            pixel_values=output_image,
            conditioning_pixel_values=input_image,
            prompt_embeds=text_feats["prompt_embeds"],
            text_embeds=text_feats["text_embeds"],
            time_ids=text_feats["time_ids"],
        )

        return data
if __name__ == "__main__":
    # Manual smoke test: build the dataset from a local index, fetch the
    # first item and dump its visualization into the current directory.
    meta_path = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json"  # noqa
    tile_hw = (512, 512)

    shared_ops = [
        transforms.Resize(
            tile_hw, interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop(tile_hw),
        transforms.ToTensor(),
    ]
    # Color images are additionally normalized; control images are not.
    color_ops = transforms.Compose(
        shared_ops + [transforms.Normalize([0.5], [0.5])]
    )
    control_ops = transforms.Compose(shared_ops)

    # Two rows of three views; set to None to sample one random view.
    grid_layout = [[0, 1, 2], [3, 4, 5]]
    if grid_layout is None:
        grid_hw = tile_hw
    else:
        n_rows, n_cols = len(grid_layout), len(grid_layout[0])
        grid_hw = (tile_hw[0] * n_rows, tile_hw[1] * n_cols)

    demo_set = Asset3dGenDataset(
        meta_path,
        grid_hw,
        color_ops,
        control_ops,
        sub_idxs=grid_layout,
    )
    first_item = demo_set[0]
    demo_set.visualize_item(
        first_item["conditioning_pixel_values"],
        first_item["pixel_values"],
        save_dir="./",
    )
class ImageRender(object):
    """A differentiable mesh renderer supporting multi-view rendering.

    This class wraps a differentiable rasterization using `nvdiffrast` to
    render mesh geometry to various maps (normal, depth, alpha, albedo, etc.).

    Args:
        render_items (list[RenderItems]): A list of rendering targets to
            generate (e.g., IMAGE, DEPTH, NORMAL, etc.). Membership is
            tested against the `.value` of each `RenderItems` member.
        camera_params (CameraSetting): The camera parameters for rendering,
            including intrinsic and extrinsic matrices.
        recompute_vtx_normal (bool, optional): If True, recomputes
            vertex normals from the mesh geometry. Defaults to True.
        with_mtl (bool, optional): Whether to load `.mtl` material files
            for meshes. Defaults to False.
        gen_color_gif (bool, optional): Generate a GIF of rendered
            color images. Defaults to False.
        gen_color_mp4 (bool, optional): Generate an MP4 video of rendered
            color images. Defaults to False.
        gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of
            view-space normals. Defaults to False.
        gen_glonormal_mp4 (bool, optional): Generate an MP4 video of
            global-space normals. Defaults to False.
        no_index_file (bool, optional): If True, skip saving the `index.json`
            summary file. Defaults to False.
        light_factor (float, optional): A scalar multiplier for
            PBR light intensity. Defaults to 1.0.
    """

    def __init__(
        self,
        render_items: list[RenderItems],
        camera_params: CameraSetting,
        recompute_vtx_normal: bool = True,
        with_mtl: bool = False,
        gen_color_gif: bool = False,
        gen_color_mp4: bool = False,
        gen_viewnormal_mp4: bool = False,
        gen_glonormal_mp4: bool = False,
        no_index_file: bool = False,
        light_factor: float = 1.0,
    ) -> None:
        camera = init_kal_camera(camera_params)
        self.camera = camera

        # Setup MVP matrix and renderer.
        mv = camera.view_matrix()  # (n 4 4) world2cam
        p = camera.intrinsics.projection_matrix()
        # NOTE: negate P[1, 1] as the y axis is flipped in `nvdiffrast` output.  # noqa
        p[:, 1, 1] = -p[:, 1, 1]
        self.mv = mv
        self.p = p

        renderer = DiffrastRender(
            p_matrix=p,
            mv_matrix=mv,
            resolution_hw=camera_params.resolution_hw,
            context=dr.RasterizeCudaContext(),
            mask_thresh=0.5,
            grad_db=False,
            device=camera_params.device,
            antialias_mask=True,
        )
        self.renderer = renderer
        self.recompute_vtx_normal = recompute_vtx_normal
        self.render_items = render_items
        self.device = camera_params.device
        self.with_mtl = with_mtl
        self.gen_color_gif = gen_color_gif
        self.gen_color_mp4 = gen_color_mp4
        self.gen_viewnormal_mp4 = gen_viewnormal_mp4
        self.gen_glonormal_mp4 = gen_glonormal_mp4
        self.light_factor = light_factor
        self.no_index_file = no_index_file

    def render_mesh(
        self,
        mesh_path: Union[str, List[str]],
        output_root: str,
        uuid: Union[str, List[str]] = None,
        prompts: List[str] = None,
    ) -> None:
        """Render one or more meshes, each into its own sub-directory.

        Unless `no_index_file` is set, the per-mesh results are written
        to `<output_root>/index.json`, keyed by uuid. A mesh that failed
        to render is recorded with a null entry.

        Args:
            mesh_path (str | List[str]): Mesh file path(s).
            output_root (str): Root directory; created if missing.
            uuid (str | List[str], optional): Output name per mesh.
                Defaults to the file name up to its first ".".
            prompts (List[str], optional): Prompt per mesh, indexed in
                the same order as `mesh_path`.
        """
        mesh_path = as_list(mesh_path)
        if uuid is None:
            uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
        uuid = as_list(uuid)
        assert len(mesh_path) == len(uuid)
        os.makedirs(output_root, exist_ok=True)

        meta_info = dict()
        for idx, (path, uid) in tqdm(
            enumerate(zip(mesh_path, uuid)), total=len(mesh_path)
        ):
            output_dir = os.path.join(output_root, uid)
            os.makedirs(output_dir, exist_ok=True)
            prompt = prompts[idx] if prompts else None
            data_dict = self(path, output_dir, prompt)
            meta_info[uid] = data_dict

        if self.no_index_file:
            return

        index_file = os.path.join(output_root, "index.json")
        with open(index_file, "w") as fout:
            json.dump(meta_info, fout)

        logger.info(f"Rendering meta info logged in {index_file}")

    def __call__(
        self, mesh_path: str, output_dir: str, prompt: str = None
    ) -> Union[dict, None]:
        """Render a single mesh and return paths to the rendered outputs.

        Processes the input mesh, renders multiple modalities (e.g., normals,
        depth, albedo), and optionally saves video or image sequences.

        Args:
            mesh_path (str): Path to the mesh file (.obj/.glb).
            output_dir (str): Directory to save rendered outputs.
            prompt (str, optional): Optional caption prompt for MP4 metadata.

        Returns:
            dict | None: A mapping of render types to the saved image
                paths, with "status" set to "success" on a full run.
                None if the mesh cannot be loaded or PBR rendering
                fails (the error is logged). When a color GIF/MP4 is
                requested the mapping is returned without "status".
        """
        try:
            mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
        except Exception as e:
            logger.error(f"[ERROR MESH LOAD]: {e}, skip {mesh_path}")
            return

        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        if self.recompute_vtx_normal:
            mesh.vertex_normals = calc_vertex_normals(
                mesh.vertices, mesh.faces
            )

        mesh = mesh.to(self.device)
        vertices, faces, vertex_normals = (
            mesh.vertices,
            mesh.faces,
            mesh.vertex_normals,
        )

        # Perform rendering.
        data_dict = defaultdict(list)
        if RenderItems.ALPHA.value in self.render_items:
            masks, _ = self.renderer.render_rast_alpha(vertices, faces)
            render_paths = save_images(
                masks, f"{output_dir}/{RenderItems.ALPHA}"
            )
            data_dict[RenderItems.ALPHA.value] = render_paths

        if RenderItems.GLOBAL_NORMAL.value in self.render_items:
            rendered_normals, masks = self.renderer.render_global_normal(
                vertices, faces, vertex_normals
            )
            if self.gen_glonormal_mp4:
                # Convert a separate copy for the video so the view-normal
                # pass below still receives what the renderer returned.
                normal_frames = rendered_normals
                if isinstance(normal_frames, torch.Tensor):
                    normal_frames = normal_frames.detach().cpu().numpy()
                create_mp4_from_images(
                    normal_frames,
                    output_path=f"{output_dir}/normal.mp4",
                    fps=15,
                    prompt=prompt,
                )
            else:
                render_paths = save_images(
                    rendered_normals,
                    f"{output_dir}/{RenderItems.GLOBAL_NORMAL}",
                    cvt_color=cv2.COLOR_BGR2RGB,
                )
                data_dict[RenderItems.GLOBAL_NORMAL.value] = render_paths

        if RenderItems.VIEW_NORMAL.value in self.render_items:
            # View normals are derived from the global normals rendered
            # above; check by `.value` like every other membership test.
            assert (
                RenderItems.GLOBAL_NORMAL.value in self.render_items
            ), f"Must render global normal firstly, got render_items: {self.render_items}."  # noqa
            rendered_view_normals = self.renderer.transform_normal(
                rendered_normals, self.mv, masks, to_view=True
            )

            if self.gen_viewnormal_mp4:
                create_mp4_from_images(
                    rendered_view_normals,
                    output_path=f"{output_dir}/view_normal.mp4",
                    fps=15,
                    prompt=prompt,
                )
            else:
                render_paths = save_images(
                    rendered_view_normals,
                    f"{output_dir}/{RenderItems.VIEW_NORMAL}",
                    cvt_color=cv2.COLOR_BGR2RGB,
                )
                data_dict[RenderItems.VIEW_NORMAL.value] = render_paths

        if RenderItems.POSITION_MAP.value in self.render_items:
            rendered_position, masks = self.renderer.render_position(
                vertices, faces
            )
            norm_position = self.renderer.normalize_map_by_mask(
                rendered_position, masks
            )
            render_paths = save_images(
                norm_position,
                f"{output_dir}/{RenderItems.POSITION_MAP}",
                cvt_color=cv2.COLOR_BGR2RGB,
            )
            data_dict[RenderItems.POSITION_MAP.value] = render_paths

        if RenderItems.DEPTH.value in self.render_items:
            rendered_depth, masks = self.renderer.render_depth(vertices, faces)
            norm_depth = self.renderer.normalize_map_by_mask(
                rendered_depth, masks
            )
            render_paths = save_images(
                norm_depth,
                f"{output_dir}/{RenderItems.DEPTH}",
            )
            data_dict[RenderItems.DEPTH.value] = render_paths

            # Also keep the un-normalized depth as .exr files.
            render_paths = save_images(
                rendered_depth,
                f"{output_dir}/{RenderItems.DEPTH}_exr",
                to_uint8=False,
                format=".exr",
            )
            data_dict[f"{RenderItems.DEPTH.value}_exr"] = render_paths

        if RenderItems.IMAGE.value in self.render_items:
            images = []
            albedos = []
            diffuses = []
            masks, _ = self.renderer.render_rast_alpha(vertices, faces)
            try:
                for idx, cam in enumerate(self.camera):
                    image, albedo, diffuse, _ = render_pbr(
                        mesh, cam, light_factor=self.light_factor
                    )
                    # Append the mask as the last channel of each output.
                    image = torch.cat([image[0], masks[idx]], axis=-1)
                    images.append(image.detach().cpu().numpy())

                    if RenderItems.ALBEDO.value in self.render_items:
                        albedo = torch.cat([albedo[0], masks[idx]], axis=-1)
                        albedos.append(albedo.detach().cpu().numpy())

                    if RenderItems.DIFFUSE.value in self.render_items:
                        diffuse = torch.cat([diffuse[0], masks[idx]], axis=-1)
                        diffuses.append(diffuse.detach().cpu().numpy())

            except Exception as e:
                logger.error(f"[ERROR pbr render]: {e}, skip {mesh_path}")
                return

            if self.gen_color_gif:
                create_gif_from_images(
                    images,
                    output_path=f"{output_dir}/color.gif",
                    fps=15,
                )

            if self.gen_color_mp4:
                create_mp4_from_images(
                    images,
                    output_path=f"{output_dir}/color.mp4",
                    fps=15,
                    prompt=prompt,
                )

            # In video/GIF mode the per-view color images are not saved.
            if self.gen_color_mp4 or self.gen_color_gif:
                return data_dict

            render_paths = save_images(
                images,
                f"{output_dir}/{RenderItems.IMAGE}",
                cvt_color=cv2.COLOR_BGRA2RGBA,
            )
            data_dict[RenderItems.IMAGE.value] = render_paths

            render_paths = save_images(
                albedos,
                f"{output_dir}/{RenderItems.ALBEDO}",
                cvt_color=cv2.COLOR_BGRA2RGBA,
            )
            data_dict[RenderItems.ALBEDO.value] = render_paths

            render_paths = save_images(
                diffuses,
                f"{output_dir}/{RenderItems.DIFFUSE}",
                cvt_color=cv2.COLOR_BGRA2RGBA,
            )
            data_dict[RenderItems.DIFFUSE.value] = render_paths

        data_dict["status"] = "success"

        logger.info(f"Finish rendering in {output_dir}")

        return data_dict


def parse_args():
    """Parse the command-line options of the rendering script.

    When `--uuid` is omitted and `--mesh_path` is given, one uuid per
    mesh is derived from the file name up to its first ".".

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="Render settings")

    parser.add_argument(
        "--mesh_path",
        type=str,
        nargs="+",
        help="Paths to the mesh files for rendering.",
    )
    parser.add_argument(
        "--output_root",
        type=str,
        help="Root directory for output",
    )
    parser.add_argument(
        "--uuid",
        type=str,
        nargs="+",
        default=None,
        help="uuid for rendering saving.",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        type=float,
        nargs="+",
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the output images (default: (512, 512))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--pbr_light_factor",
        type=float,
        default=1.0,
        help="Light factor for mesh PBR rendering (default: 1.0)",
    )
    parser.add_argument(
        "--with_mtl",
        action="store_true",
        help="Whether to render with mesh material.",
    )
    parser.add_argument(
        "--gen_color_gif",
        action="store_true",
        help="Whether to generate color .gif rendering file.",
    )
    parser.add_argument(
        "--gen_color_mp4",
        action="store_true",
        help="Whether to generate color .mp4 rendering file.",
    )
    parser.add_argument(
        "--gen_viewnormal_mp4",
        action="store_true",
        help="Whether to generate view normal .mp4 rendering file.",
    )
    parser.add_argument(
        "--gen_glonormal_mp4",
        action="store_true",
        help="Whether to generate global normal .mp4 rendering file.",
    )
    parser.add_argument(
        "--prompts",
        type=str,
        nargs="+",
        default=None,
        help="Text prompts for the rendering.",
    )

    args = parser.parse_args()

    if args.uuid is None and args.mesh_path is not None:
        args.uuid = [
            os.path.basename(path).split(".")[0] for path in args.mesh_path
        ]

    return args


def entrypoint(**kwargs) -> None:
    """Run the renderer from command-line options and keyword overrides.

    Any keyword argument whose name matches a parsed option and whose
    value is not None replaces that option. When any GIF/MP4 flag is
    set, only the items needed for the requested videos are rendered
    and no `index.json` is written.

    Args:
        **kwargs: Overrides for parsed command-line options.
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    camera_settings = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device="cuda",
    )

    render_items = [
        RenderItems.ALPHA.value,
        RenderItems.GLOBAL_NORMAL.value,
        RenderItems.VIEW_NORMAL.value,
        RenderItems.POSITION_MAP.value,
        RenderItems.IMAGE.value,
        RenderItems.DEPTH.value,
        # RenderItems.ALBEDO.value,
        # RenderItems.DIFFUSE.value,
    ]

    gen_video = (
        args.gen_color_gif
        or args.gen_color_mp4
        or args.gen_viewnormal_mp4
        or args.gen_glonormal_mp4
    )
    if gen_video:
        render_items = []
        if args.gen_color_gif or args.gen_color_mp4:
            render_items.append(RenderItems.IMAGE.value)
        if args.gen_glonormal_mp4:
            render_items.append(RenderItems.GLOBAL_NORMAL.value)
        if args.gen_viewnormal_mp4:
            render_items.append(RenderItems.VIEW_NORMAL.value)
            # View normals are computed from the global normals.
            if RenderItems.GLOBAL_NORMAL.value not in render_items:
                render_items.append(RenderItems.GLOBAL_NORMAL.value)

    image_render = ImageRender(
        render_items=render_items,
        camera_params=camera_settings,
        with_mtl=args.with_mtl,
        gen_color_gif=args.gen_color_gif,
        gen_color_mp4=args.gen_color_mp4,
        gen_viewnormal_mp4=args.gen_viewnormal_mp4,
        gen_glonormal_mp4=args.gen_glonormal_mp4,
        light_factor=args.pbr_light_factor,
        no_index_file=gen_video,
    )
    image_render.render_mesh(
        mesh_path=args.mesh_path,
        output_root=args.output_root,
        uuid=args.uuid,
        prompts=args.prompts,
    )

    return


if __name__ == "__main__":
    entrypoint()
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +import logging +from typing import Tuple, Union + +import igraph +import numpy as np +import pyvista as pv +import spaces +import torch +import utils3d +from pymeshfix import _meshfix +from tqdm import tqdm + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + + +__all__ = ["MeshFixer"] + + +def _radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + + +def _halton_sequence(dim, n): + PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + return [_radical_inverse(PRIMES[dim], n) for dim in range(dim)] + + +def _hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + _halton_sequence(dim - 1, n) + + +def _sphere_hammersley_seq(n, num_samples, offset=(0, 0), remap=False): + """Generate a point on a unit sphere using the Hammersley sequence. + + Args: + n (int): The index of the sample. + num_samples (int): The total number of samples. + offset (tuple, optional): Offset for the u and v coordinates. + remap (bool, optional): Whether to remap the u coordinate. + + Returns: + list: A list containing the spherical coordinates [phi, theta]. 
+ """ + u, v = _hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + + if remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] + + +class MeshFixer(object): + """MeshFixer simplifies and repairs 3D triangle meshes by TSDF. + + Attributes: + vertices (torch.Tensor): A tensor of shape (V, 3) representing vertex positions. + faces (torch.Tensor): A tensor of shape (F, 3) representing face indices. + device (str): Device to run computations on, typically "cuda" or "cpu". + + Main logic reference: https://github.com/microsoft/TRELLIS/blob/main/trellis/utils/postprocessing_utils.py#L22 + """ + + def __init__( + self, + vertices: Union[torch.Tensor, np.ndarray], + faces: Union[torch.Tensor, np.ndarray], + device: str = "cuda", + ) -> None: + self.device = device + if isinstance(vertices, np.ndarray): + vertices = torch.tensor(vertices) + self.vertices = vertices + + if isinstance(faces, np.ndarray): + faces = torch.tensor(faces) + self.faces = faces + + @staticmethod + def log_mesh_changes(method): + def wrapper(self, *args, **kwargs): + logger.info( + f"Before {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa + ) + result = method(self, *args, **kwargs) + logger.info( + f"After {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa + ) + return result + + return wrapper + + @log_mesh_changes + def fill_holes( + self, + max_hole_size: float, + max_hole_nbe: int, + resolution: int, + num_views: int, + norm_mesh_ratio: float = 1.0, + ) -> None: + self.vertices = self.vertices * norm_mesh_ratio + vertices, self.faces = self._fill_holes( + self.vertices, + self.faces, + max_hole_size, + max_hole_nbe, + resolution, + num_views, + ) + self.vertices = vertices / norm_mesh_ratio + + @staticmethod + @torch.no_grad() + def _fill_holes( + vertices: torch.Tensor, + faces: torch.Tensor, 
+ max_hole_size: float, + max_hole_nbe: int, + resolution: int, + num_views: int, + ) -> Union[torch.Tensor, torch.Tensor]: + yaws, pitchs = [], [] + for i in range(num_views): + y, p = _sphere_hammersley_seq(i, num_views) + yaws.append(y) + pitchs.append(p) + + yaws, pitchs = torch.tensor(yaws).to(vertices), torch.tensor( + pitchs + ).to(vertices) + radius, fov = 2.0, torch.deg2rad(torch.tensor(40)).to(vertices) + projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3) + + views = [] + for yaw, pitch in zip(yaws, pitchs): + orig = ( + torch.tensor( + [ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ] + ).to(vertices) + * radius + ) + view = utils3d.torch.view_look_at( + orig, + torch.tensor([0, 0, 0]).to(vertices), + torch.tensor([0, 0, 1]).to(vertices), + ) + views.append(view) + views = torch.stack(views, dim=0) + + # Rasterize the mesh + visibility = torch.zeros( + faces.shape[0], dtype=torch.int32, device=faces.device + ) + rastctx = utils3d.torch.RastContext(backend="cuda") + + for i in tqdm( + range(views.shape[0]), total=views.shape[0], desc="Rasterizing" + ): + view = views[i] + buffers = utils3d.torch.rasterize_triangle_faces( + rastctx, + vertices[None], + faces, + resolution, + resolution, + view=view, + projection=projection, + ) + face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1 + face_id = torch.unique(face_id).long() + visibility[face_id] += 1 + + # Normalize visibility by the number of views + visibility = visibility.float() / num_views + + # Mincut: Identify outer and inner faces + edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces) + boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1) + connected_components = utils3d.torch.compute_connected_components( + faces, edges, face2edge + ) + + outer_face_indices = torch.zeros( + faces.shape[0], dtype=torch.bool, device=faces.device + ) + for i in range(len(connected_components)): + 
outer_face_indices[connected_components[i]] = visibility[ + connected_components[i] + ] > min( + max( + visibility[connected_components[i]].quantile(0.75).item(), + 0.25, + ), + 0.5, + ) + + outer_face_indices = outer_face_indices.nonzero().reshape(-1) + inner_face_indices = torch.nonzero(visibility == 0).reshape(-1) + + if inner_face_indices.shape[0] == 0: + return vertices, faces + + # Construct dual graph (faces as nodes, edges as edges) + dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph( + face2edge + ) + dual_edge2edge = edges[dual_edge2edge] + dual_edges_weights = torch.norm( + vertices[dual_edge2edge[:, 0]] - vertices[dual_edge2edge[:, 1]], + dim=1, + ) + + # Mincut: Construct main graph and solve the mincut problem + g = igraph.Graph() + g.add_vertices(faces.shape[0]) + g.add_edges(dual_edges.cpu().numpy()) + g.es["weight"] = dual_edges_weights.cpu().numpy() + + g.add_vertex("s") # source + g.add_vertex("t") # target + + g.add_edges( + [(f, "s") for f in inner_face_indices], + attributes={ + "weight": torch.ones( + inner_face_indices.shape[0], dtype=torch.float32 + ) + .cpu() + .numpy() + }, + ) + g.add_edges( + [(f, "t") for f in outer_face_indices], + attributes={ + "weight": torch.ones( + outer_face_indices.shape[0], dtype=torch.float32 + ) + .cpu() + .numpy() + }, + ) + + cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist()) + remove_face_indices = torch.tensor( + [v for v in cut.partition[0] if v < faces.shape[0]], + dtype=torch.long, + device=faces.device, + ) + + # Check if the cut is valid with each connected component + to_remove_cc = utils3d.torch.compute_connected_components( + faces[remove_face_indices] + ) + valid_remove_cc = [] + cutting_edges = [] + for cc in to_remove_cc: + # Check visibility median for connected component + visibility_median = visibility[remove_face_indices[cc]].median() + if visibility_median > 0.25: + continue + + # Check if the cutting loop is small enough + cc_edge_indices, cc_edges_degree = 
torch.unique( + face2edge[remove_face_indices[cc]], return_counts=True + ) + cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1] + cc_new_boundary_edge_indices = cc_boundary_edge_indices[ + ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices) + ] + if len(cc_new_boundary_edge_indices) > 0: + cc_new_boundary_edge_cc = ( + utils3d.torch.compute_edge_connected_components( + edges[cc_new_boundary_edge_indices] + ) + ) + cc_new_boundary_edges_cc_center = [ + vertices[edges[cc_new_boundary_edge_indices[edge_cc]]] + .mean(dim=1) + .mean(dim=0) + for edge_cc in cc_new_boundary_edge_cc + ] + cc_new_boundary_edges_cc_area = [] + for i, edge_cc in enumerate(cc_new_boundary_edge_cc): + _e1 = ( + vertices[ + edges[cc_new_boundary_edge_indices[edge_cc]][:, 0] + ] + - cc_new_boundary_edges_cc_center[i] + ) + _e2 = ( + vertices[ + edges[cc_new_boundary_edge_indices[edge_cc]][:, 1] + ] + - cc_new_boundary_edges_cc_center[i] + ) + cc_new_boundary_edges_cc_area.append( + torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() + * 0.5 + ) + cutting_edges.append(cc_new_boundary_edge_indices) + if any( + [ + _l > max_hole_size + for _l in cc_new_boundary_edges_cc_area + ] + ): + continue + + valid_remove_cc.append(cc) + + if len(valid_remove_cc) > 0: + remove_face_indices = remove_face_indices[ + torch.cat(valid_remove_cc) + ] + mask = torch.ones( + faces.shape[0], dtype=torch.bool, device=faces.device + ) + mask[remove_face_indices] = 0 + faces = faces[mask] + faces, vertices = utils3d.torch.remove_unreferenced_vertices( + faces, vertices + ) + + tqdm.write(f"Removed {(~mask).sum()} faces by mincut") + else: + tqdm.write(f"Removed 0 faces by mincut") + + # Fill small boundaries (holes) + mesh = _meshfix.PyTMesh() + mesh.load_array(vertices.cpu().numpy(), faces.cpu().numpy()) + mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True) + + _vertices, _faces = mesh.return_arrays() + vertices = torch.tensor(_vertices).to(vertices) + faces = 
@log_mesh_changes
def simplify(self, ratio: float) -> None:
    """Simplify the mesh using quadric edge collapse decimation.

    The mesh is converted to a PyVista ``PolyData``, decimated, and the
    result is written back to ``self.vertices`` / ``self.faces`` on
    ``self.device``.

    Args:
        ratio (float): Ratio of faces to filter out. Must lie strictly
            between 0 and 1.

    Raises:
        ValueError: If ``ratio`` is not in the open interval (0, 1).
    """
    if ratio <= 0 or ratio >= 1:
        raise ValueError("Simplify ratio must be between 0 and 1.")

    # Convert to PyVista format for simplification.
    # Each face row is prefixed with its vertex count (3), giving rows of
    # the form [3, i0, i1, i2].
    mesh = pv.PolyData(
        self.vertices_np,
        np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
    )
    mesh = mesh.decimate(ratio, progress_bar=True)

    # Update vertices and faces.
    self.vertices = torch.tensor(
        mesh.points, device=self.device, dtype=torch.float32
    )
    # Drop the leading per-face vertex count again to get (F, 3) indices.
    self.faces = torch.tensor(
        mesh.faces.reshape(-1, 4)[:, 1:],
        device=self.device,
        dtype=torch.int32,
    )
+ - faces: Simplified and repaired face array of (F, 3). + """ + self.vertices = self.vertices.to(self.device) + self.faces = self.faces.to(self.device) + + self.simplify(ratio=filter_ratio) + self.fill_holes( + max_hole_size=max_hole_size, + max_hole_nbe=int(250 * np.sqrt(1 - filter_ratio)), + resolution=resolution, + num_views=num_views, + norm_mesh_ratio=norm_mesh_ratio, + ) + + return self.vertices_np, self.faces_np diff --git a/embodied_gen/data/utils.py b/embodied_gen/data/utils.py new file mode 100644 index 0000000..31a65f8 --- /dev/null +++ b/embodied_gen/data/utils.py @@ -0,0 +1,1009 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ + +import math +import os +import random +import zipfile +from shutil import rmtree +from typing import List, Tuple, Union + +import cv2 +import kaolin as kal +import numpy as np +import nvdiffrast.torch as dr +import torch +import torch.nn.functional as F +from PIL import Image + +try: + from kolors.models.modeling_chatglm import ChatGLMModel + from kolors.models.tokenization_chatglm import ChatGLMTokenizer +except ImportError: + ChatGLMTokenizer = None + ChatGLMModel = None +import logging +from dataclasses import dataclass, field +from enum import Enum + +import trimesh +from kaolin.render.camera import Camera +from torch import nn + +logger = logging.getLogger(__name__) + + +__all__ = [ + "DiffrastRender", + "save_images", + "render_pbr", + "prelabel_text_feature", + "calc_vertex_normals", + "normalize_vertices_array", + "load_mesh_to_unit_cube", + "as_list", + "CameraSetting", + "RenderItems", + "import_kaolin_mesh", + "save_mesh_with_mtl", + "get_images_from_grid", + "post_process_texture", + "quat_mult", + "quat_to_rotmat", + "gamma_shs", + "resize_pil", + "trellis_preprocess", + "delete_dir", +] + + +class DiffrastRender(object): + """A class to handle differentiable rendering using nvdiffrast. + + This class provides methods to render position, depth, and normal maps + with optional anti-aliasing and gradient disabling for rasterization. + + Attributes: + p_mtx (torch.Tensor): Projection matrix. + mv_mtx (torch.Tensor): Model-view matrix. + mvp_mtx (torch.Tensor): Model-view-projection matrix, calculated as + p_mtx @ mv_mtx if not provided. + resolution_hw (Tuple[int, int]): Height and width of the rendering resolution. # noqa + _ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): Rasterization context. # noqa + mask_thresh (float): Threshold for mask creation. + grad_db (bool): Whether to disable gradients during rasterization. + antialias_mask (bool): Whether to apply anti-aliasing to the mask. 
def __init__(
    self,
    p_matrix: torch.Tensor,
    mv_matrix: torch.Tensor,
    resolution_hw: Tuple[int, int],
    context: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None,
    mvp_matrix: torch.Tensor = None,
    mask_thresh: float = 0.5,
    grad_db: bool = False,
    antialias_mask: bool = True,
    align_coordinate: bool = True,
    device: str = "cuda",
) -> None:
    """Store camera matrices and rasterization settings.

    Args:
        p_matrix: Batched projection matrices.
        mv_matrix: Batched model-view matrices.
        resolution_hw: Height and width of the rendering resolution.
        context: Rasterization context. When ``None`` a
            ``dr.RasterizeCudaContext`` is created on ``device``.
        mvp_matrix: Optional precomputed model-view-projection matrices.
            When ``None`` they are computed as
            ``torch.bmm(p_matrix, mv_matrix)``.
        mask_thresh: Threshold used when binarizing the mask.
        grad_db: Forwarded to ``dr.rasterize`` as ``grad_db``.
        antialias_mask: Whether to anti-alias the rendered mask.
        align_coordinate: Whether rendered position/depth maps get their
            axis/sign alignment applied.
        device: Device used when creating the default context.
    """
    self.p_mtx = p_matrix
    self.mv_mtx = mv_matrix
    # Always set ``mvp_mtx``: previously a caller-supplied ``mvp_matrix``
    # was silently dropped, leaving the attribute undefined and making
    # ``compute_dr_raster`` fail with AttributeError.
    if mvp_matrix is None:
        mvp_matrix = torch.bmm(p_matrix, mv_matrix)
    self.mvp_mtx = mvp_matrix

    self.resolution_hw = resolution_hw
    if context is None:
        context = dr.RasterizeCudaContext(device=device)
    self._ctx = context
    self.mask_thresh = mask_thresh
    self.grad_db = grad_db
    self.antialias_mask = antialias_mask
    self.align_coordinate = align_coordinate
    self.device = device
def normalize_map_by_mask(
    self, map: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """Min-max normalize ``map`` using only the pixels where ``mask == 1``.

    The minimum and range are taken over the foreground pixels. The
    rescaled map is then blended toward zero by ``mask`` and negative
    values are set to zero. If there are no foreground pixels the input
    ``map`` is returned unchanged.
    """
    fg_values = map[(mask == 1).squeeze(dim=-1)]
    if fg_values.shape[0] == 0:
        return map

    lower = fg_values.min(dim=0).values
    upper = fg_values.max(dim=0).values
    # Guard against a zero range on constant foregrounds.
    span = (upper - lower).clip(min=1e-6)

    scaled = (map - lower) / span
    scaled = torch.lerp(torch.zeros_like(scaled), scaled, mask)
    scaled[scaled < 0] = 0

    return scaled
def render_depth(
    self,
    vertices: torch.Tensor,
    faces: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Render a per-pixel depth map together with its mask.

    Args:
        vertices: Mesh vertices in the model coordinate system.
        faces: Triangle indices; cast to int32 before rasterization.

    Returns:
        A ``(depth_map, mask)`` pair. The depth map is blended toward
        zero by ``mask``.
    """
    # Vertices in model coordinate system, real depth coordinate number.
    faces = faces.to(torch.int32)
    mask, rast = self.render_rast_alpha(vertices, faces)

    # Transform by the model-view matrix and keep only the z component.
    vertices_camera = self.transform_vertices(vertices, matrix=self.mv_mtx)
    vertices_camera = vertices_camera[..., 2:3].contiguous().float()
    depth_map, _ = dr.interpolate(vertices_camera, rast, faces)
    # Change camera depth minus to positive.
    if self.align_coordinate:
        depth_map = -depth_map
    depth_map = torch.lerp(torch.zeros_like(depth_map), depth_map, mask)

    return depth_map, mask
def transform_normal(
    self,
    normals: torch.Tensor,
    trans_matrix: torch.Tensor,
    masks: torch.Tensor,
    to_view: bool,
) -> torch.Tensor:
    """Apply one 4x4 transform per normal map.

    Input and output normals are encoded in [0, 1]. The input is cloned,
    so the caller's tensor is not modified.

    Args:
        normals: Normal maps of shape (b, h, w, c).
        trans_matrix: One 4x4 matrix per normal map.
        masks: Optional blend weights; when given, the result is blended
            toward zero by ``masks``.
        to_view: When True the x channel is flipped after the transform,
            otherwise it is flipped before it.

    Returns:
        The transformed normal maps, stacked along dim 0.
    """
    # NOTE: input normals in [0, 1], output normals in [0, 1].
    signed = normals.clone()
    assert len(signed) == len(trans_matrix)

    if not to_view:
        # Flip the sign on the x-axis to match inv bae system for global transformation. # noqa
        signed[..., 0] = 1 - signed[..., 0]

    signed = 2 * signed - 1
    _, height, width, channels = signed.shape

    def _transform_one(
        normal_map: torch.Tensor, matrix: torch.Tensor
    ) -> torch.Tensor:
        # Pad with w=0 so the 4x4 matrix applies no translation.
        vectors = F.pad(
            normal_map.view(-1, channels),  # (h w 3) -> (hw 3)
            pad=(0, 1),
            mode="constant",
            value=0.0,
        )
        rotated = torch.matmul(vectors, matrix.transpose(0, 1))[..., :3]

        # Normalize, then re-encode to the [0, 1] range.
        encoded = (F.normalize(rotated, p=2, dim=-1) + 1) / 2
        if to_view:
            # Flip the sign on the x-axis to match bae system for view transformation. # noqa
            encoded[..., 0] = 1 - encoded[..., 0]

        return encoded.view(height, width, channels)

    result = torch.stack(
        [_transform_one(n, m) for n, m in zip(signed, trans_matrix)],
        dim=0,
    )

    if masks is None:
        return result

    return torch.lerp(torch.zeros_like(result), result, masks)
def save_images(
    images: Union[list[np.ndarray], list[torch.Tensor]],
    output_dir: str,
    cvt_color: int = None,
    format: str = ".png",
    to_uint8: bool = True,
    verbose: bool = False,
) -> List[str]:
    """Write a sequence of images to ``output_dir`` with ``cv2.imwrite``.

    Files are named by zero-padded index, e.g. ``0000.png``.

    Args:
        images: Images to save; tensors are detached and moved to CPU.
        output_dir: Destination directory, created if missing.
        cvt_color: Optional conversion code passed to ``cv2.cvtColor``.
        format: File extension (including the dot) appended to the index.
        to_uint8: When True, clip to [0, 1] and scale to uint8 first.
        verbose: When True, log the output directory after saving.

    Returns:
        The list of written file paths, in input order.
    """
    # NOTE: images in [0, 1]
    os.makedirs(output_dir, exist_ok=True)
    save_paths = []
    for idx, image in enumerate(images):
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
        if to_uint8:
            image = image.clip(min=0, max=1)
            image = (255.0 * image).astype(np.uint8)
        if cvt_color is not None:
            image = cv2.cvtColor(image, cvt_color)
        save_path = os.path.join(output_dir, f"{idx:04d}{format}")
        save_paths.append(save_path)

        cv2.imwrite(save_path, image)

    if verbose:
        logger.info(f"Images saved in {output_dir}")

    return save_paths
def _move_to_target_device(data, device: str):
    """Move tensors (optionally nested inside dicts) to ``device``.

    A tensor is returned via ``.to(device)``. A dict is updated in place,
    recursing into its values, and the same dict object is returned. Any
    other value is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)

    if not isinstance(data, dict):
        return data

    for name in data:
        data[name] = _move_to_target_device(data[name], device)

    return data
def load_llm_models(pretrained_model_name_or_path: str, device: str):
    """Load the ChatGLM tokenizer and text encoder from a checkpoint.

    Both are loaded from the ``text_encoder`` subfolder of
    ``pretrained_model_name_or_path``; the encoder is moved to ``device``.

    Args:
        pretrained_model_name_or_path: Model identifier or local path
            passed to ``from_pretrained``.
        device: Device the text encoder is moved to.

    Returns:
        A ``(tokenizers, text_encoders)`` pair of single-element lists.

    Raises:
        ImportError: If the optional ``kolors`` package could not be
            imported when this module was loaded.
    """
    # The module-level import falls back to ``None`` when ``kolors`` is
    # missing; fail with a clear message instead of an AttributeError on
    # ``None.from_pretrained``.
    if ChatGLMTokenizer is None or ChatGLMModel is None:
        raise ImportError(
            "load_llm_models requires the `kolors` package "
            "(ChatGLMModel / ChatGLMTokenizer could not be imported)."
        )

    tokenizer = ChatGLMTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
    )
    text_encoder = ChatGLMModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
    ).to(device)

    text_encoders = [
        text_encoder,
    ]
    tokenizers = [
        tokenizer,
    ]

    logger.info(f"Load model from {pretrained_model_name_or_path} done.")

    return tokenizers, text_encoders
def _calc_face_normals(
    vertices: torch.Tensor,  # V,3 first vertex may be unreferenced
    faces: torch.Tensor,  # F,3 long, first face may be all zero
    normalize: bool = False,
) -> torch.Tensor:  # F,3
    """Return one normal per face as the cross product of two edges.

    The normals are not unit length unless ``normalize`` is True.
    """
    full_vertices = vertices[faces]  # F,C=3,3
    v0, v1, v2 = full_vertices.unbind(dim=1)  # F,3
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=1)  # F,3
    if normalize:
        face_normals = F.normalize(face_normals, eps=1e-6, dim=1)
    return face_normals  # F,3


def calc_vertex_normals(
    vertices: torch.Tensor,  # V,3 first vertex may be unreferenced
    faces: torch.Tensor,  # F,3 integer, first face may be all zero
    face_normals: torch.Tensor = None,  # F,3, not normalized
) -> torch.Tensor:  # V,3
    """Return unit vertex normals by summing adjacent face normals.

    Each face normal is added to its three corner vertices and the sums
    are normalized. Vertices referenced by no face get a zero normal.

    Args:
        vertices: Vertex positions of shape (V, 3).
        faces: Triangle indices of shape (F, 3), any integer dtype.
        face_normals: Optional precomputed, unnormalized face normals of
            shape (F, 3). Computed from the mesh when ``None``.

    Returns:
        Vertex normals of shape (V, 3).
    """
    # ``scatter_add_`` requires an int64 index; other code in this file
    # casts faces to int32, so normalize the dtype here.
    faces = faces.long()
    num_faces = faces.shape[0]

    if face_normals is None:
        face_normals = _calc_face_normals(vertices, faces)

    vertex_normals = torch.zeros(
        (vertices.shape[0], 3, 3), dtype=vertices.dtype, device=vertices.device
    )  # V,C=3,3
    # Accumulate each face normal into the slot of each of its corners.
    vertex_normals.scatter_add_(
        dim=0,
        index=faces[:, :, None].expand(num_faces, 3, 3),
        src=face_normals[:, None, :].expand(num_faces, 3, 3),
    )
    vertex_normals = vertex_normals.sum(dim=1)  # V,3
    return F.normalize(vertex_normals, eps=1e-6, dim=1)
class RenderItems(str, Enum):
    """String-valued names of the supported render outputs.

    This is deliberately a plain ``Enum`` without ``@dataclass``: the
    dataclass decorator generates ``__eq__``/``__hash__`` that make all
    members compare equal to each other and unhashable.
    """

    IMAGE = "image_color"
    ALPHA = "image_mask"
    VIEW_NORMAL = "image_view_normal"
    GLOBAL_NORMAL = "image_global_normal"
    POSITION_MAP = "image_position"
    DEPTH = "image_depth"
    ALBEDO = "image_albedo"
    DIFFUSE = "image_diffuse"
def init_kal_camera(camera_params: CameraSetting) -> Camera:
    """Build a batched kaolin ``Camera`` from a ``CameraSetting``.

    Camera positions come from the azimuth/elevation layout of
    ``camera_params``; every camera looks at ``camera_params.at`` with
    the same ``up`` vector.

    Args:
        camera_params: Rendering camera configuration.

    Returns:
        The camera created by ``Camera.from_args``.
    """
    azimuths, elevations = _compute_az_el_by_camera_params(camera_params)
    cam_pts = _compute_cam_pts_by_az_el(
        azimuths, elevations, camera_params.distance
    )

    # One ``up`` row per generated eye position. The eye count is
    # ``(num_images // len(elevation)) * len(elevation)``, which differs
    # from ``num_images`` when it is not divisible, so size from
    # ``cam_pts`` rather than ``num_images``.
    up = torch.tensor(camera_params.up).repeat(len(cam_pts), 1)

    camera = Camera.from_args(
        eye=torch.tensor(cam_pts),
        at=torch.tensor(camera_params.at),
        up=up,
        fov=camera_params.fov,
        height=camera_params.resolution_hw[0],
        width=camera_params.resolution_hw[1],
        near=camera_params.near,
        far=camera_params.far,
        device=camera_params.device,
    )

    return camera
def save_mesh_with_mtl(
    vertices: np.ndarray,
    faces: np.ndarray,
    uvs: np.ndarray,
    texture: Union[Image.Image, np.ndarray],
    output_path: str,
    material_base=(250, 250, 250, 255),
) -> trimesh.Trimesh:
    """Build a textured ``trimesh.Trimesh`` and export it to ``output_path``.

    Args:
        vertices: Mesh vertices.
        faces: Mesh faces.
        uvs: UV coordinates passed to ``TextureVisuals``.
        texture: Texture image; arrays are converted with
            ``Image.fromarray``.
        output_path: Destination file; parent directories are created.
        material_base: Color used for the material's diffuse, ambient
            and specular terms.

    Returns:
        The constructed mesh.
    """
    if isinstance(texture, np.ndarray):
        texture = Image.fromarray(texture)

    mesh = trimesh.Trimesh(
        vertices,
        faces,
        visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
    )
    mesh.visual.material = trimesh.visual.material.SimpleMaterial(
        image=texture,
        diffuse=material_base,
        ambient=material_base,
        specular=material_base,
    )

    dir_name = os.path.dirname(output_path)
    # A bare file name has an empty dirname, and ``os.makedirs("")``
    # raises FileNotFoundError.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    _ = mesh.export(output_path)

    logger.info(f"Saved mesh with texture to {output_path}")

    return mesh
def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
    """Downscale ``image`` so its longest side is at most ``max_size``.

    The aspect ratio is kept and images are never upscaled; an image
    that already fits is returned unchanged.

    Args:
        image: Input PIL image.
        max_size: Maximum allowed length of the longest side, in pixels.

    Returns:
        The original image, or a LANCZOS-resampled smaller copy.
    """
    # The ``max_size`` argument used to be overwritten by the image size
    # and a hard-coded 1024 was applied instead.
    longest_side = max(image.size)
    if longest_side <= max_size:
        return image

    scale = max_size / longest_side
    # Keep each side at least 1px so extreme aspect ratios stay valid.
    new_size = (
        max(1, int(image.width * scale)),
        max(1, int(image.height * scale)),
    )

    return image.resize(new_size, Image.Resampling.LANCZOS)
def zip_files(input_paths: list[str], output_zip: str) -> str:
    """Pack files and directories into a deflate-compressed zip archive.

    Archive names are relative to the common parent of ``input_paths``.
    A single directory is stored with its contents at the archive root;
    a single file is stored under its base name.

    Args:
        input_paths: Files and/or directories to add. Directories are
            walked recursively.
        output_zip: Path of the archive to create.

    Returns:
        ``output_zip``.

    Raises:
        FileNotFoundError: If any input path does not exist. This is
            checked before the archive is created, so no partial archive
            is left behind.
    """
    for input_path in input_paths:
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"File not found: {input_path}")

    base_dir = os.curdir
    if input_paths:
        base_dir = os.path.commonpath(input_paths)
        # For a single file the common path is the file itself, which
        # would give the useless arcname "."; anchor at its parent.
        if not os.path.isdir(base_dir):
            base_dir = os.path.dirname(base_dir) or os.curdir

    output_abs = os.path.abspath(output_zip)
    with zipfile.ZipFile(output_zip, "w", zipfile.ZIP_DEFLATED) as zipf:
        for input_path in input_paths:
            if not os.path.isdir(input_path):
                zipf.write(
                    input_path,
                    arcname=os.path.relpath(input_path, start=base_dir),
                )
                continue

            for root, _, files in os.walk(input_path):
                for file in files:
                    file_path = os.path.join(root, file)
                    # Skip the archive itself if it sits in an input dir.
                    if os.path.abspath(file_path) == output_abs:
                        continue
                    zipf.write(
                        file_path,
                        arcname=os.path.relpath(file_path, start=base_dir),
                    )

    return output_zip
class DelightingModel(object):
    """A model to remove the lighting in image space.

    This model is encapsulated based on the Hunyuan3D-Delight model
    from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa

    The diffusion pipeline is not created in ``__init__``; it is built on the
    first call (see ``_lazy_init_pipeline``).

    Attributes:
        image_guide_scale (float): Weight of image guidance in diffusion process.
        text_guide_scale (float): Weight of text (prompt) guidance in diffusion process.
        num_infer_step (int): Number of inference steps for diffusion model.
        mask_erosion_size (int): Size of erosion kernel for alpha mask cleanup.
        kernel (np.ndarray): All-ones uint8 array of shape
            (mask_erosion_size, mask_erosion_size) passed to ``cv2.erode``.
        device (str): Device used for inference, e.g., 'cuda' or 'cpu'.
        seed (int): Random seed for diffusion model reproducibility.
        model_path (str): Filesystem path to pretrained model weights.
        pipeline: Lazy-loaded diffusion pipeline instance.
    """

    def __init__(
        self,
        model_path: str = None,
        num_infer_step: int = 50,
        mask_erosion_size: int = 3,
        image_guide_scale: float = 1.5,
        text_guide_scale: float = 1.0,
        device: str = "cuda",
        seed: int = 0,
    ) -> None:
        """Store inference settings and resolve the weights directory.

        Args:
            model_path (str, optional): Directory with the pretrained
                weights. When None, the ``hunyuan3d-delight-v2-0`` folder of
                the ``tencent/Hunyuan3D-2`` Hugging Face repo is fetched via
                ``snapshot_download`` and used.
            num_infer_step (int): Number of diffusion inference steps.
            mask_erosion_size (int): Side length of the erosion kernel.
            image_guide_scale (float): Image guidance scale.
            text_guide_scale (float): Text guidance scale.
            device (str): Device the pipeline is moved to when loaded.
            seed (int): Seed handed to ``torch.manual_seed`` on every call.
        """
        self.image_guide_scale = image_guide_scale
        self.text_guide_scale = text_guide_scale
        self.num_infer_step = num_infer_step
        self.mask_erosion_size = mask_erosion_size
        self.kernel = np.ones(
            (self.mask_erosion_size, self.mask_erosion_size), np.uint8
        )
        self.seed = seed
        self.device = device
        self.pipeline = None  # lazy load model adapt to @spaces.GPU

        if model_path is None:
            suffix = "hunyuan3d-delight-v2-0"
            model_path = snapshot_download(
                repo_id="tencent/Hunyuan3D-2", allow_patterns=f"{suffix}/*"
            )
            # The snapshot root contains the sub-folder that holds the weights.
            model_path = os.path.join(model_path, suffix)

        self.model_path = model_path

    def _lazy_init_pipeline(self) -> None:
        """Build the InstructPix2Pix pipeline on first use.

        Loads the weights from ``self.model_path`` in float16 with the safety
        checker disabled, replaces the scheduler with an
        ``EulerAncestralDiscreteScheduler`` created from the loaded
        scheduler's config, silences the progress bar and moves the pipeline
        to ``self.device``. Does nothing when ``self.pipeline`` is already set.
        """
        if self.pipeline is None:
            pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                safety_checker=None,
            )
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            pipeline.set_progress_bar_config(disable=True)

            pipeline.to(self.device, torch.float16)
            self.pipeline = pipeline

    def recenter_image(
        self, image: Image.Image, border_ratio: float = 0.2
    ) -> Image.Image:
        """Crop to the visible content and center it on a square canvas.

        Args:
            image (Image.Image): Input image. ``RGB`` images are returned
                unchanged and ``L`` images are only converted to ``RGB``;
                any other mode is indexed at channel 3 as alpha.
            border_ratio (float): Border added on each side, as a fraction
                of the cropped width/height.

        Returns:
            Image.Image: For inputs with alpha, a square ``RGBA`` image whose
            background is transparent white ``(255, 255, 255, 0)``.

        Raises:
            ValueError: If no pixel has alpha greater than zero.
        """
        if image.mode == "RGB":
            return image
        elif image.mode == "L":
            image = image.convert("RGB")
            return image

        alpha_channel = np.array(image)[:, :, 3]
        non_zero_indices = np.argwhere(alpha_channel > 0)
        if non_zero_indices.size == 0:
            raise ValueError("Image is fully transparent")

        # argwhere yields (row, col) pairs; take the tight bounding box.
        min_row, min_col = non_zero_indices.min(axis=0)
        max_row, max_col = non_zero_indices.max(axis=0)

        # PIL crop boxes are (left, upper, right, lower) with exclusive ends.
        cropped_image = image.crop(
            (min_col, min_row, max_col + 1, max_row + 1)
        )

        width, height = cropped_image.size
        border_width = int(width * border_ratio)
        border_height = int(height * border_ratio)

        new_width = width + 2 * border_width
        new_height = height + 2 * border_height

        square_size = max(new_width, new_height)

        new_image = Image.new(
            "RGBA", (square_size, square_size), (255, 255, 255, 0)
        )

        # Offset of the crop inside the canvas: centering slack plus border.
        paste_x = (square_size - new_width) // 2 + border_width
        paste_y = (square_size - new_height) // 2 + border_height

        new_image.paste(cropped_image, (paste_x, paste_y))

        return new_image

    @spaces.GPU
    @torch.no_grad()
    def __call__(
        self,
        image: Union[str, np.ndarray, Image.Image],
        preprocess: bool = False,
        target_wh: tuple[int, int] = None,
    ) -> Image.Image:
        """Run delighting on an image that carries an alpha channel.

        Args:
            image (Union[str, np.ndarray, Image.Image]): File path, array or
                PIL image.
            preprocess (bool): When True, a new ``RembgRemover`` is created
                and applied, then ``recenter_image`` is run on its output.
            target_wh (tuple[int, int], optional): Size the image is resized
                to before inference. Defaults to the (possibly preprocessed)
                image size.

        Returns:
            Image.Image: ``RGBA`` result resized to ``target_wh`` whose alpha
            is the eroded input alpha.

        Raises:
            AssertionError: If the image array's last dimension is not 4,
                e.g. an ``RGB`` image passed with ``preprocess=False``.
        """
        self._lazy_init_pipeline()

        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if preprocess:
            bg_remover = RembgRemover()
            image = bg_remover(image)
            image = self.recenter_image(image)

        if target_wh is not None:
            image = image.resize(target_wh)
        else:
            target_wh = image.size

        image_array = np.array(image)
        assert image_array.shape[-1] == 4, "Image must have alpha channel"

        raw_alpha_channel = image_array[:, :, 3]
        # Erode the alpha mask, then paint everything outside it white.
        alpha_channel = cv2.erode(raw_alpha_channel, self.kernel, iterations=1)
        image_array[alpha_channel == 0, :3] = 255  # must be white background
        image_array[:, :, 3] = alpha_channel

        image = self.pipeline(
            prompt="",
            image=Image.fromarray(image_array).convert("RGB"),
            generator=torch.manual_seed(self.seed),
            num_inference_steps=self.num_infer_step,
            image_guidance_scale=self.image_guide_scale,
            guidance_scale=self.text_guide_scale,
        ).images[0]

        # Re-attach the eroded alpha to the pipeline output.
        alpha_channel = Image.fromarray(alpha_channel)
        rgba_image = image.convert("RGBA").resize(target_wh)
        rgba_image.putalpha(alpha_channel)

        return rgba_image
@dataclass
class RenderResult:
    """Container for one rendered view plus its derived mask and RGBA image.

    Tensor inputs are converted to numpy arrays in ``__post_init__``; numpy
    inputs are used as given.

    Attributes:
        rgb (np.ndarray): Color image. A tensor input is scaled by 255, cast
            to uint8 and passed through ``cv2.COLOR_BGR2RGB``.
        depth (np.ndarray): Depth map. A tensor input is only moved to numpy.
        opacity (np.ndarray): Opacity map. A tensor input is scaled by 255,
            cast to uint8 and expanded with ``cv2.COLOR_GRAY2RGB``.
        mask_threshold (float): Opacity values strictly above this become
            foreground (255) in ``mask``.
        mask (Optional[np.ndarray]): uint8 mask taken from the first opacity
            channel, keeping a trailing channel axis. Filled automatically.
        rgba (Optional[np.ndarray]): ``rgb`` with ``mask`` appended as the
            last channel. Filled automatically.
    """

    rgb: np.ndarray
    depth: np.ndarray
    opacity: np.ndarray
    mask_threshold: float = 10
    mask: Optional[np.ndarray] = None
    rgba: Optional[np.ndarray] = None

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach a tensor from the graph and return it as a CPU array."""
        return tensor.detach().cpu().numpy()

    def __post_init__(self):
        """Normalize tensor fields to numpy and derive ``mask``/``rgba``."""
        if isinstance(self.rgb, torch.Tensor):
            rgb_u8 = (self._to_numpy(self.rgb) * 255).astype(np.uint8)
            self.rgb = cv2.cvtColor(rgb_u8, cv2.COLOR_BGR2RGB)

        if isinstance(self.depth, torch.Tensor):
            self.depth = self._to_numpy(self.depth)

        if isinstance(self.opacity, torch.Tensor):
            opacity_u8 = (self._to_numpy(self.opacity) * 255).astype(np.uint8)
            self.opacity = cv2.cvtColor(opacity_u8, cv2.COLOR_GRAY2RGB)

        # Binarize the opacity and keep a single channel as the mask.
        is_foreground = self.opacity > self.mask_threshold
        binary = np.where(is_foreground, 255, 0)
        self.mask = binary[..., 0:1].astype(np.uint8)
        self.rgba = np.concatenate([self.rgb, self.mask], axis=-1)
torch.tensor( + plydata.elements[0]["f_dc_2"], dtype=torch.float32 + ) + + scale_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("scale_") + ] + scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) + scales = torch.zeros( + (xyz.shape[0], len(scale_names)), dtype=torch.float32 + ) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = torch.tensor( + plydata.elements[0][attr_name], dtype=torch.float32 + ) + + rot_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("rot_") + ] + rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) + rots = torch.zeros((xyz.shape[0], len(rot_names)), dtype=torch.float32) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = torch.tensor( + plydata.elements[0][attr_name], dtype=torch.float32 + ) + + rots = rots / torch.norm(rots, dim=-1, keepdim=True) + + # extra features + extra_f_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("f_rest_") + ] + extra_f_names = sorted( + extra_f_names, key=lambda x: int(x.split("_")[-1]) + ) + + max_sh_degree = int(np.sqrt((len(extra_f_names) + 3) / 3) - 1) + if max_sh_degree != 0: + features_extra = torch.zeros( + (xyz.shape[0], len(extra_f_names)), dtype=torch.float32 + ) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = torch.tensor( + plydata.elements[0][attr_name], dtype=torch.float32 + ) + + features_extra = features_extra.view( + (features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1) + ) + features_extra = features_extra.permute(0, 2, 1) + + if abs(gamma - 1.0) > 1e-3: + features_dc = gamma_shs(features_dc, gamma) + features_extra[..., :] = 0.0 + opacities *= 0.8 + + shs = torch.cat( + [ + features_dc.reshape(-1, 3), + features_extra.reshape(len(features_dc), -1), + ], + dim=-1, + ) + else: + # sh_dim is 0, only dc features + shs = features_dc + features_extra = None + + return cls( + sh_degree=max_sh_degree, 
+ _means=xyz, + _opacities=opacities, + _rgbs=shs, + _scales=scales, + _quats=rots, + _features_dc=features_dc, + _features_rest=features_extra, + device=device, + ) + + def save_to_ply( + self, path: str, colors: torch.Tensor = None, enable_mask: bool = False + ): + os.makedirs(os.path.dirname(path), exist_ok=True) + numpy_data = self.get_numpy_data() + means = numpy_data["_means"] + scales = numpy_data["_scales"] + quats = numpy_data["_quats"] + opacities = numpy_data["_opacities"] + sh0 = numpy_data["_features_dc"] + shN = numpy_data.get("_features_rest", np.zeros((means.shape[0], 0))) + shN = shN.reshape(means.shape[0], -1) + + # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays # noqa + if enable_mask: + invalid_mask = ( + np.isnan(means).any(axis=1) + | np.isinf(means).any(axis=1) + | np.isnan(scales).any(axis=1) + | np.isinf(scales).any(axis=1) + | np.isnan(quats).any(axis=1) + | np.isinf(quats).any(axis=1) + | np.isnan(opacities).any(axis=0) + | np.isinf(opacities).any(axis=0) + | np.isnan(sh0).any(axis=1) + | np.isinf(sh0).any(axis=1) + | np.isnan(shN).any(axis=1) + | np.isinf(shN).any(axis=1) + ) + + # Filter out rows with NaNs or Infs from all data arrays + means = means[~invalid_mask] + scales = scales[~invalid_mask] + quats = quats[~invalid_mask] + opacities = opacities[~invalid_mask] + sh0 = sh0[~invalid_mask] + shN = shN[~invalid_mask] + + num_points = means.shape[0] + + with open(path, "wb") as f: + # Write PLY header + f.write(b"ply\n") + f.write(b"format binary_little_endian 1.0\n") + f.write(f"element vertex {num_points}\n".encode()) + f.write(b"property float x\n") + f.write(b"property float y\n") + f.write(b"property float z\n") + f.write(b"property float nx\n") + f.write(b"property float ny\n") + f.write(b"property float nz\n") + + if colors is not None: + for j in range(colors.shape[1]): + f.write(f"property float f_dc_{j}\n".encode()) + else: + for i, data in enumerate([sh0, shN]): + prefix = "f_dc" if i == 0 
else "f_rest" + for j in range(data.shape[1]): + f.write(f"property float {prefix}_{j}\n".encode()) + + f.write(b"property float opacity\n") + + for i in range(scales.shape[1]): + f.write(f"property float scale_{i}\n".encode()) + for i in range(quats.shape[1]): + f.write(f"property float rot_{i}\n".encode()) + + f.write(b"end_header\n") + + # Write vertex data + for i in range(num_points): + f.write(struct.pack(" (x y z qw qx qy qz) + instance_pose = instance_pose[[0, 1, 2, 6, 3, 4, 5]] + cur_instances_quats = self.quat_norm(instance_pose[3:]) + rot_cur = quat_to_rotmat(cur_instances_quats, mode="wxyz") + + # update the means + num_gs = means.shape[0] + trans_per_pts = torch.stack([instance_pose[:3]] * num_gs, dim=0) + quat_per_pts = torch.stack([instance_pose[3:]] * num_gs, dim=0) + rot_per_pts = torch.stack([rot_cur] * num_gs, dim=0) # (num_gs, 3, 3) + + # update the means + cur_means = ( + torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1) + + trans_per_pts + ) + + # update the quats + _quats = self.quat_norm(quats) + cur_quats = quat_mult(quat_per_pts, _quats) + + return cur_means, cur_quats + + def get_gaussians( + self, + c2w: torch.Tensor = None, + instance_pose: torch.Tensor = None, + apply_activate: bool = False, + ) -> "GaussianBase": + """Get Gaussian data under the given instance_pose.""" + if c2w is None: + c2w = torch.eye(4).to(self.device) + + if instance_pose is not None: + # compute the transformed gs means and quats + world_means, world_quats = self._compute_transform( + self._means, self._quats, instance_pose.float().to(self.device) + ) + else: + world_means, world_quats = self._means, self._quats + + # get colors of gaussians + if self._features_rest is not None: + colors = torch.cat( + (self._features_dc[:, None, :], self._features_rest), dim=1 + ) + else: + colors = self._features_dc[:, None, :] + + if self.sh_degree > 0: + viewdirs = world_means.detach() - c2w[..., :3, 3] # (N, 3) + viewdirs = viewdirs / viewdirs.norm(dim=-1, 
def rescale(self, scale: float):
    """Uniformly scale the Gaussians in place.

    ``_means`` is multiplied by ``scale`` and ``log(scale)`` is added to
    ``_scales`` (``get_gaussians`` applies ``torch.exp`` to ``_scales`` when
    activating, so the stored values are handled as logarithms).

    Args:
        scale (float): Positive scale factor. ``1.0`` is a no-op.

    Raises:
        ValueError: If ``scale`` is not positive; its logarithm would
            otherwise fill ``_scales`` with NaN or -inf.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    if scale != 1.0:
        self._means *= scale
        self._scales += torch.log(self._scales.new_tensor(scale))


def set_scale_by_height(self, real_height: float) -> None:
    """Rescale in place so the extent along axis 1 equals ``real_height``.

    Args:
        real_height (float): Target extent of ``_means`` along axis 1.
    """

    def _ptp(tensor, dim):
        # Peak-to-peak (max - min) along `dim`, as a plain Python list.
        val = tensor.max(dim=dim).values - tensor.min(dim=dim).values
        return val.tolist()

    # First normalize so the largest axis extent is ~1 (no re-centering).
    xyz_scale = max(_ptp(self._means, dim=0))
    self.rescale(1 / (xyz_scale + 1e-6))
    raw_height = _ptp(self._means, dim=0)[1]
    scale = real_height / raw_height

    self.rescale(scale)


@staticmethod
def resave_ply(
    in_ply: str,
    out_ply: str,
    real_height: float = None,
    instance_pose: np.ndarray = None,
    device: str = "cuda",
) -> None:
    """Load a Gaussian-splat PLY, optionally transform it, and save it.

    Args:
        in_ply (str): Path of the PLY file to read.
        out_ply (str): Path of the PLY file to write.
        real_height (float, optional): If given, passed to
            ``set_scale_by_height`` after the pose is applied.
        instance_pose (np.ndarray, optional): Pose forwarded to
            ``get_gaussians``. Arrays and tensors are both accepted.
        device (str): Device used when loading the model.
    """
    gs_model = GaussianOperator.load_from_ply(in_ply, device=device)

    if instance_pose is not None:
        # get_gaussians calls tensor-only methods (.float().to(...)) on the
        # pose, so a numpy array has to be converted first.
        gs_model = gs_model.get_gaussians(
            instance_pose=torch.as_tensor(instance_pose)
        )

    if real_height is not None:
        gs_model.set_scale_by_height(real_height)

    gs_model.save_to_ply(out_ply)


@staticmethod
def trans_to_quatpose(
    rot_matrix: list[list[float]],
    trans_matrix: list[float] = None,
) -> torch.Tensor:
    """Build a 7-element pose ``(x, y, z, qx, qy, qz, qw)``.

    Args:
        rot_matrix (list[list[float]]): 3x3 rotation matrix. Nested lists
            are converted to a numpy array first.
        trans_matrix (list[float], optional): Translation ``(x, y, z)``.
            Defaults to the origin.

    Returns:
        torch.Tensor: Translation followed by the quaternion returned by
        ``scipy`` ``Rotation.as_quat()``, unpacked as qx, qy, qz, qw.
    """
    # A None sentinel replaces the former shared mutable list default.
    if trans_matrix is None:
        trans_matrix = [0, 0, 0]

    if isinstance(rot_matrix, list):
        rot_matrix = np.array(rot_matrix)

    rot = Rotation.from_matrix(rot_matrix)
    qx, qy, qz, qw = rot.as_quat()
    instance_pose = torch.tensor([*trans_matrix, qx, qy, qz, qw])

    return instance_pose
c2w: torch.Tensor, + Ks: torch.Tensor, + image_width: int, + image_height: int, + ) -> RenderResult: + gs = self.get_gaussians(c2w, apply_activate=True) + renders, alphas, _ = rasterization( + means=gs._means, + quats=gs._quats, + scales=gs._scales, + opacities=gs._opacities.squeeze(), + colors=gs._rgbs, + viewmats=torch.linalg.inv(c2w)[None, ...], + Ks=Ks[None, ...], + width=image_width, + height=image_height, + packed=False, + absgrad=True, + sparse_grad=False, + # rasterize_mode="classic", + rasterize_mode="antialiased", + **{ + "near_plane": 0.01, + "far_plane": 1000000000, + "radius_clip": 0.0, + "render_mode": "RGB+ED", + }, + ) + renders = renders[0] + alphas = alphas[0].squeeze(-1) + + assert renders.shape[-1] == 4, f"Must render rgb, depth and alpha" + rendered_rgb, rendered_depth = torch.split(renders, [3, 1], dim=-1) + + return RenderResult( + torch.clamp(rendered_rgb, min=0, max=1), + rendered_depth, + alphas[..., None], + ) + + +if __name__ == "__main__": + input_gs = "outputs/test/debug.ply" + output_gs = "./debug_v3.ply" + gs_model: GaussianOperator = GaussianOperator.load_from_ply(input_gs) + + # ็ป• x ่ฝดๆ—‹่ฝฌ 180ยฐ + R_x = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] + instance_pose = gs_model.trans_to_quatpose(R_x) + gs_model = gs_model.get_gaussians(instance_pose=instance_pose) + + gs_model.rescale(2) + + gs_model.set_scale_by_height(1.3) + + gs_model.save_to_ply(output_gs) diff --git a/embodied_gen/models/segment_model.py b/embodied_gen/models/segment_model.py new file mode 100644 index 0000000..ab92c0c --- /dev/null +++ b/embodied_gen/models/segment_model.py @@ -0,0 +1,379 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class SAMRemover(object):
    """Loading SAM models and performing background removal on images.

    Attributes:
        device (str): "cuda" when available, otherwise "cpu".
        model_type (str): Type of the SAM model to load (default: "vit_h").
        area_ratio (float): Area ratio filtering small connected components.
        mask_generator (SamAutomaticMaskGenerator): Automatic mask generator
            wrapping the loaded SAM model.
    """

    def __init__(
        self,
        checkpoint: str = None,
        model_type: str = "vit_h",
        area_ratio: float = 15,
    ):
        """Load the SAM checkpoint and build the mask generator.

        Args:
            checkpoint (str, optional): Path to the model checkpoint. When
                None, ``sam/sam_vit_h_4b8939.pth`` is fetched from the
                ``xinjjj/RoboAssetGen`` Hugging Face repo.
            model_type (str): Key into ``sam_model_registry``.
            area_ratio (float): Forwarded to
                ``filter_small_connected_components``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_type = model_type
        self.area_ratio = area_ratio

        if checkpoint is None:
            suffix = "sam"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            checkpoint = os.path.join(
                model_path, suffix, "sam_vit_h_4b8939.pth"
            )

        self.mask_generator = self._load_sam_model(checkpoint)

    def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
        """Instantiate SAM from ``checkpoint`` on ``self.device``."""
        sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
        sam.to(device=self.device)

        return SamAutomaticMaskGenerator(sam)

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ) -> Image.Image:
        """Removes the background from an image using the SAM model.

        Args:
            image (Union[str, Image.Image, np.ndarray]): Input image,
                can be a file path, PIL Image, or numpy array.
            save_path (str): Path to save the output image (default: None).
                Missing parent directories are created; a bare file name
                is saved into the current directory.

        Returns:
            Image.Image: ``RGBA`` image whose alpha is the largest SAM mask
            after small-component filtering. If SAM yields no mask, the
            input is returned as ``RGB`` without an alpha channel.
        """
        # Convert input to numpy array
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        image = resize_pil(image)
        image = np.array(image.convert("RGB"))

        # Generate masks
        masks = self.mask_generator.generate(image)

        if not masks:
            logger.warning(
                "Segmentation failed: No mask generated, return raw image."
            )
            output_image = Image.fromarray(image, mode="RGB")
        else:
            # Use the largest mask (first one wins on ties).
            best_mask = max(masks, key=lambda x: x["area"])["segmentation"]
            mask = (best_mask * 255).astype(np.uint8)
            mask = filter_small_connected_components(
                mask, area_ratio=self.area_ratio
            )
            # Apply the mask to remove the background
            background_removed = cv2.bitwise_and(image, image, mask=mask)
            output_image = np.dstack((background_removed, mask))
            output_image = Image.fromarray(output_image, mode="RGBA")

        if save_path is not None:
            # os.makedirs("") raises, so only create a directory when the
            # path actually has a parent component.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            output_image.save(save_path)

        return output_image
class RembgRemover(object):
    """Background remover backed by a ``rembg`` "u2net" session.

    Attributes:
        rembg_session: Session created once with ``rembg.new_session`` and
            reused for every call.
    """

    def __init__(self):
        """Create the reusable ``rembg`` session."""
        self.rembg_session = rembg.new_session("u2net")

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ) -> Image.Image:
        """Remove the background from an image.

        Args:
            image (Union[str, Image.Image, np.ndarray]): File path, PIL
                image or numpy array.
            save_path (str, optional): Where to save the result. Missing
                parent directories are created; a bare file name is saved
                into the current directory.

        Returns:
            Image.Image: Output of ``rembg.remove`` for the resized input.
        """
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        image = resize_pil(image)
        output_image = rembg.remove(image, session=self.rembg_session)

        if save_path is not None:
            # os.makedirs("") raises, so skip it for bare file names.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            output_image.save(save_path)

        return output_image


class BMGG14Remover(object):
    """Background remover using the ``briaai/RMBG-1.4`` pipeline.

    Attributes:
        model: ``transformers`` "image-segmentation" pipeline.
    """

    def __init__(self) -> None:
        """Load the segmentation pipeline."""
        # NOTE(review): trust_remote_code=True runs Python code shipped in
        # the model repository; consider pinning a revision.
        self.model = pipeline(
            "image-segmentation",
            model="briaai/RMBG-1.4",
            trust_remote_code=True,
        )

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ):
        """Run the pipeline on an image and optionally save the result.

        Args:
            image (Union[str, Image.Image, np.ndarray]): File path, PIL
                image or numpy array.
            save_path (str, optional): Where to save the result. Missing
                parent directories are created; a bare file name is saved
                into the current directory.

        Returns:
            The pipeline output for the resized input (presumably a PIL
            image, since ``.save`` is called on it; confirm against the
            model card).
        """
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        image = resize_pil(image)
        output_image = self.model(image)

        if save_path is not None:
            # os.makedirs("") raises, so skip it for bare file names.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            output_image.save(save_path)

        return output_image


def invert_rgba_pil(
    image: Image.Image, mask: Image.Image, save_path: str = None
) -> Image.Image:
    """Attach the inverted ``mask`` to ``image`` as its alpha channel.

    Args:
        image (Image.Image): Color image; its array is concatenated with the
            inverted mask along the last axis, so it must contribute exactly
            three channels for the ``RGBA`` result.
        mask (Image.Image): Single-channel mask with values in 0-255.
        save_path (str, optional): Where to save the result. Missing
            parent directories are created; a bare file name is saved
            into the current directory.

    Returns:
        Image.Image: ``RGBA`` image whose alpha is ``255 - mask``.
    """
    mask = (255 - np.array(mask))[..., None]
    image_array = np.concatenate([np.array(image), mask], axis=-1)
    inverted_image = Image.fromarray(image_array, "RGBA")

    if save_path is not None:
        # os.makedirs("") raises, so skip it for bare file names.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        inverted_image.save(save_path)

    return inverted_image
f"{save_path}_sam_inv.png" if save_path else None + out_rbg = f"{save_path}_rbg.png" if save_path else None + + seg_image = sam_remover(image, out_sam) + seg_image = seg_image.convert("RGBA") + _, _, _, alpha = seg_image.split() + seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv) + seg_image_rbg = rbg_remover(image, out_rbg) + + final_image = None + if _is_valid_seg(image, seg_image): + final_image = seg_image + elif _is_valid_seg(image, seg_image_inv): + final_image = seg_image_inv + elif _is_valid_seg(image, seg_image_rbg): + logger.warning(f"Failed to segment by `SAM`, retry with `rembg`.") + final_image = seg_image_rbg + else: + if mode == "strict": + raise RuntimeError( + f"Failed to segment by `SAM` or `rembg`, abort." + ) + logger.warning("Failed to segment by SAM or rembg, use raw image.") + final_image = image.convert("RGBA") + + if save_path: + final_image.save(save_path) + + final_image = trellis_preprocess(final_image) + + return final_image + + +if __name__ == "__main__": + input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg" + output_image = "sample_0_seg2.png" + + # input_image = "outputs/text2image/tmp/coffee_machine.jpeg" + # output_image = "outputs/text2image/tmp/coffee_machine_seg.png" + + # input_image = "outputs/text2image/tmp/bucket.jpeg" + # output_image = "outputs/text2image/tmp/bucket_seg.png" + + remover = SAMRemover(model_type="vit_h") + remover = RembgRemover() + clean_image = remover(input_image) + clean_image.save(output_image) + get_segmented_image_by_agent( + Image.open(input_image), remover, remover, None, "./test_seg.png" + ) + + remover = BMGG14Remover() + remover("embodied_gen/models/test_seg.jpg", "./seg.png") diff --git a/embodied_gen/models/sr_model.py b/embodied_gen/models/sr_model.py new file mode 100644 index 0000000..40310bb --- /dev/null +++ b/embodied_gen/models/sr_model.py @@ -0,0 +1,174 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. 
class ImageStableSR:
    """Super-resolution image upscaler using Stable Diffusion x4 upscaling model from StabilityAI.

    Attributes:
        up_pipeline_x4: ``StableDiffusionUpscalePipeline`` loaded in float16
            with the progress bar disabled and model CPU offload enabled.
    """

    def __init__(
        self,
        model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
        device: str = "cuda",
    ) -> None:
        """Load the upscaling pipeline.

        Args:
            model_path (str): Model id or path given to ``from_pretrained``.
            device (str): Device the pipeline is moved to after loading.
        """
        # Imported here so diffusers is only needed when this class is used.
        from diffusers import StableDiffusionUpscalePipeline

        self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
        ).to(device)
        self.up_pipeline_x4.set_progress_bar_config(disable=True)
        self.up_pipeline_x4.enable_model_cpu_offload()

    @spaces.GPU
    def __call__(
        self,
        image: Union[Image.Image, np.ndarray],
        prompt: str = "",
        infer_step: int = 20,
    ) -> Image.Image:
        """Upscale one image.

        Args:
            image (Union[Image.Image, np.ndarray]): Input image; arrays are
                wrapped with ``Image.fromarray``. Converted to ``RGB``.
            prompt (str): Text prompt passed to the pipeline.
            infer_step (int): Number of inference steps.

        Returns:
            Image.Image: First image of the pipeline output.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        image = image.convert("RGB")

        with torch.no_grad():
            upscaled_image = self.up_pipeline_x4(
                image=image,
                prompt=[prompt],
                num_inference_steps=infer_step,
            ).images[0]

        return upscaled_image


class ImageRealESRGAN:
    """A wrapper for Real-ESRGAN-based image super-resolution.

    This class uses the RealESRGAN model to perform image upscaling,
    typically by a factor of 4. The network is built on the first call.

    Attributes:
        outscale (int): The output image scale factor (e.g., 2, 4).
        model_path (str): Path to the pre-trained model weights.
        upsampler: ``RealESRGANer`` instance, or None until first use.
    """

    def __init__(self, outscale: int, model_path: str = None) -> None:
        """Patch torchvision if needed and resolve the weights path.

        Args:
            outscale (int): Scale factor forwarded to ``enhance``.
            model_path (str, optional): Path to the weights file. When None,
                ``super_resolution/RealESRGAN_x4plus.pth`` is fetched from
                the ``xinjjj/RoboAssetGen`` Hugging Face repo.
        """
        # Monkey patch for torchvision versions newer than 0.16: register a
        # stand-in `torchvision.transforms.functional_tensor` module that
        # exposes `rgb_to_grayscale` (presumably imported by basicsr; verify).
        import torchvision
        from packaging import version

        if version.parse(torchvision.__version__) > version.parse("0.16"):
            import sys
            import types

            import torchvision.transforms.functional as TF

            functional_tensor = types.ModuleType(
                "torchvision.transforms.functional_tensor"
            )
            functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
            sys.modules["torchvision.transforms.functional_tensor"] = (
                functional_tensor
            )

        self.outscale = outscale
        self.upsampler = None

        if model_path is None:
            suffix = "super_resolution"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            model_path = os.path.join(
                model_path, suffix, "RealESRGAN_x4plus.pth"
            )

        self.model_path = model_path

    def _lazy_init(self) -> None:
        """Build the RRDBNet model and ``RealESRGANer`` on first use."""
        if self.upsampler is None:
            # Imported after the torchvision patch installed in __init__.
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer

            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=4,
            )

            self.upsampler = RealESRGANer(
                scale=4,
                model_path=self.model_path,
                model=model,
                pre_pad=0,
                half=True,
            )

    @spaces.GPU
    def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Upscale one image with Real-ESRGAN.

        Args:
            image (Union[Image.Image, np.ndarray]): Input image; PIL images
                are converted to numpy arrays.

        Returns:
            Image.Image: The enhanced image built from the upsampler output.
        """
        self._lazy_init()

        if isinstance(image, Image.Image):
            image = np.array(image)

        with torch.no_grad():
            output, _ = self.upsampler.enhance(image, outscale=self.outscale)

        return Image.fromarray(output)
image super resolution. + super_model = ImageRealESRGAN(outscale=4) + multiviews = get_images_from_grid(color_path, img_size=512) + multiviews = [super_model(img.convert("RGB")) for img in multiviews] + for idx, img in enumerate(multiviews): + img.save(f"sr{idx}.png") + + # # Use stable diffusion for x4 (512->2048) image super resolution. + # super_model = ImageStableSR() + # multiviews = get_images_from_grid(color_path, img_size=512) + # multiviews = [super_model(img) for img in multiviews] + # for idx, img in enumerate(multiviews): + # img.save(f"sr_stable{idx}.png") diff --git a/embodied_gen/models/text_model.py b/embodied_gen/models/text_model.py new file mode 100644 index 0000000..109762b --- /dev/null +++ b/embodied_gen/models/text_model.py @@ -0,0 +1,171 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
def build_text2img_ip_pipeline(
    ckpt_dir: str,
    ref_scale: float,
    device: str = "cuda",
) -> StableDiffusionXLPipelineIP:
    """Build the Kolors text-to-image pipeline with an IP-Adapter attached.

    Components are loaded from sub-folders of ``ckpt_dir`` (``text_encoder``,
    ``vae``, ``scheduler``, ``unet``) and cast to half precision. The image
    encoder and adapter weights are read from the sibling directory
    ``{ckpt_dir}/../Kolors-IP-Adapter-Plus``.

    Args:
        ckpt_dir: Directory holding the Kolors checkpoint sub-folders.
        ref_scale: IP-Adapter scale applied via ``set_ip_adapter_scale``.
        device: Device the pipeline is moved to before CPU offload is enabled.

    Returns:
        The assembled pipeline with model CPU offload enabled.
    """
    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(
        f"{ckpt_dir}/vae", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModelIP.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus/image_encoder",
        ignore_mismatched_sizes=True,
    ).to(dtype=torch.float16)
    clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)

    pipe = StableDiffusionXLPipelineIP(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )

    # Keep a handle on the UNet's existing projection under a second name
    # before the adapter is loaded.
    # NOTE(review): presumably load_ip_adapter replaces encoder_hid_proj - confirm.
    if hasattr(pipe.unet, "encoder_hid_proj"):
        pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj

    pipe.load_ip_adapter(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus",
        subfolder="",
        weight_name=["ip_adapter_plus_general.bin"],
    )
    pipe.set_ip_adapter_scale([ref_scale])

    pipe = pipe.to(device)
    pipe.enable_model_cpu_offload()

    return pipe


def build_text2img_pipeline(
    ckpt_dir: str,
    device: str = "cuda",
) -> StableDiffusionXLPipeline:
    """Build the plain Kolors text-to-image pipeline (no IP-Adapter).

    Args:
        ckpt_dir: Directory holding ``text_encoder``, ``vae``, ``scheduler``
            and ``unet`` sub-folders.
        device: Device the pipeline is moved to before CPU offload is enabled.

    Returns:
        The pipeline with model CPU offload and xformers memory-efficient
        attention enabled.
    """
    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(
        f"{ckpt_dir}/vae", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False,
    )
    pipe = pipe.to(device)
    pipe.enable_model_cpu_offload()
    pipe.enable_xformers_memory_efficient_attention()

    return pipe


def text2img_gen(
    prompt: str,
    n_sample: int,
    guidance_scale: float,
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
    ip_image: Image.Image | str | None = None,
    image_wh: tuple[int, int] = (1024, 1024),
    infer_step: int = 50,
    ip_image_size: int = 512,
    seed: int | None = None,
) -> list[Image.Image]:
    """Generate ``n_sample`` images for ``prompt`` with a Kolors pipeline.

    The prompt is wrapped with a fixed prefix/suffix ("Single ..., in the
    center of the image, high quality, ...") before being sent to the
    pipeline.

    Args:
        prompt: Object description.
        n_sample: Number of images per prompt.
        guidance_scale: Guidance scale forwarded to the pipeline.
        pipeline: Pipeline built by one of the builders in this module.
        ip_image: Optional reference image (PIL image or path). When given it
            is resized to ``ip_image_size`` square and passed as
            ``ip_adapter_image``.
        image_wh: Output (width, height). Any 2-item sequence is accepted;
            the default is an immutable tuple rather than a shared list.
        infer_step: Number of inference steps.
        ip_image_size: Side length the reference image is resized to.
        seed: When not ``None``, seeds the pipeline generator as well as the
            torch, numpy and ``random`` global RNGs.

    Returns:
        The list of generated PIL images.
    """
    prompt = "Single " + prompt + ", in the center of the image"
    prompt += ", high quality, high resolution, best quality, white background, 3D style"  # noqa
    logger.info(f"Processing prompt: {prompt}")

    generator = None
    if seed is not None:
        generator = torch.Generator(pipeline.device).manual_seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    kwargs = dict(
        prompt=prompt,
        height=image_wh[1],
        width=image_wh[0],
        num_inference_steps=infer_step,
        guidance_scale=guidance_scale,
        num_images_per_prompt=n_sample,
        generator=generator,
    )
    if ip_image is not None:
        if isinstance(ip_image, str):
            ip_image = Image.open(ip_image)
        ip_image = ip_image.resize((ip_image_size, ip_image_size))
        kwargs.update(ip_adapter_image=[ip_image])

    return pipeline(**kwargs).images
def build_texture_gen_pipe(
    base_ckpt_dir: str,
    controlnet_ckpt: str = None,
    ip_adapt_scale: float = 0,
    device: str = "cuda",
) -> DiffusionPipeline:
    """Assemble the Kolors ControlNet img2img pipeline used for texturing.

    Args:
        base_ckpt_dir: Directory containing ``Kolors`` and, when the
            IP-Adapter is used, ``Kolors-IP-Adapter-Plus``.
        controlnet_ckpt: ControlNet checkpoint directory. When ``None`` the
            ``geo_cond_mv`` folder of the ``xinjjj/RoboAssetGen`` hub repo is
            downloaded and used.
        ip_adapt_scale: IP-Adapter scale; the adapter and its image encoder
            are only loaded when this is greater than zero.
        device: Device the pipeline is moved to before CPU offload is enabled.

    Returns:
        The pipeline with model CPU offload enabled.
    """
    kolors_dir = f"{base_ckpt_dir}/Kolors"
    adapter_dir = f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus"
    use_ip_adapter = ip_adapt_scale > 0

    # Base Kolors components, all cast to half precision.
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{kolors_dir}/text_encoder")
    text_encoder = ChatGLMModel.from_pretrained(
        f"{kolors_dir}/text_encoder", torch_dtype=torch.float16
    )
    text_encoder = text_encoder.half()
    vae = AutoencoderKL.from_pretrained(f"{kolors_dir}/vae", revision=None)
    vae = vae.half()
    unet = UNet2DConditionModel.from_pretrained(
        f"{kolors_dir}/unet", revision=None
    )
    unet = unet.half()
    scheduler = EulerDiscreteScheduler.from_pretrained(
        f"{kolors_dir}/scheduler"
    )

    # Fall back to the hub-hosted ControlNet weights.
    if controlnet_ckpt is None:
        suffix = "geo_cond_mv"
        snapshot_dir = snapshot_download(
            repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
        )
        controlnet_ckpt = os.path.join(snapshot_dir, suffix)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_ckpt, use_safetensors=True
    )
    controlnet = controlnet.half()

    # Optional IP-Adapter image branch.
    if use_ip_adapter:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            f"{adapter_dir}/image_encoder",
        ).to(dtype=torch.float16)
        ip_img_size = 336
        clip_image_processor = CLIPImageProcessor(
            size=ip_img_size, crop_size=ip_img_size
        )
    else:
        image_encoder = None
        clip_image_processor = None

    pipe = StableDiffusionXLControlNetImg2ImgPipeline(
        vae=vae,
        controlnet=controlnet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )

    if use_ip_adapter:
        # Keep the UNet's existing projection reachable under a second name
        # before the adapter weights are loaded.
        if hasattr(pipe.unet, "encoder_hid_proj"):
            pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
        pipe.load_ip_adapter(
            adapter_dir,
            subfolder="",
            weight_name=["ip_adapter_plus_general.bin"],
        )
        pipe.set_ip_adapter_scale([ip_adapt_scale])

    pipe = pipe.to(device)
    pipe.enable_model_cpu_offload()

    return pipe
# Runtime environment used by the extensions / libraries loaded below.
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
    "~/.cache/torch_extensions"
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"


# Models and checkers are created once at import time and shared by every
# image processed in the main loop.
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)

RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
    "microsoft/TRELLIS-image-large"
)
PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)


def _str2bool(value: str) -> bool:
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean, got {value!r}.")


def parse_args():
    """Parse CLI arguments for the image-to-3D pipeline.

    Either ``--image_path`` or ``--image_root`` must be given. When only
    ``--image_root`` is given, ``args.image_path`` is filled with the
    ``*.png``, ``*.jpg`` and ``*.jpeg`` files found directly inside it.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
    parser.add_argument(
        "--image_path", type=str, nargs="+", help="Path to the input images."
    )
    parser.add_argument(
        "--image_root", type=str, help="Path to the input images folder."
    )
    parser.add_argument(
        "--output_root",
        type=str,
        required=True,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
        "--height_range",
        type=str,
        default=None,
        help="The height in meter to restore the mesh real size.",
    )
    parser.add_argument(
        "--mass_range",
        type=str,
        default=None,
        help="The mass in kg to restore the mesh real weight.",
    )
    parser.add_argument("--asset_type", type=str, default=None)
    parser.add_argument("--skip_exists", action="store_true")
    parser.add_argument("--strict_seg", action="store_true")
    parser.add_argument("--version", type=str, default=VERSION)
    # Still takes an explicit value (``--remove_intermediate False``), but
    # the value is now actually interpreted instead of always being truthy.
    parser.add_argument("--remove_intermediate", type=_str2bool, default=True)
    args = parser.parse_args()

    # parser.error instead of assert: asserts are stripped under ``python -O``.
    if not (args.image_path or args.image_root):
        parser.error("Please provide either --image_path or --image_root.")

    if not args.image_path:
        args.image_path = glob(os.path.join(args.image_root, "*.png"))
        args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
        args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))

    return args


if __name__ == "__main__":
    args = parse_args()

    for image_path in args.image_path:
        try:
            filename = os.path.basename(image_path).split(".")[0]
            output_root = args.output_root
            # One sub-folder per image when several images are processed.
            if args.image_root is not None or len(args.image_path) > 1:
                output_root = os.path.join(output_root, filename)
            os.makedirs(output_root, exist_ok=True)

            mesh_out = f"{output_root}/{filename}.obj"
            if args.skip_exists and os.path.exists(mesh_out):
                logger.info(
                    f"Skip {image_path}, already processed in {mesh_out}"
                )
                continue

            image = Image.open(image_path)
            image.save(f"{output_root}/{filename}_raw.png")

            # Segmentation: non-RGBA inputs go through the Rembg remover and
            # TRELLIS preprocessing; RGBA inputs are used as they are.
            seg_path = f"{output_root}/{filename}_cond.png"
            if image.mode != "RGBA":
                seg_image = RBG_REMOVER(image, save_path=seg_path)
                seg_image = trellis_preprocess(seg_image)
            else:
                seg_image = image
                seg_image.save(seg_path)

            # Run the pipeline (default sampler parameters).
            try:
                outputs = PIPELINE.run(
                    seg_image,
                    preprocess_image=False,
                )
            except Exception as e:
                logger.exception(
                    f"[Pipeline Failed] process {image_path}: {e}, skip."
                )
                continue

            # Render and save color and mesh videos
            gs_model = outputs["gaussian"][0]
            mesh_model = outputs["mesh"][0]
            color_images = render_video(gs_model)["color"]
            normal_images = render_video(mesh_model)["normal"]
            video_path = os.path.join(output_root, "gs_mesh.mp4")
            merge_images_video(color_images, normal_images, video_path)

            # Save the raw Gaussian model
            gs_path = mesh_out.replace(".obj", "_gs.ply")
            gs_model.save_ply(gs_path)

            # Fixed 90-degree re-orientation applied to both mesh and GS.
            # NOTE(review): the original comment said "around Z-axis", but
            # rot_matrix leaves the second (Y) coordinate untouched - confirm
            # the intended axis convention.
            rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
            gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
            mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

            # Additional rotation for GS to align with the mesh.
            gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
            pose = GaussianOperator.trans_to_quatpose(gs_rot)
            aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
            GaussianOperator.resave_ply(
                in_ply=gs_path,
                out_ply=aligned_gs_path,
                instance_pose=pose,
                device="cpu",
            )
            color_path = os.path.join(output_root, "color.png")
            render_gs_api(aligned_gs_path, color_path)

            mesh = trimesh.Trimesh(
                vertices=mesh_model.vertices.cpu().numpy(),
                faces=mesh_model.faces.cpu().numpy(),
            )
            mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
            mesh.vertices = mesh.vertices @ np.array(rot_matrix)

            mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
            mesh.export(mesh_obj_path)

            # Texture back-projection; overwrites the exported .obj in place.
            mesh = backproject_api(
                delight_model=DELIGHT,
                imagesr_model=IMAGESR_MODEL,
                color_path=color_path,
                mesh_path=mesh_obj_path,
                output_path=mesh_obj_path,
                skip_fix_mesh=False,
                delight=True,
                texture_wh=[2048, 2048],
            )

            mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
            mesh.export(mesh_glb_path)

            urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
            asset_attrs = {
                "version": VERSION,
                "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
            }
            # Ranges are given as "min-max".
            if args.height_range:
                min_height, max_height = map(
                    float, args.height_range.split("-")
                )
                asset_attrs["min_height"] = min_height
                asset_attrs["max_height"] = max_height
            if args.mass_range:
                min_mass, max_mass = map(float, args.mass_range.split("-"))
                asset_attrs["min_mass"] = min_mass
                asset_attrs["max_mass"] = max_mass
            if args.asset_type:
                asset_attrs["category"] = args.asset_type
            if args.version:
                asset_attrs["version"] = args.version

            urdf_root = f"{output_root}/URDF_{filename}"
            urdf_path = urdf_convertor(
                mesh_path=mesh_obj_path,
                output_root=urdf_root,
                **asset_attrs,
            )

            # Rescale GS and save to URDF/mesh folder.
            real_height = urdf_convertor.get_attr_from_urdf(
                urdf_path, attr_name="real_height"
            )
            out_gs = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"  # noqa
            GaussianOperator.resave_ply(
                in_ply=aligned_gs_path,
                out_ply=out_gs,
                real_height=real_height,
                device="cpu",
            )

            # Quality check and update .urdf file.
            mesh_out = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}.obj"  # noqa
            trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))

            image_dir = f"{urdf_root}/{urdf_convertor.output_render_dir}/image_color"  # noqa
            image_paths = glob(f"{image_dir}/*.png")
            # One image list per checker; the segmentation checker compares
            # the raw input against the segmented conditioning image.
            images_list = []
            for checker in CHECKERS:
                images = image_paths
                if isinstance(checker, ImageSegChecker):
                    images = [
                        f"{output_root}/{filename}_raw.png",
                        f"{output_root}/{filename}_cond.png",
                    ]
                images_list.append(images)

            results = BaseChecker.validate(CHECKERS, images_list)
            urdf_convertor.add_quality_tag(urdf_path, results)

            # Organize the final result files
            result_dir = f"{output_root}/result"
            os.makedirs(result_dir, exist_ok=True)
            copy(urdf_path, f"{result_dir}/{os.path.basename(urdf_path)}")
            # dirs_exist_ok: a re-run over an existing result folder must
            # overwrite it instead of failing with FileExistsError.
            copytree(
                f"{urdf_root}/{urdf_convertor.output_mesh_dir}",
                f"{result_dir}/{urdf_convertor.output_mesh_dir}",
                dirs_exist_ok=True,
            )
            copy(video_path, f"{result_dir}/video.mp4")
            if args.remove_intermediate:
                delete_dir(output_root, keep_subs=["result"])

        except Exception as e:
            # Best effort per image: log with traceback and move on.
            logger.exception(f"Failed to process {image_path}: {e}, skip.")
            continue

    logger.info(f"Processing complete. Outputs saved to {args.output_root}")
def parse_args():
    """Parse CLI arguments for rendering a Gaussian-splat model to a grid.

    Unknown arguments are ignored (``parse_known_args``), so this can run
    inside a process whose ``sys.argv`` belongs to another script.
    """
    parser = argparse.ArgumentParser(description="Render GS color images")

    parser.add_argument(
        "--input_gs", type=str, help="Input render GS.ply path."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output grid image path for rendered GS color images.",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        type=float,
        nargs="+",
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the output images (default: (512, 512))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=512,
        help="Output image size for single view in color grid (default: 512)",
    )

    args, unknown = parser.parse_known_args()

    return args


def load_gs_model(
    input_gs: str,
    pre_quat: list[float] | tuple[float, ...] = (0.0, 0.7071, 0.0, -0.7071),
) -> GaussianOperator:
    """Load a GS ``.ply``, recentre it, apply ``pre_quat`` and rescale it.

    Args:
        input_gs: Path of the Gaussian-splat ``.ply`` file.
        pre_quat: Four quaternion components appended after the negated
            centre to form the 7-value instance pose. The default is an
            immutable tuple; lists are accepted as before.

    Returns:
        The transformed ``GaussianOperator``.
    """
    gs_model = GaussianOperator.load_from_ply(input_gs)
    # Normalize vertices to [-1, 1], center to (0, 0, 0).
    _, scale, center = normalize_vertices_array(gs_model._means)
    scale, center = float(scale), center.tolist()
    # Pose layout: translation (-center) followed by the quaternion.
    transpose = [*[-v for v in center], *pre_quat]
    instance_pose = torch.tensor(transpose).to(gs_model.device)
    gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
    gs_model.rescale(scale)

    return gs_model


@spaces.GPU
def entrypoint(input_gs: str = None, output_path: str = None) -> None:
    """Render a GS model from six views and save them as a 2x3 grid image.

    Args:
        input_gs: GS ``.ply`` path; overrides ``--input_gs`` when a string.
        output_path: Grid image path; overrides ``--output_path`` when a
            string.

    Raises:
        ValueError: If no input/output path is available, or fewer views
            were rendered than the fixed grid layout needs.
    """
    args = parse_args()
    if isinstance(input_gs, str):
        args.input_gs = input_gs
    if isinstance(output_path, str):
        args.output_path = output_path
    if not args.input_gs or not args.output_path:
        raise ValueError(
            "Both input_gs and output_path must be provided "
            "(as arguments or via --input_gs / --output_path)."
        )

    # Setup camera parameters
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    camera = init_kal_camera(camera_params)
    matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
    # Flip the sign of the translation column before inverting.
    matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
    w2cs = matrix_mv.to(camera_params.device)
    c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
    Ks = torch.tensor(camera_params.Ks).to(camera_params.device)

    # Load GS model and normalize.
    gs_model = load_gs_model(args.input_gs, pre_quat=[0.0, 0.0, 1.0, 0.0])

    # Render GS color images.
    images = []
    for idx in tqdm(range(len(c2ws)), desc="Rendering GS"):
        result = gs_model.render(
            c2ws[idx],
            Ks=Ks,
            image_width=camera_params.resolution_hw[1],
            image_height=camera_params.resolution_hw[0],
        )
        color = cv2.resize(
            result.rgba,
            (args.image_size, args.image_size),
            interpolation=cv2.INTER_AREA,
        )
        images.append(color)

    # Cat color images into grid image and save.
    select_idxs = [[0, 2, 1], [5, 4, 3]]  # fix order for 6 views
    required = max(max(row) for row in select_idxs) + 1
    if len(images) < required:
        raise ValueError(
            f"The grid layout needs {required} views, "
            f"but only {len(images)} were rendered."
        )
    grid_rows = [
        np.concatenate([images[view_idx] for view_idx in row_idxs], axis=1)
        for row_idxs in select_idxs
    ]
    grid_image = np.concatenate(grid_rows, axis=0)

    # A bare file name has an empty dirname, which os.makedirs rejects.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    cv2.imwrite(args.output_path, grid_image)
    logger.info(f"Saved grid image to {args.output_path}")


if __name__ == "__main__":
    entrypoint()
def get_init_noise_image(image: Image.Image) -> Image.Image:
    """Return a blurred, contrast-reduced grayscale version of ``image``.

    The image is converted to mode ``L``, blurred with a Gaussian of radius
    3, and its contrast is scaled by 0.5.
    """
    blurred_image = image.convert("L").filter(
        ImageFilter.GaussianBlur(radius=3)
    )

    enhancer = ImageEnhance.Contrast(blurred_image)
    image_decreased_contrast = enhancer.enhance(factor=0.5)

    return image_decreased_contrast


def infer_pipe(
    index_file: str,
    controlnet_ckpt: str = None,
    uid: str = None,
    prompt: str = None,
    controlnet_cond_scale: float = 0.4,
    control_guidance_end: float = 0.9,
    strength: float = 1.0,
    num_inference_steps: int = 50,
    guidance_scale: float = 10,
    ip_adapt_scale: float = 0,
    ip_img_path: str = None,
    sub_idxs: List[List[int]] = None,
    num_images_per_prompt: int = 3,  # increase if want similar images.
    device: str = "cuda",
    save_dir: str = "infer_vis",
    seed: int = None,
    target_hw: tuple[int, int] = (512, 512),
    pipeline: StableDiffusionXLControlNetImg2ImgPipeline = None,
) -> list[str]:
    """Generate geometry-conditioned texture images and save them.

    Args:
        index_file: Dataset index file passed to ``Asset3dGenDataset``.
        controlnet_ckpt: ControlNet checkpoint, used only when ``pipeline``
            is not supplied.
        uid: Sample id; a random one from the dataset is used when ``None``.
        prompt: Text prompt; falls back to the sample's ``"capture"`` entry.
            Lists/tuples are joined with ``", "``.
        controlnet_cond_scale: ControlNet conditioning scale.
        control_guidance_end: Forwarded to the pipeline.
        strength: img2img strength.
        num_inference_steps: Number of inference steps.
        guidance_scale: Guidance scale.
        ip_adapt_scale: IP-Adapter scale; the reference image is only used
            when this is > 0 and ``ip_img_path`` is non-empty.
        ip_img_path: Path of the IP-Adapter reference image.
        sub_idxs: View index grid, e.g. ``[[0, 1, 2], [3, 4, 5]]``. When
            ``None`` a single random view (0-5) is used and ``target_hw`` is
            doubled.
        num_images_per_prompt: Images generated per pipeline call.
        device: Device for the pipeline and the random generator.
        save_dir: Output directory.
        seed: Optional seed for the generator and the global RNGs.
        target_hw: Per-view (height, width).
        pipeline: Pre-built pipeline; built from ``./weights`` when ``None``.

    Returns:
        Paths of the three saved RGBA sample images
        (``color_sample{0,1,2}.png``).
    """
    # sub_idxs = [[0, 1, 2], [3, 4, 5]]  # None for single image.
    if sub_idxs is None:
        sub_idxs = [[random.randint(0, 5)]]  # 6 views.
        target_hw = [2 * size for size in target_hw]

    transform_list = [
        transforms.Resize(
            target_hw, interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop(target_hw),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    image_transform = transforms.Compose(transform_list)
    # Control images use the same transforms minus the final Normalize.
    control_transform = transforms.Compose(transform_list[:-1])

    grid_hw = (target_hw[0] * len(sub_idxs), target_hw[1] * len(sub_idxs[0]))
    dataset = Asset3dGenDataset(
        index_file, target_hw=grid_hw, sub_idxs=sub_idxs
    )

    if uid is None:
        uid = random.choice(list(dataset.meta_info.keys()))
    if prompt is None:
        prompt = dataset.meta_info[uid]["capture"]
    if isinstance(prompt, (list, tuple)):
        prompt = ", ".join(map(str, prompt))
    prompt += ", high quality, high resolution, best quality"
    logger.info(f"Inference with prompt: {prompt}")

    # Chinese negative prompt: nsfw, shadows, low resolution, artifacts,
    # blur, neon lights, highlights, specular reflection.
    negative_prompt = "nsfw,阴影,低分辨率,伪影、模糊,霓虹灯,高光,镜面反射"

    control_image = dataset.fetch_sample_grid_images(
        uid,
        attrs=["image_view_normal", "image_position", "image_mask"],
        sub_idxs=sub_idxs,
        transform=control_transform,
    )

    color_image = dataset.fetch_sample_grid_images(
        uid,
        attrs=["image_color"],
        sub_idxs=sub_idxs,
        transform=image_transform,
    )

    normal_pil, position_pil, mask_pil, color_pil = dataset.visualize_item(
        control_image,
        color_image,
        save_dir=save_dir,
    )

    if pipeline is None:
        pipeline = build_texture_gen_pipe(
            base_ckpt_dir="./weights",
            controlnet_ckpt=controlnet_ckpt,
            ip_adapt_scale=ip_adapt_scale,
            device=device,
        )

    if ip_adapt_scale > 0 and ip_img_path is not None and len(ip_img_path) > 0:
        ip_image = Image.open(ip_img_path).convert("RGB")
        # PIL resize takes (width, height); target_hw is (height, width).
        ip_image = ip_image.resize(target_hw[::-1])
        ip_image = [ip_image]
        pipeline.set_ip_adapter_scale([ip_adapt_scale])
    else:
        ip_image = None

    generator = None
    if seed is not None:
        generator = torch.Generator(device).manual_seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    # The blurred normal map is the img2img starting image.
    init_image = get_init_noise_image(normal_pil)

    images = []
    row_num, col_num = 2, 3
    img_save_paths = []
    # Keep sampling until there are enough images to fill one grid row.
    while len(images) < col_num:
        image = pipeline(
            prompt=prompt,
            image=init_image,
            controlnet_conditioning_scale=controlnet_cond_scale,
            control_guidance_end=control_guidance_end,
            strength=strength,
            control_image=control_image[None, ...],
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            ip_adapter_image=ip_image,
            generator=generator,
        ).images
        images.extend(image)

    grid_image = [normal_pil, position_pil, color_pil] + images[:col_num]
    os.makedirs(save_dir, exist_ok=True)

    # Attach the mask as alpha channel and save each sample.
    for idx in range(col_num):
        rgba_image = Image.merge("RGBA", (*images[idx].split(), mask_pil))
        img_save_path = os.path.join(save_dir, f"color_sample{idx}.png")
        rgba_image.save(img_save_path)
        img_save_paths.append(img_save_path)

    idx_tag = "_".join(
        [str(item) for sublist in sub_idxs for item in sublist]
    )
    save_path = os.path.join(
        save_dir, f"sample_idx{idx_tag}_ip{ip_adapt_scale}.jpg"
    )
    make_image_grid(grid_image, row_num, col_num).save(save_path)
    logger.info(f"Visualize in {save_path}")

    return img_save_paths


def entrypoint() -> None:
    """Expose ``infer_pipe`` as a command-line interface through ``fire``."""
    fire.Fire(infer_pipe)


if __name__ == "__main__":
    entrypoint()
def parse_args():
    """Parse CLI arguments for text-to-image generation."""
    parser = argparse.ArgumentParser(description="Text to Image.")
    parser.add_argument(
        "--prompts",
        type=str,
        nargs="+",
        help="List of prompts (space-separated).",
    )
    parser.add_argument(
        "--ref_image",
        type=str,
        nargs="+",
        help="List of ref_image paths (space-separated).",
    )
    parser.add_argument(
        "--output_root",
        type=str,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=12.0,
        help="Guidance scale for the diffusion model.",
    )
    parser.add_argument(
        "--ref_scale",
        type=float,
        default=0.3,
        help="Reference image scale for the IP adapter.",
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
    )
    parser.add_argument(
        "--infer_step",
        type=int,
        default=50,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
    )
    args = parser.parse_args()

    return args


def entrypoint(
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP = None,
    **kwargs,
) -> list[str]:
    """Generate images for every prompt and save them under ``output_root``.

    CLI arguments are parsed first; any non-``None`` keyword argument whose
    name matches a parsed option overrides it.

    Args:
        pipeline: Pre-built pipeline. When ``None`` one is built from
            ``weights/Kolors`` (with IP-Adapter if ``ref_scale > 0``).
        **kwargs: Overrides for the parsed options (``prompts``,
            ``ref_image``, ``output_root``, ...).

    Returns:
        Paths of all saved images, across all prompts, in generation order.

    Raises:
        ValueError: If prompts or ``output_root`` are missing, or the number
            of reference images does not match the number of prompts.
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    prompts = args.prompts
    if isinstance(prompts, str):
        prompts = [prompts]
    if not prompts:
        raise ValueError("No prompts provided, please set --prompts.")
    if not args.output_root:
        raise ValueError("No output root provided, please set --output_root.")

    # A single ``*.txt`` argument is a prompt file: one prompt per line.
    if len(prompts) == 1 and prompts[0].endswith(".txt"):
        with open(prompts[0], "r", encoding="utf-8") as f:
            prompts = f.readlines()
            prompts = [
                prompt.strip() for prompt in prompts if prompt.strip() != ""
            ]

    os.makedirs(args.output_root, exist_ok=True)

    # Normalise reference images to one entry per prompt; without any
    # reference image the IP-Adapter is disabled.
    ip_img_paths = args.ref_image
    if ip_img_paths is None or len(ip_img_paths) == 0:
        args.ref_scale = 0
        ip_img_paths = [None] * len(prompts)
    elif isinstance(ip_img_paths, str):
        ip_img_paths = [ip_img_paths] * len(prompts)
    elif isinstance(ip_img_paths, list):
        if len(ip_img_paths) == 1:
            ip_img_paths = ip_img_paths * len(prompts)
    else:
        raise ValueError("Invalid ref_image paths.")
    if len(ip_img_paths) != len(prompts):
        raise ValueError(
            f"Number of ref images does not match prompts, {len(ip_img_paths)} != {len(prompts)}"  # noqa
        )

    if pipeline is None:
        if args.ref_scale > 0:
            pipeline = build_text2img_ip_pipeline(
                "weights/Kolors",
                ref_scale=args.ref_scale,
            )
        else:
            pipeline = build_text2img_pipeline("weights/Kolors")

    # Collected across ALL prompts; creating this list inside the loop would
    # return only the last prompt's paths.
    save_paths = []
    for idx, (prompt, ip_img_path) in tqdm(
        enumerate(zip(prompts, ip_img_paths)),
        desc="Generating images",
        total=len(prompts),
    ):
        images = text2img_gen(
            prompt=prompt,
            n_sample=args.n_sample,
            guidance_scale=args.guidance_scale,
            pipeline=pipeline,
            ip_image=ip_img_path,
            image_wh=[args.resolution, args.resolution],
            infer_step=args.infer_step,
            seed=args.seed,
        )

        for sub_idx, image in enumerate(images):
            save_path = (
                f"{args.output_root}/sample_{idx*args.n_sample+sub_idx}.png"
            )
            image.save(save_path)
            save_paths.append(save_path)

    logger.info(f"Images saved to {args.output_root}")

    return save_paths


if __name__ == "__main__":
    entrypoint()
"$1" =~ ^-- ]]; do + prompts+=("$1") + shift + done + ;; + --output_root) + output_root="$2" + shift 2 + ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + +# Validate required arguments +if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then + echo "Missing required arguments." + echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" --output_root " + exit 1 +fi + +# Print arguments (for debugging) +echo "Prompts:" +for p in "${prompts[@]}"; do + echo " - $p" +done +echo "Output root: ${output_root}" + +# Concatenate prompts for Python command +prompt_args="" +for p in "${prompts[@]}"; do + prompt_args+="\"$p\" " +done + +# Step 1: Text-to-Image +eval python3 embodied_gen/scripts/text2image.py \ + --prompts ${prompt_args} \ + --output_root "${output_root}/images" + +# Step 2: Image-to-3D +python3 embodied_gen/scripts/imageto3d.py \ + --image_root "${output_root}/images" \ + --output_root "${output_root}/asset3d" diff --git a/embodied_gen/scripts/texture_gen.sh b/embodied_gen/scripts/texture_gen.sh new file mode 100644 index 0000000..747a5fe --- /dev/null +++ b/embodied_gen/scripts/texture_gen.sh @@ -0,0 +1,80 @@ +#!/bin/bash + +while [[ $# -gt 0 ]]; do + case $1 in + --mesh_path) + mesh_path="$2" + shift 2 + ;; + --prompt) + prompt="$2" + shift 2 + ;; + --uuid) + uuid="$2" + shift 2 + ;; + --output_root) + output_root="$2" + shift 2 + ;; + *) + echo "unknown: $1" + exit 1 + ;; + esac +done + + +if [[ -z "$mesh_path" || -z "$prompt" || -z "$uuid" || -z "$output_root" ]]; then + echo "params missing" + echo "usage: bash run.sh --mesh_path --prompt --uuid --output_root " + exit 1 +fi + +# Step 1: drender-cli for condition rendering +drender-cli --mesh_path ${mesh_path} \ + --output_root ${output_root}/condition \ + --uuid ${uuid} + +# Step 2: multi-view rendering +python embodied_gen/scripts/render_mv.py \ + --index_file "${output_root}/condition/index.json" \ + --controlnet_cond_scale 0.75 \ + --guidance_scale 9 \ + --strength 0.9 
\ + --num_inference_steps 40 \ + --ip_adapt_scale 0 \ + --ip_img_path None \ + --uid ${uuid} \ + --prompt "${prompt}" \ + --save_dir "${output_root}/multi_view" \ + --sub_idxs "[[0,1,2],[3,4,5]]" \ + --seed 0 + +# Step 3: backprojection +backproject-cli --mesh_path ${mesh_path} \ + --color_path ${output_root}/multi_view/color_sample0.png \ + --output_path "${output_root}/texture_mesh/${uuid}.obj" \ + --save_glb_path "${output_root}/texture_mesh/${uuid}.glb" \ + --skip_fix_mesh \ + --delight \ + --no_save_delight_img + +# Step 4: final rendering of textured mesh +drender-cli --mesh_path "${output_root}/texture_mesh/${uuid}.obj" \ + --output_root ${output_root}/texture_mesh \ + --num_images 90 \ + --elevation 20 \ + --with_mtl \ + --gen_color_mp4 \ + --pbr_light_factor 1.2 + +# Organize folders +rm -rf ${output_root}/condition +video_path="${output_root}/texture_mesh/${uuid}/color.mp4" +if [ -f "${video_path}" ]; then + cp "${video_path}" "${output_root}/texture_mesh/color.mp4" + echo "Resave video to ${output_root}/texture_mesh/color.mp4" +fi +rm -rf ${output_root}/texture_mesh/${uuid} \ No newline at end of file diff --git a/embodied_gen/utils/gpt_clients.py b/embodied_gen/utils/gpt_clients.py new file mode 100644 index 0000000..7b3f72b --- /dev/null +++ b/embodied_gen/utils/gpt_clients.py @@ -0,0 +1,211 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
class GPTclient:
    """A client to interact with the GPT model via OpenAI or Azure API.

    An ``AzureOpenAI`` client is created when ``api_version`` is given,
    otherwise a plain ``OpenAI`` client pointed at ``endpoint``.

    Args:
        endpoint (str): Azure endpoint or OpenAI-compatible base URL.
        api_key (str): API key for the endpoint.
        model_name (str): Model (deployment) name sent with every request.
        api_version (Optional[str]): Azure API version; ``None`` selects the
            non-Azure client.
        verbose (bool): Log every prompt and response when True.
    """

    # Extension -> MIME type for image files that are read from disk.
    _MIME_BY_EXT = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
        ".gif": "image/gif",
    }

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        model_name: str = "yfb-gpt-4o",
        api_version: Optional[str] = None,
        verbose: bool = False,
    ):
        if api_version is not None:
            self.client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        else:
            self.client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
            )

        self.endpoint = endpoint
        self.model_name = model_name
        self.image_formats = set(self._MIME_BY_EXT)
        self.verbose = verbose
        logger.info(f"Using GPT model: {self.model_name}.")

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=(stop_after_attempt(10) | stop_after_delay(30)),
    )
    def completion_with_backoff(self, **kwargs):
        """Call the chat completion API, retrying with random backoff."""
        return self.client.chat.completions.create(**kwargs)

    def _encode_image(self, img) -> tuple[str, str]:
        """Return ``(mime_type, base64_payload)`` for one image input.

        ``img`` may be a ``PIL.Image``, a path to an image file whose
        extension is in ``self.image_formats``, or an already base64-encoded
        string (returned unchanged).

        Raises:
            FileNotFoundError: If ``img`` looks like an image path but the
                file does not exist.
        """
        if isinstance(img, Image.Image):
            fmt = img.format or "PNG"
            buffer = BytesIO()
            img.save(buffer, format=fmt)
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
            return f"image/{fmt.lower()}", encoded

        ext = os.path.splitext(img)[-1].lower()
        if ext in self.image_formats:
            if not os.path.exists(img):
                raise FileNotFoundError(f"Image file not found: {img}")
            with open(img, "rb") as f:
                encoded = base64.b64encode(f.read()).decode("utf-8")
            return self._MIME_BY_EXT.get(ext, "image/png"), encoded

        # Pre-encoded input: its real format is unknown here, so the PNG
        # label that was always used before is kept.
        return "image/png", img

    def query(
        self,
        text_prompt: str,
        image_base64: Optional[list[str | Image.Image]] = None,
        system_role: Optional[str] = None,
    ) -> Optional[str]:
        """Queries the GPT model with a text and optional image prompts.

        Args:
            text_prompt (str): The main text input that the model responds to.
            image_base64 (Optional[List[str]]): A list of image base64 strings
                or local image paths or PIL.Image to accompany the text prompt.
                A single item that is not wrapped in a list is accepted too.
            system_role (Optional[str]): Optional system-level instructions
                that specify the behavior of the assistant.

        Returns:
            Optional[str]: The response content generated by the model based on
                the prompt. Returns `None` if an error occurs.

        Raises:
            FileNotFoundError: If an image path does not exist.
        """
        if system_role is None:
            system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."  # noqa

        content_user = [
            {
                "type": "text",
                "text": text_prompt,
            },
        ]

        # Process images if provided
        if image_base64 is not None:
            if not isinstance(image_base64, list):
                image_base64 = [image_base64]
            for img in image_base64:
                # The data URL must advertise the format the bytes are
                # really in (e.g. JPEG files are not relabelled as PNG).
                mime, encoded = self._encode_image(img)
                content_user.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime};base64,{encoded}"},
                    }
                )

        payload = {
            "messages": [
                {"role": "system", "content": system_role},
                {"role": "user", "content": content_user},
            ],
            "temperature": 0.1,
            "max_tokens": 500,
            "top_p": 0.1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
            "model": self.model_name,
        }

        response = None
        try:
            response = self.completion_with_backoff(**payload)
            response = response.choices[0].message.content
        except Exception as e:
            # Best effort by design: API failures are reported as ``None``.
            logger.error(f"Error GPTclient {self.endpoint} API call: {e}")
            response = None

        if self.verbose:
            logger.info(f"Prompt: {text_prompt}")
            logger.info(f"Response: {response}")

        return response
agent_config.get("endpoint")) +api_key = os.environ.get("API_KEY", agent_config.get("api_key")) +api_version = os.environ.get("API_VERSION", agent_config.get("api_version")) +model_name = os.environ.get("MODEL_NAME", agent_config.get("model_name")) + +GPT_CLIENT = GPTclient( + endpoint=endpoint, + api_key=api_key, + api_version=api_version, + model_name=model_name, +) + +if __name__ == "__main__": + if "openrouter" in GPT_CLIENT.endpoint: + response = GPT_CLIENT.query( + text_prompt="What is the content in each image?", + image_base64=combine_images_to_base64( + [ + "outputs/text2image/demo_objects/bed/sample_0.jpg", + "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png", # noqa + "outputs/text2image/demo_objects/cardboard/sample_1.jpg", + ] + ), # input raw image_path if only one image + ) + print(response) + else: + response = GPT_CLIENT.query( + text_prompt="What is the content in the images?", + image_base64=[ + Image.open("outputs/text2image/demo_objects/bed/sample_0.jpg"), + Image.open( + "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png" # noqa + ), + ], + ) + print(response) + + # test2: text prompt + response = GPT_CLIENT.query( + text_prompt="What is the capital of China?" 
# config.yaml
#
# Do NOT commit real credentials to this file. gpt_clients.py prefers the
# ENDPOINT / API_KEY / API_VERSION / MODEL_NAME environment variables over
# the values below, so export API_KEY instead of editing `api_key`.
# Any key that was committed here earlier must be treated as leaked and
# revoked with the provider.
agent_type: "qwen2.5-vl"  # gpt-4o or qwen2.5-vl

gpt-4o:
  endpoint: https://xxx.openai.azure.com
  api_key: xxx
  # Keep the quotes: an unquoted real date (e.g. 2025-01-01) is loaded by
  # YAML as a date object instead of the string the client expects.
  api_version: "2025-xx-xx"
  model_name: yfb-gpt-4o

qwen2.5-vl:
  endpoint: https://openrouter.ai/api/v1
  api_key: xxx
  api_version: null
  model_name: qwen/qwen2.5-vl-72b-instruct:free
+ + +import base64 +import logging +import math +import os +import subprocess +import sys +from glob import glob +from io import BytesIO +from typing import Union + +import cv2 +import imageio +import numpy as np +import PIL.Image as Image +import spaces +import torch +from moviepy.editor import VideoFileClip, clips_array +from tqdm import tqdm + +current_file_path = os.path.abspath(__file__) +current_dir = os.path.dirname(current_file_path) +sys.path.append(os.path.join(current_dir, "../..")) +from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer +from thirdparty.TRELLIS.trellis.representations import MeshExtractResult +from thirdparty.TRELLIS.trellis.utils.render_utils import ( + render_frames, + yaw_pitch_r_fov_to_extrinsics_intrinsics, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +__all__ = [ + "render_asset3d", + "merge_images_video", + "filter_small_connected_components", + "filter_image_small_connected_components", + "combine_images_to_base64", + "render_mesh", + "render_video", + "create_mp4_from_images", + "create_gif_from_images", +] + + +@spaces.GPU +def render_asset3d( + mesh_path: str, + output_root: str, + distance: float = 5.0, + num_images: int = 1, + elevation: list[float] = (0.0,), + pbr_light_factor: float = 1.5, + return_key: str = "image_color/*", + output_subdir: str = "renders", + gen_color_mp4: bool = False, + gen_viewnormal_mp4: bool = False, + gen_glonormal_mp4: bool = False, +) -> list[str]: + command = [ + "python3", + "embodied_gen/data/differentiable_render.py", + "--mesh_path", + mesh_path, + "--output_root", + output_root, + "--uuid", + output_subdir, + "--distance", + str(distance), + "--num_images", + str(num_images), + "--elevation", + *map(str, elevation), + "--pbr_light_factor", + str(pbr_light_factor), + "--with_mtl", + ] + if gen_color_mp4: + command.append("--gen_color_mp4") + if gen_viewnormal_mp4: + command.append("--gen_viewnormal_mp4") + if 
def merge_video_video(
    video_path1: str, video_path2: str, output_path: str
) -> None:
    """Merge two videos by the left half and the right half of the videos.

    The left half of ``video_path1`` and the right half of ``video_path2``
    are placed side by side and written to ``output_path`` with libx264.

    Args:
        video_path1: Video providing the left half of every frame.
        video_path2: Video providing the right half of every frame.
        output_path: Destination video file.

    Raises:
        ValueError: If the two videos have different resolutions.
    """
    from contextlib import closing

    # `closing` guarantees both readers are released, including when the
    # resolution check or the encoding step raises.
    with closing(VideoFileClip(video_path1)) as clip1, closing(
        VideoFileClip(video_path2)
    ) as clip2:
        if clip1.size != clip2.size:
            raise ValueError("The resolutions of the two videos do not match.")

        width, height = clip1.size
        clip1_half = clip1.crop(x1=0, y1=0, x2=width // 2, y2=height)
        clip2_half = clip2.crop(x1=width // 2, y1=0, x2=width, y2=height)
        final_clip = clips_array([[clip1_half, clip2_half]])
        final_clip.write_videofile(output_path, codec="libx264")
def combine_images_to_base64(
    images: list[str | Image.Image],
    cat_row_col: tuple[int, int] = None,
    target_wh: tuple[int, int] = (512, 512),
) -> str:
    """Tile images into one grid image and return it as base64-encoded PNG.

    Args:
        images: Image file paths and/or PIL images. Paths are opened and
            converted to RGB; PIL images are used as given.
        cat_row_col: ``(rows, cols)`` of the grid. When ``None`` a
            near-square grid large enough for all images is chosen. Images
            beyond ``rows * cols`` are dropped.
        target_wh: ``(width, height)`` every image is resized to.

    Returns:
        str: Base64 string of the PNG-encoded grid (white background).

    Raises:
        ValueError: If ``images`` is empty or the grid has a non-positive
            number of rows or columns.
    """
    n_images = len(images)
    if n_images == 0:
        # Would otherwise surface as a ZeroDivisionError below.
        raise ValueError("`images` must contain at least one image.")

    if cat_row_col is None:
        n_col = math.ceil(math.sqrt(n_images))
        n_row = math.ceil(n_images / n_col)
    else:
        n_row, n_col = cat_row_col
    if n_row <= 0 or n_col <= 0:
        raise ValueError(f"Invalid grid shape: {n_row} x {n_col}.")

    tiles = []
    for item in images[: n_row * n_col]:
        if isinstance(item, str):
            # `convert` returns an independent copy, so the file handle can
            # be released right away.
            with Image.open(item) as opened:
                item = opened.convert("RGB")
        tiles.append(item.resize(target_wh))

    tile_w, tile_h = target_wh
    grid = Image.new("RGB", (n_col * tile_w, n_row * tile_h), (255, 255, 255))
    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, n_col)
        grid.paste(tile, (col * tile_w, row * tile_h))

    buffer = BytesIO()
    grid.save(buffer, format="PNG")

    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def create_mp4_from_images(images, output_path, fps=10, prompt=None):
    """Write a sequence of frames to a video file, optionally captioned.

    Args:
        images: Iterable of arrays with values in ``[0, 1]`` (values outside
            are clipped) whose last axis holds at least 3 channels; only the
            first three channels are written.
        output_path: Destination file handed to ``imageio.get_writer``.
        fps: Frames per second of the output video.
        prompt: Optional text drawn in white near the top-left corner of
            every frame.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1
    color = (255, 255, 255)
    position = (20, 25)

    with imageio.get_writer(output_path, fps=fps) as writer:
        for image in images:
            frame = (255.0 * image.clip(min=0, max=1)).astype(np.uint8)
            # Dropping the alpha channel yields a strided view; OpenCV can
            # only draw in place on a contiguous buffer, so make one.
            frame = np.ascontiguousarray(frame[..., :3])
            if prompt is not None:
                cv2.putText(
                    frame,
                    prompt,
                    position,
                    font,
                    font_scale,
                    color,
                    font_thickness,
                )

            writer.append_data(frame)

    logger.info(f"MP4 video saved to {output_path}")
"apps/assets/example_image/sample_00.jpg", + "apps/assets/example_image/sample_01.jpg", + "apps/assets/example_image/sample_02.jpg", + ] + ) diff --git a/embodied_gen/utils/tags.py b/embodied_gen/utils/tags.py new file mode 100644 index 0000000..07a9a63 --- /dev/null +++ b/embodied_gen/utils/tags.py @@ -0,0 +1 @@ +VERSION = "v0.1.0" diff --git a/embodied_gen/validators/aesthetic_predictor.py b/embodied_gen/validators/aesthetic_predictor.py new file mode 100644 index 0000000..5b9c557 --- /dev/null +++ b/embodied_gen/validators/aesthetic_predictor.py @@ -0,0 +1,149 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +import os + +import clip +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from PIL import Image + + +class AestheticPredictor: + """Aesthetic Score Predictor. + + Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main + + Args: + clip_model_dir (str): Path to the directory of the CLIP model. + sac_model_path (str): Path to the pre-trained SAC model. + device (str): Device to use for computation ("cuda" or "cpu"). 
+ """ + + def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"): + + self.device = device + + if clip_model_dir is None: + model_path = snapshot_download( + repo_id="xinjjj/RoboAssetGen", allow_patterns="aesthetic/*" + ) + suffix = "aesthetic" + model_path = snapshot_download( + repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*" + ) + clip_model_dir = os.path.join(model_path, suffix) + + if sac_model_path is None: + model_path = snapshot_download( + repo_id="xinjjj/RoboAssetGen", allow_patterns="aesthetic/*" + ) + suffix = "aesthetic" + model_path = snapshot_download( + repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*" + ) + sac_model_path = os.path.join( + model_path, suffix, "sac+logos+ava1-l14-linearMSE.pth" + ) + + self.clip_model, self.preprocess = self._load_clip_model( + clip_model_dir + ) + self.sac_model = self._load_sac_model(sac_model_path, input_size=768) + + class MLP(pl.LightningModule): # noqa + def __init__(self, input_size): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(input_size, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + def forward(self, x): + return self.layers(x) + + @staticmethod + def normalized(a, axis=-1, order=2): + """Normalize the array to unit norm.""" + l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) + l2[l2 == 0] = 1 + return a / np.expand_dims(l2, axis) + + def _load_clip_model(self, model_dir: str, model_name: str = "ViT-L/14"): + """Load the CLIP model.""" + model, preprocess = clip.load( + model_name, download_root=model_dir, device=self.device + ) + return model, preprocess + + def _load_sac_model(self, model_path, input_size): + """Load the SAC model.""" + model = self.MLP(input_size) + ckpt = torch.load(model_path) + model.load_state_dict(ckpt) + model.to(self.device) + model.eval() + return model + + def predict(self, image_path): + """Predict the 
aesthetic score for a given image. + + Args: + image_path (str): Path to the image file. + + Returns: + float: Predicted aesthetic score. + """ + pil_image = Image.open(image_path) + image = self.preprocess(pil_image).unsqueeze(0).to(self.device) + + with torch.no_grad(): + # Extract CLIP features + image_features = self.clip_model.encode_image(image) + # Normalize features + normalized_features = self.normalized( + image_features.cpu().detach().numpy() + ) + # Predict score + prediction = self.sac_model( + torch.from_numpy(normalized_features) + .type(torch.FloatTensor) + .to(self.device) + ) + + return prediction.item() + + +if __name__ == "__main__": + # Configuration + img_path = "apps/assets/example_image/sample_00.jpg" + + # Initialize the predictor + predictor = AestheticPredictor() + + # Predict the aesthetic score + score = predictor.predict(img_path) + print("Aesthetic score predicted by the model:", score) diff --git a/embodied_gen/validators/quality_checkers.py b/embodied_gen/validators/quality_checkers.py new file mode 100644 index 0000000..4608c6b --- /dev/null +++ b/embodied_gen/validators/quality_checkers.py @@ -0,0 +1,242 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
class BaseChecker:
    """Base class of the quality checkers.

    A checker is called with the inputs of ``query`` and returns a
    ``(flag, info)`` pair: ``flag`` tells whether the check passed and
    ``info`` carries the explanation (or a score in some subclasses).

    Args:
        prompt: Prompt used by the checker, if it needs one.
        verbose: Log responses that are not exactly "YES" when True.
    """

    def __init__(self, prompt: str = None, verbose: bool = False) -> None:
        self.prompt = prompt
        self.verbose = verbose

    def query(self, *args, **kwargs):
        """Run the actual check; must be provided by subclasses."""
        raise NotImplementedError(
            "Subclasses must implement the query method."
        )

    def __call__(self, *args, **kwargs) -> tuple[bool, str]:
        """Run ``query`` and turn its answer into ``(passed, info)``.

        The check passes when the answer contains "YES", in which case
        ``info`` is exactly "YES"; otherwise ``info`` is the raw answer.
        A ``None`` answer is treated as a failed API call.
        """
        response = self.query(*args, **kwargs)
        if response is None:
            response = "Error when calling gpt api."

        if self.verbose and response != "YES":
            logger.info(response)

        flag = "YES" in response
        response = "YES" if flag else response

        return flag, response

    @staticmethod
    def validate(
        checkers: list["BaseChecker"], images_list: list[list[str]]
    ) -> list:
        """Run each checker on its own images and summarise the results.

        Args:
            checkers: Checkers to run.
            images_list: One list of images per checker, in the same order.

        Returns:
            list: ``[checker_class_name, info]`` per checker, followed by
            ``["overall", "YES" | "NO"]``; overall is "YES" only when every
            checker passed.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(checkers) != len(images_list):
            raise ValueError(
                f"Got {len(checkers)} checkers but {len(images_list)} image lists."  # noqa
            )

        results = []
        overall_result = True
        for checker, images in zip(checkers, images_list):
            qa_flag, qa_info = checker(images)
            if isinstance(qa_info, str):
                qa_info = qa_info.replace("\n", ".")
            results.append([checker.__class__.__name__, qa_info])
            # Truthiness, not identity: a falsy flag that is not the `False`
            # singleton (e.g. numpy.bool_) must still fail the overall result.
            if not qa_flag:
                overall_result = False

        results.append(["overall", "YES" if overall_result else "NO"])

        return results
+ """ + + def __init__( + self, + gpt_client: GPTclient, + prompt: str = None, + verbose: bool = False, + ) -> None: + super().__init__(prompt, verbose) + self.gpt_client = gpt_client + if self.prompt is None: + self.prompt = """ + Refer to the provided multi-view rendering images to evaluate + whether the geometry of the 3D object asset is complete and + whether the asset can be placed stably on the ground. + Return "YES" only if reach the requirments, + otherwise "NO" and explain the reason very briefly. + """ + + def query(self, image_paths: str) -> str: + # Hardcode tmp because of the openrouter can't input multi images. + if "openrouter" in self.gpt_client.endpoint: + from embodied_gen.utils.process_media import ( + combine_images_to_base64, + ) + + image_paths = combine_images_to_base64(image_paths) + + return self.gpt_client.query( + text_prompt=self.prompt, + image_base64=image_paths, + ) + + +class ImageSegChecker(BaseChecker): + """A segmentation quality checker for 3D assets using GPT-based reasoning. + + This class compares an original image with its segmented version to + evaluate whether the segmentation successfully isolates the main object + with minimal truncation and correct foreground extraction. + + Attributes: + gpt_client (GPTclient): GPT client used for multi-modal image analysis. + prompt (str): The prompt used to guide the GPT model for evaluation. + verbose (bool): Whether to enable verbose logging. + """ + + def __init__( + self, + gpt_client: GPTclient, + prompt: str = None, + verbose: bool = False, + ) -> None: + super().__init__(prompt, verbose) + self.gpt_client = gpt_client + if self.prompt is None: + self.prompt = """ + The first image is the original, and the second image is the + result after segmenting the main object. Evaluate the segmentation + quality to ensure the main object is clearly segmented without + significant truncation. Note that the foreground of the object + needs to be extracted instead of the background. 
class ImageAestheticChecker(BaseChecker):
    """A class for evaluating the aesthetic quality of images.

    Attributes:
        clip_model_dir (str): Path to the CLIP model directory.
        sac_model_path (str): Path to the aesthetic predictor model weights.
        thresh (float): Threshold above which images are considered aesthetically acceptable.
        verbose (bool): Whether to print detailed log messages.
        predictor (AestheticPredictor): The model used to predict aesthetic scores.
    """

    def __init__(
        self,
        clip_model_dir: str = None,
        sac_model_path: str = None,
        thresh: float = 4.50,
        verbose: bool = False,
    ) -> None:
        super().__init__(verbose=verbose)
        self.clip_model_dir = clip_model_dir
        self.sac_model_path = sac_model_path
        self.thresh = thresh
        self.predictor = AestheticPredictor(clip_model_dir, sac_model_path)

    def query(self, image_paths: list[str]) -> float:
        """Return the mean predicted aesthetic score of the images.

        Raises:
            ValueError: If ``image_paths`` is empty (there is no score to
                average).
        """
        if not image_paths:
            raise ValueError("`image_paths` must contain at least one image.")
        scores = [self.predictor.predict(img_path) for img_path in image_paths]
        return sum(scores) / len(scores)

    def __call__(
        self, image_paths: list[str], **kwargs
    ) -> tuple[bool, float]:
        """Score the images and compare the mean with the threshold.

        Returns:
            tuple[bool, float]: Whether the mean score is strictly above
            ``thresh``, and the mean score itself.
        """
        avg_score = self.query(image_paths)
        if self.verbose:
            logger.info(f"Average aesthetic score: {avg_score}")
        return avg_score > self.thresh, avg_score
b/embodied_gen/validators/urdf_convertor.py @@ -0,0 +1,419 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +import logging +import os +import shutil +import xml.etree.ElementTree as ET +from datetime import datetime +from xml.dom.minidom import parseString + +import numpy as np +import trimesh +from embodied_gen.data.utils import zip_files +from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient +from embodied_gen.utils.process_media import render_asset3d +from embodied_gen.utils.tags import VERSION + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +__all__ = ["URDFGenerator"] + + +URDF_TEMPLATE = """ + + + + + + + + + + + + + 0.8 + 0.6 + + + + + + + + + 1.0 + "0.0.0" + "unknown" + "unknown" + 0.0 + 0.0 + 0.0 + 0.0 + 0.0 + "-1" + "" + + + +""" + + +class URDFGenerator(object): + def __init__( + self, + gpt_client: GPTclient, + mesh_file_list: list[str] = ["material_0.png", "material.mtl"], + prompt_template: str = None, + attrs_name: list[str] = None, + render_dir: str = "urdf_renders", + render_view_num: int = 4, + ) -> None: + if mesh_file_list is None: + mesh_file_list = [] + self.mesh_file_list = mesh_file_list + self.output_mesh_dir = "mesh" + self.output_render_dir = render_dir + self.gpt_client = gpt_client + self.render_view_num = render_view_num + if render_view_num == 4: + view_desc = "This is orthographic projection 
showing the front, left, right and back views " # noqa + else: + view_desc = "This is the rendered views " + + if prompt_template is None: + prompt_template = ( + view_desc + + """of the 3D object asset, + category: {category}. + Give the category of this object asset (within 3 words), + (if category is already provided, use it directly), + accurately describe this 3D object asset (within 15 words), + and give the recommended geometric height range (unit: meter), + weight range (unit: kilogram), the average static friction + coefficient of the object relative to rubber and the average + dynamic friction coefficient of the object relative to rubber. + Return response format as shown in Example. + + Example: + Category: cup + Description: shiny golden cup with floral design + Height: 0.1-0.15 m + Weight: 0.3-0.6 kg + Static friction coefficient: 1.1 + Dynamic friction coefficient: 0.9 + """ + ) + + self.prompt_template = prompt_template + if attrs_name is None: + attrs_name = [ + "category", + "description", + "min_height", + "max_height", + "real_height", + "min_mass", + "max_mass", + "version", + "generate_time", + "gs_model", + ] + self.attrs_name = attrs_name + + def parse_response(self, response: str) -> dict[str, any]: + lines = response.split("\n") + lines = [line.strip() for line in lines if line] + category = lines[0].split(": ")[1] + description = lines[1].split(": ")[1] + min_height, max_height = map( + lambda x: float(x.strip().replace(",", "").split()[0]), + lines[2].split(": ")[1].split("-"), + ) + min_mass, max_mass = map( + lambda x: float(x.strip().replace(",", "").split()[0]), + lines[3].split(": ")[1].split("-"), + ) + mu1 = float(lines[4].split(": ")[1].replace(",", "")) + mu2 = float(lines[5].split(": ")[1].replace(",", "")) + + return { + "category": category.lower(), + "description": description.lower(), + "min_height": round(min_height, 4), + "max_height": round(max_height, 4), + "min_mass": round(min_mass, 4), + "max_mass": round(max_mass, 
4), + "mu1": round(mu1, 2), + "mu2": round(mu2, 2), + "version": VERSION, + "generate_time": datetime.now().strftime("%Y%m%d%H%M%S"), + } + + def generate_urdf( + self, + input_mesh: str, + output_dir: str, + attr_dict: dict, + output_name: str = None, + ) -> str: + """Generate a URDF file for a given mesh with specified attributes. + + Args: + input_mesh (str): Path to the input mesh file. + output_dir (str): Directory to store the generated URDF + and processed mesh. + attr_dict (dict): Dictionary containing attributes like height, + mass, and friction coefficients. + output_name (str, optional): Name for the generated URDF and robot. + + Returns: + str: Path to the generated URDF file. + """ + + # 1. Load and normalize the mesh + mesh = trimesh.load(input_mesh) + mesh_scale = np.ptp(mesh.vertices, axis=0).max() + mesh.vertices /= mesh_scale # Normalize to [-0.5, 0.5] + raw_height = np.ptp(mesh.vertices, axis=0)[1] + + # 2. Scale the mesh to real height + real_height = attr_dict["real_height"] + scale = round(real_height / raw_height, 6) + mesh = mesh.apply_scale(scale) + + # 3. Prepare output directories and save scaled mesh + mesh_folder = os.path.join(output_dir, self.output_mesh_dir) + os.makedirs(mesh_folder, exist_ok=True) + + obj_name = os.path.basename(input_mesh) + mesh_output_path = os.path.join(mesh_folder, obj_name) + mesh.export(mesh_output_path) + + # 4. Copy additional mesh files, if any + input_dir = os.path.dirname(input_mesh) + for file in self.mesh_file_list: + src_file = os.path.join(input_dir, file) + dest_file = os.path.join(mesh_folder, file) + if os.path.isfile(src_file): + shutil.copy(src_file, dest_file) + + # 5. Determine output name + if output_name is None: + output_name = os.path.splitext(obj_name)[0] + + # 6. 
Load URDF template and update attributes + robot = ET.fromstring(URDF_TEMPLATE) + robot.set("name", output_name) + + link = robot.find("link") + if link is None: + raise ValueError("URDF template is missing 'link' element.") + link.set("name", output_name) + + # Update visual geometry + visual = link.find("visual/geometry/mesh") + if visual is not None: + visual.set( + "filename", os.path.join(self.output_mesh_dir, obj_name) + ) + visual.set("scale", "1.0 1.0 1.0") + + # Update collision geometry + collision = link.find("collision/geometry/mesh") + if collision is not None: + collision.set( + "filename", os.path.join(self.output_mesh_dir, obj_name) + ) + collision.set("scale", "1.0 1.0 1.0") + + # Update friction coefficients + gazebo = link.find("collision/gazebo") + if gazebo is not None: + for param, key in zip(["mu1", "mu2"], ["mu1", "mu2"]): + element = gazebo.find(param) + if element is not None: + element.text = f"{attr_dict[key]:.2f}" + + # Update mass + inertial = link.find("inertial/mass") + if inertial is not None: + mass_value = (attr_dict["min_mass"] + attr_dict["max_mass"]) / 2 + inertial.set("value", f"{mass_value:.4f}") + + # Add extra_info element to the link + extra_info = link.find("extra_info/scale") + if extra_info is not None: + extra_info.text = f"{scale:.6f}" + + for key in self.attrs_name: + extra_info = link.find(f"extra_info/{key}") + if extra_info is not None and key in attr_dict: + extra_info.text = f"{attr_dict[key]}" + + # 7. 
Write URDF to file + os.makedirs(output_dir, exist_ok=True) + urdf_path = os.path.join(output_dir, f"{output_name}.urdf") + tree = ET.ElementTree(robot) + tree.write(urdf_path, encoding="utf-8", xml_declaration=True) + + logger.info(f"URDF file saved to {urdf_path}") + + return urdf_path + + @staticmethod + def get_attr_from_urdf( + urdf_path: str, + attr_root: str = ".//link/extra_info", + attr_name: str = "scale", + ) -> float: + if not os.path.exists(urdf_path): + raise FileNotFoundError(f"URDF file not found: {urdf_path}") + + mesh_scale = 1.0 + tree = ET.parse(urdf_path) + root = tree.getroot() + extra_info = root.find(attr_root) + if extra_info is not None: + scale_element = extra_info.find(attr_name) + if scale_element is not None: + mesh_scale = float(scale_element.text) + + return mesh_scale + + @staticmethod + def add_quality_tag( + urdf_path: str, results, output_path: str = None + ) -> None: + if output_path is None: + output_path = urdf_path + + tree = ET.parse(urdf_path) + root = tree.getroot() + custom_data = ET.SubElement(root, "custom_data") + quality = ET.SubElement(custom_data, "quality") + for key, value in results: + checker_tag = ET.SubElement(quality, key) + checker_tag.text = str(value) + + rough_string = ET.tostring(root, encoding="utf-8") + formatted_string = parseString(rough_string).toprettyxml(indent=" ") + cleaned_string = "\n".join( + [line for line in formatted_string.splitlines() if line.strip()] + ) + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + f.write(cleaned_string) + + logger.info(f"URDF files saved to {output_path}") + + def get_estimated_attributes(self, asset_attrs: dict): + estimated_attrs = { + "height": round( + (asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4 + ), + "mass": round( + (asset_attrs["min_mass"] + asset_attrs["max_mass"]) / 2, 4 + ), + "mu": round((asset_attrs["mu1"] + asset_attrs["mu2"]) / 2, 4), + "category": 
asset_attrs["category"], + } + + return estimated_attrs + + def __call__( + self, + mesh_path: str, + output_root: str, + text_prompt: str = None, + category: str = "unknown", + **kwargs, + ): + if text_prompt is None or len(text_prompt) == 0: + text_prompt = self.prompt_template + text_prompt = text_prompt.format(category=category.lower()) + + image_path = render_asset3d( + mesh_path, + output_root, + num_images=self.render_view_num, + output_subdir=self.output_render_dir, + ) + + # Hardcode tmp because of the openrouter can't input multi images. + if "openrouter" in self.gpt_client.endpoint: + from embodied_gen.utils.process_media import ( + combine_images_to_base64, + ) + + image_path = combine_images_to_base64(image_path) + + response = self.gpt_client.query(text_prompt, image_path) + if response is None: + asset_attrs = { + "category": category.lower(), + "description": category.lower(), + "min_height": 1, + "max_height": 1, + "min_mass": 1, + "max_mass": 1, + "mu1": 0.8, + "mu2": 0.6, + "version": VERSION, + "generate_time": datetime.now().strftime("%Y%m%d%H%M%S"), + } + else: + asset_attrs = self.parse_response(response) + for key in self.attrs_name: + if key in kwargs: + asset_attrs[key] = kwargs[key] + + asset_attrs["real_height"] = round( + (asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4 + ) + + self.estimated_attrs = self.get_estimated_attributes(asset_attrs) + + urdf_path = self.generate_urdf(mesh_path, output_root, asset_attrs) + + logger.info(f"response: {response}") + + return urdf_path + + +if __name__ == "__main__": + urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4) + urdf_path = urdf_gen( + mesh_path="outputs/imageto3d/cma/o5/URDF_o5/mesh/o5.obj", + output_root="outputs/test_urdf", + # category="coffee machine", + # min_height=1.0, + # max_height=1.2, + version=VERSION, + ) + + # zip_files( + # input_paths=[ + # "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh", + # "scripts/apps/tmp/2umpdum3e5n/URDF_sample/sample.urdf" + # ], 
+ # output_zip="zip.zip" + # ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c261905 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["setuptools", "wheel", "build"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["embodied_gen"] + +[project] +name = "embodied_gen" +version = "v0.1.0" +readme = "README.md" +license = "Apache-2.0" +license-files = ["LICENSE", "NOTICE"] + +dependencies = [] +requires-python = ">=3.10" + +[project.optional-dependencies] +dev = [ + "cpplint==2.0.0", + "pre-commit==2.13.0", + "pydocstyle", + "black", + "isort", +] + +[project.scripts] +drender-cli = "embodied_gen.data.differentiable_render:entrypoint" +backproject-cli = "embodied_gen.data.backproject_v2:entrypoint" + +[tool.pydocstyle] +match = '(?!test_).*(?!_pb2)\.py' +match-dir = '^(?!(raw|projects|tools|k8s_submit|thirdparty)$)[\w.-]+$' +convention = "google" +add-ignore = 'D104,D107,D202,D105,D100,D102,D103,D101,E203' + +[tool.pycodestyle] +max-line-length = 79 +ignore = "E203" + +[tool.black] +line-length = 79 +exclude = "thirdparty" +skip-string-normalization = true + +[tool.isort] +line_length = 79 +profile = 'black' +no_lines_before = 'FIRSTPARTY' +known_first_party = ['embodied_gen'] +skip = "thirdparty/" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c1812df --- /dev/null +++ b/requirements.txt @@ -0,0 +1,41 @@ +torch==2.4.0+cu118 +torchvision==0.19.0+cu118 +xformers==0.0.27.post2 +pytorch-lightning==2.4.0 +spconv-cu120==2.3.6 +numpy==1.26.4 +triton==2.1.0 +dataclasses_json +easydict +opencv-python>4.5 +imageio==2.36.1 +imageio-ffmpeg==0.5.1 +rembg==2.0.61 +trimesh==4.4.4 +moviepy==1.0.3 +pymeshfix==0.17.0 +igraph==0.11.8 +pyvista==0.36.1 +openai==1.58.1 +transformers==4.42.4 +gradio==5.12.0 +sentencepiece==0.2.0 +diffusers==0.31.0 +xatlas==0.0.9 +onnxruntime==1.20.1 +tenacity==8.2.2 +accelerate==0.33.0 +basicsr==1.4.2 
+realesrgan==0.3.0 +pydantic==2.9.2 +vtk==9.3.1 +spaces +utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8 +clip@git+https://github.com/openai/CLIP.git +kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d +segment-anything@git+https://github.com/facebookresearch/segment-anything.git#egg=dca509f +nvdiffrast@git+https://github.com/NVlabs/nvdiffrast.git#egg=729261d +kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0 +https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu118-cp310-cp310-linux_x86_64.whl +https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu11torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl +https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu118/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl diff --git a/scripts/autoformat.sh b/scripts/autoformat.sh new file mode 100644 index 0000000..bf12f78 --- /dev/null +++ b/scripts/autoformat.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +ROOT_DIR=${1} + +set -e + +black --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ +isort --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ +pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./ +pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ diff --git a/scripts/check_lint.py b/scripts/check_lint.py new file mode 100644 index 0000000..991ad1b --- /dev/null +++ b/scripts/check_lint.py @@ -0,0 +1,61 @@ +import argparse +import os +import subprocess +import sys + + +def get_root(): + current_file_path = os.path.abspath(__file__) + root_path = os.path.dirname(current_file_path) + for _ in range(2): + root_path = os.path.dirname(root_path) + return root_path + + +def cpp_lint(root_path: str): + # run external python file to lint cpp + subprocess.check_call( + " ".join( + [ + "python3", + f"{root_path}/scripts/lint_src/lint.py", + "--project=asset_recons", + "--path", + 
f"{root_path}/src/", + f"{root_path}/include/", + f"{root_path}/module/", + "--exclude_path", + f"{root_path}/module/web_viz/front_end/", + ] + ), + shell=True, + ) + + +def python_lint(root_path: str, auto_format: bool = False): + # run external python file to lint python + subprocess.check_call( + " ".join( + [ + "bash", + ( + f"{root_path}/scripts/lint/check_pylint.sh" + if not auto_format + else f"{root_path}/scripts/lint/autoformat.sh" + ), + f"{root_path}/", + ] + ), + shell=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="check format.") + parser.add_argument( + "--auto_format", action="store_true", help="auto format python" + ) + parser = parser.parse_args() + root_path = get_root() + cpp_lint(root_path) + python_lint(root_path, parser.auto_format) diff --git a/scripts/check_pylint.sh b/scripts/check_pylint.sh new file mode 100644 index 0000000..8938332 --- /dev/null +++ b/scripts/check_pylint.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +ROOT_DIR=${1} + +set -e + + +pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./ +pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ +black --check --diff --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ +isort --diff --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ diff --git a/scripts/lint/autoformat.sh b/scripts/lint/autoformat.sh new file mode 100644 index 0000000..bf12f78 --- /dev/null +++ b/scripts/lint/autoformat.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +ROOT_DIR=${1} + +set -e + +black --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ +isort --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ +pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./ +pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ diff --git a/scripts/lint/check_lint.py b/scripts/lint/check_lint.py new file mode 100644 index 0000000..991ad1b --- /dev/null +++ b/scripts/lint/check_lint.py @@ -0,0 +1,61 @@ +import argparse +import os +import subprocess +import sys 
+ + +def get_root(): + current_file_path = os.path.abspath(__file__) + root_path = os.path.dirname(current_file_path) + for _ in range(2): + root_path = os.path.dirname(root_path) + return root_path + + +def cpp_lint(root_path: str): + # run external python file to lint cpp + subprocess.check_call( + " ".join( + [ + "python3", + f"{root_path}/scripts/lint_src/lint.py", + "--project=asset_recons", + "--path", + f"{root_path}/src/", + f"{root_path}/include/", + f"{root_path}/module/", + "--exclude_path", + f"{root_path}/module/web_viz/front_end/", + ] + ), + shell=True, + ) + + +def python_lint(root_path: str, auto_format: bool = False): + # run external python file to lint python + subprocess.check_call( + " ".join( + [ + "bash", + ( + f"{root_path}/scripts/lint/check_pylint.sh" + if not auto_format + else f"{root_path}/scripts/lint/autoformat.sh" + ), + f"{root_path}/", + ] + ), + shell=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="check format.") + parser.add_argument( + "--auto_format", action="store_true", help="auto format python" + ) + parser = parser.parse_args() + root_path = get_root() + cpp_lint(root_path) + python_lint(root_path, parser.auto_format) diff --git a/scripts/lint/check_pylint.sh b/scripts/lint/check_pylint.sh new file mode 100644 index 0000000..8938332 --- /dev/null +++ b/scripts/lint/check_pylint.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +ROOT_DIR=${1} + +set -e + + +pycodestyle --show-source --config=${ROOT_DIR}setup.cfg ${ROOT_DIR}./ +pydocstyle --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ +black --check --diff --config=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ +isort --diff --settings-file=${ROOT_DIR}pyproject.toml ${ROOT_DIR}./ diff --git a/scripts/lint_src/cpplint.hook b/scripts/lint_src/cpplint.hook new file mode 100755 index 0000000..a9d87ee --- /dev/null +++ b/scripts/lint_src/cpplint.hook @@ -0,0 +1,10 @@ +#!/bin/bash + +TOTAL_ERRORS=0 +if [[ ! 
$(which cpplint) ]]; then + pip install cpplint +fi +# diff files on local machine. +files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}') +python3 scripts/lint_src/lint.py --project=asset_recons --path $files --exclude_path thirdparty patch_files; + diff --git a/scripts/lint_src/lint.py b/scripts/lint_src/lint.py new file mode 100644 index 0000000..5003569 --- /dev/null +++ b/scripts/lint_src/lint.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- + +import argparse +import codecs +import os +import re +import sys + +import cpplint +import pycodestyle +from cpplint import _cpplint_state + +CXX_SUFFIX = set(["cc", "c", "cpp", "h", "cu", "hpp"]) + + +def filepath_enumerate(paths): + """Enumerate the file paths of all subfiles of the list of paths.""" + out = [] + for path in paths: + if os.path.isfile(path): + out.append(path) + else: + for root, dirs, files in os.walk(path): + for name in files: + out.append(os.path.normpath(os.path.join(root, name))) + return out + + +class LintHelper(object): + @staticmethod + def _print_summary_map(strm, result_map, ftype): + """Print summary of certain result map.""" + if len(result_map) == 0: + return 0 + npass = len([x for k, x in result_map.items() if len(x) == 0]) + strm.write( + "=====%d/%d %s files passed check=====\n" + % (npass, len(result_map), ftype) + ) + for fname, emap in result_map.items(): + if len(emap) == 0: + continue + strm.write( + "%s: %d Errors of %d Categories map=%s\n" + % (fname, sum(emap.values()), len(emap), str(emap)) + ) + return len(result_map) - npass + + def __init__(self) -> None: + self.project_name = None + self.cpp_header_map = {} + self.cpp_src_map = {} + super().__init__() + cpplint_args = [".", "--extensions=" + (",".join(CXX_SUFFIX))] + _ = cpplint.ParseArguments(cpplint_args) + cpplint._SetFilters( + ",".join( + [ + "-build/c++11", + "-build/namespaces", + "-build/include,", + "+build/include_what_you_use", + "+build/include_order", + ] + ) + ) + 
cpplint._SetCountingStyle("toplevel") + cpplint._line_length = 80 + + def process_cpp(self, path, suffix): + """Process a cpp file.""" + _cpplint_state.ResetErrorCounts() + cpplint.ProcessFile(str(path), _cpplint_state.verbose_level) + _cpplint_state.PrintErrorCounts() + errors = _cpplint_state.errors_by_category.copy() + + if suffix == "h": + self.cpp_header_map[str(path)] = errors + else: + self.cpp_src_map[str(path)] = errors + + def print_summary(self, strm): + """Print summary of lint.""" + nerr = 0 + nerr += LintHelper._print_summary_map( + strm, self.cpp_header_map, "cpp-header" + ) + nerr += LintHelper._print_summary_map( + strm, self.cpp_src_map, "cpp-source" + ) + if nerr == 0: + strm.write("All passed!\n") + else: + strm.write("%d files failed lint\n" % nerr) + return nerr + + +# singleton helper for lint check +_HELPER = LintHelper() + + +def process(fname, allow_type): + """Process a file.""" + fname = str(fname) + arr = fname.rsplit(".", 1) + if fname.find("#") != -1 or arr[-1] not in allow_type: + return + if arr[-1] in CXX_SUFFIX: + _HELPER.process_cpp(fname, arr[-1]) + + +def main(): + """Main entry function.""" + parser = argparse.ArgumentParser(description="lint source codes") + parser.add_argument("--project", help="project name") + parser.add_argument( + "--path", + nargs="+", + default=[], + help="path to traverse", + required=False, + ) + parser.add_argument( + "--exclude_path", + nargs="+", + default=[], + help="exclude this path, and all subfolders " + "if path is a folder", + ) + + args = parser.parse_args() + _HELPER.project_name = args.project + allow_type = [] + allow_type += [x for x in CXX_SUFFIX] + allow_type = set(allow_type) + + # get excluded files + excluded_paths = filepath_enumerate(args.exclude_path) + for path in args.path: + if os.path.isfile(path): + normpath = os.path.normpath(path) + if normpath not in excluded_paths: + process(path, allow_type) + else: + for root, dirs, files in os.walk(path): + for name in files: + 
file_path = os.path.normpath(os.path.join(root, name)) + if file_path not in excluded_paths: + process(file_path, allow_type) + nerr = _HELPER.print_summary(sys.stderr) + sys.exit(nerr > 0) + + +if __name__ == "__main__": + main() diff --git a/scripts/lint_src/pep8.hook b/scripts/lint_src/pep8.hook new file mode 100755 index 0000000..6926a05 --- /dev/null +++ b/scripts/lint_src/pep8.hook @@ -0,0 +1,16 @@ +#!/bin/bash + +TOTAL_ERRORS=0 +if [[ ! $(which pycodestyle) ]]; then + pip install pycodestyle +fi +# diff files on local machine. +files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}') +for file in $files; do + if [ "${file##*.}" == "py" & -f "${file}"] ; then + pycodestyle --show-source $file --config=setup.cfg; + TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); + fi +done + +exit $TOTAL_ERRORS diff --git a/scripts/lint_src/pydocstyle.hook b/scripts/lint_src/pydocstyle.hook new file mode 100755 index 0000000..8b689d6 --- /dev/null +++ b/scripts/lint_src/pydocstyle.hook @@ -0,0 +1,16 @@ +#!/bin/bash + +TOTAL_ERRORS=0 +if [[ ! $(which pydocstyle) ]]; then + pip install pydocstyle +fi +# diff files on local machine. +files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}') +for file in $files; do + if [ "${file##*.}" == "py" ] ; then + pydocstyle $file; + TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); + fi +done + +exit $TOTAL_ERRORS diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..7ef6850 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[pycodestyle] +ignore = E203,W503,E402,E501 diff --git a/thirdparty/TRELLIS b/thirdparty/TRELLIS new file mode 160000 index 0000000..55a8e81 --- /dev/null +++ b/thirdparty/TRELLIS @@ -0,0 +1 @@ +Subproject commit 55a8e8164b195bbf927e0978f00e76c835e6011f