Compare commits
No commits in common. "eae67dc538b79ab8d4655a88c73941bd8b80cfa2" and "031ba17742325830f356b0431baa8e3ad1698544" have entirely different histories.
eae67dc538
...
031ba17742
63
.github/workflows/deploy_docs.yml
vendored
@ -1,63 +0,0 @@
|
|||||||
name: Deploy MkDocs Documentation
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: ubuntu-22.04
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
- name: Setup Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
pip install --upgrade pip
|
|
||||||
pip install mkdocs-material mkdocstrings[python] mkdocs-git-revision-date-localized-plugin
|
|
||||||
|
|
||||||
- name: Set PYTHONPATH
|
|
||||||
run: echo "PYTHONPATH=$(pwd)" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Build docs
|
|
||||||
run: mkdocs build
|
|
||||||
|
|
||||||
deploy:
|
|
||||||
needs: build
|
|
||||||
if: github.event_name == 'push' && github.ref == 'refs/heads/master'
|
|
||||||
runs-on: ubuntu-22.04
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
- name: Setup Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
pip install --upgrade pip
|
|
||||||
pip install mkdocs-material mkdocstrings[python] mkdocs-git-revision-date-localized-plugin
|
|
||||||
|
|
||||||
- name: Set PYTHONPATH
|
|
||||||
run: echo "PYTHONPATH=$(pwd)" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Deploy docs
|
|
||||||
run: mkdocs gh-deploy --force
|
|
||||||
5
.gitignore
vendored
@ -1,6 +1,5 @@
|
|||||||
build/
|
build/
|
||||||
dummy/
|
dummy/
|
||||||
thirdparty/
|
|
||||||
!scripts/build
|
!scripts/build
|
||||||
builddir/
|
builddir/
|
||||||
conan-deps/
|
conan-deps/
|
||||||
@ -61,7 +60,3 @@ scripts/tools/
|
|||||||
weights
|
weights
|
||||||
apps/sessions/
|
apps/sessions/
|
||||||
apps/assets/
|
apps/assets/
|
||||||
|
|
||||||
# Larger than 1MB
|
|
||||||
docs/assets/real2sim_mujoco.gif
|
|
||||||
docs/assets/scene3d.gif
|
|
||||||
@ -1,4 +0,0 @@
|
|||||||
FROM embodiedgen:v0.1.2
|
|
||||||
WORKDIR /EmbodiedGen
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
70
README.md
@ -1,26 +1,25 @@
|
|||||||
# *EmbodiedGen*: Towards a Generative 3D World Engine for Embodied Intelligence
|
# *EmbodiedGen*: Towards a Generative 3D World Engine for Embodied Intelligence
|
||||||
|
|
||||||
[](https://horizonrobotics.github.io/EmbodiedGen/)
|
[](https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html)
|
||||||
[](https://github.com/HorizonRobotics/EmbodiedGen)
|
|
||||||
[](https://arxiv.org/abs/2506.10600)
|
[](https://arxiv.org/abs/2506.10600)
|
||||||
[](https://www.youtube.com/watch?v=rG4odybuJRk)
|
[](https://www.youtube.com/watch?v=rG4odybuJRk)
|
||||||
[](https://mp.weixin.qq.com/s/HH1cPBhK2xcDbyCK4BBTbw)
|
|
||||||
<!-- [](https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html) -->
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer)
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D)
|
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D)
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D)
|
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D)
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen)
|
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen)
|
||||||
|
[](https://mp.weixin.qq.com/s/HH1cPBhK2xcDbyCK4BBTbw)
|
||||||
|
|
||||||
|
|
||||||
|
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer)
|
||||||
|
|
||||||
|
|
||||||
> ***EmbodiedGen*** is a generative engine to create diverse and interactive 3D worlds composed of high-quality 3D assets (mesh & 3DGS) with plausible physics, leveraging generative AI to address the challenges of generalization in embodied-intelligence-related research.
|
> ***EmbodiedGen*** is a generative engine to create diverse and interactive 3D worlds composed of high-quality 3D assets (mesh & 3DGS) with plausible physics, leveraging generative AI to address the challenges of generalization in embodied-intelligence-related research.
|
||||||
> It is composed of six key modules: `Image-to-3D`, `Text-to-3D`, `Texture Generation`, `Articulated Object Generation`, `Scene Generation` and `Layout Generation`.
|
> It is composed of six key modules: `Image-to-3D`, `Text-to-3D`, `Texture Generation`, `Articulated Object Generation`, `Scene Generation` and `Layout Generation`.
|
||||||
|
|
||||||
<img src="docs/assets/overall.jpg" alt="Overall Framework" width="700"/>
|
<img src="apps/assets/overall.jpg" alt="Overall Framework" width="700"/>
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## ✨ Table of Contents of EmbodiedGen
|
## ✨ Table of Contents of EmbodiedGen
|
||||||
[](https://horizonrobotics.github.io/EmbodiedGen/) Follow the documentation to get started!
|
|
||||||
|
|
||||||
- [🖼️ Image-to-3D](#image-to-3d)
|
- [🖼️ Image-to-3D](#image-to-3d)
|
||||||
- [📝 Text-to-3D](#text-to-3d)
|
- [📝 Text-to-3D](#text-to-3d)
|
||||||
- [🎨 Texture Generation](#texture-generation)
|
- [🎨 Texture Generation](#texture-generation)
|
||||||
@ -31,13 +30,11 @@
|
|||||||
|
|
||||||
## 🚀 Quick Start
|
## 🚀 Quick Start
|
||||||
|
|
||||||
[](https://horizonrobotics.github.io/EmbodiedGen/)
|
|
||||||
|
|
||||||
### ✅ Setup Environment
|
### ✅ Setup Environment
|
||||||
```sh
|
```sh
|
||||||
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
|
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
|
||||||
cd EmbodiedGen
|
cd EmbodiedGen
|
||||||
git checkout v0.1.6
|
git checkout v0.1.5
|
||||||
git submodule update --init --recursive --progress
|
git submodule update --init --recursive --progress
|
||||||
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
|
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
|
||||||
conda activate embodiedgen
|
conda activate embodiedgen
|
||||||
@ -70,7 +67,7 @@ You can choose between two backends for the GPT agent:
|
|||||||
|
|
||||||
### 📸 Directly use EmbodiedGen All-Simulators-Ready Assets
|
### 📸 Directly use EmbodiedGen All-Simulators-Ready Assets
|
||||||
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer) Explore EmbodiedGen generated assets that are ready for simulation across any simulators (SAPIEN, Isaac Sim, MuJoCo, PyBullet, Genesis, Isaac Gym etc.). Details in chapter [any-simulators](#any-simulators).
|
Explore EmbodiedGen generated assets in [](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer) that are ready for simulation across any simulators (SAPIEN, Isaac Sim, MuJoCo, PyBullet, Genesis, Isaac Gym etc.). Details in chapter [any-simulators](#any-simulators).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -79,7 +76,7 @@ You can choose between two backends for the GPT agent:
|
|||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) Generate physically plausible 3D asset URDF from single input image, offering high-quality support for digital twin systems.
|
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) Generate physically plausible 3D asset URDF from single input image, offering high-quality support for digital twin systems.
|
||||||
(HF space is a simplified demonstration. For the full functionality, please refer to `img3d-cli`.)
|
(HF space is a simplified demonstration. For the full functionality, please refer to `img3d-cli`.)
|
||||||
|
|
||||||
<img src="docs/assets/image_to_3d.jpg" alt="Image to 3D" width="700">
|
<img src="apps/assets/image_to_3d.jpg" alt="Image to 3D" width="700">
|
||||||
|
|
||||||
### ☁️ Service
|
### ☁️ Service
|
||||||
Run the image-to-3D generation service locally.
|
Run the image-to-3D generation service locally.
|
||||||
@ -94,8 +91,8 @@ CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
|
|||||||
### ⚡ API
|
### ⚡ API
|
||||||
Generate physically plausible 3D assets from image input via the command-line API.
|
Generate physically plausible 3D assets from image input via the command-line API.
|
||||||
```sh
|
```sh
|
||||||
img3d-cli --image_path apps/assets/example_image/sample_00.jpg apps/assets/example_image/sample_01.jpg apps/assets/example_image/sample_19.jpg \
|
img3d-cli --image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
|
||||||
--n_retry 1 --output_root outputs/imageto3d
|
--n_retry 2 --output_root outputs/imageto3d
|
||||||
|
|
||||||
# See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result
|
# See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result
|
||||||
```
|
```
|
||||||
@ -107,7 +104,7 @@ img3d-cli --image_path apps/assets/example_image/sample_00.jpg apps/assets/examp
|
|||||||
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) Create 3D assets from text descriptions for a wide range of geometry and styles. (HF space is a simplified demonstration. For the full functionality, please refer to `text3d-cli`.)
|
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) Create 3D assets from text descriptions for a wide range of geometry and styles. (HF space is a simplified demonstration. For the full functionality, please refer to `text3d-cli`.)
|
||||||
|
|
||||||
<img src="docs/assets/text_to_3d.jpg" alt="Text to 3D" width="700">
|
<img src="apps/assets/text_to_3d.jpg" alt="Text to 3D" width="700">
|
||||||
|
|
||||||
### ☁️ Service
|
### ☁️ Service
|
||||||
Deploy the text-to-3D generation service locally.
|
Deploy the text-to-3D generation service locally.
|
||||||
@ -122,11 +119,11 @@ python apps/text_to_3d.py
|
|||||||
Text-to-image model based on SD3.5 Medium, English prompts only.
|
Text-to-image model based on SD3.5 Medium, English prompts only.
|
||||||
Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically.
|
Usage requires agreement to the [model license(click accept)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), models downloaded automatically.
|
||||||
|
|
||||||
For large-scale 3D asset generation, set `--n_image_retry=4` `--n_asset_retry=3` `--n_pipe_retry=2`, slower but better, via automatic checking and retries. For more diverse results, omit `--seed_img`.
|
For large-scale 3D asset generation, set `--n_pipe_retry=2` to ensure high end-to-end 3D asset usability through automatic quality checks and retries. For more diverse results, do not set `--seed_img`.
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
text3d-cli --prompts "small bronze figurine of a lion" "A globe with wooden base" "wooden table with embroidery" \
|
text3d-cli --prompts "small bronze figurine of a lion" "A globe with wooden base" "wooden table with embroidery" \
|
||||||
--n_image_retry 1 --n_asset_retry 1 --n_pipe_retry 1 --seed_img 0 \
|
--n_image_retry 2 --n_asset_retry 2 --n_pipe_retry 1 --seed_img 0 \
|
||||||
--output_root outputs/textto3d
|
--output_root outputs/textto3d
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -145,7 +142,7 @@ ps: models with more permissive licenses found in `embodied_gen/models/image_com
|
|||||||
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) Generate visually rich textures for 3D mesh.
|
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) Generate visually rich textures for 3D mesh.
|
||||||
|
|
||||||
<img src="docs/assets/texture_gen.jpg" alt="Texture Gen" width="700">
|
<img src="apps/assets/texture_gen.jpg" alt="Texture Gen" width="700">
|
||||||
|
|
||||||
|
|
||||||
### ☁️ Service
|
### ☁️ Service
|
||||||
@ -170,7 +167,7 @@ texture-cli --mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \
|
|||||||
|
|
||||||
<h2 id="3d-scene-generation">🌍 3D Scene Generation</h2>
|
<h2 id="3d-scene-generation">🌍 3D Scene Generation</h2>
|
||||||
|
|
||||||
<img src="docs/assets/scene3d.gif" alt="scene3d" style="width: 600px;">
|
<img src="apps/assets/scene3d.gif" alt="scene3d" style="width: 600px;">
|
||||||
|
|
||||||
### ⚡ API
|
### ⚡ API
|
||||||
> Run `bash install.sh extra` to install additional requirements if you need to use `scene3d-cli`.
|
> Run `bash install.sh extra` to install additional requirements if you need to use `scene3d-cli`.
|
||||||
@ -193,7 +190,7 @@ CUDA_VISIBLE_DEVICES=0 scene3d-cli \
|
|||||||
|
|
||||||
🚧 *Coming Soon*
|
🚧 *Coming Soon*
|
||||||
|
|
||||||
<img src="docs/assets/articulate.gif" alt="articulate" style="width: 500px;">
|
<img src="apps/assets/articulate.gif" alt="articulate" style="width: 500px;">
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
@ -205,12 +202,12 @@ CUDA_VISIBLE_DEVICES=0 scene3d-cli \
|
|||||||
|
|
||||||
<table>
|
<table>
|
||||||
<tr>
|
<tr>
|
||||||
<td><img src="docs/assets/layout1.gif" alt="layout1" width="320"/></td>
|
<td><img src="apps/assets/layout1.gif" alt="layout1" width="320"/></td>
|
||||||
<td><img src="docs/assets/layout2.gif" alt="layout2" width="320"/></td>
|
<td><img src="apps/assets/layout2.gif" alt="layout2" width="320"/></td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td><img src="docs/assets/layout3.gif" alt="layout3" width="320"/></td>
|
<td><img src="apps/assets/layout3.gif" alt="layout3" width="320"/></td>
|
||||||
<td><img src="docs/assets/layout4.gif" alt="layout4" width="320"/></td>
|
<td><img src="apps/assets/layout4.gif" alt="layout4" width="320"/></td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
@ -228,8 +225,8 @@ layout-cli --task_descs "Place the pen in the mug on the desk" "Put the fruit on
|
|||||||
|
|
||||||
<table>
|
<table>
|
||||||
<tr>
|
<tr>
|
||||||
<td><img src="docs/assets/Iscene_demo1.gif" alt="Iscene_demo1" width="234"/></td>
|
<td><img src="apps/assets/Iscene_demo1.gif" alt="Iscene_demo1" width="234"/></td>
|
||||||
<td><img src="docs/assets/Iscene_demo2.gif" alt="Iscene_demo2" width="350"/></td>
|
<td><img src="apps/assets/Iscene_demo2.gif" alt="Iscene_demo2" width="350"/></td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
@ -246,8 +243,7 @@ Using `compose_layout.py`, you can recompose the layout of the generated interac
|
|||||||
```sh
|
```sh
|
||||||
python embodied_gen/scripts/compose_layout.py \
|
python embodied_gen/scripts/compose_layout.py \
|
||||||
--layout_path "outputs/layouts_gens/task_0000/layout.json" \
|
--layout_path "outputs/layouts_gens/task_0000/layout.json" \
|
||||||
--output_dir "outputs/layouts_gens/task_0000/recompose" \
|
--output_dir "outputs/layouts_gens/task_0000/recompose" --insert_robot
|
||||||
--insert_robot
|
|
||||||
```
|
```
|
||||||
|
|
||||||
We provide `sim-cli`, which allows users to easily load generated layouts into an interactive 3D simulation using the SAPIEN engine (support for more simulators will be added in future updates).
|
We provide `sim-cli`, which allows users to easily load generated layouts into an interactive 3D simulation using the SAPIEN engine (support for more simulators will be added in future updates).
|
||||||
@ -261,8 +257,8 @@ Example: generate multiple parallel simulation envs with `gym.make` and record s
|
|||||||
|
|
||||||
<table>
|
<table>
|
||||||
<tr>
|
<tr>
|
||||||
<td><img src="docs/assets/parallel_sim.gif" alt="parallel_sim1" width="290"/></td>
|
<td><img src="apps/assets/parallel_sim.gif" alt="parallel_sim1" width="290"/></td>
|
||||||
<td><img src="docs/assets/parallel_sim2.gif" alt="parallel_sim2" width="290"/></td>
|
<td><img src="apps/assets/parallel_sim2.gif" alt="parallel_sim2" width="290"/></td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
@ -275,7 +271,7 @@ python embodied_gen/scripts/parallel_sim.py \
|
|||||||
|
|
||||||
### 🖼️ Real-to-Sim Digital Twin
|
### 🖼️ Real-to-Sim Digital Twin
|
||||||
|
|
||||||
<img src="docs/assets/real2sim_mujoco.gif" alt="real2sim_mujoco" width="400">
|
<img src="apps/assets/real2sim_mujoco.gif" alt="real2sim_mujoco" width="400">
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -288,11 +284,11 @@ Example in `tests/test_examples/test_asset_converter.py`.
|
|||||||
| Simulator | Conversion Class |
|
| Simulator | Conversion Class |
|
||||||
|-----------|------------------|
|
|-----------|------------------|
|
||||||
| [isaacsim](https://github.com/isaac-sim/IsaacSim) | MeshtoUSDConverter |
|
| [isaacsim](https://github.com/isaac-sim/IsaacSim) | MeshtoUSDConverter |
|
||||||
| [mujoco](https://github.com/google-deepmind/mujoco) / [genesis](https://github.com/Genesis-Embodied-AI/Genesis) | MeshtoMJCFConverter |
|
| [mujoco](https://github.com/google-deepmind/mujoco) | MeshtoMJCFConverter |
|
||||||
| [sapien](https://github.com/haosulab/SAPIEN) / [isaacgym](https://github.com/isaac-sim/IsaacGymEnvs) / [pybullet](https://github.com/bulletphysics/bullet3) | EmbodiedGen generated .urdf can be used directly |
|
| [genesis](https://github.com/Genesis-Embodied-AI/Genesis) / [sapien](https://github.com/haosulab/SAPIEN) / [isaacgym](https://github.com/isaac-sim/IsaacGymEnvs) / [pybullet](https://github.com/bulletphysics/bullet3) | EmbodiedGen generated .urdf can be used directly |
|
||||||
|
|
||||||
|
|
||||||
<img src="docs/assets/simulators_collision.jpg" alt="simulators_collision" width="500">
|
<img src="apps/assets/simulators_collision.jpg" alt="simulators_collision" width="500">
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -300,8 +296,6 @@ Example in `tests/test_examples/test_asset_converter.py`.
|
|||||||
```sh
|
```sh
|
||||||
pip install -e .[dev] && pre-commit install
|
pip install -e .[dev] && pre-commit install
|
||||||
python -m pytest # All unit tests must pass.
|
python -m pytest # All unit tests must pass.
|
||||||
# mkdocs serve --dev-addr 0.0.0.0:8000
|
|
||||||
# mkdocs gh-deploy --force
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 📚 Citation
|
## 📚 Citation
|
||||||
@ -331,4 +325,4 @@ EmbodiedGen builds upon the following amazing projects and models:
|
|||||||
|
|
||||||
## ⚖️ License
|
## ⚖️ License
|
||||||
|
|
||||||
This project is licensed under the [Apache License 2.0](docs/LICENSE). See the `LICENSE` file for details.
|
This project is licensed under the [Apache License 2.0](LICENSE). See the `LICENSE` file for details.
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
|
|||||||
lighting_css = """
|
lighting_css = """
|
||||||
<style>
|
<style>
|
||||||
#lighter_mesh canvas {
|
#lighter_mesh canvas {
|
||||||
filter: brightness(2.0) !important;
|
filter: brightness(1.9) !important;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
"""
|
"""
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 471 KiB After Width: | Height: | Size: 471 KiB |
|
Before Width: | Height: | Size: 875 KiB After Width: | Height: | Size: 875 KiB |
|
Before Width: | Height: | Size: 809 KiB After Width: | Height: | Size: 809 KiB |
|
Before Width: | Height: | Size: 82 KiB |
|
Before Width: | Height: | Size: 71 KiB |
@ -15,7 +15,4 @@ Pick up the marker from the table and put it in the bowl
|
|||||||
Pick up the charger and move it slightly to the left
|
Pick up the charger and move it slightly to the left
|
||||||
Move the jar to the left side of the desk
|
Move the jar to the left side of the desk
|
||||||
Pick the rubik's cube on the top of the desk
|
Pick the rubik's cube on the top of the desk
|
||||||
Move the mug to the right
|
Move the mug to the right
|
||||||
Put the apples from table to the basket
|
|
||||||
Put the oranges from table to the bowl
|
|
||||||
Put the red cup on the tray on the table
|
|
||||||
|
Before Width: | Height: | Size: 36 KiB After Width: | Height: | Size: 36 KiB |
|
Before Width: | Height: | Size: 684 KiB After Width: | Height: | Size: 684 KiB |
|
Before Width: | Height: | Size: 713 KiB After Width: | Height: | Size: 713 KiB |
|
Before Width: | Height: | Size: 642 KiB After Width: | Height: | Size: 642 KiB |
|
Before Width: | Height: | Size: 699 KiB After Width: | Height: | Size: 699 KiB |
|
Before Width: | Height: | Size: 236 KiB After Width: | Height: | Size: 236 KiB |
|
Before Width: | Height: | Size: 771 KiB After Width: | Height: | Size: 771 KiB |
|
Before Width: | Height: | Size: 653 KiB After Width: | Height: | Size: 653 KiB |
|
Before Width: | Height: | Size: 712 KiB After Width: | Height: | Size: 712 KiB |
|
Before Width: | Height: | Size: 2.8 MiB After Width: | Height: | Size: 2.8 MiB |
|
Before Width: | Height: | Size: 3.3 MiB After Width: | Height: | Size: 3.3 MiB |
|
Before Width: | Height: | Size: 926 KiB After Width: | Height: | Size: 926 KiB |
|
Before Width: | Height: | Size: 47 KiB After Width: | Height: | Size: 47 KiB |
|
Before Width: | Height: | Size: 76 KiB After Width: | Height: | Size: 76 KiB |
@ -32,9 +32,8 @@ import trimesh
|
|||||||
from easydict import EasyDict as edict
|
from easydict import EasyDict as edict
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
||||||
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
|
|
||||||
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
||||||
from embodied_gen.data.utils import resize_pil, trellis_preprocess, zip_files
|
from embodied_gen.data.utils import trellis_preprocess, zip_files
|
||||||
from embodied_gen.models.delight_model import DelightingModel
|
from embodied_gen.models.delight_model import DelightingModel
|
||||||
from embodied_gen.models.gs_model import GaussianOperator
|
from embodied_gen.models.gs_model import GaussianOperator
|
||||||
from embodied_gen.models.segment_model import (
|
from embodied_gen.models.segment_model import (
|
||||||
@ -132,8 +131,8 @@ def patched_setup_functions(self):
|
|||||||
Gaussian.setup_functions = patched_setup_functions
|
Gaussian.setup_functions = patched_setup_functions
|
||||||
|
|
||||||
|
|
||||||
# DELIGHT = DelightingModel()
|
DELIGHT = DelightingModel()
|
||||||
# IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
||||||
# IMAGESR_MODEL = ImageStableSR()
|
# IMAGESR_MODEL = ImageStableSR()
|
||||||
if os.getenv("GRADIO_APP") == "imageto3d":
|
if os.getenv("GRADIO_APP") == "imageto3d":
|
||||||
RBG_REMOVER = RembgRemover()
|
RBG_REMOVER = RembgRemover()
|
||||||
@ -170,8 +169,6 @@ elif os.getenv("GRADIO_APP") == "textto3d":
|
|||||||
)
|
)
|
||||||
os.makedirs(TMP_DIR, exist_ok=True)
|
os.makedirs(TMP_DIR, exist_ok=True)
|
||||||
elif os.getenv("GRADIO_APP") == "texture_edit":
|
elif os.getenv("GRADIO_APP") == "texture_edit":
|
||||||
DELIGHT = DelightingModel()
|
|
||||||
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
|
||||||
PIPELINE_IP = build_texture_gen_pipe(
|
PIPELINE_IP = build_texture_gen_pipe(
|
||||||
base_ckpt_dir="./weights",
|
base_ckpt_dir="./weights",
|
||||||
ip_adapt_scale=0.7,
|
ip_adapt_scale=0.7,
|
||||||
@ -208,7 +205,7 @@ def preprocess_image_fn(
|
|||||||
elif isinstance(image, np.ndarray):
|
elif isinstance(image, np.ndarray):
|
||||||
image = Image.fromarray(image)
|
image = Image.fromarray(image)
|
||||||
|
|
||||||
image_cache = resize_pil(image.copy(), 1024)
|
image_cache = image.copy().resize((512, 512))
|
||||||
|
|
||||||
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
|
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
|
||||||
image = bg_remover(image)
|
image = bg_remover(image)
|
||||||
@ -224,7 +221,7 @@ def preprocess_sam_image_fn(
|
|||||||
image = Image.fromarray(image)
|
image = Image.fromarray(image)
|
||||||
|
|
||||||
sam_image = SAM_PREDICTOR.preprocess_image(image)
|
sam_image = SAM_PREDICTOR.preprocess_image(image)
|
||||||
image_cache = sam_image.copy()
|
image_cache = Image.fromarray(sam_image).resize((512, 512))
|
||||||
SAM_PREDICTOR.predictor.set_image(sam_image)
|
SAM_PREDICTOR.predictor.set_image(sam_image)
|
||||||
|
|
||||||
return sam_image, image_cache
|
return sam_image, image_cache
|
||||||
@ -515,60 +512,6 @@ def extract_3d_representations_v2(
|
|||||||
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
||||||
|
|
||||||
|
|
||||||
def extract_3d_representations_v3(
|
|
||||||
state: dict,
|
|
||||||
enable_delight: bool,
|
|
||||||
texture_size: int,
|
|
||||||
req: gr.Request,
|
|
||||||
):
|
|
||||||
output_root = TMP_DIR
|
|
||||||
user_dir = os.path.join(output_root, str(req.session_hash))
|
|
||||||
gs_model, mesh_model = unpack_state(state, device="cpu")
|
|
||||||
|
|
||||||
filename = "sample"
|
|
||||||
gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
|
|
||||||
gs_model.save_ply(gs_path)
|
|
||||||
|
|
||||||
# Rotate mesh and GS by 90 degrees around Z-axis.
|
|
||||||
rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
|
|
||||||
gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
|
|
||||||
mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
|
|
||||||
|
|
||||||
# Additional rotation for GS to align it with the mesh.
|
|
||||||
gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
|
|
||||||
pose = GaussianOperator.trans_to_quatpose(gs_rot)
|
|
||||||
aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
|
|
||||||
GaussianOperator.resave_ply(
|
|
||||||
in_ply=gs_path,
|
|
||||||
out_ply=aligned_gs_path,
|
|
||||||
instance_pose=pose,
|
|
||||||
device="cpu",
|
|
||||||
)
|
|
||||||
|
|
||||||
mesh = trimesh.Trimesh(
|
|
||||||
vertices=mesh_model.vertices.cpu().numpy(),
|
|
||||||
faces=mesh_model.faces.cpu().numpy(),
|
|
||||||
)
|
|
||||||
mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
|
|
||||||
mesh.vertices = mesh.vertices @ np.array(rot_matrix)
|
|
||||||
|
|
||||||
mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
|
|
||||||
mesh.export(mesh_obj_path)
|
|
||||||
|
|
||||||
mesh = backproject_api_v3(
|
|
||||||
gs_path=aligned_gs_path,
|
|
||||||
mesh_path=mesh_obj_path,
|
|
||||||
output_path=mesh_obj_path,
|
|
||||||
skip_fix_mesh=False,
|
|
||||||
texture_size=texture_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
|
|
||||||
mesh.export(mesh_glb_path)
|
|
||||||
|
|
||||||
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
|
||||||
|
|
||||||
|
|
||||||
def extract_urdf(
|
def extract_urdf(
|
||||||
gs_path: str,
|
gs_path: str,
|
||||||
mesh_obj_path: str,
|
mesh_obj_path: str,
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from common import (
|
|||||||
VERSION,
|
VERSION,
|
||||||
active_btn_by_content,
|
active_btn_by_content,
|
||||||
end_session,
|
end_session,
|
||||||
extract_3d_representations_v3,
|
extract_3d_representations_v2,
|
||||||
extract_urdf,
|
extract_urdf,
|
||||||
get_seed,
|
get_seed,
|
||||||
image_to_3d,
|
image_to_3d,
|
||||||
@ -45,8 +45,8 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
## ***EmbodiedGen***: Image-to-3D Asset
|
## ***EmbodiedGen***: Image-to-3D Asset
|
||||||
**🔖 Version**: {VERSION}
|
**🔖 Version**: {VERSION}
|
||||||
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
|
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
|
||||||
<a href="https://horizonrobotics.github.io/EmbodiedGen">
|
<a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
|
||||||
<img alt="📖 Documentation" src="https://img.shields.io/badge/📖-Documentation-blue">
|
<img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
|
||||||
</a>
|
</a>
|
||||||
<a href="https://arxiv.org/abs/2506.10600">
|
<a href="https://arxiv.org/abs/2506.10600">
|
||||||
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
|
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
|
||||||
@ -179,17 +179,17 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
)
|
)
|
||||||
|
|
||||||
generate_btn = gr.Button(
|
generate_btn = gr.Button(
|
||||||
"🚀 1. Generate(~2 mins)",
|
"🚀 1. Generate(~0.5 mins)",
|
||||||
variant="primary",
|
variant="primary",
|
||||||
interactive=False,
|
interactive=False,
|
||||||
)
|
)
|
||||||
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
|
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
|
||||||
# with gr.Row():
|
with gr.Row():
|
||||||
# extract_rep3d_btn = gr.Button(
|
extract_rep3d_btn = gr.Button(
|
||||||
# "🔍 2. Extract 3D Representation(~2 mins)",
|
"🔍 2. Extract 3D Representation(~2 mins)",
|
||||||
# variant="primary",
|
variant="primary",
|
||||||
# interactive=False,
|
interactive=False,
|
||||||
# )
|
)
|
||||||
with gr.Accordion(
|
with gr.Accordion(
|
||||||
label="Enter Asset Attributes(optional)", open=False
|
label="Enter Asset Attributes(optional)", open=False
|
||||||
):
|
):
|
||||||
@ -207,7 +207,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
extract_urdf_btn = gr.Button(
|
extract_urdf_btn = gr.Button(
|
||||||
"🧩 2. Extract URDF with physics(~1 mins)",
|
"🧩 3. Extract URDF with physics(~1 mins)",
|
||||||
variant="primary",
|
variant="primary",
|
||||||
interactive=False,
|
interactive=False,
|
||||||
)
|
)
|
||||||
@ -230,7 +230,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
download_urdf = gr.DownloadButton(
|
download_urdf = gr.DownloadButton(
|
||||||
label="⬇️ 3. Download URDF",
|
label="⬇️ 4. Download URDF",
|
||||||
variant="primary",
|
variant="primary",
|
||||||
interactive=False,
|
interactive=False,
|
||||||
)
|
)
|
||||||
@ -326,7 +326,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
image_prompt.change(
|
image_prompt.change(
|
||||||
lambda: tuple(
|
lambda: tuple(
|
||||||
[
|
[
|
||||||
# gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
None,
|
None,
|
||||||
@ -344,7 +344,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
outputs=[
|
outputs=[
|
||||||
# extract_rep3d_btn,
|
extract_rep3d_btn,
|
||||||
extract_urdf_btn,
|
extract_urdf_btn,
|
||||||
download_urdf,
|
download_urdf,
|
||||||
model_output_gs,
|
model_output_gs,
|
||||||
@ -375,7 +375,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
image_prompt_sam.change(
|
image_prompt_sam.change(
|
||||||
lambda: tuple(
|
lambda: tuple(
|
||||||
[
|
[
|
||||||
# gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
None,
|
None,
|
||||||
@ -394,7 +394,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
outputs=[
|
outputs=[
|
||||||
# extract_rep3d_btn,
|
extract_rep3d_btn,
|
||||||
extract_urdf_btn,
|
extract_urdf_btn,
|
||||||
download_urdf,
|
download_urdf,
|
||||||
model_output_gs,
|
model_output_gs,
|
||||||
@ -447,7 +447,12 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
],
|
],
|
||||||
outputs=[output_buf, video_output],
|
outputs=[output_buf, video_output],
|
||||||
).success(
|
).success(
|
||||||
extract_3d_representations_v3,
|
lambda: gr.Button(interactive=True),
|
||||||
|
outputs=[extract_rep3d_btn],
|
||||||
|
)
|
||||||
|
|
||||||
|
extract_rep3d_btn.click(
|
||||||
|
extract_3d_representations_v2,
|
||||||
inputs=[
|
inputs=[
|
||||||
output_buf,
|
output_buf,
|
||||||
project_delight,
|
project_delight,
|
||||||
@ -490,4 +495,4 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(server_port=8081)
|
demo.launch()
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from common import (
|
|||||||
VERSION,
|
VERSION,
|
||||||
active_btn_by_text_content,
|
active_btn_by_text_content,
|
||||||
end_session,
|
end_session,
|
||||||
extract_3d_representations_v3,
|
extract_3d_representations_v2,
|
||||||
extract_urdf,
|
extract_urdf,
|
||||||
get_cached_image,
|
get_cached_image,
|
||||||
get_seed,
|
get_seed,
|
||||||
@ -45,8 +45,8 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
## ***EmbodiedGen***: Text-to-3D Asset
|
## ***EmbodiedGen***: Text-to-3D Asset
|
||||||
**🔖 Version**: {VERSION}
|
**🔖 Version**: {VERSION}
|
||||||
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
|
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
|
||||||
<a href="https://horizonrobotics.github.io/EmbodiedGen">
|
<a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
|
||||||
<img alt="📖 Documentation" src="https://img.shields.io/badge/📖-Documentation-blue">
|
<img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
|
||||||
</a>
|
</a>
|
||||||
<a href="https://arxiv.org/abs/2506.10600">
|
<a href="https://arxiv.org/abs/2506.10600">
|
||||||
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
|
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
|
||||||
@ -178,17 +178,17 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
)
|
)
|
||||||
|
|
||||||
generate_btn = gr.Button(
|
generate_btn = gr.Button(
|
||||||
"🚀 2. Generate 3D(~2 mins)",
|
"🚀 2. Generate 3D(~0.5 mins)",
|
||||||
variant="primary",
|
variant="primary",
|
||||||
interactive=False,
|
interactive=False,
|
||||||
)
|
)
|
||||||
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
|
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
|
||||||
# with gr.Row():
|
with gr.Row():
|
||||||
# extract_rep3d_btn = gr.Button(
|
extract_rep3d_btn = gr.Button(
|
||||||
# "🔍 3. Extract 3D Representation(~1 mins)",
|
"🔍 3. Extract 3D Representation(~1 mins)",
|
||||||
# variant="primary",
|
variant="primary",
|
||||||
# interactive=False,
|
interactive=False,
|
||||||
# )
|
)
|
||||||
with gr.Accordion(
|
with gr.Accordion(
|
||||||
label="Enter Asset Attributes(optional)", open=False
|
label="Enter Asset Attributes(optional)", open=False
|
||||||
):
|
):
|
||||||
@ -206,13 +206,13 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
extract_urdf_btn = gr.Button(
|
extract_urdf_btn = gr.Button(
|
||||||
"🧩 3. Extract URDF with physics(~1 mins)",
|
"🧩 4. Extract URDF with physics(~1 mins)",
|
||||||
variant="primary",
|
variant="primary",
|
||||||
interactive=False,
|
interactive=False,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
download_urdf = gr.DownloadButton(
|
download_urdf = gr.DownloadButton(
|
||||||
label="⬇️ 4. Download URDF",
|
label="⬇️ 5. Download URDF",
|
||||||
variant="primary",
|
variant="primary",
|
||||||
interactive=False,
|
interactive=False,
|
||||||
)
|
)
|
||||||
@ -336,7 +336,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
generate_img_btn.click(
|
generate_img_btn.click(
|
||||||
lambda: tuple(
|
lambda: tuple(
|
||||||
[
|
[
|
||||||
# gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
gr.Button(interactive=False),
|
gr.Button(interactive=False),
|
||||||
@ -358,7 +358,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
outputs=[
|
outputs=[
|
||||||
# extract_rep3d_btn,
|
extract_rep3d_btn,
|
||||||
extract_urdf_btn,
|
extract_urdf_btn,
|
||||||
download_urdf,
|
download_urdf,
|
||||||
generate_btn,
|
generate_btn,
|
||||||
@ -428,7 +428,12 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
],
|
],
|
||||||
outputs=[output_buf, video_output],
|
outputs=[output_buf, video_output],
|
||||||
).success(
|
).success(
|
||||||
extract_3d_representations_v3,
|
lambda: gr.Button(interactive=True),
|
||||||
|
outputs=[extract_rep3d_btn],
|
||||||
|
)
|
||||||
|
|
||||||
|
extract_rep3d_btn.click(
|
||||||
|
extract_3d_representations_v2,
|
||||||
inputs=[
|
inputs=[
|
||||||
output_buf,
|
output_buf,
|
||||||
project_delight,
|
project_delight,
|
||||||
@ -471,4 +476,4 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(server_name="0.0.0.0", server_port=8082)
|
demo.launch()
|
||||||
|
|||||||
@ -55,8 +55,8 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
## ***EmbodiedGen***: Texture Generation
|
## ***EmbodiedGen***: Texture Generation
|
||||||
**🔖 Version**: {VERSION}
|
**🔖 Version**: {VERSION}
|
||||||
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
|
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
|
||||||
<a href="https://horizonrobotics.github.io/EmbodiedGen">
|
<a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
|
||||||
<img alt="📖 Documentation" src="https://img.shields.io/badge/📖-Documentation-blue">
|
<img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
|
||||||
</a>
|
</a>
|
||||||
<a href="https://arxiv.org/abs/2506.10600">
|
<a href="https://arxiv.org/abs/2506.10600">
|
||||||
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
|
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
|
||||||
@ -381,4 +381,4 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(server_port=8083)
|
demo.launch()
|
||||||
|
|||||||
@ -1,53 +1,13 @@
|
|||||||
# Project EmbodiedGen
|
|
||||||
#
|
|
||||||
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
# implied. See the License for the specific language governing
|
|
||||||
# permissions and limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
gradio_tmp_dir = os.path.join(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)), "gradio_cache"
|
|
||||||
)
|
|
||||||
os.makedirs(gradio_tmp_dir, exist_ok=True)
|
|
||||||
os.environ["GRADIO_TEMP_DIR"] = gradio_tmp_dir
|
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Tuple
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from app_style import custom_theme, lighting_css
|
from app_style import custom_theme, lighting_css
|
||||||
from embodied_gen.utils.tags import VERSION
|
|
||||||
|
|
||||||
try:
|
|
||||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT as gpt_client
|
|
||||||
|
|
||||||
gpt_client.check_connection()
|
|
||||||
GPT_AVAILABLE = True
|
|
||||||
except Exception as e:
|
|
||||||
gpt_client = None
|
|
||||||
GPT_AVAILABLE = False
|
|
||||||
print(
|
|
||||||
f"Warning: GPT client could not be initialized. Search will be disabled. Error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# --- Configuration & Data Loading ---
|
# --- Configuration & Data Loading ---
|
||||||
|
VERSION = "v0.1.5"
|
||||||
RUNNING_MODE = "local" # local or hf_remote
|
RUNNING_MODE = "local" # local or hf_remote
|
||||||
CSV_FILE = "dataset_index.csv"
|
CSV_FILE = "dataset_index.csv"
|
||||||
|
|
||||||
@ -76,7 +36,6 @@ TMP_DIR = os.path.join(
|
|||||||
)
|
)
|
||||||
os.makedirs(TMP_DIR, exist_ok=True)
|
os.makedirs(TMP_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
# --- Custom CSS for Styling ---
|
# --- Custom CSS for Styling ---
|
||||||
css = """
|
css = """
|
||||||
.gradio-container .gradio-group { box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important; }
|
.gradio-container .gradio-group { box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important; }
|
||||||
@ -85,43 +44,14 @@ css = """
|
|||||||
|
|
||||||
lighting_css = """
|
lighting_css = """
|
||||||
<style>
|
<style>
|
||||||
#visual_mesh canvas { filter: brightness(2.2) !important; }
|
#lighter_mesh canvas {
|
||||||
#collision_mesh_a canvas, #collision_mesh_b canvas { filter: brightness(1.0) !important; }
|
filter: brightness(2.2) !important;
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_prev_temp = {}
|
|
||||||
|
|
||||||
|
# --- Helper Functions ---
|
||||||
def _unique_path(
|
|
||||||
src_path: str | None, session_hash: str, kind: str
|
|
||||||
) -> str | None:
|
|
||||||
"""Link/copy src to GRADIO_TEMP_DIR/session_hash with random filename. Always return a fresh URL."""
|
|
||||||
if not src_path:
|
|
||||||
return None
|
|
||||||
tmp_root = (
|
|
||||||
Path(os.environ.get("GRADIO_TEMP_DIR", "/tmp"))
|
|
||||||
/ "model3d-cache"
|
|
||||||
/ session_hash
|
|
||||||
)
|
|
||||||
tmp_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# rolling cleanup for same kind
|
|
||||||
prev = _prev_temp.get(session_hash, {})
|
|
||||||
old = prev.get(kind)
|
|
||||||
if old and old.exists():
|
|
||||||
old.unlink()
|
|
||||||
|
|
||||||
ext = Path(src_path).suffix or ".glb"
|
|
||||||
dst = tmp_root / f"{kind}-{uuid.uuid4().hex}{ext}"
|
|
||||||
shutil.copy2(src_path, dst)
|
|
||||||
|
|
||||||
prev[kind] = dst
|
|
||||||
_prev_temp[session_hash] = prev
|
|
||||||
return str(dst)
|
|
||||||
|
|
||||||
|
|
||||||
# --- Helper Functions (data filtering) ---
|
|
||||||
def get_primary_categories():
|
def get_primary_categories():
|
||||||
return sorted(df["primary_category"].dropna().unique())
|
return sorted(df["primary_category"].dropna().unique())
|
||||||
|
|
||||||
@ -151,7 +81,7 @@ def get_categories(primary, secondary):
|
|||||||
|
|
||||||
def get_assets(primary, secondary, category):
|
def get_assets(primary, secondary, category):
|
||||||
if not primary or not secondary:
|
if not primary or not secondary:
|
||||||
return [], gr.update(interactive=False), pd.DataFrame()
|
return [], gr.update(interactive=False)
|
||||||
|
|
||||||
subset = df[
|
subset = df[
|
||||||
(df["primary_category"] == primary)
|
(df["primary_category"] == primary)
|
||||||
@ -175,211 +105,79 @@ def get_assets(primary, secondary, category):
|
|||||||
else "https://dummyimage.com/512x512/cccccc/000000&text=No+Preview"
|
else "https://dummyimage.com/512x512/cccccc/000000&text=No+Preview"
|
||||||
)
|
)
|
||||||
|
|
||||||
return items, gr.update(interactive=True), subset
|
return items, gr.update(interactive=True)
|
||||||
|
|
||||||
|
|
||||||
def search_assets(query: str, top_k: int):
|
def show_asset_from_gallery(
|
||||||
if not GPT_AVAILABLE or not query:
|
evt: gr.SelectData, primary: str, secondary: str, category: str
|
||||||
gr.Warning(
|
):
|
||||||
"GPT client is not available or query is empty. Cannot perform search."
|
index = evt.index
|
||||||
)
|
subset = df[
|
||||||
return [], gr.update(interactive=False), pd.DataFrame()
|
(df["primary_category"] == primary)
|
||||||
|
& (df["secondary_category"] == secondary)
|
||||||
|
]
|
||||||
|
if category:
|
||||||
|
subset = subset[subset["category"] == category]
|
||||||
|
|
||||||
gr.Info(f"Searching for assets matching: '{query}'...")
|
est_type_text = "N/A"
|
||||||
|
est_height_text = "N/A"
|
||||||
|
est_mass_text = "N/A"
|
||||||
|
est_mu_text = "N/A"
|
||||||
|
|
||||||
keywords = query.split()
|
if index >= len(subset):
|
||||||
keyword_filter = pd.Series([False] * len(df), index=df.index)
|
return (
|
||||||
for keyword in keywords:
|
None,
|
||||||
keyword_filter |= df['description'].str.contains(
|
"Error: Selection index is out of bounds.",
|
||||||
keyword, case=False, na=False
|
None,
|
||||||
|
None,
|
||||||
|
est_type_text,
|
||||||
|
est_height_text,
|
||||||
|
est_mass_text,
|
||||||
|
est_mu_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
candidates = df[keyword_filter]
|
row = subset.iloc[index]
|
||||||
|
|
||||||
if len(candidates) > 100:
|
|
||||||
candidates = candidates.head(100)
|
|
||||||
|
|
||||||
if candidates.empty:
|
|
||||||
gr.Warning("No assets found matching the keywords.")
|
|
||||||
return [], gr.update(interactive=True), pd.DataFrame()
|
|
||||||
|
|
||||||
try:
|
|
||||||
descriptions = [
|
|
||||||
f"{idx}: {desc}" for idx, desc in candidates['description'].items()
|
|
||||||
]
|
|
||||||
descriptions_text = "\n".join(descriptions)
|
|
||||||
|
|
||||||
prompt = f"""
|
|
||||||
A user is searching for 3D assets with the query: "{query}".
|
|
||||||
Below is a list of available assets, each with an ID and a description.
|
|
||||||
Please evaluate how well each asset description matches the user's query and rate them on a scale from 0 to 10, where 10 is a perfect match.
|
|
||||||
|
|
||||||
Your task is to return a list of the top {top_k} asset IDs, sorted from the most relevant to the least relevant.
|
|
||||||
The output format must be a simple comma-separated list of IDs, for example: "123,45,678". Do not add any other text.
|
|
||||||
|
|
||||||
Asset Descriptions:
|
|
||||||
{descriptions_text}
|
|
||||||
|
|
||||||
User Query: "{query}"
|
|
||||||
|
|
||||||
Top {top_k} sorted asset IDs:
|
|
||||||
"""
|
|
||||||
response = gpt_client.query(prompt)
|
|
||||||
sorted_ids_str = response.strip().split(',')
|
|
||||||
sorted_ids = [
|
|
||||||
int(id_str.strip())
|
|
||||||
for id_str in sorted_ids_str
|
|
||||||
if id_str.strip().isdigit()
|
|
||||||
]
|
|
||||||
top_assets = df.loc[sorted_ids].head(top_k)
|
|
||||||
except Exception as e:
|
|
||||||
gr.Error(f"An error occurred while using GPT for ranking: {e}")
|
|
||||||
top_assets = candidates.head(top_k)
|
|
||||||
|
|
||||||
items = []
|
|
||||||
for row in top_assets.itertuples():
|
|
||||||
asset_dir = os.path.join(DATA_ROOT, row.asset_dir)
|
|
||||||
video_path = None
|
|
||||||
if pd.notna(row.asset_dir) and os.path.exists(asset_dir):
|
|
||||||
for f in os.listdir(asset_dir):
|
|
||||||
if f.lower().endswith(".mp4"):
|
|
||||||
video_path = os.path.join(asset_dir, f)
|
|
||||||
break
|
|
||||||
items.append(
|
|
||||||
video_path
|
|
||||||
if video_path
|
|
||||||
else "https://dummyimage.com/512x512/cccccc/000000&text=No+Preview"
|
|
||||||
)
|
|
||||||
|
|
||||||
gr.Info(f"Found {len(items)} assets.")
|
|
||||||
return items, gr.update(interactive=True), top_assets
|
|
||||||
|
|
||||||
|
|
||||||
# --- Mesh extraction ---
|
|
||||||
def _extract_mesh_paths(row) -> Tuple[str | None, str | None, str]:
|
|
||||||
desc = row["description"]
|
desc = row["description"]
|
||||||
urdf_path = os.path.join(DATA_ROOT, row["urdf_path"])
|
urdf_path = os.path.join(DATA_ROOT, row["urdf_path"])
|
||||||
asset_dir = os.path.join(DATA_ROOT, row["asset_dir"])
|
asset_dir = os.path.join(DATA_ROOT, row["asset_dir"])
|
||||||
visual_mesh_path = None
|
mesh_to_display = None
|
||||||
collision_mesh_path = None
|
|
||||||
|
|
||||||
if pd.notna(urdf_path) and os.path.exists(urdf_path):
|
if pd.notna(urdf_path) and os.path.exists(urdf_path):
|
||||||
try:
|
try:
|
||||||
tree = ET.parse(urdf_path)
|
tree = ET.parse(urdf_path)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
|
|
||||||
visual_mesh_element = root.find('.//visual/geometry/mesh')
|
mesh_element = root.find('.//visual/geometry/mesh')
|
||||||
if visual_mesh_element is not None:
|
if mesh_element is not None:
|
||||||
visual_mesh_filename = visual_mesh_element.get('filename')
|
mesh_filename = mesh_element.get('filename')
|
||||||
if visual_mesh_filename:
|
if mesh_filename:
|
||||||
glb_filename = (
|
glb_filename = os.path.splitext(mesh_filename)[0] + ".glb"
|
||||||
os.path.splitext(visual_mesh_filename)[0] + ".glb"
|
|
||||||
)
|
|
||||||
potential_path = os.path.join(asset_dir, glb_filename)
|
potential_path = os.path.join(asset_dir, glb_filename)
|
||||||
if os.path.exists(potential_path):
|
if os.path.exists(potential_path):
|
||||||
visual_mesh_path = potential_path
|
mesh_to_display = potential_path
|
||||||
|
|
||||||
collision_mesh_element = root.find('.//collision/geometry/mesh')
|
category_elem = root.find('.//extra_info/category')
|
||||||
if collision_mesh_element is not None:
|
if category_elem is not None and category_elem.text:
|
||||||
collision_mesh_filename = collision_mesh_element.get(
|
est_type_text = category_elem.text.strip()
|
||||||
'filename'
|
|
||||||
)
|
height_elem = root.find('.//extra_info/real_height')
|
||||||
if collision_mesh_filename:
|
if height_elem is not None and height_elem.text:
|
||||||
potential_collision_path = os.path.join(
|
est_height_text = height_elem.text.strip()
|
||||||
asset_dir, collision_mesh_filename
|
|
||||||
)
|
mass_elem = root.find('.//extra_info/min_mass')
|
||||||
if os.path.exists(potential_collision_path):
|
if mass_elem is not None and mass_elem.text:
|
||||||
collision_mesh_path = potential_collision_path
|
est_mass_text = mass_elem.text.strip()
|
||||||
|
|
||||||
|
mu_elem = root.find('.//collision/gazebo/mu2')
|
||||||
|
if mu_elem is not None and mu_elem.text:
|
||||||
|
est_mu_text = mu_elem.text.strip()
|
||||||
|
|
||||||
except ET.ParseError:
|
except ET.ParseError:
|
||||||
desc = f"Error: Failed to parse URDF at {urdf_path}. {desc}"
|
desc = f"Error: Failed to parse URDF at {urdf_path}. {desc}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
desc = f"An error occurred while processing URDF: {str(e)}. {desc}"
|
desc = f"An error occurred while processing URDF: {str(e)}. {desc}"
|
||||||
|
|
||||||
return visual_mesh_path, collision_mesh_path, desc
|
|
||||||
|
|
||||||
|
|
||||||
def show_asset_from_gallery(
|
|
||||||
evt: gr.SelectData,
|
|
||||||
primary: str,
|
|
||||||
secondary: str,
|
|
||||||
category: str,
|
|
||||||
search_query: str,
|
|
||||||
gallery_df: pd.DataFrame,
|
|
||||||
):
|
|
||||||
"""Parse the selected asset and return raw paths + metadata."""
|
|
||||||
index = evt.index
|
|
||||||
|
|
||||||
if search_query and gallery_df is not None and not gallery_df.empty:
|
|
||||||
subset = gallery_df
|
|
||||||
else:
|
|
||||||
if not primary or not secondary:
|
|
||||||
return (
|
|
||||||
None, # visual_path
|
|
||||||
None, # collision_path
|
|
||||||
"Error: Primary or secondary category not selected.",
|
|
||||||
None, # asset_dir
|
|
||||||
None, # urdf_path
|
|
||||||
"N/A",
|
|
||||||
"N/A",
|
|
||||||
"N/A",
|
|
||||||
"N/A",
|
|
||||||
)
|
|
||||||
|
|
||||||
subset = df[
|
|
||||||
(df["primary_category"] == primary)
|
|
||||||
& (df["secondary_category"] == secondary)
|
|
||||||
]
|
|
||||||
if category:
|
|
||||||
subset = subset[subset["category"] == category]
|
|
||||||
|
|
||||||
if subset.empty or index >= len(subset):
|
|
||||||
return (
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
"Error: Selection index is out of bounds or data is missing.",
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
"N/A",
|
|
||||||
"N/A",
|
|
||||||
"N/A",
|
|
||||||
"N/A",
|
|
||||||
)
|
|
||||||
|
|
||||||
row = subset.iloc[index]
|
|
||||||
visual_path, collision_path, desc = _extract_mesh_paths(row)
|
|
||||||
|
|
||||||
urdf_path = os.path.join(DATA_ROOT, row["urdf_path"])
|
|
||||||
asset_dir = os.path.join(DATA_ROOT, row["asset_dir"])
|
|
||||||
|
|
||||||
# read extra info
|
|
||||||
est_type_text = "N/A"
|
|
||||||
est_height_text = "N/A"
|
|
||||||
est_mass_text = "N/A"
|
|
||||||
est_mu_text = "N/A"
|
|
||||||
|
|
||||||
if pd.notna(urdf_path) and os.path.exists(urdf_path):
|
|
||||||
try:
|
|
||||||
tree = ET.parse(urdf_path)
|
|
||||||
root = tree.getroot()
|
|
||||||
category_elem = root.find('.//extra_info/category')
|
|
||||||
if category_elem is not None and category_elem.text:
|
|
||||||
est_type_text = category_elem.text.strip()
|
|
||||||
height_elem = root.find('.//extra_info/real_height')
|
|
||||||
if height_elem is not None and height_elem.text:
|
|
||||||
est_height_text = height_elem.text.strip()
|
|
||||||
mass_elem = root.find('.//extra_info/min_mass')
|
|
||||||
if mass_elem is not None and mass_elem.text:
|
|
||||||
est_mass_text = mass_elem.text.strip()
|
|
||||||
mu_elem = root.find('.//collision/gazebo/mu2')
|
|
||||||
if mu_elem is not None and mu_elem.text:
|
|
||||||
est_mu_text = mu_elem.text.strip()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
visual_path,
|
gr.update(value=mesh_to_display),
|
||||||
collision_path,
|
|
||||||
desc,
|
desc,
|
||||||
asset_dir,
|
asset_dir,
|
||||||
urdf_path,
|
urdf_path,
|
||||||
@ -390,56 +188,6 @@ def show_asset_from_gallery(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def render_meshes(
|
|
||||||
visual_path: str | None,
|
|
||||||
collision_path: str | None,
|
|
||||||
switch_viewer: bool,
|
|
||||||
req: gr.Request,
|
|
||||||
):
|
|
||||||
session_hash = getattr(req, "session_hash", "default")
|
|
||||||
|
|
||||||
if switch_viewer:
|
|
||||||
yield (
|
|
||||||
gr.update(value=None),
|
|
||||||
gr.update(value=None, visible=False),
|
|
||||||
gr.update(value=None, visible=True),
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield (
|
|
||||||
gr.update(value=None),
|
|
||||||
gr.update(value=None, visible=True),
|
|
||||||
gr.update(value=None, visible=False),
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
|
|
||||||
visual_unique = (
|
|
||||||
_unique_path(visual_path, session_hash, "visual")
|
|
||||||
if visual_path
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
collision_unique = (
|
|
||||||
_unique_path(collision_path, session_hash, "collision")
|
|
||||||
if collision_path
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if switch_viewer:
|
|
||||||
yield (
|
|
||||||
gr.update(value=visual_unique),
|
|
||||||
gr.update(value=None, visible=False),
|
|
||||||
gr.update(value=collision_unique, visible=True),
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield (
|
|
||||||
gr.update(value=visual_unique),
|
|
||||||
gr.update(value=collision_unique, visible=True),
|
|
||||||
gr.update(value=None, visible=False),
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_asset_zip(asset_dir: str, req: gr.Request):
|
def create_asset_zip(asset_dir: str, req: gr.Request):
|
||||||
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
os.makedirs(user_dir, exist_ok=True)
|
os.makedirs(user_dir, exist_ok=True)
|
||||||
@ -466,7 +214,7 @@ def end_session(req: gr.Request) -> None:
|
|||||||
shutil.rmtree(user_dir)
|
shutil.rmtree(user_dir)
|
||||||
|
|
||||||
|
|
||||||
# --- UI ---
|
# --- Gradio UI Definition ---
|
||||||
with gr.Blocks(
|
with gr.Blocks(
|
||||||
theme=custom_theme,
|
theme=custom_theme,
|
||||||
css=css,
|
css=css,
|
||||||
@ -479,8 +227,8 @@ with gr.Blocks(
|
|||||||
|
|
||||||
**🔖 Version**: {VERSION}
|
**🔖 Version**: {VERSION}
|
||||||
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
|
<p style="display: flex; gap: 10px; flex-wrap: nowrap;">
|
||||||
<a href="https://horizonrobotics.github.io/EmbodiedGen">
|
<a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
|
||||||
<img alt="📖 Documentation" src="https://img.shields.io/badge/📖-Documentation-blue">
|
<img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
|
||||||
</a>
|
</a>
|
||||||
<a href="https://arxiv.org/abs/2506.10600">
|
<a href="https://arxiv.org/abs/2506.10600">
|
||||||
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
|
<img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
|
||||||
@ -508,35 +256,9 @@ with gr.Blocks(
|
|||||||
category_list = get_categories(primary_val, secondary_val)
|
category_list = get_categories(primary_val, secondary_val)
|
||||||
category_val = category_list[0] if category_list else None
|
category_val = category_list[0] if category_list else None
|
||||||
asset_folder = gr.State(value=None)
|
asset_folder = gr.State(value=None)
|
||||||
gallery_df_state = gr.State()
|
|
||||||
|
|
||||||
switch_viewer_state = gr.State(value=False)
|
|
||||||
|
|
||||||
with gr.Row(equal_height=False):
|
with gr.Row(equal_height=False):
|
||||||
with gr.Column(scale=1, min_width=350):
|
with gr.Column(scale=1, min_width=350):
|
||||||
with gr.Group():
|
|
||||||
gr.Markdown("### Search Asset with Descriptions")
|
|
||||||
search_box = gr.Textbox(
|
|
||||||
label="🔎 Enter your search query",
|
|
||||||
placeholder="e.g., 'a red chair with four legs'",
|
|
||||||
interactive=GPT_AVAILABLE,
|
|
||||||
)
|
|
||||||
top_k_slider = gr.Slider(
|
|
||||||
minimum=1,
|
|
||||||
maximum=50,
|
|
||||||
value=10,
|
|
||||||
step=1,
|
|
||||||
label="Number of results",
|
|
||||||
interactive=GPT_AVAILABLE,
|
|
||||||
)
|
|
||||||
search_button = gr.Button(
|
|
||||||
"Search", variant="primary", interactive=GPT_AVAILABLE
|
|
||||||
)
|
|
||||||
if not GPT_AVAILABLE:
|
|
||||||
gr.Markdown(
|
|
||||||
"<p style='color: #ff4b4b;'>⚠️ GPT client not available. Search is disabled.</p>"
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gr.Markdown("### Select Asset Category")
|
gr.Markdown("### Select Asset Category")
|
||||||
primary = gr.Dropdown(
|
primary = gr.Dropdown(
|
||||||
@ -556,11 +278,10 @@ with gr.Blocks(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
initial_assets, _, initial_df = get_assets(
|
|
||||||
primary_val, secondary_val, category_val
|
|
||||||
)
|
|
||||||
gallery = gr.Gallery(
|
gallery = gr.Gallery(
|
||||||
value=initial_assets,
|
value=get_assets(primary_val, secondary_val, category_val)[
|
||||||
|
0
|
||||||
|
],
|
||||||
label="🖼️ Asset Previews",
|
label="🖼️ Asset Previews",
|
||||||
columns=3,
|
columns=3,
|
||||||
height="auto",
|
height="auto",
|
||||||
@ -571,40 +292,14 @@ with gr.Blocks(
|
|||||||
|
|
||||||
with gr.Column(scale=2, min_width=500):
|
with gr.Column(scale=2, min_width=500):
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
with gr.Tabs():
|
viewer = gr.Model3D(
|
||||||
with gr.TabItem("Visual Mesh") as t1:
|
label="🧊 3D Model Viewer",
|
||||||
viewer = gr.Model3D(
|
height=500,
|
||||||
label="🧊 3D Model Viewer",
|
clear_color=[0.95, 0.95, 0.95],
|
||||||
height=500,
|
elem_id="lighter_mesh",
|
||||||
clear_color=[0.95, 0.95, 0.95],
|
|
||||||
elem_id="visual_mesh",
|
|
||||||
)
|
|
||||||
with gr.TabItem("Collision Mesh") as t2:
|
|
||||||
collision_viewer_a = gr.Model3D(
|
|
||||||
label="🧊 Collision Mesh",
|
|
||||||
height=500,
|
|
||||||
clear_color=[0.95, 0.95, 0.95],
|
|
||||||
elem_id="collision_mesh_a",
|
|
||||||
visible=True,
|
|
||||||
)
|
|
||||||
collision_viewer_b = gr.Model3D(
|
|
||||||
label="🧊 Collision Mesh",
|
|
||||||
height=500,
|
|
||||||
clear_color=[0.95, 0.95, 0.95],
|
|
||||||
elem_id="collision_mesh_b",
|
|
||||||
visible=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
t1.select(
|
|
||||||
fn=lambda: None,
|
|
||||||
js="() => { window.dispatchEvent(new Event('resize')); }",
|
|
||||||
)
|
)
|
||||||
t2.select(
|
|
||||||
fn=lambda: None,
|
|
||||||
js="() => { window.dispatchEvent(new Event('resize')); }",
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
# TODO: Add more asset details if needed
|
||||||
est_type_text = gr.Textbox(
|
est_type_text = gr.Textbox(
|
||||||
label="Asset category", interactive=False
|
label="Asset category", interactive=False
|
||||||
)
|
)
|
||||||
@ -617,11 +312,10 @@ with gr.Blocks(
|
|||||||
est_mu_text = gr.Textbox(
|
est_mu_text = gr.Textbox(
|
||||||
label="Friction coefficient", interactive=False
|
label="Friction coefficient", interactive=False
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Accordion(label="Asset Details", open=False):
|
||||||
desc_box = gr.Textbox(
|
desc_box = gr.Textbox(
|
||||||
label="📝 Asset Description", interactive=False
|
label="📝 Asset Description", interactive=False
|
||||||
)
|
)
|
||||||
with gr.Accordion(label="Asset Details", open=False):
|
|
||||||
urdf_file = gr.Textbox(
|
urdf_file = gr.Textbox(
|
||||||
label="URDF File Path", interactive=False, lines=2
|
label="URDF File Path", interactive=False, lines=2
|
||||||
)
|
)
|
||||||
@ -637,64 +331,55 @@ with gr.Blocks(
|
|||||||
interactive=False,
|
interactive=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
search_button.click(
|
|
||||||
fn=search_assets,
|
|
||||||
inputs=[search_box, top_k_slider],
|
|
||||||
outputs=[gallery, gallery, gallery_df_state],
|
|
||||||
)
|
|
||||||
search_box.submit(
|
|
||||||
fn=search_assets,
|
|
||||||
inputs=[search_box, top_k_slider],
|
|
||||||
outputs=[gallery, gallery, gallery_df_state],
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_on_primary_change(p):
|
def update_on_primary_change(p):
|
||||||
s_choices = get_secondary_categories(p)
|
s_choices = get_secondary_categories(p)
|
||||||
initial_assets, gallery_update, initial_df = get_assets(p, None, None)
|
|
||||||
return (
|
return (
|
||||||
gr.update(choices=s_choices, value=None),
|
gr.update(choices=s_choices, value=None),
|
||||||
gr.update(choices=[], value=None),
|
gr.update(choices=[], value=None),
|
||||||
initial_assets,
|
[],
|
||||||
gallery_update,
|
gr.update(interactive=False),
|
||||||
initial_df,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_on_secondary_change(p, s):
|
def update_on_secondary_change(p, s):
|
||||||
c_choices = get_categories(p, s)
|
c_choices = get_categories(p, s)
|
||||||
asset_previews, gallery_update, gallery_df = get_assets(p, s, None)
|
return (
|
||||||
|
gr.update(choices=c_choices, value=None),
|
||||||
|
[],
|
||||||
|
gr.update(interactive=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_on_secondary_change(p, s):
|
||||||
|
c_choices = get_categories(p, s)
|
||||||
|
asset_previews, gallery_update = get_assets(p, s, None)
|
||||||
return (
|
return (
|
||||||
gr.update(choices=c_choices, value=None),
|
gr.update(choices=c_choices, value=None),
|
||||||
asset_previews,
|
asset_previews,
|
||||||
gallery_update,
|
gallery_update,
|
||||||
gallery_df,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_assets(p, s, c):
|
|
||||||
asset_previews, gallery_update, gallery_df = get_assets(p, s, c)
|
|
||||||
return asset_previews, gallery_update, gallery_df
|
|
||||||
|
|
||||||
primary.change(
|
primary.change(
|
||||||
fn=update_on_primary_change,
|
fn=update_on_primary_change,
|
||||||
inputs=[primary],
|
inputs=[primary],
|
||||||
outputs=[secondary, category, gallery, gallery, gallery_df_state],
|
outputs=[secondary, category, gallery, gallery],
|
||||||
)
|
)
|
||||||
|
|
||||||
secondary.change(
|
secondary.change(
|
||||||
fn=update_on_secondary_change,
|
fn=update_on_secondary_change,
|
||||||
inputs=[primary, secondary],
|
inputs=[primary, secondary],
|
||||||
outputs=[category, gallery, gallery, gallery_df_state],
|
outputs=[category, gallery, gallery],
|
||||||
)
|
)
|
||||||
|
|
||||||
category.change(
|
category.change(
|
||||||
fn=update_assets,
|
fn=get_assets,
|
||||||
inputs=[primary, secondary, category],
|
inputs=[primary, secondary, category],
|
||||||
outputs=[gallery, gallery, gallery_df_state],
|
outputs=[gallery, gallery],
|
||||||
)
|
)
|
||||||
|
|
||||||
gallery.select(
|
gallery.select(
|
||||||
fn=show_asset_from_gallery,
|
fn=show_asset_from_gallery,
|
||||||
inputs=[primary, secondary, category, search_box, gallery_df_state],
|
inputs=[primary, secondary, category],
|
||||||
outputs=[
|
outputs=[
|
||||||
(visual_path_state := gr.State()),
|
viewer,
|
||||||
(collision_path_state := gr.State()),
|
|
||||||
desc_box,
|
desc_box,
|
||||||
asset_folder,
|
asset_folder,
|
||||||
urdf_file,
|
urdf_file,
|
||||||
@ -703,23 +388,22 @@ with gr.Blocks(
|
|||||||
est_mass_text,
|
est_mass_text,
|
||||||
est_mu_text,
|
est_mu_text,
|
||||||
],
|
],
|
||||||
).then(
|
|
||||||
fn=render_meshes,
|
|
||||||
inputs=[visual_path_state, collision_path_state, switch_viewer_state],
|
|
||||||
outputs=[
|
|
||||||
viewer,
|
|
||||||
collision_viewer_a,
|
|
||||||
collision_viewer_b,
|
|
||||||
switch_viewer_state,
|
|
||||||
],
|
|
||||||
).success(
|
).success(
|
||||||
lambda: (gr.Button(interactive=True), gr.Button(interactive=False)),
|
lambda: tuple(
|
||||||
|
[
|
||||||
|
gr.Button(interactive=True),
|
||||||
|
gr.Button(interactive=False),
|
||||||
|
]
|
||||||
|
),
|
||||||
outputs=[extract_btn, download_btn],
|
outputs=[extract_btn, download_btn],
|
||||||
)
|
)
|
||||||
|
|
||||||
extract_btn.click(
|
extract_btn.click(
|
||||||
fn=create_asset_zip, inputs=[asset_folder], outputs=[download_btn]
|
fn=create_asset_zip, inputs=[asset_folder], outputs=[download_btn]
|
||||||
).success(fn=lambda: gr.update(interactive=True), outputs=download_btn)
|
).success(
|
||||||
|
fn=lambda: gr.update(interactive=True),
|
||||||
|
outputs=download_btn,
|
||||||
|
)
|
||||||
|
|
||||||
demo.load(start_session)
|
demo.load(start_session)
|
||||||
demo.unload(end_session)
|
demo.unload(end_session)
|
||||||
@ -727,6 +411,7 @@ with gr.Blocks(
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(
|
demo.launch(
|
||||||
|
server_name="10.34.8.82",
|
||||||
server_port=8088,
|
server_port=8088,
|
||||||
allowed_paths=[
|
allowed_paths=[
|
||||||
"/horizon-bucket/robot_lab/datasets/embodiedgen/assets"
|
"/horizon-bucket/robot_lab/datasets/embodiedgen/assets"
|
||||||
|
|||||||
26
docker.sh
@ -1,26 +0,0 @@
|
|||||||
IMAGE=wangxinjie/embodiedgen:v0.1.x
|
|
||||||
CONTAINER=EmbodiedGen-docker-${USER}
|
|
||||||
docker pull ${IMAGE}
|
|
||||||
docker run -itd --shm-size="64g" --gpus all --cap-add=SYS_PTRACE \
|
|
||||||
--security-opt seccomp=unconfined --privileged --net=host \
|
|
||||||
--name ${CONTAINER} ${IMAGE}
|
|
||||||
|
|
||||||
docker exec -it ${CONTAINER} bash
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export no_proxy="localhost,127.0.0.1,192.168.48.210,120.48.161.22"
|
|
||||||
export ENDPOINT="https://llmproxy.d-robotics.cc/v1"
|
|
||||||
export API_KEY="sk-B8urDShf4TLeruwI3dB8286485Aa4984A722E945F566EfF4"
|
|
||||||
export MODEL_NAME="azure/gpt-4.1"
|
|
||||||
|
|
||||||
|
|
||||||
# start a tmux run in backend
|
|
||||||
CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
|
|
||||||
CUDA_VISIBLE_DEVICES=1 nohup python apps/text_to_3d.py > /dev/null 2>&1 &
|
|
||||||
CUDA_VISIBLE_DEVICES=2 nohup python apps/texture_edit.py > /dev/null 2>&1 &
|
|
||||||
|
|
||||||
# server_name="0.0.0.0", server_port=10001
|
|
||||||
# 120.48.161.22:10001
|
|
||||||
# unset http_proxy https_proxy no_proxy ENDPOINT API_KEY MODEL_NAME
|
|
||||||
# export http_proxy=http://192.168.16.76:18000 https_proxy=http://192.168.16.76:18000
|
|
||||||
@ -25,10 +25,12 @@ ENV CUDA_HOME=/usr/local/cuda-11.8 \
|
|||||||
|
|
||||||
RUN useradd -m -s /bin/bash e_user
|
RUN useradd -m -s /bin/bash e_user
|
||||||
WORKDIR /EmbodiedGen
|
WORKDIR /EmbodiedGen
|
||||||
|
COPY . .
|
||||||
RUN chown -R e_user:e_user /EmbodiedGen
|
RUN chown -R e_user:e_user /EmbodiedGen
|
||||||
USER e_user
|
USER e_user
|
||||||
|
|
||||||
RUN conda create -n embodiedgen python=3.10.13 -y
|
RUN conda create -n embodiedgen python=3.10.13 -y && \
|
||||||
|
conda run -n embodiedgen bash install.sh
|
||||||
|
|
||||||
RUN /opt/conda/bin/conda init bash && \
|
RUN /opt/conda/bin/conda init bash && \
|
||||||
echo "conda activate embodiedgen" >> /home/e_user/.bashrc
|
echo "conda activate embodiedgen" >> /home/e_user/.bashrc
|
||||||
@ -1,28 +0,0 @@
|
|||||||
# 🙌 Acknowledgement
|
|
||||||
|
|
||||||
EmbodiedGen builds upon the following amazing projects and models:
|
|
||||||
🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | 🌟 [ManiSkill](https://github.com/haosulab/ManiSkill)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 📚 Citation
|
|
||||||
|
|
||||||
If you use EmbodiedGen in your research or projects, please cite:
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@misc{wang2025embodiedgengenerative3dworld,
|
|
||||||
title={EmbodiedGen: Towards a Generative 3D World Engine for Embodied Intelligence},
|
|
||||||
author={Xinjie Wang and Liu Liu and Yu Cao and Ruiqi Wu and Wenkang Qin and Dehui Wang and Wei Sui and Zhizhong Su},
|
|
||||||
year={2025},
|
|
||||||
eprint={2506.10600},
|
|
||||||
archivePrefix={arXiv},
|
|
||||||
primaryClass={cs.RO},
|
|
||||||
url={https://arxiv.org/abs/2506.10600},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⚖️ License
|
|
||||||
|
|
||||||
This project is licensed under the [Apache License 2.0](LICENSE). See the `LICENSE` file for details.
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
# Data API
|
|
||||||
|
|
||||||
::: embodied_gen.data.asset_converter
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.data.datasets
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.data.differentiable_render
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.data.mesh_operator
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.data.backproject_v2
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.data.convex_decomposer
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
@ -1,7 +0,0 @@
|
|||||||
# Envs API
|
|
||||||
|
|
||||||
Documentation for simulation environments and task definitions.
|
|
||||||
|
|
||||||
::: embodied_gen.envs.pick_embodiedgen
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
@ -1,14 +0,0 @@
|
|||||||
# API Reference
|
|
||||||
|
|
||||||
Welcome to the API reference for EmbodiedGen.
|
|
||||||
|
|
||||||
This section contains detailed documentation for all public modules, classes,
|
|
||||||
and functions. Use the navigation on the left (or the list below) to
|
|
||||||
browse the different components.
|
|
||||||
|
|
||||||
* [**Data API**](data.md): Tools for data processing, conversion, and rendering.
|
|
||||||
* [**Envs API**](envs.md): Simulation environment definitions.
|
|
||||||
* [**Models API**](models.md): The core generative models (Texture, 3DGS, Layout, etc.).
|
|
||||||
* [**Trainer API**](trainer.md): PyTorch-Lightning style trainers for models.
|
|
||||||
* [**Utilities API**](utils.md): Helper functions and configuration.
|
|
||||||
* [**Validators API**](validators.md): Tools for checking and validating assets.
|
|
||||||
@ -1,33 +0,0 @@
|
|||||||
# Models API
|
|
||||||
|
|
||||||
::: embodied_gen.models.texture_model
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.models.gs_model
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.models.layout
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.models.text_model
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.models.sr_model
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.models.segment_model
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.models.image_comm_model
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.models.delight_model
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
@ -1,11 +0,0 @@
|
|||||||
# Trainer API
|
|
||||||
|
|
||||||
This section covers the training pipelines for various models.
|
|
||||||
|
|
||||||
::: embodied_gen.trainer.gsplat_trainer
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.trainer.pono2mesh_trainer
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
@ -1,47 +0,0 @@
|
|||||||
# Utilities API
|
|
||||||
|
|
||||||
General-purpose utility functions, configuration, and helper classes.
|
|
||||||
|
|
||||||
::: embodied_gen.utils.config
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.log
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.enum
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.geometry
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.gaussian
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.gpt_clients
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.process_media
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.simulation
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.tags
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.trender
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.utils.monkey_patches
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
@ -1,15 +0,0 @@
|
|||||||
# Validators API
|
|
||||||
|
|
||||||
Tools for asset validation, quality control, and conversion.
|
|
||||||
|
|
||||||
::: embodied_gen.validators.aesthetic_predictor
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.validators.quality_checkers
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
|
|
||||||
::: embodied_gen.validators.urdf_convertor
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
---
|
|
||||||
hide:
|
|
||||||
- navigation
|
|
||||||
---
|
|
||||||
|
|
||||||
# 👋 Welcome to EmbodiedGen
|
|
||||||
|
|
||||||
[](https://horizonrobotics.github.io/EmbodiedGen/)
|
|
||||||
[](https://github.com/HorizonRobotics/EmbodiedGen)
|
|
||||||
[](https://arxiv.org/abs/2506.10600)
|
|
||||||
[](https://www.youtube.com/watch?v=rG4odybuJRk)
|
|
||||||
[](https://mp.weixin.qq.com/s/HH1cPBhK2xcDbyCK4BBTbw)
|
|
||||||
<!-- [](https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html) -->
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer)
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D)
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D)
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen)
|
|
||||||
|
|
||||||
*EmbodiedGen*: Towards a Generative 3D World Engine for Embodied Intelligence.
|
|
||||||
|
|
||||||
<img src="assets/overall.jpg" alt="Overall Framework" width="700"/>
|
|
||||||
|
|
||||||
> ***EmbodiedGen*** is a generative engine to create diverse and interactive 3D worlds composed of high-quality 3D assets (mesh & 3DGS) with plausible physics, leveraging generative AI to address the challenges of generalization in embodied intelligence-related research.
|
|
||||||
|
|
||||||
---
|
|
||||||
@ -1,40 +0,0 @@
|
|||||||
---
|
|
||||||
hide:
|
|
||||||
- navigation
|
|
||||||
---
|
|
||||||
|
|
||||||
## ✅ Setup Environment
|
|
||||||
```sh
|
|
||||||
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
|
|
||||||
cd EmbodiedGen
|
|
||||||
git checkout v0.1.6
|
|
||||||
git submodule update --init --recursive --progress
|
|
||||||
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
|
|
||||||
conda activate embodiedgen
|
|
||||||
bash install.sh basic
|
|
||||||
```
|
|
||||||
|
|
||||||
Please `huggingface-cli login` to ensure that the ckpts can be downloaded automatically afterwards.
|
|
||||||
|
|
||||||
## ✅ Starting from Docker
|
|
||||||
|
|
||||||
We provide a pre-built Docker image on [Docker Hub](https://hub.docker.com/repository/docker/wangxinjie/embodiedgen) with a configured environment for your convenience. For more details, please refer to [Docker documentation](https://github.com/HorizonRobotics/EmbodiedGen/tree/master/docker).
|
|
||||||
|
|
||||||
> **Note:** Model checkpoints are not included in the image; they will be downloaded automatically on first run. You still need to set up the GPT Agent manually.
|
|
||||||
|
|
||||||
```sh
|
|
||||||
IMAGE=wangxinjie/embodiedgen:env_v0.1.x
|
|
||||||
CONTAINER=EmbodiedGen-docker-${USER}
|
|
||||||
docker pull ${IMAGE}
|
|
||||||
docker run -itd --shm-size="64g" --gpus all --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged --net=host --name ${CONTAINER} ${IMAGE}
|
|
||||||
docker exec -it ${CONTAINER} bash
|
|
||||||
```
|
|
||||||
|
|
||||||
## ✅ Setup GPT Agent
|
|
||||||
|
|
||||||
Update the API key in file: `embodied_gen/utils/gpt_config.yaml`.
|
|
||||||
|
|
||||||
You can choose between two backends for the GPT agent:
|
|
||||||
|
|
||||||
- **`gpt-4o`** (Recommended) – Use this if you have access to **Azure OpenAI**.
|
|
||||||
- **`qwen2.5-vl`** – A free alternative via OpenRouter. Apply for a free key [here](https://openrouter.ai/settings/keys) and update `api_key` in `embodied_gen/utils/gpt_config.yaml` (50 free requests per day).
|
|
||||||
@ -1,34 +0,0 @@
|
|||||||
document.addEventListener('DOMContentLoaded', function () {
|
|
||||||
|
|
||||||
const swiperElement = document.querySelector('.swiper1');
|
|
||||||
|
|
||||||
if (swiperElement) {
|
|
||||||
const swiper = new Swiper('.swiper1', {
|
|
||||||
loop: true,
|
|
||||||
slidesPerView: 3,
|
|
||||||
spaceBetween: 20,
|
|
||||||
navigation: {
|
|
||||||
nextEl: '.swiper-button-next',
|
|
||||||
prevEl: '.swiper-button-prev',
|
|
||||||
},
|
|
||||||
centeredSlides: false,
|
|
||||||
noSwiping: true,
|
|
||||||
noSwipingClass: 'swiper-no-swiping',
|
|
||||||
watchSlidesProgress: true,
|
|
||||||
});
|
|
||||||
|
|
||||||
const modelViewers = swiperElement.querySelectorAll('model-viewer');
|
|
||||||
|
|
||||||
if (modelViewers.length > 0) {
|
|
||||||
let loadedCount = 0;
|
|
||||||
modelViewers.forEach(mv => {
|
|
||||||
mv.addEventListener('load', () => {
|
|
||||||
loadedCount++;
|
|
||||||
if (loadedCount === modelViewers.length) {
|
|
||||||
swiper.update();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
@ -1,75 +0,0 @@
|
|||||||
# 🖼️ Image-to-3D Service
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D)
|
|
||||||
|
|
||||||
This service launches a web application to generate a physically plausible 3D asset URDF from a single input image, offering high-quality support for digital twin systems.
|
|
||||||
|
|
||||||
<div class="swiper swiper1" style="max-width: 1000px; margin: 20px auto; border-radius: 12px;">
|
|
||||||
<div class="swiper-wrapper">
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image/astronaut.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image/robot_i.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image/desk.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image/chair.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image/desk2.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-button-prev"></div>
|
|
||||||
<div class="swiper-button-next"></div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ☁️ Run the App Service
|
|
||||||
|
|
||||||
!!! note "Note"
|
|
||||||
The Gradio service is a simplified demonstration. For the full functionality, please refer to [img3d-cli](../tutorials/image_to_3d.md).
|
|
||||||
|
|
||||||
Run the image-to-3D generation service locally. Models are automatically downloaded on first run, please be patient.
|
|
||||||
|
|
||||||
```sh
|
|
||||||
# Run in foreground
|
|
||||||
python apps/image_to_3d.py
|
|
||||||
|
|
||||||
# Or run in the background
|
|
||||||
CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
!!! tip "Getting Started"
|
|
||||||
- Try it directly online via our [Hugging Face Space](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) — no installation required.
|
|
||||||
- Explore EmbodiedGen generated sim-ready [Assets Gallery](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer).
|
|
||||||
- For instructions on using the generated asset in any simulator, see [Any Simulators Tutorial](../tutorials/any_simulators.md).
|
|
||||||
@ -1,46 +0,0 @@
|
|||||||
# Interactive 3D Generation & Visualization Services
|
|
||||||
|
|
||||||
EmbodiedGen provides a suite of **interactive services** that transform images and text into **physically plausible, simulator-ready 3D assets**.
|
|
||||||
Each service is optimized for visual quality, simulation compatibility, and scalability — making it easy to create, edit, and explore assets for **digital twin**, **robotic simulation**, and **AI embodiment** scenarios.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⚙️ Prerequisites
|
|
||||||
|
|
||||||
!!! tip "Prerequisites"
|
|
||||||
Make sure to finish the [Installation Guide](../install.md) before launching any service. Missing dependencies will cause initialization errors. Model weights are automatically downloaded on first run.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🧩 Overview of Available Services
|
|
||||||
|
|
||||||
| Service | Description |
|
|
||||||
|----------|--------------|
|
|
||||||
| [🖼️ **Image to 3D**](image_to_3d.md) | Generate a physically plausible 3D asset URDF from a single input image, offering high-quality support for digital twin systems. |
|
|
||||||
| [📝 **Text to 3D**](text_to_3d.md) | Generate physically plausible 3D assets from text descriptions for a wide range of geometry and styles. |
|
|
||||||
| [🎨 **Texture Edit**](texture_edit.md) | Generate visually rich textures for existing 3D meshes. |
|
|
||||||
| [📸 **Asset Gallery**](visualize_asset.md) | Explore and download EmbodiedGen All-Simulators-Ready Assets. |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⚙️ How to Run Locally
|
|
||||||
|
|
||||||
!!! tip "Quick Start"
|
|
||||||
Each service can be launched directly as a local Gradio app:
|
|
||||||
```bash
|
|
||||||
# Example: Run the Image-to-3D service
|
|
||||||
python apps/image_to_3d.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Models are automatically downloaded on first run. For full CLI usage, please check the corresponding [tutorials](../tutorials/index.md).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🧭 Next Steps
|
|
||||||
|
|
||||||
- [📘 Tutorials](../tutorials/index.md) – Learn how to use EmbodiedGen in generating interactive 3D scenes for embodied intelligence.
|
|
||||||
- [🧱 API Reference](../api/index.md) – Integrate EmbodiedGen code programmatically.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
> 💡 *EmbodiedGen bridges the gap between AI-driven 3D generation and physically grounded simulation, enabling true embodiment for intelligent agents.*
|
|
||||||
@ -1,89 +0,0 @@
|
|||||||
# 📝 Text-to-3D Service
|
|
||||||
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D)
|
|
||||||
|
|
||||||
This service launches a web application to generate physically plausible 3D assets from text descriptions for a wide range of geometry and styles.
|
|
||||||
|
|
||||||
<div class="swiper swiper1" style="max-width: 1000px; margin: 20px auto; border-radius: 12px;">
|
|
||||||
<div class="swiper-wrapper">
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text/c2.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px; border-radius: 12px;"
|
|
||||||
>
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"Antique brass key, intricate filigree"</p>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text/c3.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px;">
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"Rusty old wrench, peeling paint"</p>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text/c4.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px;">
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"Sleek black drone, red sensors"</p>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text/c7.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px;">
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"Miniature screwdriver with bright orange handle"</p>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text/c9.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px;">
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"European style wooden dressing table"</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-button-prev swiper1-prev"></div>
|
|
||||||
<div class="swiper-button-next swiper1-next"></div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ☁️ Run the App Service
|
|
||||||
|
|
||||||
Create 3D assets from text descriptions for a wide range of geometry and styles.
|
|
||||||
!!! note "Note"
|
|
||||||
The Gradio service is a simplified demonstration. For the full functionality, please refer to [text3d-cli](../tutorials/text_to_3d.md).
|
|
||||||
|
|
||||||
|
|
||||||
The text-to-image model is based on the Kolors model and supports both Chinese and English prompts. Models are downloaded automatically on first run; please be patient.
|
|
||||||
|
|
||||||
```sh
|
|
||||||
# Run in foreground
|
|
||||||
python apps/text_to_3d.py
|
|
||||||
|
|
||||||
# Or run in the background
|
|
||||||
CUDA_VISIBLE_DEVICES=0 nohup python apps/text_to_3d.py > /dev/null 2>&1 &
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
!!! tip "Getting Started"
|
|
||||||
- You can also try Text-to-3D instantly online via our [Hugging Face Space](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) — no installation required.
|
|
||||||
- Explore EmbodiedGen generated sim-ready [Assets Gallery](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer).
|
|
||||||
- For instructions on using the generated asset in any simulator, see [Any Simulators Tutorial](../tutorials/any_simulators.md).
|
|
||||||
@ -1,141 +0,0 @@
|
|||||||
# 🎨 Texture Generation Service
|
|
||||||
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen)
|
|
||||||
|
|
||||||
This service launches a web application to generate visually rich textures for 3D mesh.
|
|
||||||
|
|
||||||
<div class="swiper swiper1" style="max-width: 1000px; margin: 20px auto; border-radius: 12px;">
|
|
||||||
<div class="swiper-wrapper">
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/hello2.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/love4.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/robot_china.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/horse1.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/horse2.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/shoe_0_0.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/shoe_0_3.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/clock_num.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/clock5.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/vase1.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/vase2.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/drill1.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit/drill4.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-button-prev"></div>
|
|
||||||
<div class="swiper-button-next"></div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ☁️ Run the App Service
|
|
||||||
|
|
||||||
!!! note "Note"
|
|
||||||
The Gradio service is a simplified demonstration. For the full functionality, please refer to [texture-cli](../tutorials/texture_edit.md).
|
|
||||||
|
|
||||||
Run the texture generation service locally. Models are downloaded automatically on first run; see `download_kolors_weights` and `geo_cond_mv`.
|
|
||||||
|
|
||||||
|
|
||||||
```sh
|
|
||||||
# Run in foreground
|
|
||||||
python apps/texture_edit.py
|
|
||||||
|
|
||||||
# Or run in the background
|
|
||||||
CUDA_VISIBLE_DEVICES=0 nohup python apps/texture_edit.py > /dev/null 2>&1 &
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
!!! tip "Getting Started"
|
|
||||||
- Try it directly online via our [Hugging Face Space](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) — no installation required.
|
|
||||||
- Explore EmbodiedGen generated sim-ready [Assets Gallery](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer).
|
|
||||||
- For instructions on using the generated asset in any simulator, see [Any Simulators Tutorial](../tutorials/any_simulators.md).
|
|
||||||
@ -1,8 +0,0 @@
|
|||||||
# 📸 EmbodiedGen All-Simulators-Ready Assets Gallery
|
|
||||||
|
|
||||||
[](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer)
|
|
||||||
|
|
||||||
|
|
||||||
!!! tip "Getting Started"
|
|
||||||
- Explore EmbodiedGen generated sim-ready [Assets Gallery](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer).
|
|
||||||
- For instructions on using the generated asset in any simulator, see [Any Simulators Tutorial](tutorials/any_simulators.md).
|
|
||||||
@ -1,22 +0,0 @@
|
|||||||
/* Adjust the logo size */
|
|
||||||
.md-header__button.md-logo {
|
|
||||||
height: 4rem;
|
|
||||||
padding: 0;
|
|
||||||
display: inline-flex;
|
|
||||||
align-items: center;
|
|
||||||
}
|
|
||||||
.md-header__button.md-logo img {
|
|
||||||
height: 4rem;
|
|
||||||
width: auto;
|
|
||||||
}
|
|
||||||
|
|
||||||
.md-typeset pre code {
|
|
||||||
font-size: 0.7rem;
|
|
||||||
/* line-height: 1.5; */
|
|
||||||
font-family: "Fira Code", "JetBrains Mono", monospace;
|
|
||||||
}
|
|
||||||
|
|
||||||
.md-typeset .admonition,
|
|
||||||
.md-typeset details {
|
|
||||||
font-size: 0.77rem;
|
|
||||||
}
|
|
||||||
@ -1,68 +0,0 @@
|
|||||||
# 🎮 Use EmbodiedGen in Any Simulator
|
|
||||||
|
|
||||||
Leverage **EmbodiedGen-generated assets** with *accurate physical collisions* and *consistent visual appearance* across major simulation engines — **IsaacSim**, **MuJoCo**, **Genesis**, **PyBullet**, **IsaacGym**, and **SAPIEN**.
|
|
||||||
|
|
||||||
!!! tip "Universal Compatibility"
|
|
||||||
EmbodiedGen assets follow **standardized URDF semantics** with **physically consistent collision meshes**,
|
|
||||||
enabling seamless loading across multiple simulation frameworks — no manual editing needed.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🧩 Supported Simulators
|
|
||||||
|
|
||||||
| Simulator | Conversion Class |
|
|
||||||
|------------|------------------|
|
|
||||||
| [IsaacSim](https://github.com/isaac-sim/IsaacSim) | `MeshtoUSDConverter` |
|
|
||||||
| [MuJoCo](https://github.com/google-deepmind/mujoco) / [Genesis](https://github.com/Genesis-Embodied-AI/Genesis) | `MeshtoMJCFConverter` |
|
|
||||||
| [SAPIEN](https://github.com/haosulab/SAPIEN) / [IsaacGym](https://github.com/isaac-sim/IsaacGymEnvs) / [PyBullet](https://github.com/bulletphysics/bullet3) | `.urdf` generated by EmbodiedGen can be used **directly** |
|
|
||||||
|
|
||||||
!!! note "Simulator Integration Overview"
|
|
||||||
|
|
||||||
This table summarizes the compatibility of EmbodiedGen assets with various simulators:
|
|
||||||
|
|
||||||
| Simulator | Supported Format | Notes |
|
|
||||||
|-----------|-----------------|-------|
|
|
||||||
| IsaacSim | USD / .usda | Use `MeshtoUSDConverter` to convert mesh to USD format. |
|
|
||||||
| MuJoCo | MJCF (.xml) | Use `MeshtoMJCFConverter` for physics-ready assets. |
|
|
||||||
| Genesis | MJCF (.xml) | Same as MuJoCo; fully compatible with Genesis scenes. |
|
|
||||||
| SAPIEN | URDF (.urdf) | Can directly load EmbodiedGen `.urdf` assets. |
|
|
||||||
| IsaacGym | URDF (.urdf) | Directly usable. |
|
|
||||||
| PyBullet | URDF (.urdf) | Directly usable. |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
|
|
||||||
## 🧱 Example: Conversion to Target Simulator
|
|
||||||
|
|
||||||
```python
|
|
||||||
from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
|
|
||||||
from embodied_gen.utils.enum import AssetType, SimAssetMapper
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
simulator_name: Literal[
|
|
||||||
"isaacsim",
|
|
||||||
"isaacgym",
|
|
||||||
"genesis",
|
|
||||||
"pybullet",
|
|
||||||
"sapien3",
|
|
||||||
"mujoco",
|
|
||||||
] = "mujoco"
|
|
||||||
|
|
||||||
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
|
||||||
urdf_files=[
|
|
||||||
"path1_to_embodiedgen_asset/asset.urdf",
|
|
||||||
"path2_to_embodiedgen_asset/asset.urdf",
|
|
||||||
],
|
|
||||||
target_dirs=[
|
|
||||||
"path1_to_target_dir/asset.usd",
|
|
||||||
"path2_to_target_dir/asset.usd",
|
|
||||||
],
|
|
||||||
target_type=SimAssetMapper[simulator_name],
|
|
||||||
source_type=AssetType.MESH,
|
|
||||||
overwrite=True,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
<img src="../assets/simulators_collision.jpg" alt="simulators_collision" width="800">
|
|
||||||
|
|
||||||
Collision and visualization mesh across simulators, showing consistent geometry and material fidelity.
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
# Real-to-Sim Digital Twin Creation
|
|
||||||
|
|
||||||
<img src="../assets/real2sim_mujoco.gif" alt="real2sim_mujoco" width="600">
|
|
||||||
@ -1,22 +0,0 @@
|
|||||||
# Simulation in Parallel Envs
|
|
||||||
|
|
||||||
Generate multiple parallel simulation environments with `gym.make` and record sensor and trajectory data.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⚡ Command-Line Usage
|
|
||||||
|
|
||||||
```sh
|
|
||||||
python embodied_gen/scripts/parallel_sim.py \
|
|
||||||
--layout_file "outputs/layouts_gen/task_0000/layout.json" \
|
|
||||||
--output_dir "outputs/parallel_sim/task_0000" \
|
|
||||||
--num_envs 16
|
|
||||||
```
|
|
||||||
|
|
||||||
<div style="display: flex; justify-content: center; align-items: center; gap: 16px; margin: 16px 0;">
|
|
||||||
<img src="../assets/parallel_sim.gif" alt="parallel_sim1"
|
|
||||||
style="width: 330px; max-width: 100%; border-radius: 12px; display: block;">
|
|
||||||
<img src="../assets/parallel_sim2.gif" alt="parallel_sim2"
|
|
||||||
style="width: 330px; max-width: 100%; border-radius: 12px; display: block;">
|
|
||||||
</div>
|
|
||||||
|
|
||||||
@ -1,71 +0,0 @@
|
|||||||
# 🖼️ Image-to-3D: Physically Plausible 3D Asset Generation
|
|
||||||
|
|
||||||
Generate **physically plausible 3D assets** from a single input image, supporting **digital twin** and **simulation environments**.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⚡ Command-Line Usage
|
|
||||||
|
|
||||||
```bash
|
|
||||||
img3d-cli --image_path apps/assets/example_image/sample_00.jpg \
|
|
||||||
apps/assets/example_image/sample_01.jpg apps/assets/example_image/sample_19.jpg \
|
|
||||||
--n_retry 1 --output_root outputs/imageto3d
|
|
||||||
```
|
|
||||||
|
|
||||||
You will get the following results:
|
|
||||||
|
|
||||||
<div class="swiper swiper1" style="max-width: 1000px; margin: 20px auto; border-radius: 12px;">
|
|
||||||
<div class="swiper-wrapper">
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image2/sample_00.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image2/sample_01.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image2/sample_19.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-button-prev"></div>
|
|
||||||
<div class="swiper-button-next"></div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
|
|
||||||
The generated results are organized as follows:
|
|
||||||
```sh
|
|
||||||
outputs/imageto3d/sample_xx/result
|
|
||||||
├── mesh
|
|
||||||
│ ├── material_0.png
|
|
||||||
│ ├── material.mtl
|
|
||||||
│ ├── sample_xx_collision.ply
|
|
||||||
│ ├── sample_xx.glb
|
|
||||||
│ ├── sample_xx_gs.ply
|
|
||||||
│ └── sample_xx.obj
|
|
||||||
├── sample_xx.urdf
|
|
||||||
└── video.mp4
|
|
||||||
```
|
|
||||||
|
|
||||||
- `mesh/` → Geometry and texture files, including visual mesh, collision mesh and 3DGS.
|
|
||||||
- `*.urdf` → Simulator-ready URDF with collision and visual meshes
|
|
||||||
- `video.mp4` → Preview of the generated 3D asset
|
|
||||||
|
|
||||||
|
|
||||||
!!! tip "Getting Started"
|
|
||||||
- Try it directly online via our [Hugging Face Space](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) — no installation required.
|
|
||||||
- Explore EmbodiedGen generated sim-ready [Assets Gallery](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer).
|
|
||||||
- For instructions on using the generated asset in any simulator, see [Any Simulators Tutorial](any_simulators.md).
|
|
||||||
@ -1,176 +0,0 @@
|
|||||||
# Tutorials & Interface Usage
|
|
||||||
|
|
||||||
Welcome to the tutorials for `EmbodiedGen`. `EmbodiedGen` is a powerful toolset for generating 3D assets, textures, scenes, and interactive layouts ready for simulators and digital twin environments.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⚙️ Prerequisites
|
|
||||||
|
|
||||||
!!! tip "Prerequisites"
|
|
||||||
Make sure to finish the [Installation Guide](../install.md) before starting the tutorials. Missing dependencies will cause initialization errors. Model weights are automatically downloaded on first run.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## [🖼️ Image-to-3D](image_to_3d.md)
|
|
||||||
|
|
||||||
Generate **physically plausible 3D assets** from a single input image, supporting digital twin and simulation environments.
|
|
||||||
|
|
||||||
|
|
||||||
<div class="swiper swiper1" style="max-width: 1000px; margin: 20px auto; border-radius: 12px;">
|
|
||||||
<div class="swiper-wrapper">
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image2/sample_00.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image2/sample_01.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/image2/sample_19.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-button-prev"></div>
|
|
||||||
<div class="swiper-button-next"></div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## [📝 Text-to-3D](text_to_3d.md)
|
|
||||||
|
|
||||||
Create **physically plausible 3D assets** from **text descriptions**, supporting a wide range of geometry, style, and material details.
|
|
||||||
|
|
||||||
|
|
||||||
<div class="swiper swiper1" style="max-width: 1000px; margin: 20px auto; border-radius: 12px;">
|
|
||||||
<div class="swiper-wrapper">
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text2/sample3d_0.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px; border-radius: 12px;"
|
|
||||||
>
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"small bronze figurine of a lion"</p>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text2/sample3d_1.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px;">
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"A globe with wooden base"</p>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text2/sample3d_2.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px;">
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"wooden table with embroidery"</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-button-prev"></div>
|
|
||||||
<div class="swiper-button-next"></div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## [🎨 Texture Generation](texture_gen.md)
|
|
||||||
|
|
||||||
Generate **high-quality textures** for 3D meshes using **text prompts**, supporting both Chinese and English, to enhance the visual appearance of existing 3D assets.
|
|
||||||
|
|
||||||
<div class="swiper swiper1" style="max-width: 1000px; margin: 20px auto; border-radius: 12px;">
|
|
||||||
<div class="swiper-wrapper">
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit2/robot_text.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
camera-orbit="180deg auto auto"
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit2/horse.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
camera-orbit="90deg auto auto"
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-button-prev"></div>
|
|
||||||
<div class="swiper-button-next"></div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## [🌍 3D Scene Generation](scene_gen.md)
|
|
||||||
|
|
||||||
Generate **physically consistent and visually coherent 3D environments** from text prompts. Typically used as **background** 3DGS scenes in simulators for efficient and photo-realistic rendering.
|
|
||||||
|
|
||||||
<img src="../assets/scene3d.gif" style="width: 600px; max-width: 100%; border-radius: 12px; display: block; margin: 16px auto;">
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## [🏞️ Layout Generation](layout_gen.md)
|
|
||||||
|
|
||||||
Generate diverse, physically realistic, and scalable **interactive 3D scenes** from natural language task descriptions, while also modeling the robot and manipulable objects.
|
|
||||||
|
|
||||||
<div align="center" style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 16px; justify-items: center; margin: 20px 0;">
|
|
||||||
<img src="../assets/layout1.gif" alt="layout1" style="width: 400px; border-radius: 12px; display: block;">
|
|
||||||
<img src="../assets/layout2.gif" alt="layout2" style="width: 400px; border-radius: 12px; display: block;">
|
|
||||||
<img src="../assets/layout3.gif" alt="layout3" style="width: 400px; border-radius: 12px; display: block;">
|
|
||||||
<img src="../assets/Iscene_demo2.gif" alt="layout4" style="width: 400px; border-radius: 12px; display: block;">
|
|
||||||
</div>
|
|
||||||
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## [🏎️ Parallel Simulation](gym_env.md)
|
|
||||||
|
|
||||||
Generate multiple **parallel simulation environments** with `gym.make` and record sensor and trajectory data.
|
|
||||||
|
|
||||||
<div style="display: flex; justify-content: center; align-items: center; gap: 16px; margin: 16px 0;">
|
|
||||||
<img src="../assets/parallel_sim.gif" alt="parallel_sim1"
|
|
||||||
style="width: 330px; max-width: 100%; border-radius: 12px; display: block;">
|
|
||||||
<img src="../assets/parallel_sim2.gif" alt="parallel_sim2"
|
|
||||||
style="width: 330px; max-width: 100%; border-radius: 12px; display: block;">
|
|
||||||
</div>
|
|
||||||
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## [🎮 Use in Any Simulator](any_simulators.md)
|
|
||||||
|
|
||||||
Seamlessly use EmbodiedGen-generated assets in major simulators like **IsaacSim**, **MuJoCo**, **Genesis**, **PyBullet**, **IsaacGym**, and **SAPIEN**, featuring **accurate physical collisions** and **consistent visual appearance**.
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
<img src="../assets/simulators_collision.jpg" alt="simulators_collision" style="width: 600px; max-width: 100%; border-radius: 12px; display: block; margin: 16px 0;">
|
|
||||||
</div>
|
|
||||||
|
|
||||||
## [🔧 Real-to-Sim Digital Twin Creation](digital_twin.md)
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
<img src="../assets/real2sim_mujoco.gif" alt="real2sim_mujoco" style="width: 400px; max-width: 100%; border-radius: 12px; display: block; margin: 16px 0;">
|
|
||||||
</div>
|
|
||||||
@ -1,93 +0,0 @@
|
|||||||
# 🏞️ Layout Generation — Interactive 3D Scenes
|
|
||||||
|
|
||||||
Layout Generation enables the generation of diverse, physically realistic, and scalable **interactive 3D scenes** directly from natural language task descriptions, while also modeling the robot's pose and relationships with manipulable objects. Target objects are randomly placed within the robot's reachable range, making the scenes readily usable for downstream simulation and reinforcement learning tasks in any mainstream simulator.
|
|
||||||
|
|
||||||
<div align="center" style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 16px; justify-items: center; margin: 20px 0;">
|
|
||||||
<img src="../assets/layout1.gif" alt="layout1" style="width: 400px; border-radius: 12px; display: block;">
|
|
||||||
<img src="../assets/layout2.gif" alt="layout2" style="width: 400px; border-radius: 12px; display: block;">
|
|
||||||
<img src="../assets/layout3.gif" alt="layout3" style="width: 400px; border-radius: 12px; display: block;">
|
|
||||||
<img src="../assets/Iscene_demo2.gif" alt="layout4" style="width: 400px; border-radius: 12px; display: block;">
|
|
||||||
</div>
|
|
||||||
|
|
||||||
!!! note "Model Requirement"
|
|
||||||
The text-to-image model is based on `SD3.5 Medium`. Usage requires agreement to the [model license](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Prerequisites — Prepare Background 3D Scenes
|
|
||||||
|
|
||||||
Before running `layout-cli`, you need to prepare background 3D scenes.
|
|
||||||
You can either **generate your own** using the [`scene3d-cli`](scene_gen.md), or **download pre-generated backgrounds** for convenience.
|
|
||||||
|
|
||||||
Each scene takes approximately **30 minutes** to generate. For efficiency, we recommend pre-generating and listing them in `outputs/bg_scenes/scene_list.txt`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Option 1: Download pre-generated backgrounds (~4 GB)
|
|
||||||
hf download xinjjj/scene3d-bg --repo-type dataset --local-dir outputs
|
|
||||||
|
|
||||||
# Option 2: Download a larger background set (~14 GB)
|
|
||||||
hf download xinjjj/EmbodiedGenRLv2-BG --repo-type dataset --local-dir outputs
|
|
||||||
```
|
|
||||||
|
|
||||||
## Generate Interactive Layout Scenes
|
|
||||||
|
|
||||||
Use the `layout-cli` to create interactive 3D scenes based on task descriptions. Each layout generation takes approximately 30 minutes.
|
|
||||||
|
|
||||||
```sh
|
|
||||||
layout-cli \
|
|
||||||
--task_descs "Place the pen in the mug on the desk" \
|
|
||||||
"Put the fruit on the table on the plate" \
|
|
||||||
--bg_list "outputs/bg_scenes/scene_list.txt" \
|
|
||||||
--output_root "outputs/layouts_gen" \
|
|
||||||
--insert_robot
|
|
||||||
```
|
|
||||||
|
|
||||||
You will get the following results:
|
|
||||||
<div align="center" style="display: flex; justify-content: center; align-items: flex-start; gap: 24px; margin: 20px auto; flex-wrap: wrap;">
|
|
||||||
<img src="../assets/Iscene_demo1.gif" alt="Iscene_demo1"
|
|
||||||
style="height: 200px; border-radius: 12px; display: block; width: auto;">
|
|
||||||
<img src="../assets/Iscene_demo2.gif" alt="Iscene_demo2"
|
|
||||||
style="height: 200px; border-radius: 12px; display: block; width: auto;">
|
|
||||||
</div>
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Batch Generation
|
|
||||||
|
|
||||||
You can also run multiple tasks in the background via a task list file.
|
|
||||||
|
|
||||||
```sh
|
|
||||||
CUDA_VISIBLE_DEVICES=0 nohup layout-cli \
|
|
||||||
--task_descs "apps/assets/example_layout/task_list.txt" \
|
|
||||||
--bg_list "outputs/bg_scenes/scene_list.txt" \
|
|
||||||
--output_root "outputs/layouts_gens" \
|
|
||||||
--insert_robot > layouts_gens.log &
|
|
||||||
```
|
|
||||||
|
|
||||||
> 💡 Remove `--insert_robot` if you don’t need robot pose consideration in layout generation.
|
|
||||||
|
|
||||||
### Layout Randomization
|
|
||||||
|
|
||||||
Using `compose_layout.py`, you can **recompose the layout** of the generated interactive 3D scenes.
|
|
||||||
|
|
||||||
```sh
|
|
||||||
python embodied_gen/scripts/compose_layout.py \
|
|
||||||
--layout_path "outputs/layouts_gens/task_0000/layout.json" \
|
|
||||||
--output_dir "outputs/layouts_gens/task_0000/recompose" \
|
|
||||||
--insert_robot
|
|
||||||
```
|
|
||||||
|
|
||||||
### Load Interactive 3D Scenes in Simulators
|
|
||||||
|
|
||||||
We provide `sim-cli`, which allows users to easily load generated layouts into an interactive 3D simulation using the SAPIEN engine.
|
|
||||||
|
|
||||||
```sh
|
|
||||||
sim-cli --layout_path "outputs/layouts_gen/task_0000/layout.json" \
|
|
||||||
--output_dir "outputs/layouts_gen/task_0000/sapien_render" --insert_robot
|
|
||||||
```
|
|
||||||
|
|
||||||
!!! tip "Recommended Workflow"
|
|
||||||
1. Generate or download background scenes using `scene3d-cli`.
|
|
||||||
2. Create interactive layouts from task descriptions using `layout-cli`.
|
|
||||||
3. Optionally recompose them using `compose_layout.py`.
|
|
||||||
4. Load the final layouts into simulators with `sim-cli`.
|
|
||||||
@ -1,47 +0,0 @@
|
|||||||
# 🌍 3D Scene Generation
|
|
||||||
|
|
||||||
Generate **physically consistent and visually coherent 3D environments** from text prompts. Typically used as **background** 3DGS scenes in simulators for efficient and photo-realistic rendering.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
<img src="../assets/scene3d.gif" style="width: 600px; border-radius: 12px; display: block; margin: 16px auto;">
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⚡ Command-Line Usage
|
|
||||||
|
|
||||||
> 💡 Run `bash install.sh extra` to install additional dependencies if you plan to use `scene3d-cli`.
|
|
||||||
|
|
||||||
It typically takes ~30 minutes per scene to generate both the colored mesh and the 3D Gaussian Splat (3DGS) representation.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 scene3d-cli \
|
|
||||||
--prompts "Art studio with easel and canvas" \
|
|
||||||
--output_dir outputs/bg_scenes/ \
|
|
||||||
--seed 0 \
|
|
||||||
--gs3d.max_steps 4000 \
|
|
||||||
--disable_pano_check
|
|
||||||
```
|
|
||||||
|
|
||||||
The generated results are organized as follows:
|
|
||||||
```sh
|
|
||||||
outputs/bg_scenes/scene_000
|
|
||||||
├── gs_model.ply
|
|
||||||
├── gsplat_cfg.yml
|
|
||||||
├── mesh_model.ply
|
|
||||||
├── pano_image.png
|
|
||||||
├── prompt.txt
|
|
||||||
└── video.mp4
|
|
||||||
```
|
|
||||||
|
|
||||||
- `gs_model.ply` → Generated 3D scene in 3D Gaussian Splat representation.
|
|
||||||
- `mesh_model.ply` → Color mesh representation of the generated scene.
|
|
||||||
- `gsplat_cfg.yml` → Configuration file for 3DGS training and rendering parameters.
|
|
||||||
- `pano_image.png` → Generated panoramic view image.
|
|
||||||
- `prompt.txt` → Original scene generation prompt for traceability.
|
|
||||||
- `video.mp4` → RGB and depth preview of the generated 3D scene.
|
|
||||||
|
|
||||||
!!! note "Usage Notes"
|
|
||||||
- `3D Scene Generation` produces background 3DGS scenes optimized for efficient rendering in simulation environments. We also provide hybrid rendering examples combining background 3DGS with foreground interactive assets, see the [example]()
|
|
||||||
for details.
|
|
||||||
- In Layout Generation, we further demonstrate task-desc-driven interactive 3D scene generation, building complete 3D scenes based on natural language task descriptions. See the [Layout Generation Guide](layout_gen.md).
|
|
||||||
@ -1,115 +0,0 @@
|
|||||||
# 📝 Text-to-3D: Generate 3D Assets from Text
|
|
||||||
|
|
||||||
Create **physically plausible 3D assets** from **text descriptions**, supporting a wide range of geometry, style, and material details.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⚡ Command-Line Usage
|
|
||||||
|
|
||||||
**Basic CLI (recommended)**
|
|
||||||
|
|
||||||
The text-to-image model is based on Stable Diffusion 3.5 Medium and supports English prompts only. Usage requires agreement to the [model license (click “Accept”)](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium).
|
|
||||||
|
|
||||||
```bash
|
|
||||||
text3d-cli \
|
|
||||||
--prompts "small bronze figurine of a lion" "A globe with wooden base" "wooden table with embroidery" \
|
|
||||||
--n_image_retry 1 \
|
|
||||||
--n_asset_retry 1 \
|
|
||||||
--n_pipe_retry 1 \
|
|
||||||
--seed_img 0 \
|
|
||||||
--output_root outputs/textto3d
|
|
||||||
```
|
|
||||||
|
|
||||||
- `--n_image_retry`: Number of retries per prompt for text-to-image generation
|
|
||||||
- `--n_asset_retry`: Retry attempts for image-to-3D assets generation
|
|
||||||
- `--n_pipe_retry`: Pipeline retry for end-to-end 3D asset quality check
|
|
||||||
- `--seed_img`: Optional random seed for the text-to-image step, for reproducible images
|
|
||||||
- `--output_root`: Directory to save generated assets
|
|
||||||
|
|
||||||
For large-scale 3D asset generation, set `--n_image_retry=4` `--n_asset_retry=3` `--n_pipe_retry=2`; this is slower but yields better results thanks to automatic checking and retries. For more diverse results, omit `--seed_img`.
|
|
||||||
|
|
||||||
You will get the following results:
|
|
||||||
|
|
||||||
<div class="swiper swiper1" style="max-width: 1000px; margin: 20px auto; border-radius: 12px;">
|
|
||||||
<div class="swiper-wrapper">
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text2/sample3d_0.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px; border-radius: 12px;"
|
|
||||||
>
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"small bronze figurine of a lion"</p>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text2/sample3d_1.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px;">
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"A globe with wooden base"</p>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/text2/sample3d_2.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
background-color="#ffffff"
|
|
||||||
style="display:block; width: 100%; height: 160px;">
|
|
||||||
</model-viewer>
|
|
||||||
<p style="text-align: center; margin-top: 8px; font-size: 14px;">"wooden table with embroidery"</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-button-prev swiper1-prev"></div>
|
|
||||||
<div class="swiper-button-next swiper1-next"></div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
|
|
||||||
Kolors Model CLI (Supports Chinese & English Prompts):
|
|
||||||
```bash
|
|
||||||
bash embodied_gen/scripts/textto3d.sh \
|
|
||||||
--prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
|
|
||||||
--output_root outputs/textto3d_k
|
|
||||||
```
|
|
||||||
|
|
||||||
> Models with more permissive licenses can be found in `embodied_gen/models/image_comm_model.py`.
|
|
||||||
|
|
||||||
|
|
||||||
The generated results are organized as follows:
|
|
||||||
```sh
|
|
||||||
outputs/textto3d
|
|
||||||
├── asset3d
|
|
||||||
│ ├── sample3d_xx
|
|
||||||
│ │ └── result
|
|
||||||
│ │ ├── mesh
|
|
||||||
│ │ │ ├── material_0.png
|
|
||||||
│ │ │ ├── material.mtl
|
|
||||||
│ │ │ ├── sample3d_xx_collision.obj
|
|
||||||
│ │ │ ├── sample3d_xx.glb
|
|
||||||
│ │ │ ├── sample3d_xx_gs.ply
|
|
||||||
│ │ │ └── sample3d_xx.obj
|
|
||||||
│ │ ├── sample3d_xx.urdf
|
|
||||||
│ │ └── video.mp4
|
|
||||||
└── images
|
|
||||||
├── sample3d_xx.png
|
|
||||||
├── sample3d_xx_raw.png
|
|
||||||
```
|
|
||||||
|
|
||||||
- `mesh/` → 3D geometry and texture files for the asset, including visual mesh, collision mesh and 3DGS
|
|
||||||
- `*.urdf` → Simulator-ready URDF including collision and visual meshes
|
|
||||||
- `video.mp4` → Preview video of the generated 3D asset
|
|
||||||
- `images/sample3d_xx.png` → Foreground-extracted image used for image-to-3D step
|
|
||||||
- `images/sample3d_xx_raw.png` → Original generated image from the text-to-image step
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
!!! tip "Getting Started"
|
|
||||||
- You can also try Text-to-3D instantly online via our [Hugging Face Space](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) — no installation required.
|
|
||||||
- Explore EmbodiedGen generated sim-ready [Assets Gallery](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer).
|
|
||||||
- For instructions on using the generated asset in any simulator, see [Any Simulators Tutorial](any_simulators.md).
|
|
||||||
@ -1,57 +0,0 @@
|
|||||||
# 🎨 Texture Generation: Create Visually Rich Textures for 3D Meshes
|
|
||||||
|
|
||||||
Generate **high-quality textures** for 3D meshes using **text prompts**, supporting both **Chinese and English**. This allows you to enhance the visual appearance of existing 3D assets for simulation, visualization, or digital twin applications.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⚡ Command-Line Usage
|
|
||||||
|
|
||||||
```bash
|
|
||||||
texture-cli \
|
|
||||||
--mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \
|
|
||||||
"apps/assets/example_texture/meshes/horse.obj" \
|
|
||||||
--prompt "举着牌子的写实风格机器人,大眼睛,牌子上写着“Hello”的文字" \
|
|
||||||
"A gray horse head with flying mane and brown eyes" \
|
|
||||||
--output_root "outputs/texture_gen" \
|
|
||||||
--seed 0
|
|
||||||
```
|
|
||||||
|
|
||||||
- `--mesh_path` → Path(s) to input 3D mesh files
|
|
||||||
- `--prompt` → Text prompt(s) describing desired texture/style for each mesh
|
|
||||||
- `--output_root` → Directory to save textured meshes and related outputs
|
|
||||||
- `--seed` → Random seed for reproducible texture generation
|
|
||||||
|
|
||||||
|
|
||||||
You will get the following results:
|
|
||||||
|
|
||||||
<div class="swiper swiper1" style="max-width: 1000px; margin: 20px auto; border-radius: 12px;">
|
|
||||||
<div class="swiper-wrapper">
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit2/robot_text.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
camera-orbit="180deg auto auto"
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-slide model-card">
|
|
||||||
<model-viewer
|
|
||||||
src="https://raw.githubusercontent.com/HochCC/ShowCase/main/edit2/horse.glb"
|
|
||||||
auto-rotate
|
|
||||||
camera-controls
|
|
||||||
camera-orbit="90deg auto auto"
|
|
||||||
style="display:block; width:100%; height:250px; background-color: #f8f8f8;">
|
|
||||||
</model-viewer>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="swiper-button-prev"></div>
|
|
||||||
<div class="swiper-button-next"></div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
!!! tip "Getting Started"
|
|
||||||
- Try it directly online via our [Hugging Face Space](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) — no installation required.
|
|
||||||
- Explore EmbodiedGen generated sim-ready [Assets Gallery](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Gallery-Explorer).
|
|
||||||
- For instructions on using the generated asset in any simulator, see [Any Simulators Tutorial](any_simulators.md).
|
|
||||||
@ -4,12 +4,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from shutil import copy, copytree, rmtree
|
from shutil import copy, copytree, rmtree
|
||||||
|
|
||||||
import trimesh
|
import trimesh
|
||||||
from scipy.spatial.transform import Rotation
|
from scipy.spatial.transform import Rotation
|
||||||
from embodied_gen.utils.enum import AssetType
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -17,62 +17,72 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AssetConverterFactory",
|
"AssetConverterFactory",
|
||||||
|
"AssetType",
|
||||||
"MeshtoMJCFConverter",
|
"MeshtoMJCFConverter",
|
||||||
"MeshtoUSDConverter",
|
"MeshtoUSDConverter",
|
||||||
"URDFtoUSDConverter",
|
"URDFtoUSDConverter",
|
||||||
"cvt_embodiedgen_asset_to_anysim",
|
"cvt_embodiedgen_asset_to_anysim",
|
||||||
"PhysicsUSDAdder",
|
"PhysicsUSDAdder",
|
||||||
|
"SimAssetMapper",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AssetType(str):
    """Namespace of asset-format identifiers used to select a converter.

    The members are plain ``str`` class attributes, so they compare equal to
    their literal values (``AssetType.USD == "usd"``).

    This class is intentionally NOT a dataclass: it declares no annotated
    fields, so ``@dataclass`` would only generate a zero-argument
    ``__init__`` (making ``AssetType("usd")`` raise ``TypeError``) and set
    ``__hash__`` to ``None`` (making instances unhashable).
    """

    MJCF = "mjcf"
    USD = "usd"
    URDF = "urdf"
    MESH = "mesh"
|
||||||
|
|
||||||
|
|
||||||
|
class SimAssetMapper:
    """Look up the asset format a simulator consumes via ``SimAssetMapper[name]``.

    The lookup is case-insensitive, and any name starting with ``"sapien"``
    is treated as ``SAPIEN``.
    """

    _mapping = dict(
        ISAACSIM=AssetType.USD,
        ISAACGYM=AssetType.URDF,
        MUJOCO=AssetType.MJCF,
        GENESIS=AssetType.MJCF,
        SAPIEN=AssetType.URDF,
        PYBULLET=AssetType.URDF,
    )

    # ``__class_getitem__`` is implicitly a class method when defined in a
    # class body, so no explicit ``@classmethod`` decorator is needed.
    def __class_getitem__(cls, key: str):
        """Return the asset type registered for the simulator named ``key``.

        Args:
            key (str): Simulator name, matched case-insensitively.

        Returns:
            The ``AssetType`` value mapped to the simulator.

        Raises:
            KeyError: If the simulator name is not in the mapping.
        """
        name = key.upper()
        if name.startswith("SAPIEN"):
            name = "SAPIEN"
        try:
            return cls._mapping[name]
        except KeyError:
            supported = ", ".join(sorted(cls._mapping))
            raise KeyError(
                f"Unknown simulator {key!r}, supported: {supported}."
            ) from None
|
||||||
|
|
||||||
|
|
||||||
def cvt_embodiedgen_asset_to_anysim(
|
def cvt_embodiedgen_asset_to_anysim(
|
||||||
urdf_files: list[str],
|
urdf_files: list[str],
|
||||||
target_dirs: list[str],
|
|
||||||
target_type: AssetType,
|
target_type: AssetType,
|
||||||
source_type: AssetType,
|
source_type: AssetType,
|
||||||
overwrite: bool = False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Convert URDF files generated by EmbodiedGen into formats required by simulators.
|
"""Convert URDF files generated by EmbodiedGen into the format required by all simulators.
|
||||||
|
|
||||||
Supported simulators include SAPIEN, Isaac Sim, MuJoCo, Isaac Gym, Genesis, and Pybullet.
|
Supported simulators include SAPIEN, Isaac Sim, MuJoCo, Isaac Gym, Genesis, and Pybullet.
|
||||||
Converting to the `USD` format requires `isaacsim` to be installed.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```py
|
|
||||||
from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
|
|
||||||
from embodied_gen.utils.enum import AssetType
|
|
||||||
|
|
||||||
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
||||||
urdf_files=[
|
urdf_files,
|
||||||
"path1_to_embodiedgen_asset/asset.urdf",
|
target_type=SimAssetMapper[simulator_name],
|
||||||
"path2_to_embodiedgen_asset/asset.urdf",
|
|
||||||
],
|
|
||||||
target_dirs=[
|
|
||||||
"path1_to_target_dir/asset.usd",
|
|
||||||
"path2_to_target_dir/asset.usd",
|
|
||||||
],
|
|
||||||
target_type=AssetType.USD,
|
|
||||||
source_type=AssetType.MESH,
|
source_type=AssetType.MESH,
|
||||||
)
|
)
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
urdf_files (list[str]): List of URDF file paths.
|
urdf_files (List[str]): List of URDF file paths to be converted.
|
||||||
target_dirs (list[str]): List of target directories.
|
target_type (AssetType): The target asset type.
|
||||||
target_type (AssetType): Target asset type.
|
source_type (AssetType): The source asset type.
|
||||||
source_type (AssetType): Source asset type.
|
|
||||||
overwrite (bool, optional): Overwrite existing files.
|
|
||||||
**kwargs: Additional converter arguments.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, str]: Mapping from URDF file to converted asset file.
|
Dict[str, str]: A dictionary mapping the original URDF file path to the converted asset file path.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(urdf_files, str):
|
if isinstance(urdf_files, str):
|
||||||
urdf_files = [urdf_files]
|
urdf_files = [urdf_files]
|
||||||
if isinstance(target_dirs, str):
|
|
||||||
urdf_files = [target_dirs]
|
|
||||||
|
|
||||||
# If the target type is URDF, no conversion is needed.
|
# If the target type is URDF, no conversion is needed.
|
||||||
if target_type == AssetType.URDF:
|
if target_type == AssetType.URDF:
|
||||||
@ -86,17 +96,18 @@ def cvt_embodiedgen_asset_to_anysim(
|
|||||||
asset_paths = dict()
|
asset_paths = dict()
|
||||||
|
|
||||||
with asset_converter:
|
with asset_converter:
|
||||||
for urdf_file, target_dir in zip(urdf_files, target_dirs):
|
for urdf_file in urdf_files:
|
||||||
filename = os.path.basename(urdf_file).replace(".urdf", "")
|
filename = os.path.basename(urdf_file).replace(".urdf", "")
|
||||||
|
asset_dir = os.path.dirname(urdf_file)
|
||||||
if target_type == AssetType.MJCF:
|
if target_type == AssetType.MJCF:
|
||||||
target_file = f"{target_dir}/{filename}.xml"
|
target_file = f"{asset_dir}/../mjcf/{filename}.xml"
|
||||||
elif target_type == AssetType.USD:
|
elif target_type == AssetType.USD:
|
||||||
target_file = f"{target_dir}/{filename}.usd"
|
target_file = f"{asset_dir}/../usd/{filename}.usd"
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Target type {target_type} not supported."
|
f"Target type {target_type} not supported."
|
||||||
)
|
)
|
||||||
if not os.path.exists(target_file) or overwrite:
|
if not os.path.exists(target_file):
|
||||||
asset_converter.convert(urdf_file, target_file)
|
asset_converter.convert(urdf_file, target_file)
|
||||||
|
|
||||||
asset_paths[urdf_file] = target_file
|
asset_paths[urdf_file] = target_file
|
||||||
@ -105,35 +116,16 @@ def cvt_embodiedgen_asset_to_anysim(
|
|||||||
|
|
||||||
|
|
||||||
class AssetConverterBase(ABC):
|
class AssetConverterBase(ABC):
|
||||||
"""Abstract base class for asset converters.
|
"""Converter abstract base class."""
|
||||||
|
|
||||||
Provides context management and mesh transformation utilities.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
|
def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
|
||||||
"""Convert an asset file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
urdf_path (str): Path to input URDF file.
|
|
||||||
output_path (str): Path to output file.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Path to converted asset.
|
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def transform_mesh(
|
def transform_mesh(
|
||||||
self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
|
self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Apply transform to mesh based on URDF origin element.
|
"""Apply transform to the mesh based on the origin element in URDF."""
|
||||||
|
|
||||||
Args:
|
|
||||||
input_mesh (str): Path to input mesh.
|
|
||||||
output_mesh (str): Path to output mesh.
|
|
||||||
mesh_origin (ET.Element): Origin element from URDF.
|
|
||||||
"""
|
|
||||||
mesh = trimesh.load(input_mesh, group_material=False)
|
mesh = trimesh.load(input_mesh, group_material=False)
|
||||||
rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
|
rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
|
||||||
rotation = Rotation.from_euler("xyz", rpy, degrees=False)
|
rotation = Rotation.from_euler("xyz", rpy, degrees=False)
|
||||||
@ -155,19 +147,14 @@ class AssetConverterBase(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""Context manager entry."""
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""Context manager exit."""
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class MeshtoMJCFConverter(AssetConverterBase):
|
class MeshtoMJCFConverter(AssetConverterBase):
|
||||||
"""Converts mesh-based URDF files to MJCF format.
|
"""Convert URDF files into MJCF format."""
|
||||||
|
|
||||||
Handles geometry, materials, and asset copying.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -176,12 +163,6 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def _copy_asset_file(self, src: str, dst: str) -> None:
|
def _copy_asset_file(self, src: str, dst: str) -> None:
|
||||||
"""Copies asset file if not already present.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
src (str): Source file path.
|
|
||||||
dst (str): Destination file path.
|
|
||||||
"""
|
|
||||||
if os.path.exists(dst):
|
if os.path.exists(dst):
|
||||||
return
|
return
|
||||||
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
||||||
@ -199,19 +180,7 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
material: ET.Element | None = None,
|
material: ET.Element | None = None,
|
||||||
is_collision: bool = False,
|
is_collision: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Adds geometry to MJCF body from URDF link.
|
"""Add geometry to the MJCF body from the URDF link."""
|
||||||
|
|
||||||
Args:
|
|
||||||
mujoco_element (ET.Element): MJCF asset element.
|
|
||||||
link (ET.Element): URDF link element.
|
|
||||||
body (ET.Element): MJCF body element.
|
|
||||||
tag (str): Tag name ("visual" or "collision").
|
|
||||||
input_dir (str): Input directory.
|
|
||||||
output_dir (str): Output directory.
|
|
||||||
mesh_name (str): Mesh name.
|
|
||||||
material (ET.Element, optional): Material element.
|
|
||||||
is_collision (bool, optional): If True, treat as collision geometry.
|
|
||||||
"""
|
|
||||||
element = link.find(tag)
|
element = link.find(tag)
|
||||||
geometry = element.find("geometry")
|
geometry = element.find("geometry")
|
||||||
mesh = geometry.find("mesh")
|
mesh = geometry.find("mesh")
|
||||||
@ -270,20 +239,7 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
name: str,
|
name: str,
|
||||||
reflectance: float = 0.2,
|
reflectance: float = 0.2,
|
||||||
) -> ET.Element:
|
) -> ET.Element:
|
||||||
"""Adds materials to MJCF asset from URDF link.
|
"""Add materials to the MJCF asset from the URDF link."""
|
||||||
|
|
||||||
Args:
|
|
||||||
mujoco_element (ET.Element): MJCF asset element.
|
|
||||||
link (ET.Element): URDF link element.
|
|
||||||
tag (str): Tag name.
|
|
||||||
input_dir (str): Input directory.
|
|
||||||
output_dir (str): Output directory.
|
|
||||||
name (str): Material name.
|
|
||||||
reflectance (float, optional): Reflectance value.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ET.Element: Material element.
|
|
||||||
"""
|
|
||||||
element = link.find(tag)
|
element = link.find(tag)
|
||||||
geometry = element.find("geometry")
|
geometry = element.find("geometry")
|
||||||
mesh = geometry.find("mesh")
|
mesh = geometry.find("mesh")
|
||||||
@ -323,12 +279,7 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
return material
|
return material
|
||||||
|
|
||||||
def convert(self, urdf_path: str, mjcf_path: str):
|
def convert(self, urdf_path: str, mjcf_path: str):
|
||||||
"""Converts a URDF file to MJCF format.
|
"""Convert a URDF file to MJCF format."""
|
||||||
|
|
||||||
Args:
|
|
||||||
urdf_path (str): Path to URDF file.
|
|
||||||
mjcf_path (str): Path to output MJCF file.
|
|
||||||
"""
|
|
||||||
tree = ET.parse(urdf_path)
|
tree = ET.parse(urdf_path)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
|
|
||||||
@ -382,22 +333,10 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
|
|
||||||
|
|
||||||
class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
||||||
"""Converts URDF files with joints to MJCF format, handling joint transformations.
|
"""Convert URDF files with joints to MJCF format, handling transformations from joints."""
|
||||||
|
|
||||||
Handles fixed joints and hierarchical body structure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str:
|
def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str:
|
||||||
"""Converts a URDF file with joints to MJCF format.
|
"""Convert a URDF file with joints to MJCF format."""
|
||||||
|
|
||||||
Args:
|
|
||||||
urdf_path (str): Path to URDF file.
|
|
||||||
mjcf_path (str): Path to output MJCF file.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Path to converted MJCF file.
|
|
||||||
"""
|
|
||||||
tree = ET.parse(urdf_path)
|
tree = ET.parse(urdf_path)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
|
|
||||||
@ -481,15 +420,11 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
|||||||
|
|
||||||
|
|
||||||
class MeshtoUSDConverter(AssetConverterBase):
|
class MeshtoUSDConverter(AssetConverterBase):
|
||||||
"""Converts mesh-based URDF files to USD format.
|
"""Convert Mesh file from URDF into USD format."""
|
||||||
|
|
||||||
Adds physics APIs and post-processes collision meshes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
DEFAULT_BIND_APIS = [
|
DEFAULT_BIND_APIS = [
|
||||||
"MaterialBindingAPI",
|
"MaterialBindingAPI",
|
||||||
"PhysicsMeshCollisionAPI",
|
"PhysicsMeshCollisionAPI",
|
||||||
"PhysxConvexDecompositionCollisionAPI",
|
|
||||||
"PhysicsCollisionAPI",
|
"PhysicsCollisionAPI",
|
||||||
"PhysxCollisionAPI",
|
"PhysxCollisionAPI",
|
||||||
"PhysicsMassAPI",
|
"PhysicsMassAPI",
|
||||||
@ -504,21 +439,13 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
simulation_app=None,
|
simulation_app=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Initializes the converter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
force_usd_conversion (bool, optional): Force USD conversion.
|
|
||||||
make_instanceable (bool, optional): Make prims instanceable.
|
|
||||||
simulation_app (optional): Simulation app instance.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
"""
|
|
||||||
if simulation_app is not None:
|
if simulation_app is not None:
|
||||||
self.simulation_app = simulation_app
|
self.simulation_app = simulation_app
|
||||||
|
|
||||||
self.exit_close = kwargs.pop("exit_close", True)
|
if "exit_close" in kwargs:
|
||||||
self.physx_max_convex_hulls = kwargs.pop("physx_max_convex_hulls", 32)
|
self.exit_close = kwargs.pop("exit_close")
|
||||||
self.physx_max_vertices = kwargs.pop("physx_max_vertices", 16)
|
else:
|
||||||
self.physx_max_voxel_res = kwargs.pop("physx_max_voxel_res", 10000)
|
self.exit_close = True
|
||||||
|
|
||||||
self.usd_parms = dict(
|
self.usd_parms = dict(
|
||||||
force_usd_conversion=force_usd_conversion,
|
force_usd_conversion=force_usd_conversion,
|
||||||
@ -527,7 +454,6 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""Context manager entry, launches simulation app if needed."""
|
|
||||||
from isaaclab.app import AppLauncher
|
from isaaclab.app import AppLauncher
|
||||||
|
|
||||||
if not hasattr(self, "simulation_app"):
|
if not hasattr(self, "simulation_app"):
|
||||||
@ -546,23 +472,17 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""Context manager exit, closes simulation app if created."""
|
|
||||||
# Close the simulation app if it was created here
|
# Close the simulation app if it was created here
|
||||||
if exc_val is not None:
|
|
||||||
logger.error(f"Exception occurred: {exc_val}.")
|
|
||||||
|
|
||||||
if hasattr(self, "app_launcher") and self.exit_close:
|
if hasattr(self, "app_launcher") and self.exit_close:
|
||||||
self.simulation_app.close()
|
self.simulation_app.close()
|
||||||
|
|
||||||
|
if exc_val is not None:
|
||||||
|
logger.error(f"Exception occurred: {exc_val}.")
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def convert(self, urdf_path: str, output_file: str):
|
def convert(self, urdf_path: str, output_file: str):
|
||||||
"""Converts a URDF file to USD and post-processes collision meshes.
|
"""Convert a URDF file to USD and post-process collision meshes."""
|
||||||
|
|
||||||
Args:
|
|
||||||
urdf_path (str): Path to URDF file.
|
|
||||||
output_file (str): Path to output USD file.
|
|
||||||
"""
|
|
||||||
from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
|
from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
|
||||||
from pxr import PhysxSchema, Sdf, Usd, UsdShade
|
from pxr import PhysxSchema, Sdf, Usd, UsdShade
|
||||||
|
|
||||||
@ -589,8 +509,6 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
stage = Usd.Stage.Open(usd_path)
|
stage = Usd.Stage.Open(usd_path)
|
||||||
layer = stage.GetRootLayer()
|
layer = stage.GetRootLayer()
|
||||||
with Usd.EditContext(stage, layer):
|
with Usd.EditContext(stage, layer):
|
||||||
base_prim = stage.GetPseudoRoot().GetChildren()[0]
|
|
||||||
base_prim.SetMetadata("kind", "component")
|
|
||||||
for prim in stage.Traverse():
|
for prim in stage.Traverse():
|
||||||
# Change texture path to relative path.
|
# Change texture path to relative path.
|
||||||
if prim.GetName() == "material_0":
|
if prim.GetName() == "material_0":
|
||||||
@ -603,9 +521,11 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
|
|
||||||
# Add convex decomposition collision and set ShrinkWrap.
|
# Add convex decomposition collision and set ShrinkWrap.
|
||||||
elif prim.GetName() == "mesh":
|
elif prim.GetName() == "mesh":
|
||||||
approx_attr = prim.CreateAttribute(
|
approx_attr = prim.GetAttribute("physics:approximation")
|
||||||
"physics:approximation", Sdf.ValueTypeNames.Token
|
if not approx_attr:
|
||||||
)
|
approx_attr = prim.CreateAttribute(
|
||||||
|
"physics:approximation", Sdf.ValueTypeNames.Token
|
||||||
|
)
|
||||||
approx_attr.Set("convexDecomposition")
|
approx_attr.Set("convexDecomposition")
|
||||||
|
|
||||||
physx_conv_api = (
|
physx_conv_api = (
|
||||||
@ -613,15 +533,6 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
prim
|
prim
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
physx_conv_api.GetMaxConvexHullsAttr().Set(
|
|
||||||
self.physx_max_convex_hulls
|
|
||||||
)
|
|
||||||
physx_conv_api.GetHullVertexLimitAttr().Set(
|
|
||||||
self.physx_max_vertices
|
|
||||||
)
|
|
||||||
physx_conv_api.GetVoxelResolutionAttr().Set(
|
|
||||||
self.physx_max_voxel_res
|
|
||||||
)
|
|
||||||
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
||||||
|
|
||||||
api_schemas = prim.GetMetadata("apiSchemas")
|
api_schemas = prim.GetMetadata("apiSchemas")
|
||||||
@ -641,27 +552,15 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
|
|
||||||
|
|
||||||
class PhysicsUSDAdder(MeshtoUSDConverter):
|
class PhysicsUSDAdder(MeshtoUSDConverter):
|
||||||
"""Adds physics APIs and collision properties to USD assets.
|
|
||||||
|
|
||||||
Useful for post-processing USD files for simulation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
DEFAULT_BIND_APIS = [
|
DEFAULT_BIND_APIS = [
|
||||||
"MaterialBindingAPI",
|
"MaterialBindingAPI",
|
||||||
"PhysicsMeshCollisionAPI",
|
"PhysicsMeshCollisionAPI",
|
||||||
"PhysxConvexDecompositionCollisionAPI",
|
|
||||||
"PhysicsCollisionAPI",
|
"PhysicsCollisionAPI",
|
||||||
"PhysxCollisionAPI",
|
"PhysxCollisionAPI",
|
||||||
"PhysicsRigidBodyAPI",
|
"PhysicsRigidBodyAPI",
|
||||||
]
|
]
|
||||||
|
|
||||||
def convert(self, usd_path: str, output_file: str = None):
|
def convert(self, usd_path: str, output_file: str = None):
|
||||||
"""Adds physics APIs and collision properties to a USD file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
usd_path (str): Path to input USD file.
|
|
||||||
output_file (str, optional): Path to output USD file.
|
|
||||||
"""
|
|
||||||
from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics
|
from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics
|
||||||
|
|
||||||
if output_file is None:
|
if output_file is None:
|
||||||
@ -684,23 +583,18 @@ class PhysicsUSDAdder(MeshtoUSDConverter):
|
|||||||
if "lightfactory" in prim.GetName().lower():
|
if "lightfactory" in prim.GetName().lower():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
approx_attr = prim.CreateAttribute(
|
approx_attr = prim.GetAttribute(
|
||||||
"physics:approximation", Sdf.ValueTypeNames.Token
|
"physics:approximation"
|
||||||
)
|
)
|
||||||
|
if not approx_attr:
|
||||||
|
approx_attr = prim.CreateAttribute(
|
||||||
|
"physics:approximation",
|
||||||
|
Sdf.ValueTypeNames.Token,
|
||||||
|
)
|
||||||
approx_attr.Set("convexDecomposition")
|
approx_attr.Set("convexDecomposition")
|
||||||
|
|
||||||
physx_conv_api = PhysxSchema.PhysxConvexDecompositionCollisionAPI.Apply(
|
physx_conv_api = PhysxSchema.PhysxConvexDecompositionCollisionAPI.Apply(
|
||||||
prim
|
prim
|
||||||
)
|
)
|
||||||
physx_conv_api.GetMaxConvexHullsAttr().Set(
|
|
||||||
self.physx_max_convex_hulls
|
|
||||||
)
|
|
||||||
physx_conv_api.GetHullVertexLimitAttr().Set(
|
|
||||||
self.physx_max_vertices
|
|
||||||
)
|
|
||||||
physx_conv_api.GetVoxelResolutionAttr().Set(
|
|
||||||
self.physx_max_voxel_res
|
|
||||||
)
|
|
||||||
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
||||||
|
|
||||||
rigid_body_api = UsdPhysics.RigidBodyAPI.Apply(prim)
|
rigid_body_api = UsdPhysics.RigidBodyAPI.Apply(prim)
|
||||||
@ -727,18 +621,14 @@ class PhysicsUSDAdder(MeshtoUSDConverter):
|
|||||||
|
|
||||||
|
|
||||||
class URDFtoUSDConverter(MeshtoUSDConverter):
|
class URDFtoUSDConverter(MeshtoUSDConverter):
|
||||||
"""Converts URDF files to USD format.
|
"""Convert URDF files into USD format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fix_base (bool, optional): Fix the base link.
|
fix_base (bool): Whether to fix the base link.
|
||||||
merge_fixed_joints (bool, optional): Merge fixed joints.
|
merge_fixed_joints (bool): Whether to merge fixed joints.
|
||||||
make_instanceable (bool, optional): Make prims instanceable.
|
make_instanceable (bool): Whether to make prims instanceable.
|
||||||
force_usd_conversion (bool, optional): Force conversion to USD.
|
force_usd_conversion (bool): Force conversion to USD.
|
||||||
collision_from_visuals (bool, optional): Generate collisions from visuals.
|
collision_from_visuals (bool): Generate collisions from visuals if not provided.
|
||||||
joint_drive (optional): Joint drive configuration.
|
|
||||||
rotate_wxyz (tuple[float], optional): Quaternion for rotation.
|
|
||||||
simulation_app (optional): Simulation app instance.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -753,19 +643,6 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|||||||
simulation_app=None,
|
simulation_app=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Initializes the converter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fix_base (bool, optional): Fix the base link.
|
|
||||||
merge_fixed_joints (bool, optional): Merge fixed joints.
|
|
||||||
make_instanceable (bool, optional): Make prims instanceable.
|
|
||||||
force_usd_conversion (bool, optional): Force conversion to USD.
|
|
||||||
collision_from_visuals (bool, optional): Generate collisions from visuals.
|
|
||||||
joint_drive (optional): Joint drive configuration.
|
|
||||||
rotate_wxyz (tuple[float], optional): Quaternion for rotation.
|
|
||||||
simulation_app (optional): Simulation app instance.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
"""
|
|
||||||
self.usd_parms = dict(
|
self.usd_parms = dict(
|
||||||
fix_base=fix_base,
|
fix_base=fix_base,
|
||||||
merge_fixed_joints=merge_fixed_joints,
|
merge_fixed_joints=merge_fixed_joints,
|
||||||
@ -780,12 +657,7 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|||||||
self.simulation_app = simulation_app
|
self.simulation_app = simulation_app
|
||||||
|
|
||||||
def convert(self, urdf_path: str, output_file: str):
|
def convert(self, urdf_path: str, output_file: str):
|
||||||
"""Converts a URDF file to USD and post-processes collision meshes.
|
"""Convert a URDF file to USD and post-process collision meshes."""
|
||||||
|
|
||||||
Args:
|
|
||||||
urdf_path (str): Path to URDF file.
|
|
||||||
output_file (str): Path to output USD file.
|
|
||||||
"""
|
|
||||||
from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
|
from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
|
||||||
from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
|
from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
|
||||||
|
|
||||||
@ -804,9 +676,11 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|||||||
with Usd.EditContext(stage, layer):
|
with Usd.EditContext(stage, layer):
|
||||||
for prim in stage.Traverse():
|
for prim in stage.Traverse():
|
||||||
if prim.GetName() == "collisions":
|
if prim.GetName() == "collisions":
|
||||||
approx_attr = prim.CreateAttribute(
|
approx_attr = prim.GetAttribute("physics:approximation")
|
||||||
"physics:approximation", Sdf.ValueTypeNames.Token
|
if not approx_attr:
|
||||||
)
|
approx_attr = prim.CreateAttribute(
|
||||||
|
"physics:approximation", Sdf.ValueTypeNames.Token
|
||||||
|
)
|
||||||
approx_attr.Set("convexDecomposition")
|
approx_attr.Set("convexDecomposition")
|
||||||
|
|
||||||
physx_conv_api = (
|
physx_conv_api = (
|
||||||
@ -814,9 +688,6 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|||||||
prim
|
prim
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
physx_conv_api.GetMaxConvexHullsAttr().Set(32)
|
|
||||||
physx_conv_api.GetHullVertexLimitAttr().Set(16)
|
|
||||||
physx_conv_api.GetVoxelResolutionAttr().Set(10000)
|
|
||||||
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
||||||
|
|
||||||
api_schemas = prim.GetMetadata("apiSchemas")
|
api_schemas = prim.GetMetadata("apiSchemas")
|
||||||
@ -847,36 +718,13 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|||||||
|
|
||||||
|
|
||||||
class AssetConverterFactory:
|
class AssetConverterFactory:
|
||||||
"""Factory for creating asset converters based on target and source types.
|
"""Factory class for creating asset converters based on target and source types."""
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.data.asset_converter import AssetConverterFactory
|
|
||||||
from embodied_gen.utils.enum import AssetType
|
|
||||||
|
|
||||||
converter = AssetConverterFactory.create(
|
|
||||||
target_type=AssetType.USD, source_type=AssetType.MESH
|
|
||||||
)
|
|
||||||
with converter:
|
|
||||||
for urdf_path, output_file in zip(urdf_paths, output_files):
|
|
||||||
converter.convert(urdf_path, output_file)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(
|
def create(
|
||||||
target_type: AssetType, source_type: AssetType = "urdf", **kwargs
|
target_type: AssetType, source_type: AssetType = "urdf", **kwargs
|
||||||
) -> AssetConverterBase:
|
) -> AssetConverterBase:
|
||||||
"""Creates an asset converter instance.
|
"""Create an asset converter instance based on target and source types."""
|
||||||
|
|
||||||
Args:
|
|
||||||
target_type (AssetType): Target asset type.
|
|
||||||
source_type (AssetType, optional): Source asset type.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AssetConverterBase: Converter instance.
|
|
||||||
"""
|
|
||||||
if target_type == AssetType.MJCF and source_type == AssetType.MESH:
|
if target_type == AssetType.MJCF and source_type == AssetType.MESH:
|
||||||
converter = MeshtoMJCFConverter(**kwargs)
|
converter = MeshtoMJCFConverter(**kwargs)
|
||||||
elif target_type == AssetType.MJCF and source_type == AssetType.URDF:
|
elif target_type == AssetType.MJCF and source_type == AssetType.URDF:
|
||||||
@ -898,14 +746,7 @@ if __name__ == "__main__":
|
|||||||
# target_asset_type = AssetType.USD
|
# target_asset_type = AssetType.USD
|
||||||
|
|
||||||
urdf_paths = [
|
urdf_paths = [
|
||||||
'outputs/EmbodiedGenData/demo_assets/banana/result/banana.urdf',
|
"outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf",
|
||||||
'outputs/EmbodiedGenData/demo_assets/book/result/book.urdf',
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/lamp/result/lamp.urdf',
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/mug/result/mug.urdf',
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/remote_control/result/remote_control.urdf',
|
|
||||||
"outputs/EmbodiedGenData/demo_assets/rubik's_cube/result/rubik's_cube.urdf",
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/table/result/table.urdf',
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/vase/result/vase.urdf',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if target_asset_type == AssetType.MJCF:
|
if target_asset_type == AssetType.MJCF:
|
||||||
@ -919,14 +760,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
elif target_asset_type == AssetType.USD:
|
elif target_asset_type == AssetType.USD:
|
||||||
output_files = [
|
output_files = [
|
||||||
'outputs/EmbodiedGenData/demo_assets/banana/usd/banana.usd',
|
"outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd",
|
||||||
'outputs/EmbodiedGenData/demo_assets/book/usd/book.usd',
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/lamp/usd/lamp.usd',
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/mug/usd/mug.usd',
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/remote_control/usd/remote_control.usd',
|
|
||||||
"outputs/EmbodiedGenData/demo_assets/rubik's_cube/usd/rubik's_cube.usd",
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/table/usd/table.usd',
|
|
||||||
'outputs/EmbodiedGenData/demo_assets/vase/usd/vase.usd',
|
|
||||||
]
|
]
|
||||||
asset_converter = AssetConverterFactory.create(
|
asset_converter = AssetConverterFactory.create(
|
||||||
target_type=AssetType.USD,
|
target_type=AssetType.USD,
|
||||||
|
|||||||
@ -34,7 +34,6 @@ from embodied_gen.data.utils import (
|
|||||||
CameraSetting,
|
CameraSetting,
|
||||||
get_images_from_grid,
|
get_images_from_grid,
|
||||||
init_kal_camera,
|
init_kal_camera,
|
||||||
kaolin_to_opencv_view,
|
|
||||||
normalize_vertices_array,
|
normalize_vertices_array,
|
||||||
post_process_texture,
|
post_process_texture,
|
||||||
save_mesh_with_mtl,
|
save_mesh_with_mtl,
|
||||||
@ -307,6 +306,28 @@ class TextureBaker(object):
|
|||||||
raise ValueError(f"Unknown mode: {mode}")
|
raise ValueError(f"Unknown mode: {mode}")
|
||||||
|
|
||||||
|
|
||||||
|
def kaolin_to_opencv_view(raw_matrix):
    """Permute the rotation columns of batched view matrices.

    For every matrix in the batch, output rotation columns ``(0, 1, 2)`` are
    taken from input rotation columns ``(2, 0, 1)``. The translation column
    is copied unchanged and the last row is set to ``[0, 0, 0, 1]``.

    NOTE(review): the name suggests a Kaolin -> OpenCV camera convention
    change; confirm the intended axis convention against the callers.

    Args:
        raw_matrix (torch.Tensor): Batch of matrices indexed as
            ``[batch, row, col]``; the top-left 3x3 block and the first
            three entries of column 3 are read.

    Returns:
        torch.Tensor: New ``(batch, 4, 4)`` tensor on the same device and
            with the same dtype as ``raw_matrix``.
    """
    batch_size = raw_matrix.size(0)
    # Build the identity with the input dtype so float64 (or half) matrices
    # are not silently converted to the default float32.
    target_matrix = (
        torch.eye(4, device=raw_matrix.device, dtype=raw_matrix.dtype)
        .unsqueeze(0)
        .repeat(batch_size, 1, 1)
    )
    # Column permutation: new col 0 <- old col 2, 1 <- 0, 2 <- 1.
    target_matrix[:, :3, :3] = raw_matrix[:, :3, :3][:, :, [2, 0, 1]]
    target_matrix[:, :3, 3] = raw_matrix[:, :3, 3]

    return target_matrix
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Render settings")
|
parser = argparse.ArgumentParser(description="Render settings")
|
||||||
|
|
||||||
|
|||||||
@ -58,16 +58,7 @@ __all__ = [
|
|||||||
def _transform_vertices(
|
def _transform_vertices(
|
||||||
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
|
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Transforms 3D vertices using a projection matrix.
|
"""Transform 3D vertices using a projection matrix."""
|
||||||
|
|
||||||
Args:
|
|
||||||
mtx (torch.Tensor): Projection matrix.
|
|
||||||
pos (torch.Tensor): Vertex positions.
|
|
||||||
keepdim (bool, optional): If True, keeps the batch dimension.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Transformed vertices.
|
|
||||||
"""
|
|
||||||
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
|
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
|
||||||
if pos.size(-1) == 3:
|
if pos.size(-1) == 3:
|
||||||
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
|
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
|
||||||
@ -80,17 +71,7 @@ def _transform_vertices(
|
|||||||
def _bilinear_interpolation_scattering(
|
def _bilinear_interpolation_scattering(
|
||||||
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
|
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Performs bilinear interpolation scattering for grid-based value accumulation.
|
"""Bilinear interpolation scattering for grid-based value accumulation."""
|
||||||
|
|
||||||
Args:
|
|
||||||
image_h (int): Image height.
|
|
||||||
image_w (int): Image width.
|
|
||||||
coords (torch.Tensor): Normalized coordinates.
|
|
||||||
values (torch.Tensor): Values to scatter.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Interpolated grid.
|
|
||||||
"""
|
|
||||||
device = values.device
|
device = values.device
|
||||||
dtype = values.dtype
|
dtype = values.dtype
|
||||||
C = values.shape[-1]
|
C = values.shape[-1]
|
||||||
@ -154,18 +135,7 @@ def _texture_inpaint_smooth(
|
|||||||
faces: np.ndarray,
|
faces: np.ndarray,
|
||||||
uv_map: np.ndarray,
|
uv_map: np.ndarray,
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""Performs texture inpainting using vertex-based color propagation.
|
"""Perform texture inpainting using vertex-based color propagation."""
|
||||||
|
|
||||||
Args:
|
|
||||||
texture (np.ndarray): Texture image.
|
|
||||||
mask (np.ndarray): Mask image.
|
|
||||||
vertices (np.ndarray): Mesh vertices.
|
|
||||||
faces (np.ndarray): Mesh faces.
|
|
||||||
uv_map (np.ndarray): UV coordinates.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[np.ndarray, np.ndarray]: Inpainted texture and updated mask.
|
|
||||||
"""
|
|
||||||
image_h, image_w, C = texture.shape
|
image_h, image_w, C = texture.shape
|
||||||
N = vertices.shape[0]
|
N = vertices.shape[0]
|
||||||
|
|
||||||
@ -261,41 +231,29 @@ def _texture_inpaint_smooth(
|
|||||||
class TextureBacker:
|
class TextureBacker:
|
||||||
"""Texture baking pipeline for multi-view projection and fusion.
|
"""Texture baking pipeline for multi-view projection and fusion.
|
||||||
|
|
||||||
This class generates UV-based textures for a 3D mesh using multi-view images,
|
This class performs UV-based texture generation for a 3D mesh using
|
||||||
depth, and normal information. It includes mesh normalization, UV unwrapping,
|
multi-view color images, depth, and normal information. The pipeline
|
||||||
visibility-aware back-projection, confidence-weighted fusion, and inpainting.
|
includes mesh normalization and UV unwrapping, visibility-aware
|
||||||
|
back-projection, confidence-weighted texture fusion, and inpainting
|
||||||
|
of missing texture regions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
camera_params (CameraSetting): Camera intrinsics and extrinsics.
|
camera_params (CameraSetting): Camera intrinsics and extrinsics used
|
||||||
view_weights (list[float]): Weights for each view in texture fusion.
|
for rendering each view.
|
||||||
render_wh (tuple[int, int], optional): Intermediate rendering resolution.
|
view_weights (list[float]): A list of weights for each view, used
|
||||||
texture_wh (tuple[int, int], optional): Output texture resolution.
|
to blend confidence maps during texture fusion.
|
||||||
bake_angle_thresh (int, optional): Max angle for valid projection.
|
render_wh (tuple[int, int], optional): Resolution (width, height) for
|
||||||
mask_thresh (float, optional): Threshold for visibility masks.
|
intermediate rendering passes. Defaults to (2048, 2048).
|
||||||
smooth_texture (bool, optional): Apply post-processing to texture.
|
texture_wh (tuple[int, int], optional): Output texture resolution
|
||||||
inpaint_smooth (bool, optional): Apply inpainting smoothing.
|
(width, height). Defaults to (2048, 2048).
|
||||||
|
bake_angle_thresh (int, optional): Maximum angle (in degrees) between
|
||||||
Example:
|
view direction and surface normal for projection to be considered valid.
|
||||||
```py
|
Defaults to 75.
|
||||||
from embodied_gen.data.backproject_v2 import TextureBacker
|
mask_thresh (float, optional): Threshold applied to visibility masks
|
||||||
from embodied_gen.data.utils import CameraSetting
|
during rendering. Defaults to 0.5.
|
||||||
import trimesh
|
smooth_texture (bool, optional): If True, apply post-processing (e.g.,
|
||||||
from PIL import Image
|
blurring) to the final texture. Defaults to True.
|
||||||
|
inpaint_smooth (bool, optional): If True, apply inpainting to smooth.
|
||||||
camera_params = CameraSetting(
|
|
||||||
num_images=6,
|
|
||||||
elevation=[20, -10],
|
|
||||||
distance=5,
|
|
||||||
resolution_hw=(2048,2048),
|
|
||||||
fov=math.radians(30),
|
|
||||||
device='cuda',
|
|
||||||
)
|
|
||||||
view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
|
|
||||||
mesh = trimesh.load('mesh.obj')
|
|
||||||
images = [Image.open(f'view_{i}.png') for i in range(6)]
|
|
||||||
texture_backer = TextureBacker(camera_params, view_weights)
|
|
||||||
textured_mesh = texture_backer(images, mesh, 'output.obj')
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -325,12 +283,6 @@ class TextureBacker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _lazy_init_render(self, camera_params, mask_thresh):
|
def _lazy_init_render(self, camera_params, mask_thresh):
|
||||||
"""Lazily initializes the renderer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
camera_params (CameraSetting): Camera settings.
|
|
||||||
mask_thresh (float): Mask threshold.
|
|
||||||
"""
|
|
||||||
if self.renderer is None:
|
if self.renderer is None:
|
||||||
camera = init_kal_camera(camera_params)
|
camera = init_kal_camera(camera_params)
|
||||||
mv = camera.view_matrix() # (n 4 4) world2cam
|
mv = camera.view_matrix() # (n 4 4) world2cam
|
||||||
@ -349,14 +301,6 @@ class TextureBacker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
|
def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
|
||||||
"""Normalizes mesh and unwraps UVs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mesh (trimesh.Trimesh): Input mesh.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
trimesh.Trimesh: Mesh with normalized vertices and UVs.
|
|
||||||
"""
|
|
||||||
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
|
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
|
||||||
self.scale, self.center = scale, center
|
self.scale, self.center = scale, center
|
||||||
|
|
||||||
@ -374,16 +318,6 @@ class TextureBacker:
|
|||||||
scale: float = None,
|
scale: float = None,
|
||||||
center: np.ndarray = None,
|
center: np.ndarray = None,
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Gets mesh attributes as numpy arrays.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mesh (trimesh.Trimesh): Input mesh.
|
|
||||||
scale (float, optional): Scale factor.
|
|
||||||
center (np.ndarray, optional): Center offset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (vertices, faces, uv_map)
|
|
||||||
"""
|
|
||||||
vertices = mesh.vertices.copy()
|
vertices = mesh.vertices.copy()
|
||||||
faces = mesh.faces.copy()
|
faces = mesh.faces.copy()
|
||||||
uv_map = mesh.visual.uv.copy()
|
uv_map = mesh.visual.uv.copy()
|
||||||
@ -397,14 +331,6 @@ class TextureBacker:
|
|||||||
return vertices, faces, uv_map
|
return vertices, faces, uv_map
|
||||||
|
|
||||||
def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
|
def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
|
||||||
"""Computes edge image from depth map.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
depth_image (torch.Tensor): Depth map.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Edge image.
|
|
||||||
"""
|
|
||||||
depth_image_np = depth_image.cpu().numpy()
|
depth_image_np = depth_image.cpu().numpy()
|
||||||
depth_image_np = (depth_image_np * 255).astype(np.uint8)
|
depth_image_np = (depth_image_np * 255).astype(np.uint8)
|
||||||
depth_edges = cv2.Canny(depth_image_np, 30, 80)
|
depth_edges = cv2.Canny(depth_image_np, 30, 80)
|
||||||
@ -418,16 +344,6 @@ class TextureBacker:
|
|||||||
def compute_enhanced_viewnormal(
|
def compute_enhanced_viewnormal(
|
||||||
self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
|
self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Computes enhanced view normals for mesh faces.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mv_mtx (torch.Tensor): View matrices.
|
|
||||||
vertices (torch.Tensor): Mesh vertices.
|
|
||||||
faces (torch.Tensor): Mesh faces.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: View normals.
|
|
||||||
"""
|
|
||||||
rast, _ = self.renderer.compute_dr_raster(vertices, faces)
|
rast, _ = self.renderer.compute_dr_raster(vertices, faces)
|
||||||
rendered_view_normals = []
|
rendered_view_normals = []
|
||||||
for idx in range(len(mv_mtx)):
|
for idx in range(len(mv_mtx)):
|
||||||
@ -460,18 +376,6 @@ class TextureBacker:
|
|||||||
def back_project(
|
def back_project(
|
||||||
self, image, vis_mask, depth, normal, uv
|
self, image, vis_mask, depth, normal, uv
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Back-projects image and confidence to UV texture space.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (PIL.Image or np.ndarray): Input image.
|
|
||||||
vis_mask (torch.Tensor): Visibility mask.
|
|
||||||
depth (torch.Tensor): Depth map.
|
|
||||||
normal (torch.Tensor): Normal map.
|
|
||||||
uv (torch.Tensor): UV coordinates.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[torch.Tensor, torch.Tensor]: Texture and confidence map.
|
|
||||||
"""
|
|
||||||
image = np.array(image)
|
image = np.array(image)
|
||||||
image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
|
image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
|
||||||
if image.ndim == 2:
|
if image.ndim == 2:
|
||||||
@ -514,17 +418,6 @@ class TextureBacker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _scatter_texture(self, uv, data, mask):
|
def _scatter_texture(self, uv, data, mask):
|
||||||
"""Scatters data to texture using UV coordinates and mask.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
uv (torch.Tensor): UV coordinates.
|
|
||||||
data (torch.Tensor): Data to scatter.
|
|
||||||
mask (torch.Tensor): Mask for valid pixels.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Scattered texture.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __filter_data(data, mask):
|
def __filter_data(data, mask):
|
||||||
return data.view(-1, data.shape[-1])[mask]
|
return data.view(-1, data.shape[-1])[mask]
|
||||||
|
|
||||||
@ -539,15 +432,6 @@ class TextureBacker:
|
|||||||
def fast_bake_texture(
|
def fast_bake_texture(
|
||||||
self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
|
self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Fuses multiple textures and confidence maps.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
textures (list[torch.Tensor]): List of textures.
|
|
||||||
confidence_maps (list[torch.Tensor]): List of confidence maps.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[torch.Tensor, torch.Tensor]: Fused texture and mask.
|
|
||||||
"""
|
|
||||||
channel = textures[0].shape[-1]
|
channel = textures[0].shape[-1]
|
||||||
texture_merge = torch.zeros(self.texture_wh + [channel]).to(
|
texture_merge = torch.zeros(self.texture_wh + [channel]).to(
|
||||||
self.device
|
self.device
|
||||||
@ -567,16 +451,6 @@ class TextureBacker:
|
|||||||
def uv_inpaint(
|
def uv_inpaint(
|
||||||
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
|
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Inpaints missing regions in the UV texture.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mesh (trimesh.Trimesh): Mesh.
|
|
||||||
texture (np.ndarray): Texture image.
|
|
||||||
mask (np.ndarray): Mask image.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Inpainted texture.
|
|
||||||
"""
|
|
||||||
if self.inpaint_smooth:
|
if self.inpaint_smooth:
|
||||||
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
|
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
|
||||||
texture, mask = _texture_inpaint_smooth(
|
texture, mask = _texture_inpaint_smooth(
|
||||||
@ -599,15 +473,6 @@ class TextureBacker:
|
|||||||
colors: list[Image.Image],
|
colors: list[Image.Image],
|
||||||
mesh: trimesh.Trimesh,
|
mesh: trimesh.Trimesh,
|
||||||
) -> trimesh.Trimesh:
|
) -> trimesh.Trimesh:
|
||||||
"""Computes the fused texture for the mesh from multi-view images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
colors (list[Image.Image]): List of view images.
|
|
||||||
mesh (trimesh.Trimesh): Mesh to texture.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[np.ndarray, np.ndarray]: Texture and mask.
|
|
||||||
"""
|
|
||||||
self._lazy_init_render(self.camera_params, self.mask_thresh)
|
self._lazy_init_render(self.camera_params, self.mask_thresh)
|
||||||
|
|
||||||
vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
|
vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
|
||||||
@ -652,7 +517,7 @@ class TextureBacker:
|
|||||||
Args:
|
Args:
|
||||||
colors (list[Image.Image]): List of input view images.
|
colors (list[Image.Image]): List of input view images.
|
||||||
mesh (trimesh.Trimesh): Input mesh to be textured.
|
mesh (trimesh.Trimesh): Input mesh to be textured.
|
||||||
output_path (str): Path to save the output textured mesh.
|
output_path (str): Path to save the output textured mesh (.obj or .glb).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
trimesh.Trimesh: The textured mesh with UV and texture image.
|
trimesh.Trimesh: The textured mesh with UV and texture image.
|
||||||
@ -675,11 +540,6 @@ class TextureBacker:
|
|||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""Parses command-line arguments for texture backprojection.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
argparse.Namespace: Parsed arguments.
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser(description="Backproject texture")
|
parser = argparse.ArgumentParser(description="Backproject texture")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--color_path",
|
"--color_path",
|
||||||
@ -776,16 +636,6 @@ def entrypoint(
|
|||||||
imagesr_model: ImageRealESRGAN = None,
|
imagesr_model: ImageRealESRGAN = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> trimesh.Trimesh:
|
) -> trimesh.Trimesh:
|
||||||
"""Entrypoint for texture backprojection from multi-view images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
delight_model (DelightingModel, optional): Delighting model.
|
|
||||||
imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
|
|
||||||
**kwargs: Additional arguments to override CLI.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
trimesh.Trimesh: Textured mesh.
|
|
||||||
"""
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if hasattr(args, k) and v is not None:
|
if hasattr(args, k) and v is not None:
|
||||||
|
|||||||
@ -1,558 +0,0 @@
|
|||||||
# Project EmbodiedGen
|
|
||||||
#
|
|
||||||
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
# implied. See the License for the specific language governing
|
|
||||||
# permissions and limitations under the License.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from typing import Literal, Union
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import nvdiffrast.torch as dr
|
|
||||||
import spaces
|
|
||||||
import torch
|
|
||||||
import trimesh
|
|
||||||
import utils3d
|
|
||||||
import xatlas
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
from embodied_gen.data.mesh_operator import MeshFixer
|
|
||||||
from embodied_gen.data.utils import (
|
|
||||||
CameraSetting,
|
|
||||||
init_kal_camera,
|
|
||||||
kaolin_to_opencv_view,
|
|
||||||
normalize_vertices_array,
|
|
||||||
post_process_texture,
|
|
||||||
save_mesh_with_mtl,
|
|
||||||
)
|
|
||||||
from embodied_gen.models.delight_model import DelightingModel
|
|
||||||
from embodied_gen.models.gs_model import load_gs_model
|
|
||||||
from embodied_gen.models.sr_model import ImageRealESRGAN
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"TextureBaker",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TextureBaker(object):
    """Bake a UV texture onto a mesh from multiple posed observations.

    The baker holds the mesh (vertices, faces, UVs) and one world-to-camera
    and projection matrix per camera derived from ``camera_params``. Images
    are projected onto the mesh either by direct accumulation (``fast``) or
    by gradient-based optimization with total-variation regularization
    (``opt``).

    Attributes:
        vertices (torch.Tensor): Mesh vertices on ``device``.
        faces (torch.Tensor): Mesh faces (vertex indices) on ``device``.
        uvs (torch.Tensor): Per-vertex UV coordinates on ``device``.
        camera_params (CameraSetting): Camera setting used to build cameras.
        device (str): Device the tensors are stored on.
        w2cs (torch.Tensor): World-to-camera matrices, one per camera.
        projections (torch.Tensor): Projection matrices, one per camera.

    Example:
        >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)  # noqa
        >>> texture_baker = TextureBaker(vertices, faces, uvs, camera_params)
        >>> images = get_images_from_grid(args.color_path, image_size)
        >>> texture = texture_baker.bake_texture(
        ...     images, texture_size=args.texture_size, mode=args.baker_mode
        ... )
        >>> texture = post_process_texture(texture)
    """

    def __init__(
        self,
        vertices: np.ndarray,
        faces: np.ndarray,
        uvs: np.ndarray,
        camera_params: CameraSetting,
        device: str = "cuda",
    ) -> None:
        """Move the mesh to ``device`` and precompute camera matrices.

        Args:
            vertices: Mesh vertices (numpy array or torch tensor).
            faces: Mesh faces (numpy array or torch tensor); numpy input is
                cast to int32.
            uvs: UV coordinates (numpy array or torch tensor).
            camera_params: Camera setting passed to ``init_kal_camera``.
            device: Target device for all tensors.
        """
        self.vertices = (
            torch.tensor(vertices, device=device)
            if isinstance(vertices, np.ndarray)
            else vertices.to(device)
        )
        self.faces = (
            torch.tensor(faces.astype(np.int32), device=device)
            if isinstance(faces, np.ndarray)
            else faces.to(device)
        )
        self.uvs = (
            torch.tensor(uvs, device=device)
            if isinstance(uvs, np.ndarray)
            else uvs.to(device)
        )
        self.camera_params = camera_params
        self.device = device

        camera = init_kal_camera(camera_params)
        matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
        matrix_mv = kaolin_to_opencv_view(matrix_mv)
        matrix_p = camera.intrinsics.projection_matrix()  # (n_cam 4 4)
        self.w2cs = matrix_mv.to(self.device)
        self.projections = matrix_p.to(self.device)

    @staticmethod
    def parametrize_mesh(
        vertices: np.ndarray, faces: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """UV-unwrap a mesh with xatlas.

        Args:
            vertices: Mesh vertices.
            faces: Mesh faces.

        Returns:
            tuple: ``(vertices, faces, uvs)`` where ``vertices`` is re-indexed
            by the xatlas vertex mapping and ``faces`` are the new indices.
        """
        vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
        return vertices[vmapping], indices, uvs

    def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
        """Bake by averaging every observed color that lands on a texel.

        Args:
            observations: Per-view image tensors.
            w2cs: Per-view world-to-camera matrices.
            projections: Per-view projection matrices.
            texture_size: Side length of the square output texture.
            masks: Per-view boolean foreground masks.

        Returns:
            np.ndarray: ``(texture_size, texture_size, 3)`` uint8 texture with
            unobserved texels filled by OpenCV inpainting.
        """
        n_texels = texture_size * texture_size
        # Allocate on the baker's device instead of the implicit default
        # CUDA device so that all operands of scatter_add live together.
        texture = torch.zeros(
            (n_texels, 3), dtype=torch.float32, device=self.device
        )
        texture_weights = torch.zeros(
            n_texels, dtype=torch.float32, device=self.device
        )
        rastctx = utils3d.torch.RastContext(backend="cuda")
        # Each view is paired with its own mask; the mask of the first view
        # is not valid for the other views.
        for observation, w2c, projection, obs_mask in tqdm(
            zip(observations, w2cs, projections, masks),
            total=len(observations),
            desc="Texture baking (fast)",
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
                uv_map = rast["uv"][0].detach().flip(0)
                mask = rast["mask"][0].detach().bool() & obs_mask

            # Nearest-neighbor interpolation. uv == 1.0 would floor to
            # `texture_size`, one past the last texel, so clamp into range.
            uv_map = (uv_map * texture_size).floor().long()
            uv_map = uv_map.clamp(0, texture_size - 1)
            obs = observation[mask]
            uv_map = uv_map[mask]
            # Flatten (u, v) to a row-major texel index with v flipped.
            idx = (
                uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            )
            texture = texture.scatter_add(
                0, idx.view(-1, 1).expand(-1, 3), obs
            )
            texture_weights = texture_weights.scatter_add(
                0,
                idx,
                torch.ones(
                    (obs.shape[0]), dtype=torch.float32, device=texture.device
                ),
            )

        # Average accumulated colors where at least one sample landed.
        observed = texture_weights > 0
        texture[observed] /= texture_weights[observed][:, None]
        texture = np.clip(
            texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
            0,
            255,
        ).astype(np.uint8)

        # Inpaint texels that received no observation.
        hole_mask = (
            (texture_weights == 0)
            .cpu()
            .numpy()
            .astype(np.uint8)
            .reshape(texture_size, texture_size)
        )
        texture = cv2.inpaint(texture, hole_mask, 3, cv2.INPAINT_TELEA)

        return texture

    def _bake_opt(
        self,
        observations,
        w2cs,
        projections,
        texture_size,
        lambda_tv,
        masks,
        total_steps,
    ):
        """Bake by optimizing the texture against the observations.

        A texture parameter is fitted with Adam so that sampling it through
        the rasterized UVs reproduces each (randomly selected) view under an
        L1 loss, optionally regularized by a total-variation term.

        Args:
            observations: Per-view image tensors.
            w2cs: Per-view world-to-camera matrices.
            projections: Per-view projection matrices.
            texture_size: Side length of the square output texture.
            lambda_tv: Weight of the total-variation loss (disabled if <= 0).
            masks: Per-view boolean foreground masks.
            total_steps: Number of optimization steps.

        Returns:
            np.ndarray: ``(texture_size, texture_size, 3)`` uint8 texture with
            texels outside the UV layout filled by OpenCV inpainting.
        """
        rastctx = utils3d.torch.RastContext(backend="cuda")
        # Images and masks are flipped vertically before being compared
        # with the rasterizer output.
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        view_uvs = []
        view_uv_derivs = []
        for observation, w2c, projection in tqdm(
            zip(observations, w2cs, projections),
            total=len(w2cs),
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
                view_uvs.append(rast["uv"].detach())
                view_uv_derivs.append(rast["uv_dr"].detach())

        texture = torch.nn.Parameter(
            torch.zeros(
                (1, texture_size, texture_size, 3),
                dtype=torch.float32,
                device=self.device,
            )
        )
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def cosine_annealing(step, total_steps, start_lr, end_lr):
            # Half-cosine decay from start_lr (step 0) towards end_lr.
            return end_lr + 0.5 * (start_lr - end_lr) * (
                1 + np.cos(np.pi * step / total_steps)
            )

        def tv_loss(tex):
            # L1 difference between vertically and horizontally adjacent
            # texels.
            return torch.nn.functional.l1_loss(
                tex[:, :-1, :, :], tex[:, 1:, :, :]
            ) + torch.nn.functional.l1_loss(
                tex[:, :, :-1, :], tex[:, :, 1:, :]
            )

        with tqdm(total=total_steps, desc="Texture baking") as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                # One randomly chosen view per step.
                selected = np.random.randint(0, len(w2cs))
                uv = view_uvs[selected]
                uv_dr = view_uv_derivs[selected]
                observation = observations[selected]
                mask = masks[selected]

                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(
                    render[mask], observation[mask]
                )
                if lambda_tv > 0:
                    loss = loss + lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()

                # The learning rate for the next step follows the schedule.
                optimizer.param_groups[0]["lr"] = cosine_annealing(
                    step, total_steps, 1e-2, 1e-5
                )
                pbar.set_postfix({"loss": loss.item()})
                pbar.update()

        texture = np.clip(
            texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)
        # Texels not covered by any UV triangle are inpainted.
        hole_mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx,
            (self.uvs * 2 - 1)[None],
            self.faces,
            texture_size,
            texture_size,
        )["mask"][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, hole_mask, 3, cv2.INPAINT_TELEA)

        return texture

    def bake_texture(
        self,
        images: list[np.ndarray],
        texture_size: int = 1024,
        mode: Literal["fast", "opt"] = "opt",
        lambda_tv: float = 1e-2,
        opt_step: int = 2000,
    ):
        """Bake a texture from per-view images.

        Args:
            images: One image array per camera; values are divided by 255
                before baking. Pixels whose channels are all zero are
                treated as background.
            texture_size: Side length of the square output texture.
            mode: ``"fast"`` for direct accumulation, ``"opt"`` for
                optimization-based baking.
            lambda_tv: Total-variation weight (``opt`` mode only).
            opt_step: Number of optimization steps (``opt`` mode only).

        Returns:
            np.ndarray: The baked uint8 texture.

        Raises:
            ValueError: If ``mode`` is neither ``"fast"`` nor ``"opt"``.
        """
        masks = [np.any(img > 0, axis=-1) for img in images]
        masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
        images = [
            torch.tensor(obs / 255.0).float().to(self.device) for obs in images
        ]

        if mode == "fast":
            return self._bake_fast(
                images, self.w2cs, self.projections, texture_size, masks
            )
        if mode == "opt":
            return self._bake_opt(
                images,
                self.w2cs,
                self.projections,
                texture_size,
                lambda_tv,
                masks,
                opt_step,
            )
        raise ValueError(f"Unknown mode: {mode}")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
    """Parse command-line arguments for texture backprojection.

    Arguments are read from ``sys.argv``. Options this parser does not
    know are ignored rather than rejected, because ``parse_known_args`` is
    used.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Backproject texture")

    # Input / output paths.
    parser.add_argument(
        "--gs_path",
        type=str,
        help="Path to the GS.ply gaussian splatting model",
    )
    parser.add_argument(
        "--mesh_path",
        type=str,
        help="Mesh path, .obj, .glb or .ply",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output mesh path with suffix",
    )

    # Camera / rendering setup.
    parser.add_argument(
        "--num_images",
        type=int,
        default=180,
        help="Number of images to render.",
    )
    parser.add_argument(
        "--elevation",
        nargs="+",
        type=float,
        default=list(range(85, -90, -10)),
        help="Elevation angles for the camera",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the render images (default: (512, 512))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )

    # Mesh processing and texture baking.
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip fixing the mesh geometry.",
    )
    parser.add_argument(
        "--texture_size",
        type=int,
        default=2048,
        help="Texture size for texture baking (default: 2048)",
    )
    parser.add_argument(
        "--baker_mode",
        type=str,
        default="opt",
        help="Texture baking mode, `fast` or `opt` (default: opt)",
    )
    parser.add_argument(
        "--opt_step",
        type=int,
        default=3000,
        help="Optimization steps for texture baking (default: 3000)",
    )
    # NOTE: the option name is misspelled ("sipmlify"); it is kept as-is
    # because callers read `args.mesh_sipmlify_ratio`.
    parser.add_argument(
        "--mesh_sipmlify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
    )
    parser.add_argument(
        "--no_smooth_texture",
        action="store_true",
        help="Do not smooth the texture.",
    )
    parser.add_argument(
        "--no_coor_trans",
        action="store_true",
        help="Do not transform the asset coordinate system.",
    )
    parser.add_argument(
        "--save_glb_path", type=str, default=None, help="Save glb path."
    )
    parser.add_argument("--n_max_faces", type=int, default=30000)

    # Unknown options are tolerated on purpose.
    args, _unknown = parser.parse_known_args()

    return args
|
|
||||||
|
|
||||||
|
|
||||||
@spaces.GPU
def entrypoint(
    delight_model: DelightingModel = None,
    imagesr_model: ImageRealESRGAN = None,
    **kwargs,
) -> trimesh.Trimesh:
    """Entrypoint for texture backprojection from multi-view images.

    Renders views from the Gaussian-splatting model at ``args.gs_path``,
    optionally runs them through the delighting model, bakes them into a
    texture for the mesh loaded from ``args.mesh_path`` and saves the result
    to ``args.output_path`` (plus ``args.save_glb_path`` when given).

    Args:
        delight_model (DelightingModel, optional): Delighting model. A
            ``DelightingModel()`` is created when ``args.delight`` is set and
            no model is supplied.
        imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
            Accepted for interface compatibility; this function does not
            use it.
        **kwargs: Additional arguments to override CLI. Only keys that exist
            on the parsed namespace and whose value is not None are applied.

    Returns:
        trimesh.Trimesh: Textured mesh, as returned by ``save_mesh_with_mtl``.
    """
    args = parse_args()
    for key, value in kwargs.items():
        if value is not None and hasattr(args, key):
            setattr(args, key, value)

    # Setup camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )

    # GS render.
    camera = init_kal_camera(camera_params, flip_az=True)
    matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
    # Negate the translation column in place before inverting.
    matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
    w2cs = matrix_mv.to(camera_params.device)
    c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
    Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
    gs_model = load_gs_model(args.gs_path, pre_quat=[0.0, 0.0, 1.0, 0.0])
    multiviews = []
    for c2w in tqdm(c2ws, desc="Rendering GS"):
        result = gs_model.render(
            c2w,
            Ks=Ks,
            image_width=camera_params.resolution_hw[1],
            image_height=camera_params.resolution_hw[0],
        )
        color = cv2.cvtColor(result.rgba, cv2.COLOR_BGRA2RGBA)
        multiviews.append(Image.fromarray(color))

    if args.delight:
        # Build the delighting model lazily, only when it is actually needed.
        if delight_model is None:
            delight_model = DelightingModel()
        multiviews = [delight_model(img) for img in multiviews]

    multiviews = [img.convert("RGB") for img in multiviews]

    mesh = trimesh.load(args.mesh_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    vertices, scale, center = normalize_vertices_array(mesh.vertices)

    # Transform mesh coordinate system by default.
    if not args.no_coor_trans:
        x_rot = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
        z_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
        vertices = vertices @ x_rot
        vertices = vertices @ z_rot

    faces = mesh.faces.astype(np.int32)
    vertices = vertices.astype(np.float32)

    # NOTE(review): the second pass is assumed to be nested inside the first
    # condition; the nesting was reconstructed from whitespace-stripped text
    # and should be confirmed against the original file.
    if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces:
        mesh_fixer = MeshFixer(vertices, faces, args.device)
        vertices, faces = mesh_fixer(
            filter_ratio=args.mesh_sipmlify_ratio,
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
            norm_mesh_ratio=0.5,
        )
        if len(faces) > args.n_max_faces:
            # Still above the face budget: run again with a lower ratio.
            mesh_fixer = MeshFixer(vertices, faces, args.device)
            vertices, faces = mesh_fixer(
                filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2),
                max_hole_size=0.04,
                resolution=1024,
                num_views=1000,
                norm_mesh_ratio=0.5,
            )

    vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
    texture_backer = TextureBaker(
        vertices,
        faces,
        uvs,
        camera_params,
    )

    multiviews = [np.array(img) for img in multiviews]
    texture = texture_backer.bake_texture(
        images=[img[..., :3] for img in multiviews],
        texture_size=args.texture_size,
        mode=args.baker_mode,
        opt_step=args.opt_step,
    )
    if not args.no_smooth_texture:
        texture = post_process_texture(texture)

    # Recover mesh original orientation, scale and center.
    if not args.no_coor_trans:
        vertices = vertices @ np.linalg.inv(z_rot)
        vertices = vertices @ np.linalg.inv(x_rot)
    vertices = vertices / scale
    vertices = vertices + center

    textured_mesh = save_mesh_with_mtl(
        vertices, faces, uvs, texture, args.output_path
    )
    if args.save_glb_path is not None:
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        glb_dir = os.path.dirname(args.save_glb_path)
        if glb_dir:
            os.makedirs(glb_dir, exist_ok=True)
        textured_mesh.export(args.save_glb_path)

    return textured_mesh


if __name__ == "__main__":
    entrypoint()
|
|
||||||
@ -39,22 +39,6 @@ def decompose_convex_coacd(
|
|||||||
auto_scale: bool = True,
|
auto_scale: bool = True,
|
||||||
scale_factor: float = 1.0,
|
scale_factor: float = 1.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Decomposes a mesh using CoACD and saves the result.
|
|
||||||
|
|
||||||
This function loads a mesh from a file, runs the CoACD algorithm with the
|
|
||||||
given parameters, optionally scales the resulting convex hulls to match the
|
|
||||||
original mesh's bounding box, and exports the combined result to a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename: Path to the input mesh file.
|
|
||||||
outfile: Path to save the decomposed output mesh.
|
|
||||||
params: A dictionary of parameters for the CoACD algorithm.
|
|
||||||
verbose: If True, sets the CoACD log level to 'info'.
|
|
||||||
auto_scale: If True, automatically computes a scale factor to match the
|
|
||||||
decomposed mesh's bounding box to the visual mesh's bounding box.
|
|
||||||
scale_factor: An additional scaling factor applied to the vertices of
|
|
||||||
the decomposed mesh parts.
|
|
||||||
"""
|
|
||||||
coacd.set_log_level("info" if verbose else "warn")
|
coacd.set_log_level("info" if verbose else "warn")
|
||||||
|
|
||||||
mesh = trimesh.load(filename, force="mesh")
|
mesh = trimesh.load(filename, force="mesh")
|
||||||
@ -99,38 +83,7 @@ def decompose_convex_mesh(
|
|||||||
scale_factor: float = 1.005,
|
scale_factor: float = 1.005,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Decomposes a mesh into convex parts with retry logic.
|
"""Decompose a mesh into convex parts using the CoACD algorithm."""
|
||||||
|
|
||||||
This function serves as a wrapper for `decompose_convex_coacd`, providing
|
|
||||||
explicit parameters for the CoACD algorithm and implementing a retry
|
|
||||||
mechanism. If the initial decomposition fails, it attempts again with
|
|
||||||
`preprocess_mode` set to 'on'.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename: Path to the input mesh file.
|
|
||||||
outfile: Path to save the decomposed output mesh.
|
|
||||||
threshold: CoACD parameter. See CoACD documentation for details.
|
|
||||||
max_convex_hull: CoACD parameter. See CoACD documentation for details.
|
|
||||||
preprocess_mode: CoACD parameter. See CoACD documentation for details.
|
|
||||||
preprocess_resolution: CoACD parameter. See CoACD documentation for details.
|
|
||||||
resolution: CoACD parameter. See CoACD documentation for details.
|
|
||||||
mcts_nodes: CoACD parameter. See CoACD documentation for details.
|
|
||||||
mcts_iterations: CoACD parameter. See CoACD documentation for details.
|
|
||||||
mcts_max_depth: CoACD parameter. See CoACD documentation for details.
|
|
||||||
pca: CoACD parameter. See CoACD documentation for details.
|
|
||||||
merge: CoACD parameter. See CoACD documentation for details.
|
|
||||||
seed: CoACD parameter. See CoACD documentation for details.
|
|
||||||
auto_scale: If True, automatically scale the output to match the input
|
|
||||||
bounding box.
|
|
||||||
scale_factor: Additional scaling factor to apply.
|
|
||||||
verbose: If True, enables detailed logging.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The path to the output file if decomposition is successful.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If convex decomposition fails after all attempts.
|
|
||||||
"""
|
|
||||||
coacd.set_log_level("info" if verbose else "warn")
|
coacd.set_log_level("info" if verbose else "warn")
|
||||||
|
|
||||||
if os.path.exists(outfile):
|
if os.path.exists(outfile):
|
||||||
@ -195,37 +148,9 @@ def decompose_convex_mp(
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
auto_scale: bool = True,
|
auto_scale: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Decomposes a mesh into convex parts in a separate process.
|
"""Decompose a mesh into convex parts using the CoACD algorithm in a separate process.
|
||||||
|
|
||||||
This function uses the `multiprocessing` module to run the CoACD algorithm
|
|
||||||
in a spawned subprocess. This is useful for isolating the decomposition
|
|
||||||
process to prevent potential memory leaks or crashes in the main process.
|
|
||||||
It includes a retry mechanism similar to `decompose_convex_mesh`.
|
|
||||||
|
|
||||||
See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
|
See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
|
||||||
|
|
||||||
Args:
|
|
||||||
filename: Path to the input mesh file.
|
|
||||||
outfile: Path to save the decomposed output mesh.
|
|
||||||
threshold: CoACD parameter.
|
|
||||||
max_convex_hull: CoACD parameter.
|
|
||||||
preprocess_mode: CoACD parameter.
|
|
||||||
preprocess_resolution: CoACD parameter.
|
|
||||||
resolution: CoACD parameter.
|
|
||||||
mcts_nodes: CoACD parameter.
|
|
||||||
mcts_iterations: CoACD parameter.
|
|
||||||
mcts_max_depth: CoACD parameter.
|
|
||||||
pca: CoACD parameter.
|
|
||||||
merge: CoACD parameter.
|
|
||||||
seed: CoACD parameter.
|
|
||||||
verbose: If True, enables detailed logging in the subprocess.
|
|
||||||
auto_scale: If True, automatically scale the output.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The path to the output file if decomposition is successful.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If convex decomposition fails after all attempts.
|
|
||||||
"""
|
"""
|
||||||
params = dict(
|
params = dict(
|
||||||
threshold=threshold,
|
threshold=threshold,
|
||||||
|
|||||||
@ -66,14 +66,6 @@ def create_mp4_from_images(
|
|||||||
fps: int = 10,
|
fps: int = 10,
|
||||||
prompt: str = None,
|
prompt: str = None,
|
||||||
):
|
):
|
||||||
"""Creates an MP4 video from a list of images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (list[np.ndarray]): List of images as numpy arrays.
|
|
||||||
output_path (str): Path to save the MP4 file.
|
|
||||||
fps (int, optional): Frames per second. Defaults to 10.
|
|
||||||
prompt (str, optional): Optional text prompt overlay.
|
|
||||||
"""
|
|
||||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
font_scale = 0.5
|
font_scale = 0.5
|
||||||
font_thickness = 1
|
font_thickness = 1
|
||||||
@ -104,13 +96,6 @@ def create_mp4_from_images(
|
|||||||
def create_gif_from_images(
|
def create_gif_from_images(
|
||||||
images: list[np.ndarray], output_path: str, fps: int = 10
|
images: list[np.ndarray], output_path: str, fps: int = 10
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Creates a GIF animation from a list of images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (list[np.ndarray]): List of images as numpy arrays.
|
|
||||||
output_path (str): Path to save the GIF file.
|
|
||||||
fps (int, optional): Frames per second. Defaults to 10.
|
|
||||||
"""
|
|
||||||
pil_images = []
|
pil_images = []
|
||||||
for image in images:
|
for image in images:
|
||||||
image = image.clip(min=0, max=1)
|
image = image.clip(min=0, max=1)
|
||||||
@ -131,47 +116,32 @@ def create_gif_from_images(
|
|||||||
|
|
||||||
|
|
||||||
class ImageRender(object):
|
class ImageRender(object):
|
||||||
"""Differentiable mesh renderer supporting multi-view rendering.
|
"""A differentiable mesh renderer supporting multi-view rendering.
|
||||||
|
|
||||||
This class wraps differentiable rasterization using `nvdiffrast` to render mesh
|
This class wraps a differentiable rasterization using `nvdiffrast` to
|
||||||
geometry to various maps (normal, depth, alpha, albedo, etc.) and supports
|
render mesh geometry to various maps (normal, depth, alpha, albedo, etc.).
|
||||||
saving images and videos.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
render_items (list[RenderItems]): List of rendering targets.
|
render_items (list[RenderItems]): A list of rendering targets to
|
||||||
camera_params (CameraSetting): Camera parameters for rendering.
|
generate (e.g., IMAGE, DEPTH, NORMAL, etc.).
|
||||||
recompute_vtx_normal (bool, optional): Recompute vertex normals. Defaults to True.
|
camera_params (CameraSetting): The camera parameters for rendering,
|
||||||
with_mtl (bool, optional): Load mesh material files. Defaults to False.
|
including intrinsic and extrinsic matrices.
|
||||||
gen_color_gif (bool, optional): Generate GIF of color images. Defaults to False.
|
recompute_vtx_normal (bool, optional): If True, recomputes
|
||||||
gen_color_mp4 (bool, optional): Generate MP4 of color images. Defaults to False.
|
vertex normals from the mesh geometry. Defaults to True.
|
||||||
gen_viewnormal_mp4 (bool, optional): Generate MP4 of view-space normals. Defaults to False.
|
with_mtl (bool, optional): Whether to load `.mtl` material files
|
||||||
gen_glonormal_mp4 (bool, optional): Generate MP4 of global-space normals. Defaults to False.
|
for meshes. Defaults to False.
|
||||||
no_index_file (bool, optional): Skip saving index file. Defaults to False.
|
gen_color_gif (bool, optional): Generate a GIF of rendered
|
||||||
light_factor (float, optional): PBR light intensity multiplier. Defaults to 1.0.
|
color images. Defaults to False.
|
||||||
|
gen_color_mp4 (bool, optional): Generate an MP4 video of rendered
|
||||||
Example:
|
color images. Defaults to False.
|
||||||
```py
|
gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of
|
||||||
from embodied_gen.data.differentiable_render import ImageRender
|
view-space normals. Defaults to False.
|
||||||
from embodied_gen.data.utils import CameraSetting
|
gen_glonormal_mp4 (bool, optional): Generate an MP4 video of
|
||||||
from embodied_gen.utils.enum import RenderItems
|
global-space normals. Defaults to False.
|
||||||
|
no_index_file (bool, optional): If True, skip saving the `index.json`
|
||||||
camera_params = CameraSetting(
|
summary file. Defaults to False.
|
||||||
num_images=6,
|
light_factor (float, optional): A scalar multiplier for
|
||||||
elevation=[20, -10],
|
PBR light intensity. Defaults to 1.0.
|
||||||
distance=5,
|
|
||||||
resolution_hw=(512,512),
|
|
||||||
fov=math.radians(30),
|
|
||||||
device='cuda',
|
|
||||||
)
|
|
||||||
render_items = [RenderItems.IMAGE.value, RenderItems.DEPTH.value]
|
|
||||||
renderer = ImageRender(
|
|
||||||
render_items,
|
|
||||||
camera_params,
|
|
||||||
with_mtl=args.with_mtl,
|
|
||||||
gen_color_mp4=True,
|
|
||||||
)
|
|
||||||
renderer.render_mesh(mesh_path='mesh.obj', output_root='./renders')
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -228,14 +198,6 @@ class ImageRender(object):
|
|||||||
uuid: Union[str, List[str]] = None,
|
uuid: Union[str, List[str]] = None,
|
||||||
prompts: List[str] = None,
|
prompts: List[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Renders one or more meshes and saves outputs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mesh_path (Union[str, List[str]]): Path(s) to mesh files.
|
|
||||||
output_root (str): Directory to save outputs.
|
|
||||||
uuid (Union[str, List[str]], optional): Unique IDs for outputs.
|
|
||||||
prompts (List[str], optional): Text prompts for videos.
|
|
||||||
"""
|
|
||||||
mesh_path = as_list(mesh_path)
|
mesh_path = as_list(mesh_path)
|
||||||
if uuid is None:
|
if uuid is None:
|
||||||
uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
|
uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
|
||||||
@ -265,15 +227,18 @@ class ImageRender(object):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, mesh_path: str, output_dir: str, prompt: str = None
|
self, mesh_path: str, output_dir: str, prompt: str = None
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Renders a single mesh and returns output paths.
|
"""Render a single mesh and return paths to the rendered outputs.
|
||||||
|
|
||||||
|
Processes the input mesh, renders multiple modalities (e.g., normals,
|
||||||
|
depth, albedo), and optionally saves video or image sequences.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mesh_path (str): Path to mesh file.
|
mesh_path (str): Path to the mesh file (.obj/.glb).
|
||||||
output_dir (str): Directory to save outputs.
|
output_dir (str): Directory to save rendered outputs.
|
||||||
prompt (str, optional): Caption prompt for MP4 metadata.
|
prompt (str, optional): Optional caption prompt for MP4 metadata.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, str]: Mapping of render types to saved image paths.
|
dict[str, str]: A mapping render types to the saved image paths.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
|
mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
|
||||||
|
|||||||
@ -16,13 +16,17 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import coacd
|
||||||
import igraph
|
import igraph
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyvista as pv
|
import pyvista as pv
|
||||||
import spaces
|
import spaces
|
||||||
import torch
|
import torch
|
||||||
|
import trimesh
|
||||||
import utils3d
|
import utils3d
|
||||||
from pymeshfix import _meshfix
|
from pymeshfix import _meshfix
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|||||||
@ -66,7 +66,6 @@ __all__ = [
|
|||||||
"resize_pil",
|
"resize_pil",
|
||||||
"trellis_preprocess",
|
"trellis_preprocess",
|
||||||
"delete_dir",
|
"delete_dir",
|
||||||
"kaolin_to_opencv_view",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -374,18 +373,10 @@ def _compute_az_el_by_views(
|
|||||||
def _compute_cam_pts_by_az_el(
|
def _compute_cam_pts_by_az_el(
|
||||||
azs: np.ndarray,
|
azs: np.ndarray,
|
||||||
els: np.ndarray,
|
els: np.ndarray,
|
||||||
distance: float | list[float] | np.ndarray,
|
distance: float,
|
||||||
extra_pts: np.ndarray = None,
|
extra_pts: np.ndarray = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
if np.isscalar(distance) or isinstance(distance, (float, int)):
|
distances = np.array([distance for _ in range(len(azs))])
|
||||||
distances = np.full(len(azs), distance)
|
|
||||||
else:
|
|
||||||
distances = np.array(distance)
|
|
||||||
if len(distances) != len(azs):
|
|
||||||
raise ValueError(
|
|
||||||
f"Length of distances ({len(distances)}) must match length of azs ({len(azs)})"
|
|
||||||
)
|
|
||||||
|
|
||||||
cam_pts = _az_el_to_points(azs, els) * distances[:, None]
|
cam_pts = _az_el_to_points(azs, els) * distances[:, None]
|
||||||
|
|
||||||
if extra_pts is not None:
|
if extra_pts is not None:
|
||||||
@ -719,7 +710,7 @@ class CameraSetting:
|
|||||||
|
|
||||||
num_images: int
|
num_images: int
|
||||||
elevation: list[float]
|
elevation: list[float]
|
||||||
distance: float | list[float]
|
distance: float
|
||||||
resolution_hw: tuple[int, int]
|
resolution_hw: tuple[int, int]
|
||||||
fov: float
|
fov: float
|
||||||
at: tuple[float, float, float] = field(
|
at: tuple[float, float, float] = field(
|
||||||
@ -833,28 +824,6 @@ def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False):
|
|||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
def kaolin_to_opencv_view(raw_matrix):
    """Build 4x4 matrices with the rotation columns of the input re-ordered.

    For each matrix in the batch, the upper-left 3x3 block of the result
    takes its columns from the input in the order (2, 0, 1), the translation
    column is copied unchanged, and the bottom row is ``[0, 0, 0, 1]``.

    NOTE(review): the name suggests a Kaolin-to-OpenCV camera convention
    change; confirm against the callers.

    Args:
        raw_matrix: Batched matrices indexed as ``[batch, row, col]`` with
            at least 3 rows and 4 columns.

    Returns:
        A ``(batch, 4, 4)`` tensor created with ``torch.eye`` on the input's
        device (so it has torch's default dtype).
    """
    batch_size = raw_matrix.size(0)
    rotation = raw_matrix[:, :3, :3]
    translation = raw_matrix[:, :3, 3]

    # New column k is old column (2, 0, 1)[k].
    permuted_rotation = rotation[:, :, [2, 0, 1]]

    converted = torch.eye(4, device=raw_matrix.device).repeat(
        batch_size, 1, 1
    )
    converted[:, :3, :3] = permuted_rotation
    converted[:, :3, 3] = translation
    return converted
|
|
||||||
|
|
||||||
|
|
||||||
def save_mesh_with_mtl(
|
def save_mesh_with_mtl(
|
||||||
vertices: np.ndarray,
|
vertices: np.ndarray,
|
||||||
faces: np.ndarray,
|
faces: np.ndarray,
|
||||||
|
|||||||
@ -51,33 +51,6 @@ __all__ = ["PickEmbodiedGen"]
|
|||||||
|
|
||||||
@register_env("PickEmbodiedGen-v1", max_episode_steps=100)
|
@register_env("PickEmbodiedGen-v1", max_episode_steps=100)
|
||||||
class PickEmbodiedGen(BaseEnv):
|
class PickEmbodiedGen(BaseEnv):
|
||||||
"""PickEmbodiedGen as gym env example for object pick-and-place tasks.
|
|
||||||
|
|
||||||
This environment simulates a robot interacting with 3D assets in the
|
|
||||||
embodiedgen generated scene in SAPIEN. It supports multi-environment setups,
|
|
||||||
dynamic reconfiguration, and hybrid rendering with 3D Gaussian Splatting.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
Use `gym.make` to create the `PickEmbodiedGen-v1` parallel environment.
|
|
||||||
```python
|
|
||||||
import gymnasium as gym
|
|
||||||
env = gym.make(
|
|
||||||
"PickEmbodiedGen-v1",
|
|
||||||
num_envs=cfg.num_envs,
|
|
||||||
render_mode=cfg.render_mode,
|
|
||||||
enable_shadow=cfg.enable_shadow,
|
|
||||||
layout_file=cfg.layout_file,
|
|
||||||
control_mode=cfg.control_mode,
|
|
||||||
camera_cfg=dict(
|
|
||||||
camera_eye=cfg.camera_eye,
|
|
||||||
camera_target_pt=cfg.camera_target_pt,
|
|
||||||
image_hw=cfg.image_hw,
|
|
||||||
fovy_deg=cfg.fovy_deg,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
|
SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
|
||||||
goal_thresh = 0.0
|
goal_thresh = 0.0
|
||||||
|
|
||||||
@ -90,19 +63,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
reconfiguration_freq: int = None,
|
reconfiguration_freq: int = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Initializes the PickEmbodiedGen environment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*args: Variable length argument list for the base class.
|
|
||||||
robot_uids: The robot(s) to use in the environment.
|
|
||||||
robot_init_qpos_noise: Noise added to the robot's initial joint
|
|
||||||
positions.
|
|
||||||
num_envs: The number of parallel environments to create.
|
|
||||||
reconfiguration_freq: How often to reconfigure the scene. If None,
|
|
||||||
it is set based on num_envs.
|
|
||||||
**kwargs: Additional keyword arguments for environment setup,
|
|
||||||
including layout_file, replace_objs, enable_grasp, etc.
|
|
||||||
"""
|
|
||||||
self.robot_init_qpos_noise = robot_init_qpos_noise
|
self.robot_init_qpos_noise = robot_init_qpos_noise
|
||||||
if reconfiguration_freq is None:
|
if reconfiguration_freq is None:
|
||||||
if num_envs == 1:
|
if num_envs == 1:
|
||||||
@ -156,22 +116,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def init_env_layouts(
|
def init_env_layouts(
|
||||||
layout_file: str, num_envs: int, replace_objs: bool
|
layout_file: str, num_envs: int, replace_objs: bool
|
||||||
) -> list[LayoutInfo]:
|
) -> list[LayoutInfo]:
|
||||||
"""Initializes and saves layout files for each environment instance.
|
|
||||||
|
|
||||||
For each environment, this method creates a layout configuration. If
|
|
||||||
`replace_objs` is True, it generates new object placements for each
|
|
||||||
subsequent environment. The generated layouts are saved as new JSON
|
|
||||||
files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layout_file: Path to the base layout JSON file.
|
|
||||||
num_envs: The number of environments to create layouts for.
|
|
||||||
replace_objs: If True, generates new object placements for each
|
|
||||||
environment after the first one using BFS placement.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of file paths to the generated layout for each environment.
|
|
||||||
"""
|
|
||||||
layouts = []
|
layouts = []
|
||||||
for env_idx in range(num_envs):
|
for env_idx in range(num_envs):
|
||||||
if replace_objs and env_idx > 0:
|
if replace_objs and env_idx > 0:
|
||||||
@ -192,18 +136,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def compute_robot_init_pose(
|
def compute_robot_init_pose(
|
||||||
layouts: list[str], num_envs: int, z_offset: float = 0.0
|
layouts: list[str], num_envs: int, z_offset: float = 0.0
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
"""Computes the initial pose for the robot in each environment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layouts: A list of file paths to the environment layouts.
|
|
||||||
num_envs: The number of environments.
|
|
||||||
z_offset: An optional vertical offset to apply to the robot's
|
|
||||||
position to prevent collisions.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of initial poses ([x, y, z, qw, qx, qy, qz]) for the robot
|
|
||||||
in each environment.
|
|
||||||
"""
|
|
||||||
robot_pose = []
|
robot_pose = []
|
||||||
for env_idx in range(num_envs):
|
for env_idx in range(num_envs):
|
||||||
layout = json.load(open(layouts[env_idx], "r"))
|
layout = json.load(open(layouts[env_idx], "r"))
|
||||||
@ -216,11 +148,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_sim_config(self):
|
def _default_sim_config(self):
|
||||||
"""Returns the default simulation configuration.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The default simulation configuration object.
|
|
||||||
"""
|
|
||||||
return SimConfig(
|
return SimConfig(
|
||||||
scene_config=SceneConfig(
|
scene_config=SceneConfig(
|
||||||
solver_position_iterations=30,
|
solver_position_iterations=30,
|
||||||
@ -236,11 +163,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_sensor_configs(self):
|
def _default_sensor_configs(self):
|
||||||
"""Returns the default sensor configurations for the agent.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list containing the default camera configuration.
|
|
||||||
"""
|
|
||||||
pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
|
pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
|
||||||
|
|
||||||
return [
|
return [
|
||||||
@ -249,11 +171,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_human_render_camera_configs(self):
|
def _default_human_render_camera_configs(self):
|
||||||
"""Returns the default camera configuration for human-friendly rendering.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The default camera configuration for the renderer.
|
|
||||||
"""
|
|
||||||
pose = sapien_utils.look_at(
|
pose = sapien_utils.look_at(
|
||||||
eye=self.camera_cfg["camera_eye"],
|
eye=self.camera_cfg["camera_eye"],
|
||||||
target=self.camera_cfg["camera_target_pt"],
|
target=self.camera_cfg["camera_target_pt"],
|
||||||
@ -270,24 +187,10 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _load_agent(self, options: dict):
|
def _load_agent(self, options: dict):
|
||||||
"""Loads the agent (robot) and a ground plane into the scene.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
options: A dictionary of options for loading the agent.
|
|
||||||
"""
|
|
||||||
self.ground = build_ground(self.scene)
|
self.ground = build_ground(self.scene)
|
||||||
super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
|
super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
|
||||||
|
|
||||||
def _load_scene(self, options: dict):
|
def _load_scene(self, options: dict):
|
||||||
"""Loads all assets, objects, and the goal site into the scene.
|
|
||||||
|
|
||||||
This method iterates through the layouts for each environment, loads the
|
|
||||||
specified assets, and adds them to the simulation. It also creates a
|
|
||||||
kinematic sphere to represent the goal site.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
options: A dictionary of options for loading the scene.
|
|
||||||
"""
|
|
||||||
all_objects = []
|
all_objects = []
|
||||||
logger.info(f"Loading EmbodiedGen assets...")
|
logger.info(f"Loading EmbodiedGen assets...")
|
||||||
for env_idx in range(self.num_envs):
|
for env_idx in range(self.num_envs):
|
||||||
@ -319,15 +222,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
self._hidden_objects.append(self.goal_site)
|
self._hidden_objects.append(self.goal_site)
|
||||||
|
|
||||||
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
|
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
|
||||||
"""Initializes an episode for a given set of environments.
|
|
||||||
|
|
||||||
This method sets the goal position, resets the robot's joint positions
|
|
||||||
with optional noise, and sets its root pose.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env_idx: A tensor of environment indices to initialize.
|
|
||||||
options: A dictionary of options for initialization.
|
|
||||||
"""
|
|
||||||
with torch.device(self.device):
|
with torch.device(self.device):
|
||||||
b = len(env_idx)
|
b = len(env_idx)
|
||||||
goal_xyz = torch.zeros((b, 3))
|
goal_xyz = torch.zeros((b, 3))
|
||||||
@ -362,21 +256,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def render_gs3d_images(
|
def render_gs3d_images(
|
||||||
self, layouts: list[str], num_envs: int, init_quat: list[float]
|
self, layouts: list[str], num_envs: int, init_quat: list[float]
|
||||||
) -> dict[str, np.ndarray]:
|
) -> dict[str, np.ndarray]:
|
||||||
"""Renders background images using a pre-trained Gaussian Splatting model.
|
|
||||||
|
|
||||||
This method pre-renders the static background for each environment from
|
|
||||||
the perspective of all cameras to be used for hybrid rendering.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layouts: A list of file paths to the environment layouts.
|
|
||||||
num_envs: The number of environments.
|
|
||||||
init_quat: An initial quaternion to orient the Gaussian Splatting
|
|
||||||
model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary mapping a unique key (e.g., 'camera-env_idx') to the
|
|
||||||
rendered background image as a numpy array.
|
|
||||||
"""
|
|
||||||
sim_coord_align = (
|
sim_coord_align = (
|
||||||
torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
|
torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
|
||||||
)
|
)
|
||||||
@ -414,15 +293,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
return bg_images
|
return bg_images
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
"""Renders the environment based on the configured render_mode.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If `render_mode` is not set.
|
|
||||||
NotImplementedError: If the `render_mode` is not supported.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The rendered output, which varies depending on the render mode.
|
|
||||||
"""
|
|
||||||
if self.render_mode is None:
|
if self.render_mode is None:
|
||||||
raise RuntimeError("render_mode is not set.")
|
raise RuntimeError("render_mode is not set.")
|
||||||
if self.render_mode == "human":
|
if self.render_mode == "human":
|
||||||
@ -445,17 +315,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def render_rgb_array(
|
def render_rgb_array(
|
||||||
self, camera_name: str = None, return_alpha: bool = False
|
self, camera_name: str = None, return_alpha: bool = False
|
||||||
):
|
):
|
||||||
"""Renders an RGB image from the human-facing render camera.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
camera_name: The name of the camera to render from. If None, uses
|
|
||||||
all human render cameras.
|
|
||||||
return_alpha: Whether to include the alpha channel in the output.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A numpy array representing the rendered image(s). If multiple
|
|
||||||
cameras are used, the images are tiled.
|
|
||||||
"""
|
|
||||||
for obj in self._hidden_objects:
|
for obj in self._hidden_objects:
|
||||||
obj.show_visual()
|
obj.show_visual()
|
||||||
self.scene.update_render(
|
self.scene.update_render(
|
||||||
@ -476,11 +335,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
return tile_images(images)
|
return tile_images(images)
|
||||||
|
|
||||||
def render_sensors(self):
|
def render_sensors(self):
|
||||||
"""Renders images from all on-board sensor cameras.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tiled image of all sensor outputs as a numpy array.
|
|
||||||
"""
|
|
||||||
images = []
|
images = []
|
||||||
sensor_images = self.get_sensor_images()
|
sensor_images = self.get_sensor_images()
|
||||||
for image in sensor_images.values():
|
for image in sensor_images.values():
|
||||||
@ -489,14 +343,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
return tile_images(images)
|
return tile_images(images)
|
||||||
|
|
||||||
def hybrid_render(self):
|
def hybrid_render(self):
|
||||||
"""Renders a hybrid image by blending simulated foreground with a background.
|
|
||||||
|
|
||||||
The foreground is rendered with an alpha channel and then blended with
|
|
||||||
the pre-rendered Gaussian Splatting background image.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A torch tensor of the final blended RGB images.
|
|
||||||
"""
|
|
||||||
fg_images = self.render_rgb_array(
|
fg_images = self.render_rgb_array(
|
||||||
return_alpha=True
|
return_alpha=True
|
||||||
) # (n_env, h, w, 3)
|
) # (n_env, h, w, 3)
|
||||||
@ -516,16 +362,6 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
return images[..., :3]
|
return images[..., :3]
|
||||||
|
|
||||||
def evaluate(self):
|
def evaluate(self):
|
||||||
"""Evaluates the current state of the environment.
|
|
||||||
|
|
||||||
Checks for task success criteria such as whether the object is grasped,
|
|
||||||
placed at the goal, and if the robot is static.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary containing boolean tensors for various success
|
|
||||||
metrics, including 'is_grasped', 'is_obj_placed', and overall
|
|
||||||
'success'.
|
|
||||||
"""
|
|
||||||
obj_to_goal_pos = (
|
obj_to_goal_pos = (
|
||||||
self.obj.pose.p
|
self.obj.pose.p
|
||||||
) # self.goal_site.pose.p - self.obj.pose.p
|
) # self.goal_site.pose.p - self.obj.pose.p
|
||||||
@ -545,31 +381,10 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_obs_extra(self, info: dict):
|
def _get_obs_extra(self, info: dict):
|
||||||
"""Gets extra information for the observation dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
info: A dictionary containing evaluation information.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An empty dictionary, as no extra observations are added.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return dict()
|
return dict()
|
||||||
|
|
||||||
def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
|
def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
|
||||||
"""Computes a dense reward for the current step.
|
|
||||||
|
|
||||||
The reward is a composite of reaching, grasping, placing, and
|
|
||||||
maintaining a static final pose.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obs: The current observation.
|
|
||||||
action: The action taken in the current step.
|
|
||||||
info: A dictionary containing evaluation information from `evaluate()`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tensor containing the dense reward for each environment.
|
|
||||||
"""
|
|
||||||
tcp_to_obj_dist = torch.linalg.norm(
|
tcp_to_obj_dist = torch.linalg.norm(
|
||||||
self.obj.pose.p - self.agent.tcp.pose.p, axis=1
|
self.obj.pose.p - self.agent.tcp.pose.p, axis=1
|
||||||
)
|
)
|
||||||
@ -602,14 +417,4 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def compute_normalized_dense_reward(
|
def compute_normalized_dense_reward(
|
||||||
self, obs: any, action: torch.Tensor, info: dict
|
self, obs: any, action: torch.Tensor, info: dict
|
||||||
):
|
):
|
||||||
"""Computes a dense reward normalized to be between 0 and 1.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obs: The current observation.
|
|
||||||
action: The action taken in the current step.
|
|
||||||
info: A dictionary containing evaluation information from `evaluate()`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tensor containing the normalized dense reward for each environment.
|
|
||||||
"""
|
|
||||||
return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
|
return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
|
||||||
|
|||||||
@ -40,7 +40,7 @@ class DelightingModel(object):
|
|||||||
"""A model to remove the lighting in image space.
|
"""A model to remove the lighting in image space.
|
||||||
|
|
||||||
This model is encapsulated based on the Hunyuan3D-Delight model
|
This model is encapsulated based on the Hunyuan3D-Delight model
|
||||||
from `https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0` # noqa
|
from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
image_guide_scale (float): Weight of image guidance in diffusion process.
|
image_guide_scale (float): Weight of image guidance in diffusion process.
|
||||||
|
|||||||
@ -21,18 +21,14 @@ import struct
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from gsplat.cuda._wrapper import spherical_harmonics
|
from gsplat.cuda._wrapper import spherical_harmonics
|
||||||
from gsplat.rendering import rasterization
|
from gsplat.rendering import rasterization
|
||||||
from plyfile import PlyData
|
from plyfile import PlyData
|
||||||
from scipy.spatial.transform import Rotation
|
from scipy.spatial.transform import Rotation
|
||||||
from embodied_gen.data.utils import (
|
from embodied_gen.data.utils import gamma_shs, quat_mult, quat_to_rotmat
|
||||||
gamma_shs,
|
|
||||||
normalize_vertices_array,
|
|
||||||
quat_mult,
|
|
||||||
quat_to_rotmat,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -498,21 +494,6 @@ class GaussianOperator(GaussianBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_gs_model(
|
|
||||||
input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
|
|
||||||
) -> GaussianOperator:
|
|
||||||
gs_model = GaussianOperator.load_from_ply(input_gs)
|
|
||||||
# Normalize vertices to [-1, 1], center to (0, 0, 0).
|
|
||||||
_, scale, center = normalize_vertices_array(gs_model._means)
|
|
||||||
scale, center = float(scale), center.tolist()
|
|
||||||
transpose = [*[v for v in center], *pre_quat]
|
|
||||||
instance_pose = torch.tensor(transpose).to(gs_model.device)
|
|
||||||
gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
|
|
||||||
gs_model.rescale(scale)
|
|
||||||
|
|
||||||
return gs_model
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
|
input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
|
||||||
output_gs = "./gs_model.ply"
|
output_gs = "./gs_model.ply"
|
||||||
|
|||||||
@ -38,61 +38,26 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class BasePipelineLoader(ABC):
|
class BasePipelineLoader(ABC):
|
||||||
"""Abstract base class for loading Hugging Face image generation pipelines.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
device (str): Device to load the pipeline on.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
load(): Loads and returns the pipeline.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, device="cuda"):
|
def __init__(self, device="cuda"):
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self):
|
def load(self):
|
||||||
"""Load and return the pipeline instance."""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BasePipelineRunner(ABC):
|
class BasePipelineRunner(ABC):
|
||||||
"""Abstract base class for running image generation pipelines.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
pipe: The loaded pipeline.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
run(prompt, **kwargs): Runs the pipeline with a prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, pipe):
|
def __init__(self, pipe):
|
||||||
self.pipe = pipe
|
self.pipe = pipe
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||||
"""Run the pipeline with the given prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): Text prompt for image generation.
|
|
||||||
**kwargs: Additional pipeline arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Generated image(s).
|
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# ===== SD3.5-medium =====
|
# ===== SD3.5-medium =====
|
||||||
class SD35Loader(BasePipelineLoader):
|
class SD35Loader(BasePipelineLoader):
|
||||||
"""Loader for Stable Diffusion 3.5 medium pipeline."""
|
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
"""Load the Stable Diffusion 3.5 medium pipeline.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StableDiffusion3Pipeline: Loaded pipeline.
|
|
||||||
"""
|
|
||||||
pipe = StableDiffusion3Pipeline.from_pretrained(
|
pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||||
"stabilityai/stable-diffusion-3.5-medium",
|
"stabilityai/stable-diffusion-3.5-medium",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
@ -105,25 +70,12 @@ class SD35Loader(BasePipelineLoader):
|
|||||||
|
|
||||||
|
|
||||||
class SD35Runner(BasePipelineRunner):
|
class SD35Runner(BasePipelineRunner):
|
||||||
"""Runner for Stable Diffusion 3.5 medium pipeline."""
|
|
||||||
|
|
||||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||||
"""Generate images using Stable Diffusion 3.5 medium.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): Text prompt.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Generated image(s).
|
|
||||||
"""
|
|
||||||
return self.pipe(prompt=prompt, **kwargs).images
|
return self.pipe(prompt=prompt, **kwargs).images
|
||||||
|
|
||||||
|
|
||||||
# ===== Cosmos2 =====
|
# ===== Cosmos2 =====
|
||||||
class CosmosLoader(BasePipelineLoader):
|
class CosmosLoader(BasePipelineLoader):
|
||||||
"""Loader for Cosmos2 text-to-image pipeline."""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
|
model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
|
||||||
@ -135,8 +87,6 @@ class CosmosLoader(BasePipelineLoader):
|
|||||||
self.local_dir = local_dir
|
self.local_dir = local_dir
|
||||||
|
|
||||||
def _patch(self):
|
def _patch(self):
|
||||||
"""Patch model and processor for optimized loading."""
|
|
||||||
|
|
||||||
def patch_model(cls):
|
def patch_model(cls):
|
||||||
orig = cls.from_pretrained
|
orig = cls.from_pretrained
|
||||||
|
|
||||||
@ -160,11 +110,6 @@ class CosmosLoader(BasePipelineLoader):
|
|||||||
patch_processor(SiglipProcessor)
|
patch_processor(SiglipProcessor)
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
"""Load the Cosmos2 text-to-image pipeline.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Cosmos2TextToImagePipeline: Loaded pipeline.
|
|
||||||
"""
|
|
||||||
self._patch()
|
self._patch()
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=self.model_id,
|
repo_id=self.model_id,
|
||||||
@ -196,19 +141,7 @@ class CosmosLoader(BasePipelineLoader):
|
|||||||
|
|
||||||
|
|
||||||
class CosmosRunner(BasePipelineRunner):
|
class CosmosRunner(BasePipelineRunner):
|
||||||
"""Runner for Cosmos2 text-to-image pipeline."""
|
|
||||||
|
|
||||||
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
||||||
"""Generate images using Cosmos2 pipeline.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): Text prompt.
|
|
||||||
negative_prompt (str, optional): Negative prompt.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Generated image(s).
|
|
||||||
"""
|
|
||||||
return self.pipe(
|
return self.pipe(
|
||||||
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
||||||
).images
|
).images
|
||||||
@ -216,14 +149,7 @@ class CosmosRunner(BasePipelineRunner):
|
|||||||
|
|
||||||
# ===== Kolors =====
|
# ===== Kolors =====
|
||||||
class KolorsLoader(BasePipelineLoader):
|
class KolorsLoader(BasePipelineLoader):
|
||||||
"""Loader for Kolors pipeline."""
|
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
"""Load the Kolors pipeline.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
KolorsPipeline: Loaded pipeline.
|
|
||||||
"""
|
|
||||||
pipe = KolorsPipeline.from_pretrained(
|
pipe = KolorsPipeline.from_pretrained(
|
||||||
"Kwai-Kolors/Kolors-diffusers",
|
"Kwai-Kolors/Kolors-diffusers",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
@ -238,31 +164,13 @@ class KolorsLoader(BasePipelineLoader):
|
|||||||
|
|
||||||
|
|
||||||
class KolorsRunner(BasePipelineRunner):
|
class KolorsRunner(BasePipelineRunner):
|
||||||
"""Runner for Kolors pipeline."""
|
|
||||||
|
|
||||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||||
"""Generate images using Kolors pipeline.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): Text prompt.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Generated image(s).
|
|
||||||
"""
|
|
||||||
return self.pipe(prompt=prompt, **kwargs).images
|
return self.pipe(prompt=prompt, **kwargs).images
|
||||||
|
|
||||||
|
|
||||||
# ===== Flux =====
|
# ===== Flux =====
|
||||||
class FluxLoader(BasePipelineLoader):
|
class FluxLoader(BasePipelineLoader):
|
||||||
"""Loader for Flux pipeline."""
|
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
"""Load the Flux pipeline.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
FluxPipeline: Loaded pipeline.
|
|
||||||
"""
|
|
||||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||||
pipe = FluxPipeline.from_pretrained(
|
pipe = FluxPipeline.from_pretrained(
|
||||||
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
|
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
|
||||||
@ -274,50 +182,20 @@ class FluxLoader(BasePipelineLoader):
|
|||||||
|
|
||||||
|
|
||||||
class FluxRunner(BasePipelineRunner):
|
class FluxRunner(BasePipelineRunner):
|
||||||
"""Runner for Flux pipeline."""
|
|
||||||
|
|
||||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||||
"""Generate images using Flux pipeline.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): Text prompt.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Generated image(s).
|
|
||||||
"""
|
|
||||||
return self.pipe(prompt=prompt, **kwargs).images
|
return self.pipe(prompt=prompt, **kwargs).images
|
||||||
|
|
||||||
|
|
||||||
# ===== Chroma =====
|
# ===== Chroma =====
|
||||||
class ChromaLoader(BasePipelineLoader):
|
class ChromaLoader(BasePipelineLoader):
|
||||||
"""Loader for Chroma pipeline."""
|
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
"""Load the Chroma pipeline.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChromaPipeline: Loaded pipeline.
|
|
||||||
"""
|
|
||||||
return ChromaPipeline.from_pretrained(
|
return ChromaPipeline.from_pretrained(
|
||||||
"lodestones/Chroma", torch_dtype=torch.bfloat16
|
"lodestones/Chroma", torch_dtype=torch.bfloat16
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
|
|
||||||
class ChromaRunner(BasePipelineRunner):
|
class ChromaRunner(BasePipelineRunner):
|
||||||
"""Runner for Chroma pipeline."""
|
|
||||||
|
|
||||||
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
||||||
"""Generate images using Chroma pipeline.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): Text prompt.
|
|
||||||
negative_prompt (str, optional): Negative prompt.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Generated image(s).
|
|
||||||
"""
|
|
||||||
return self.pipe(
|
return self.pipe(
|
||||||
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
||||||
).images
|
).images
|
||||||
@ -333,22 +211,6 @@ PIPELINE_REGISTRY = {
|
|||||||
|
|
||||||
|
|
||||||
def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
|
def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
|
||||||
"""Build a Hugging Face image generation pipeline runner by name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): Name of the pipeline (e.g., "sd35", "cosmos").
|
|
||||||
device (str): Device to load the pipeline on.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BasePipelineRunner: Pipeline runner instance.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
|
|
||||||
runner = build_hf_image_pipeline("sd35")
|
|
||||||
images = runner.run(prompt="A robot holding a sign that says 'Hello'")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if name not in PIPELINE_REGISTRY:
|
if name not in PIPELINE_REGISTRY:
|
||||||
raise ValueError(f"Unsupported model: {name}")
|
raise ValueError(f"Unsupported model: {name}")
|
||||||
loader_cls, runner_cls = PIPELINE_REGISTRY[name]
|
loader_cls, runner_cls = PIPELINE_REGISTRY[name]
|
||||||
|
|||||||
@ -376,21 +376,6 @@ LAYOUT_DESCRIBER_PROMPT = """
|
|||||||
|
|
||||||
|
|
||||||
class LayoutDesigner(object):
|
class LayoutDesigner(object):
|
||||||
"""A class for querying GPT-based scene layout reasoning and formatting responses.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
prompt (str): The system prompt for GPT.
|
|
||||||
verbose (bool): Whether to log responses.
|
|
||||||
gpt_client (GPTclient): The GPT client instance.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
query(prompt, params): Query GPT with a prompt and parameters.
|
|
||||||
format_response(response): Parse and clean JSON response.
|
|
||||||
format_response_repair(response): Repair and parse JSON response.
|
|
||||||
save_output(output, save_path): Save output to file.
|
|
||||||
__call__(prompt, save_path, params): Query and process output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -402,15 +387,6 @@ class LayoutDesigner(object):
|
|||||||
self.gpt_client = gpt_client
|
self.gpt_client = gpt_client
|
||||||
|
|
||||||
def query(self, prompt: str, params: dict = None) -> str:
|
def query(self, prompt: str, params: dict = None) -> str:
|
||||||
"""Query GPT with the system prompt and user prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): User prompt.
|
|
||||||
params (dict, optional): GPT parameters.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: GPT response.
|
|
||||||
"""
|
|
||||||
full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
|
full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
|
||||||
|
|
||||||
response = self.gpt_client.query(
|
response = self.gpt_client.query(
|
||||||
@ -424,17 +400,6 @@ class LayoutDesigner(object):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
def format_response(self, response: str) -> dict:
|
def format_response(self, response: str) -> dict:
|
||||||
"""Format and parse GPT response as JSON.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response (str): Raw GPT response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Parsed JSON output.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
json.JSONDecodeError: If parsing fails.
|
|
||||||
"""
|
|
||||||
cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
|
cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
|
||||||
try:
|
try:
|
||||||
output = json.loads(cleaned)
|
output = json.loads(cleaned)
|
||||||
@ -446,23 +411,9 @@ class LayoutDesigner(object):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def format_response_repair(self, response: str) -> dict:
|
def format_response_repair(self, response: str) -> dict:
|
||||||
"""Repair and parse possibly broken JSON response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response (str): Raw GPT response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Parsed JSON output.
|
|
||||||
"""
|
|
||||||
return json_repair.loads(response)
|
return json_repair.loads(response)
|
||||||
|
|
||||||
def save_output(self, output: dict, save_path: str) -> None:
|
def save_output(self, output: dict, save_path: str) -> None:
|
||||||
"""Save output dictionary to a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output (dict): Output data.
|
|
||||||
save_path (str): Path to save the file.
|
|
||||||
"""
|
|
||||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
with open(save_path, 'w') as f:
|
with open(save_path, 'w') as f:
|
||||||
json.dump(output, f, indent=4)
|
json.dump(output, f, indent=4)
|
||||||
@ -470,16 +421,6 @@ class LayoutDesigner(object):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, prompt: str, save_path: str = None, params: dict = None
|
self, prompt: str, save_path: str = None, params: dict = None
|
||||||
) -> dict | str:
|
) -> dict | str:
|
||||||
"""Query GPT and process the output.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): User prompt.
|
|
||||||
save_path (str, optional): Path to save output.
|
|
||||||
params (dict, optional): GPT parameters.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict | str: Output data.
|
|
||||||
"""
|
|
||||||
response = self.query(prompt, params=params)
|
response = self.query(prompt, params=params)
|
||||||
output = self.format_response_repair(response)
|
output = self.format_response_repair(response)
|
||||||
self.save_output(output, save_path) if save_path else None
|
self.save_output(output, save_path) if save_path else None
|
||||||
@ -501,29 +442,6 @@ LAYOUT_DESCRIBER = LayoutDesigner(
|
|||||||
def build_scene_layout(
|
def build_scene_layout(
|
||||||
task_desc: str, output_path: str = None, gpt_params: dict = None
|
task_desc: str, output_path: str = None, gpt_params: dict = None
|
||||||
) -> LayoutInfo:
|
) -> LayoutInfo:
|
||||||
"""Build a 3D scene layout from a natural language task description.
|
|
||||||
|
|
||||||
This function uses GPT-based reasoning to generate a structured scene layout,
|
|
||||||
including object hierarchy, spatial relations, and style descriptions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_desc (str): Natural language description of the robotic task.
|
|
||||||
output_path (str, optional): Path to save the visualized scene tree.
|
|
||||||
gpt_params (dict, optional): Parameters for GPT queries.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LayoutInfo: Structured layout information for the scene.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.layout import build_scene_layout
|
|
||||||
layout_info = build_scene_layout(
|
|
||||||
task_desc="Put the apples on the table on the plate",
|
|
||||||
output_path="outputs/scene_tree.jpg",
|
|
||||||
)
|
|
||||||
print(layout_info)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
|
layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
|
||||||
layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
|
layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
|
||||||
object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
|
object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
|
||||||
|
|||||||
@ -48,19 +48,12 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class SAMRemover(object):
|
class SAMRemover(object):
|
||||||
"""Loads SAM models and performs background removal on images.
|
"""Loading SAM models and performing background removal on images.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
checkpoint (str): Path to the model checkpoint.
|
checkpoint (str): Path to the model checkpoint.
|
||||||
model_type (str): Type of the SAM model to load.
|
model_type (str): Type of the SAM model to load (default: "vit_h").
|
||||||
area_ratio (float): Area ratio for filtering small connected components.
|
area_ratio (float): Area ratio filtering small connected components.
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.segment_model import SAMRemover
|
|
||||||
remover = SAMRemover(model_type="vit_h")
|
|
||||||
result = remover("input.jpg", "output.png")
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -85,14 +78,6 @@ class SAMRemover(object):
|
|||||||
self.mask_generator = self._load_sam_model(checkpoint)
|
self.mask_generator = self._load_sam_model(checkpoint)
|
||||||
|
|
||||||
def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
|
def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
|
||||||
"""Loads the SAM model and returns a mask generator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
checkpoint (str): Path to model checkpoint.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SamAutomaticMaskGenerator: Mask generator instance.
|
|
||||||
"""
|
|
||||||
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
||||||
sam.to(device=self.device)
|
sam.to(device=self.device)
|
||||||
|
|
||||||
@ -104,11 +89,13 @@ class SAMRemover(object):
|
|||||||
"""Removes the background from an image using the SAM model.
|
"""Removes the background from an image using the SAM model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image (Union[str, Image.Image, np.ndarray]): Input image.
|
image (Union[str, Image.Image, np.ndarray]): Input image,
|
||||||
save_path (str, optional): Path to save the output image.
|
can be a file path, PIL Image, or numpy array.
|
||||||
|
save_path (str): Path to save the output image (default: None).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Image.Image: Image with background removed (RGBA).
|
Image.Image: The image with background removed,
|
||||||
|
including an alpha channel.
|
||||||
"""
|
"""
|
||||||
# Convert input to numpy array
|
# Convert input to numpy array
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
@ -147,15 +134,6 @@ class SAMRemover(object):
|
|||||||
|
|
||||||
|
|
||||||
class SAMPredictor(object):
|
class SAMPredictor(object):
|
||||||
"""Loads SAM models and predicts segmentation masks from user points.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
checkpoint (str, optional): Path to model checkpoint.
|
|
||||||
model_type (str, optional): SAM model type.
|
|
||||||
binary_thresh (float, optional): Threshold for binary mask.
|
|
||||||
device (str, optional): Device for inference.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
checkpoint: str = None,
|
checkpoint: str = None,
|
||||||
@ -179,28 +157,12 @@ class SAMPredictor(object):
|
|||||||
self.binary_thresh = binary_thresh
|
self.binary_thresh = binary_thresh
|
||||||
|
|
||||||
def _load_sam_model(self, checkpoint: str) -> SamPredictor:
|
def _load_sam_model(self, checkpoint: str) -> SamPredictor:
|
||||||
"""Loads the SAM model and returns a predictor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
checkpoint (str): Path to model checkpoint.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SamPredictor: Predictor instance.
|
|
||||||
"""
|
|
||||||
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
||||||
sam.to(device=self.device)
|
sam.to(device=self.device)
|
||||||
|
|
||||||
return SamPredictor(sam)
|
return SamPredictor(sam)
|
||||||
|
|
||||||
def preprocess_image(self, image: Image.Image) -> np.ndarray:
|
def preprocess_image(self, image: Image.Image) -> np.ndarray:
|
||||||
"""Preprocesses input image for SAM prediction.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Image.Image): Input image.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Preprocessed image array.
|
|
||||||
"""
|
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image)
|
image = Image.open(image)
|
||||||
elif isinstance(image, np.ndarray):
|
elif isinstance(image, np.ndarray):
|
||||||
@ -216,15 +178,6 @@ class SAMPredictor(object):
|
|||||||
image: np.ndarray,
|
image: np.ndarray,
|
||||||
selected_points: list[list[int]],
|
selected_points: list[list[int]],
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Generates segmentation masks from selected points.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (np.ndarray): Input image array.
|
|
||||||
selected_points (list[list[int]]): List of points and labels.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[tuple[np.ndarray, str]]: List of masks and names.
|
|
||||||
"""
|
|
||||||
if len(selected_points) == 0:
|
if len(selected_points) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -267,15 +220,6 @@ class SAMPredictor(object):
|
|||||||
def get_segmented_image(
|
def get_segmented_image(
|
||||||
self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
|
self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""Combines masks and returns segmented image with alpha channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (np.ndarray): Input image array.
|
|
||||||
masks (list[tuple[np.ndarray, str]]): List of masks.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Segmented RGBA image.
|
|
||||||
"""
|
|
||||||
seg_image = Image.fromarray(image, mode="RGB")
|
seg_image = Image.fromarray(image, mode="RGB")
|
||||||
alpha_channel = np.zeros(
|
alpha_channel = np.zeros(
|
||||||
(seg_image.height, seg_image.width), dtype=np.uint8
|
(seg_image.height, seg_image.width), dtype=np.uint8
|
||||||
@ -297,15 +241,6 @@ class SAMPredictor(object):
|
|||||||
image: Union[str, Image.Image, np.ndarray],
|
image: Union[str, Image.Image, np.ndarray],
|
||||||
selected_points: list[list[int]],
|
selected_points: list[list[int]],
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""Segments image using selected points.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Union[str, Image.Image, np.ndarray]): Input image.
|
|
||||||
selected_points (list[list[int]]): List of points and labels.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Segmented RGBA image.
|
|
||||||
"""
|
|
||||||
image = self.preprocess_image(image)
|
image = self.preprocess_image(image)
|
||||||
self.predictor.set_image(image)
|
self.predictor.set_image(image)
|
||||||
masks = self.generate_masks(image, selected_points)
|
masks = self.generate_masks(image, selected_points)
|
||||||
@ -314,32 +249,12 @@ class SAMPredictor(object):
|
|||||||
|
|
||||||
|
|
||||||
class RembgRemover(object):
|
class RembgRemover(object):
|
||||||
"""Removes background from images using the rembg library.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.segment_model import RembgRemover
|
|
||||||
remover = RembgRemover()
|
|
||||||
result = remover("input.jpg", "output.png")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initializes the RembgRemover."""
|
|
||||||
self.rembg_session = rembg.new_session("u2net")
|
self.rembg_session = rembg.new_session("u2net")
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""Removes background from an image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Union[str, Image.Image, np.ndarray]): Input image.
|
|
||||||
save_path (str, optional): Path to save the output image.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Image with background removed (RGBA).
|
|
||||||
"""
|
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image)
|
image = Image.open(image)
|
||||||
elif isinstance(image, np.ndarray):
|
elif isinstance(image, np.ndarray):
|
||||||
@ -356,18 +271,7 @@ class RembgRemover(object):
|
|||||||
|
|
||||||
|
|
||||||
class BMGG14Remover(object):
|
class BMGG14Remover(object):
|
||||||
"""Removes background using the RMBG-1.4 segmentation model.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.segment_model import BMGG14Remover
|
|
||||||
remover = BMGG14Remover()
|
|
||||||
result = remover("input.jpg", "output.png")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initializes the BMGG14Remover."""
|
|
||||||
self.model = pipeline(
|
self.model = pipeline(
|
||||||
"image-segmentation",
|
"image-segmentation",
|
||||||
model="briaai/RMBG-1.4",
|
model="briaai/RMBG-1.4",
|
||||||
@ -377,15 +281,6 @@ class BMGG14Remover(object):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
||||||
):
|
):
|
||||||
"""Removes background from an image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Union[str, Image.Image, np.ndarray]): Input image.
|
|
||||||
save_path (str, optional): Path to save the output image.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Image with background removed.
|
|
||||||
"""
|
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image)
|
image = Image.open(image)
|
||||||
elif isinstance(image, np.ndarray):
|
elif isinstance(image, np.ndarray):
|
||||||
@ -404,16 +299,6 @@ class BMGG14Remover(object):
|
|||||||
def invert_rgba_pil(
|
def invert_rgba_pil(
|
||||||
image: Image.Image, mask: Image.Image, save_path: str = None
|
image: Image.Image, mask: Image.Image, save_path: str = None
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""Inverts the alpha channel of an RGBA image using a mask.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Image.Image): Input RGB image.
|
|
||||||
mask (Image.Image): Mask image for alpha inversion.
|
|
||||||
save_path (str, optional): Path to save the output image.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: RGBA image with inverted alpha.
|
|
||||||
"""
|
|
||||||
mask = (255 - np.array(mask))[..., None]
|
mask = (255 - np.array(mask))[..., None]
|
||||||
image_array = np.concatenate([np.array(image), mask], axis=-1)
|
image_array = np.concatenate([np.array(image), mask], axis=-1)
|
||||||
inverted_image = Image.fromarray(image_array, "RGBA")
|
inverted_image = Image.fromarray(image_array, "RGBA")
|
||||||
@ -433,20 +318,6 @@ def get_segmented_image_by_agent(
|
|||||||
save_path: str = None,
|
save_path: str = None,
|
||||||
mode: Literal["loose", "strict"] = "loose",
|
mode: Literal["loose", "strict"] = "loose",
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""Segments an image using SAM and rembg, with quality checking.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Image.Image): Input image.
|
|
||||||
sam_remover (SAMRemover): SAM-based remover.
|
|
||||||
rbg_remover (RembgRemover): rembg-based remover.
|
|
||||||
seg_checker (ImageSegChecker, optional): Quality checker.
|
|
||||||
save_path (str, optional): Path to save the output image.
|
|
||||||
mode (Literal["loose", "strict"], optional): Segmentation mode.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Segmented RGBA image.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
|
def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
|
||||||
if seg_checker is None:
|
if seg_checker is None:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@ -39,38 +39,13 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class ImageStableSR:
|
class ImageStableSR:
|
||||||
"""Super-resolution image upscaler using Stable Diffusion x4 upscaling model.
|
"""Super-resolution image upscaler using Stable Diffusion x4 upscaling model from StabilityAI."""
|
||||||
|
|
||||||
This class wraps the StabilityAI Stable Diffusion x4 upscaler for high-quality
|
|
||||||
image super-resolution.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path (str, optional): Path or HuggingFace repo for the model.
|
|
||||||
device (str, optional): Device for inference.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.sr_model import ImageStableSR
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
sr_model = ImageStableSR()
|
|
||||||
img = Image.open("input.png")
|
|
||||||
upscaled = sr_model(img)
|
|
||||||
upscaled.save("output.png")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
|
model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
|
||||||
device="cuda",
|
device="cuda",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes the Stable Diffusion x4 upscaler.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path (str, optional): Model path or repo.
|
|
||||||
device (str, optional): Device for inference.
|
|
||||||
"""
|
|
||||||
from diffusers import StableDiffusionUpscalePipeline
|
from diffusers import StableDiffusionUpscalePipeline
|
||||||
|
|
||||||
self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
|
self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
|
||||||
@ -87,16 +62,6 @@ class ImageStableSR:
|
|||||||
prompt: str = "",
|
prompt: str = "",
|
||||||
infer_step: int = 20,
|
infer_step: int = 20,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""Performs super-resolution on the input image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Union[Image.Image, np.ndarray]): Input image.
|
|
||||||
prompt (str, optional): Text prompt for upscaling.
|
|
||||||
infer_step (int, optional): Number of inference steps.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Upscaled image.
|
|
||||||
"""
|
|
||||||
if isinstance(image, np.ndarray):
|
if isinstance(image, np.ndarray):
|
||||||
image = Image.fromarray(image)
|
image = Image.fromarray(image)
|
||||||
|
|
||||||
@ -121,26 +86,9 @@ class ImageRealESRGAN:
|
|||||||
Attributes:
|
Attributes:
|
||||||
outscale (int): The output image scale factor (e.g., 2, 4).
|
outscale (int): The output image scale factor (e.g., 2, 4).
|
||||||
model_path (str): Path to the pre-trained model weights.
|
model_path (str): Path to the pre-trained model weights.
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.sr_model import ImageRealESRGAN
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
sr_model = ImageRealESRGAN(outscale=4)
|
|
||||||
img = Image.open("input.png")
|
|
||||||
upscaled = sr_model(img)
|
|
||||||
upscaled.save("output.png")
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, outscale: int, model_path: str = None) -> None:
|
def __init__(self, outscale: int, model_path: str = None) -> None:
|
||||||
"""Initializes the RealESRGAN upscaler.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
outscale (int): Output scale factor.
|
|
||||||
model_path (str, optional): Path to model weights.
|
|
||||||
"""
|
|
||||||
# monkey patch to support torchvision>=0.16
|
# monkey patch to support torchvision>=0.16
|
||||||
import torchvision
|
import torchvision
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@ -174,7 +122,6 @@ class ImageRealESRGAN:
|
|||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
|
||||||
def _lazy_init(self):
|
def _lazy_init(self):
|
||||||
"""Lazily initializes the RealESRGAN model."""
|
|
||||||
if self.upsampler is None:
|
if self.upsampler is None:
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
@ -198,14 +145,6 @@ class ImageRealESRGAN:
|
|||||||
|
|
||||||
@spaces.GPU
|
@spaces.GPU
|
||||||
def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
|
def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
|
||||||
"""Performs super-resolution on the input image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Union[Image.Image, np.ndarray]): Input image.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Upscaled image.
|
|
||||||
"""
|
|
||||||
self._lazy_init()
|
self._lazy_init()
|
||||||
|
|
||||||
if isinstance(image, Image.Image):
|
if isinstance(image, Image.Image):
|
||||||
|
|||||||
@ -60,11 +60,6 @@ PROMPT_KAPPEND = "Single {object}, in the center of the image, white background,
|
|||||||
|
|
||||||
|
|
||||||
def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
|
def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
|
||||||
"""Downloads Kolors model weights from HuggingFace.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
local_dir (str, optional): Local directory to store weights.
|
|
||||||
"""
|
|
||||||
logger.info(f"Download kolors weights from huggingface...")
|
logger.info(f"Download kolors weights from huggingface...")
|
||||||
os.makedirs(local_dir, exist_ok=True)
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
@ -98,22 +93,6 @@ def build_text2img_ip_pipeline(
|
|||||||
ref_scale: float,
|
ref_scale: float,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> StableDiffusionXLPipelineIP:
|
) -> StableDiffusionXLPipelineIP:
|
||||||
"""Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ckpt_dir (str): Directory containing model checkpoints.
|
|
||||||
ref_scale (float): Reference scale for IP-Adapter.
|
|
||||||
device (str, optional): Device for inference.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StableDiffusionXLPipelineIP: Configured pipeline.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.text_model import build_text2img_ip_pipeline
|
|
||||||
pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
download_kolors_weights(ckpt_dir)
|
download_kolors_weights(ckpt_dir)
|
||||||
|
|
||||||
text_encoder = ChatGLMModel.from_pretrained(
|
text_encoder = ChatGLMModel.from_pretrained(
|
||||||
@ -167,21 +146,6 @@ def build_text2img_pipeline(
|
|||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> StableDiffusionXLPipeline:
|
) -> StableDiffusionXLPipeline:
|
||||||
"""Builds a Stable Diffusion XL pipeline for text-to-image generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ckpt_dir (str): Directory containing model checkpoints.
|
|
||||||
device (str, optional): Device for inference.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StableDiffusionXLPipeline: Configured pipeline.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.text_model import build_text2img_pipeline
|
|
||||||
pipe = build_text2img_pipeline("weights/Kolors")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
download_kolors_weights(ckpt_dir)
|
download_kolors_weights(ckpt_dir)
|
||||||
|
|
||||||
text_encoder = ChatGLMModel.from_pretrained(
|
text_encoder = ChatGLMModel.from_pretrained(
|
||||||
@ -221,29 +185,6 @@ def text2img_gen(
|
|||||||
ip_image_size: int = 512,
|
ip_image_size: int = 512,
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
) -> list[Image.Image]:
|
) -> list[Image.Image]:
|
||||||
"""Generates images from text prompts using a Stable Diffusion XL pipeline.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): Text prompt for image generation.
|
|
||||||
n_sample (int): Number of images to generate.
|
|
||||||
guidance_scale (float): Guidance scale for diffusion.
|
|
||||||
pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP): Pipeline instance.
|
|
||||||
ip_image (Image.Image | str, optional): Reference image for IP-Adapter.
|
|
||||||
image_wh (tuple[int, int], optional): Output image size (width, height).
|
|
||||||
infer_step (int, optional): Number of inference steps.
|
|
||||||
ip_image_size (int, optional): Size for IP-Adapter image.
|
|
||||||
seed (int, optional): Random seed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Image.Image]: List of generated images.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.models.text_model import text2img_gen
|
|
||||||
images = text2img_gen(prompt="banana", n_sample=3, guidance_scale=7.5)
|
|
||||||
images[0].save("banana.png")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
prompt = PROMPT_KAPPEND.format(object=prompt.strip())
|
prompt = PROMPT_KAPPEND.format(object=prompt.strip())
|
||||||
logger.info(f"Processing prompt: {prompt}")
|
logger.info(f"Processing prompt: {prompt}")
|
||||||
|
|
||||||
|
|||||||
@ -42,56 +42,6 @@ def build_texture_gen_pipe(
|
|||||||
ip_adapt_scale: float = 0,
|
ip_adapt_scale: float = 0,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> DiffusionPipeline:
|
) -> DiffusionPipeline:
|
||||||
"""Build and initialize the Kolors + ControlNet (optional IP-Adapter) texture generation pipeline.
|
|
||||||
|
|
||||||
Loads Kolors tokenizer, text encoder (ChatGLM), VAE, UNet, scheduler and (optionally)
|
|
||||||
a ControlNet checkpoint plus IP-Adapter vision encoder. If ``controlnet_ckpt`` is
|
|
||||||
not provided, the default multi-view texture ControlNet weights are downloaded
|
|
||||||
automatically from the hub. When ``ip_adapt_scale > 0`` an IP-Adapter vision
|
|
||||||
encoder and its weights are also loaded and activated.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
base_ckpt_dir (str):
|
|
||||||
Root directory where Kolors (and optionally Kolors-IP-Adapter-Plus) weights
|
|
||||||
are or will be stored. Required subfolders: ``Kolors/{text_encoder,vae,unet,scheduler}``.
|
|
||||||
controlnet_ckpt (str, optional):
|
|
||||||
Directory containing a ControlNet checkpoint (safetensors). If ``None``,
|
|
||||||
downloads the default ``texture_gen_mv_v1`` snapshot.
|
|
||||||
ip_adapt_scale (float, optional):
|
|
||||||
Strength (>=0) of IP-Adapter conditioning. Set >0 to enable IP-Adapter;
|
|
||||||
typical values: 0.4-0.8. Default: 0 (disabled).
|
|
||||||
device (str, optional):
|
|
||||||
Target device to move the pipeline to (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``).
|
|
||||||
Default: ``"cuda"``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DiffusionPipeline: A configured
|
|
||||||
``StableDiffusionXLControlNetImg2ImgPipeline`` ready for multi-view texture
|
|
||||||
generation (with optional IP-Adapter support).
|
|
||||||
|
|
||||||
Example:
|
|
||||||
Initialize pipeline with IP-Adapter enabled.
|
|
||||||
```python
|
|
||||||
from embodied_gen.models.texture_model import build_texture_gen_pipe
|
|
||||||
ip_adapt_scale = 0.7
|
|
||||||
PIPELINE = build_texture_gen_pipe(
|
|
||||||
base_ckpt_dir="./weights",
|
|
||||||
ip_adapt_scale=ip_adapt_scale,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
PIPELINE.set_ip_adapter_scale([ip_adapt_scale])
|
|
||||||
```
|
|
||||||
Initialize pipeline without IP-Adapter.
|
|
||||||
```python
|
|
||||||
from embodied_gen.models.texture_model import build_texture_gen_pipe
|
|
||||||
PIPELINE = build_texture_gen_pipe(
|
|
||||||
base_ckpt_dir="./weights",
|
|
||||||
ip_adapt_scale=0,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
download_kolors_weights(f"{base_ckpt_dir}/Kolors")
|
download_kolors_weights(f"{base_ckpt_dir}/Kolors")
|
||||||
logger.info(f"Load Kolors weights...")
|
logger.info(f"Load Kolors weights...")
|
||||||
tokenizer = ChatGLMTokenizer.from_pretrained(
|
tokenizer = ChatGLMTokenizer.from_pretrained(
|
||||||
|
|||||||
@ -26,14 +26,12 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import trimesh
|
import trimesh
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api
|
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
||||||
from embodied_gen.data.utils import delete_dir, trellis_preprocess
|
from embodied_gen.data.utils import delete_dir, trellis_preprocess
|
||||||
|
from embodied_gen.models.delight_model import DelightingModel
|
||||||
# from embodied_gen.models.delight_model import DelightingModel
|
|
||||||
from embodied_gen.models.gs_model import GaussianOperator
|
from embodied_gen.models.gs_model import GaussianOperator
|
||||||
from embodied_gen.models.segment_model import RembgRemover
|
from embodied_gen.models.segment_model import RembgRemover
|
||||||
|
from embodied_gen.models.sr_model import ImageRealESRGAN
|
||||||
# from embodied_gen.models.sr_model import ImageRealESRGAN
|
|
||||||
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
|
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
|
||||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||||
from embodied_gen.utils.log import logger
|
from embodied_gen.utils.log import logger
|
||||||
@ -61,8 +59,8 @@ os.environ["SPCONV_ALGO"] = "native"
|
|||||||
random.seed(0)
|
random.seed(0)
|
||||||
|
|
||||||
logger.info("Loading Image3D Models...")
|
logger.info("Loading Image3D Models...")
|
||||||
# DELIGHT = DelightingModel()
|
DELIGHT = DelightingModel()
|
||||||
# IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
||||||
RBG_REMOVER = RembgRemover()
|
RBG_REMOVER = RembgRemover()
|
||||||
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
||||||
"microsoft/TRELLIS-image-large"
|
"microsoft/TRELLIS-image-large"
|
||||||
@ -110,7 +108,9 @@ def parse_args():
|
|||||||
default=2,
|
default=2,
|
||||||
)
|
)
|
||||||
parser.add_argument("--disable_decompose_convex", action="store_true")
|
parser.add_argument("--disable_decompose_convex", action="store_true")
|
||||||
parser.add_argument("--texture_size", type=int, default=2048)
|
parser.add_argument(
|
||||||
|
"--texture_wh", type=int, nargs=2, default=[2048, 2048]
|
||||||
|
)
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
return args
|
return args
|
||||||
@ -248,14 +248,16 @@ def entrypoint(**kwargs):
|
|||||||
mesh.export(mesh_obj_path)
|
mesh.export(mesh_obj_path)
|
||||||
|
|
||||||
mesh = backproject_api(
|
mesh = backproject_api(
|
||||||
# delight_model=DELIGHT,
|
delight_model=DELIGHT,
|
||||||
# imagesr_model=IMAGESR_MODEL,
|
imagesr_model=IMAGESR_MODEL,
|
||||||
gs_path=aligned_gs_path,
|
color_path=color_path,
|
||||||
mesh_path=mesh_obj_path,
|
mesh_path=mesh_obj_path,
|
||||||
output_path=mesh_obj_path,
|
output_path=mesh_obj_path,
|
||||||
skip_fix_mesh=False,
|
skip_fix_mesh=False,
|
||||||
texture_size=args.texture_size,
|
delight=True,
|
||||||
delight=False,
|
texture_wh=args.texture_wh,
|
||||||
|
elevation=[20, -10, 60, -50],
|
||||||
|
num_images=12,
|
||||||
)
|
)
|
||||||
|
|
||||||
mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
|
mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from embodied_gen.data.utils import (
|
|||||||
init_kal_camera,
|
init_kal_camera,
|
||||||
normalize_vertices_array,
|
normalize_vertices_array,
|
||||||
)
|
)
|
||||||
from embodied_gen.models.gs_model import load_gs_model
|
from embodied_gen.models.gs_model import GaussianOperator
|
||||||
from embodied_gen.utils.process_media import combine_images_to_grid
|
from embodied_gen.utils.process_media import combine_images_to_grid
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -97,6 +97,21 @@ def parse_args():
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def load_gs_model(
|
||||||
|
input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
|
||||||
|
) -> GaussianOperator:
|
||||||
|
gs_model = GaussianOperator.load_from_ply(input_gs)
|
||||||
|
# Normalize vertices to [-1, 1], center to (0, 0, 0).
|
||||||
|
_, scale, center = normalize_vertices_array(gs_model._means)
|
||||||
|
scale, center = float(scale), center.tolist()
|
||||||
|
transpose = [*[v for v in center], *pre_quat]
|
||||||
|
instance_pose = torch.tensor(transpose).to(gs_model.device)
|
||||||
|
gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
|
||||||
|
gs_model.rescale(scale)
|
||||||
|
|
||||||
|
return gs_model
|
||||||
|
|
||||||
|
|
||||||
@spaces.GPU
|
@spaces.GPU
|
||||||
def entrypoint(**kwargs) -> None:
|
def entrypoint(**kwargs) -> None:
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|||||||
@ -53,31 +53,26 @@ from thirdparty.pano2room.utils.functions import (
|
|||||||
|
|
||||||
|
|
||||||
class Pano2MeshSRPipeline:
|
class Pano2MeshSRPipeline:
|
||||||
"""Pipeline for converting panoramic RGB images into 3D mesh representations.
|
"""Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement.
|
||||||
|
|
||||||
This class integrates depth estimation, inpainting, mesh conversion, multi-view mesh repair,
|
This class integrates several key components including:
|
||||||
and 3D Gaussian Splatting (3DGS) dataset generation.
|
- Depth estimation from RGB panorama
|
||||||
|
- Inpainting of missing regions under offsets
|
||||||
|
- RGB-D to mesh conversion
|
||||||
|
- Multi-view mesh repair
|
||||||
|
- 3D Gaussian Splatting (3DGS) dataset generation
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
|
config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```py
|
```python
|
||||||
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
|
|
||||||
from embodied_gen.utils.config import Pano2MeshSRConfig
|
|
||||||
|
|
||||||
config = Pano2MeshSRConfig()
|
|
||||||
pipeline = Pano2MeshSRPipeline(config)
|
pipeline = Pano2MeshSRPipeline(config)
|
||||||
pipeline(pano_image='example.png', output_dir='./output')
|
pipeline(pano_image='example.png', output_dir='./output')
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: Pano2MeshSRConfig) -> None:
|
def __init__(self, config: Pano2MeshSRConfig) -> None:
|
||||||
"""Initializes the pipeline with models and camera poses.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (Pano2MeshSRConfig): Configuration object.
|
|
||||||
"""
|
|
||||||
self.cfg = config
|
self.cfg = config
|
||||||
self.device = config.device
|
self.device = config.device
|
||||||
|
|
||||||
@ -98,7 +93,6 @@ class Pano2MeshSRPipeline:
|
|||||||
self.kernel = torch.from_numpy(kernel).float().to(self.device)
|
self.kernel = torch.from_numpy(kernel).float().to(self.device)
|
||||||
|
|
||||||
def init_mesh_params(self) -> None:
|
def init_mesh_params(self) -> None:
|
||||||
"""Initializes mesh parameters and inpaint mask."""
|
|
||||||
torch.set_default_device(self.device)
|
torch.set_default_device(self.device)
|
||||||
self.inpaint_mask = torch.ones(
|
self.inpaint_mask = torch.ones(
|
||||||
(self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
|
(self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
|
||||||
@ -109,14 +103,6 @@ class Pano2MeshSRPipeline:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def read_camera_pose_file(filepath: str) -> np.ndarray:
|
def read_camera_pose_file(filepath: str) -> np.ndarray:
|
||||||
"""Reads a camera pose file and returns the pose matrix.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filepath (str): Path to the camera pose file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: 4x4 camera pose matrix.
|
|
||||||
"""
|
|
||||||
with open(filepath, "r") as f:
|
with open(filepath, "r") as f:
|
||||||
values = [float(num) for line in f for num in line.split()]
|
values = [float(num) for line in f for num in line.split()]
|
||||||
|
|
||||||
@ -125,14 +111,6 @@ class Pano2MeshSRPipeline:
|
|||||||
def load_camera_poses(
|
def load_camera_poses(
|
||||||
self, trajectory_dir: str
|
self, trajectory_dir: str
|
||||||
) -> tuple[np.ndarray, list[torch.Tensor]]:
|
) -> tuple[np.ndarray, list[torch.Tensor]]:
|
||||||
"""Loads camera poses from a directory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trajectory_dir (str): Directory containing camera pose files.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[np.ndarray, list[torch.Tensor]]: List of relative camera poses.
|
|
||||||
"""
|
|
||||||
pose_filenames = sorted(
|
pose_filenames = sorted(
|
||||||
[
|
[
|
||||||
fname
|
fname
|
||||||
@ -170,14 +148,6 @@ class Pano2MeshSRPipeline:
|
|||||||
def load_inpaint_poses(
|
def load_inpaint_poses(
|
||||||
self, poses: torch.Tensor
|
self, poses: torch.Tensor
|
||||||
) -> dict[int, torch.Tensor]:
|
) -> dict[int, torch.Tensor]:
|
||||||
"""Samples and loads poses for inpainting.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
poses (torch.Tensor): Tensor of camera poses.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[int, torch.Tensor]: Dictionary mapping indices to pose tensors.
|
|
||||||
"""
|
|
||||||
inpaint_poses = dict()
|
inpaint_poses = dict()
|
||||||
sampled_views = poses[:: self.cfg.inpaint_frame_stride]
|
sampled_views = poses[:: self.cfg.inpaint_frame_stride]
|
||||||
init_pose = torch.eye(4)
|
init_pose = torch.eye(4)
|
||||||
@ -192,14 +162,6 @@ class Pano2MeshSRPipeline:
|
|||||||
return inpaint_poses
|
return inpaint_poses
|
||||||
|
|
||||||
def project(self, world_to_cam: torch.Tensor):
|
def project(self, world_to_cam: torch.Tensor):
|
||||||
"""Projects the mesh to an image using the given camera pose.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
world_to_cam (torch.Tensor): World-to-camera transformation matrix.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Projected RGB image, inpaint mask, and depth map.
|
|
||||||
"""
|
|
||||||
(
|
(
|
||||||
project_image,
|
project_image,
|
||||||
project_depth,
|
project_depth,
|
||||||
@ -223,14 +185,6 @@ class Pano2MeshSRPipeline:
|
|||||||
return project_image[:3, ...], inpaint_mask, project_depth
|
return project_image[:3, ...], inpaint_mask, project_depth
|
||||||
|
|
||||||
def render_pano(self, pose: torch.Tensor):
|
def render_pano(self, pose: torch.Tensor):
|
||||||
"""Renders a panorama from the mesh using the given pose.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pose (torch.Tensor): Camera pose.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: RGB panorama, depth map, and mask.
|
|
||||||
"""
|
|
||||||
cubemap_list = []
|
cubemap_list = []
|
||||||
for cubemap_pose in self.cubemap_w2cs:
|
for cubemap_pose in self.cubemap_w2cs:
|
||||||
project_pose = cubemap_pose @ pose
|
project_pose = cubemap_pose @ pose
|
||||||
@ -259,15 +213,6 @@ class Pano2MeshSRPipeline:
|
|||||||
world_to_cam: torch.Tensor = None,
|
world_to_cam: torch.Tensor = None,
|
||||||
using_distance_map: bool = True,
|
using_distance_map: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Converts RGB-D images to mesh and updates mesh parameters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rgb (torch.Tensor): RGB image tensor.
|
|
||||||
depth (torch.Tensor): Depth map tensor.
|
|
||||||
inpaint_mask (torch.Tensor): Inpaint mask tensor.
|
|
||||||
world_to_cam (torch.Tensor, optional): Camera pose.
|
|
||||||
using_distance_map (bool, optional): Whether to use distance map.
|
|
||||||
"""
|
|
||||||
if world_to_cam is None:
|
if world_to_cam is None:
|
||||||
world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
|
world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
|
||||||
|
|
||||||
@ -294,15 +239,6 @@ class Pano2MeshSRPipeline:
|
|||||||
def get_edge_image_by_depth(
|
def get_edge_image_by_depth(
|
||||||
self, depth: torch.Tensor, dilate_iter: int = 1
|
self, depth: torch.Tensor, dilate_iter: int = 1
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Computes edge image from depth map.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
depth (torch.Tensor): Depth map tensor.
|
|
||||||
dilate_iter (int, optional): Number of dilation iterations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Edge image.
|
|
||||||
"""
|
|
||||||
if isinstance(depth, torch.Tensor):
|
if isinstance(depth, torch.Tensor):
|
||||||
depth = depth.cpu().detach().numpy()
|
depth = depth.cpu().detach().numpy()
|
||||||
|
|
||||||
@ -317,15 +253,6 @@ class Pano2MeshSRPipeline:
|
|||||||
def mesh_repair_by_greedy_view_selection(
|
def mesh_repair_by_greedy_view_selection(
|
||||||
self, pose_dict: dict[str, torch.Tensor], output_dir: str
|
self, pose_dict: dict[str, torch.Tensor], output_dir: str
|
||||||
) -> list:
|
) -> list:
|
||||||
"""Repairs mesh by selecting views greedily and inpainting missing regions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pose_dict (dict[str, torch.Tensor]): Dictionary of poses for inpainting.
|
|
||||||
output_dir (str): Directory to save visualizations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of inpainted panoramas with poses.
|
|
||||||
"""
|
|
||||||
inpainted_panos_w_pose = []
|
inpainted_panos_w_pose = []
|
||||||
while len(pose_dict) > 0:
|
while len(pose_dict) > 0:
|
||||||
logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
|
logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
|
||||||
@ -416,17 +343,6 @@ class Pano2MeshSRPipeline:
|
|||||||
distances: torch.Tensor,
|
distances: torch.Tensor,
|
||||||
pano_mask: torch.Tensor,
|
pano_mask: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor]:
|
) -> tuple[torch.Tensor]:
|
||||||
"""Inpaints missing regions in a panorama.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
idx (int): Index of the panorama.
|
|
||||||
colors (torch.Tensor): RGB image tensor.
|
|
||||||
distances (torch.Tensor): Distance map tensor.
|
|
||||||
pano_mask (torch.Tensor): Mask tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[torch.Tensor]: Inpainted RGB image, distances, and normals.
|
|
||||||
"""
|
|
||||||
mask = (pano_mask[None, ..., None] > 0.5).float()
|
mask = (pano_mask[None, ..., None] > 0.5).float()
|
||||||
mask = mask.permute(0, 3, 1, 2)
|
mask = mask.permute(0, 3, 1, 2)
|
||||||
mask = dilation(mask, kernel=self.kernel)
|
mask = dilation(mask, kernel=self.kernel)
|
||||||
@ -448,14 +364,6 @@ class Pano2MeshSRPipeline:
|
|||||||
def preprocess_pano(
|
def preprocess_pano(
|
||||||
self, image: Image.Image | str
|
self, image: Image.Image | str
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Preprocesses a panoramic image for mesh generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Image.Image | str): Input image or path.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[torch.Tensor, torch.Tensor]: Preprocessed RGB and depth tensors.
|
|
||||||
"""
|
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image)
|
image = Image.open(image)
|
||||||
|
|
||||||
@ -479,17 +387,6 @@ class Pano2MeshSRPipeline:
|
|||||||
def pano_to_perpective(
|
def pano_to_perpective(
|
||||||
self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
|
self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Converts a panoramic image to a perspective view.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pano_image (torch.Tensor): Panoramic image tensor.
|
|
||||||
pitch (float): Pitch angle.
|
|
||||||
yaw (float): Yaw angle.
|
|
||||||
fov (float): Field of view.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Perspective image tensor.
|
|
||||||
"""
|
|
||||||
rots = dict(
|
rots = dict(
|
||||||
roll=0,
|
roll=0,
|
||||||
pitch=pitch,
|
pitch=pitch,
|
||||||
@ -507,14 +404,6 @@ class Pano2MeshSRPipeline:
|
|||||||
return perspective
|
return perspective
|
||||||
|
|
||||||
def pano_to_cubemap(self, pano_rgb: torch.Tensor):
|
def pano_to_cubemap(self, pano_rgb: torch.Tensor):
|
||||||
"""Converts a panoramic RGB image to six cubemap views.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pano_rgb (torch.Tensor): Panoramic RGB image tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of cubemap RGB tensors.
|
|
||||||
"""
|
|
||||||
# Define six canonical cube directions in (pitch, yaw)
|
# Define six canonical cube directions in (pitch, yaw)
|
||||||
directions = [
|
directions = [
|
||||||
(0, 0),
|
(0, 0),
|
||||||
@ -535,11 +424,6 @@ class Pano2MeshSRPipeline:
|
|||||||
return cubemaps_rgb
|
return cubemaps_rgb
|
||||||
|
|
||||||
def save_mesh(self, output_path: str) -> None:
|
def save_mesh(self, output_path: str) -> None:
|
||||||
"""Saves the mesh to a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_path (str): Path to save the mesh file.
|
|
||||||
"""
|
|
||||||
vertices_np = self.vertices.T.cpu().numpy()
|
vertices_np = self.vertices.T.cpu().numpy()
|
||||||
colors_np = self.colors.T.cpu().numpy()
|
colors_np = self.colors.T.cpu().numpy()
|
||||||
faces_np = self.faces.T.cpu().numpy()
|
faces_np = self.faces.T.cpu().numpy()
|
||||||
@ -550,14 +434,6 @@ class Pano2MeshSRPipeline:
|
|||||||
mesh.export(output_path)
|
mesh.export(output_path)
|
||||||
|
|
||||||
def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
|
def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
|
||||||
"""Converts mesh pose to 3D Gaussian Splatting pose.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mesh_pose (torch.Tensor): Mesh pose tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Converted pose matrix.
|
|
||||||
"""
|
|
||||||
pose = mesh_pose.clone()
|
pose = mesh_pose.clone()
|
||||||
pose[0, :] *= -1
|
pose[0, :] *= -1
|
||||||
pose[1, :] *= -1
|
pose[1, :] *= -1
|
||||||
@ -574,15 +450,6 @@ class Pano2MeshSRPipeline:
|
|||||||
return c2w
|
return c2w
|
||||||
|
|
||||||
def __call__(self, pano_image: Image.Image | str, output_dir: str):
|
def __call__(self, pano_image: Image.Image | str, output_dir: str):
|
||||||
"""Runs the pipeline to generate mesh and 3DGS data from a panoramic image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pano_image (Image.Image | str): Input panoramic image or path.
|
|
||||||
output_dir (str): Directory to save outputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
self.init_mesh_params()
|
self.init_mesh_params()
|
||||||
pano_rgb, pano_depth = self.preprocess_pano(pano_image)
|
pano_rgb, pano_depth = self.preprocess_pano(pano_image)
|
||||||
self.sup_pool = SupInfoPool()
|
self.sup_pool = SupInfoPool()
|
||||||
|
|||||||
@ -24,27 +24,11 @@ __all__ = [
|
|||||||
"Scene3DItemEnum",
|
"Scene3DItemEnum",
|
||||||
"SpatialRelationEnum",
|
"SpatialRelationEnum",
|
||||||
"RobotItemEnum",
|
"RobotItemEnum",
|
||||||
"LayoutInfo",
|
|
||||||
"AssetType",
|
|
||||||
"SimAssetMapper",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RenderItems(str, Enum):
|
class RenderItems(str, Enum):
|
||||||
"""Enumeration of render item types for 3D scenes.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
IMAGE: Color image.
|
|
||||||
ALPHA: Mask image.
|
|
||||||
VIEW_NORMAL: View-space normal image.
|
|
||||||
GLOBAL_NORMAL: World-space normal image.
|
|
||||||
POSITION_MAP: Position map image.
|
|
||||||
DEPTH: Depth image.
|
|
||||||
ALBEDO: Albedo image.
|
|
||||||
DIFFUSE: Diffuse image.
|
|
||||||
"""
|
|
||||||
|
|
||||||
IMAGE = "image_color"
|
IMAGE = "image_color"
|
||||||
ALPHA = "image_mask"
|
ALPHA = "image_mask"
|
||||||
VIEW_NORMAL = "image_view_normal"
|
VIEW_NORMAL = "image_view_normal"
|
||||||
@ -57,21 +41,6 @@ class RenderItems(str, Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Scene3DItemEnum(str, Enum):
|
class Scene3DItemEnum(str, Enum):
|
||||||
"""Enumeration of 3D scene item categories.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
BACKGROUND: Background objects.
|
|
||||||
CONTEXT: Contextual objects.
|
|
||||||
ROBOT: Robot entity.
|
|
||||||
MANIPULATED_OBJS: Objects manipulated by the robot.
|
|
||||||
DISTRACTOR_OBJS: Distractor objects.
|
|
||||||
OTHERS: Other objects.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
object_list(layout_relation): Returns a list of objects in the scene.
|
|
||||||
object_mapping(layout_relation): Returns a mapping from object to category.
|
|
||||||
"""
|
|
||||||
|
|
||||||
BACKGROUND = "background"
|
BACKGROUND = "background"
|
||||||
CONTEXT = "context"
|
CONTEXT = "context"
|
||||||
ROBOT = "robot"
|
ROBOT = "robot"
|
||||||
@ -81,14 +50,6 @@ class Scene3DItemEnum(str, Enum):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def object_list(cls, layout_relation: dict) -> list:
|
def object_list(cls, layout_relation: dict) -> list:
|
||||||
"""Returns a list of objects in the scene.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layout_relation: Dictionary mapping categories to objects.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of objects in the scene.
|
|
||||||
"""
|
|
||||||
return (
|
return (
|
||||||
[
|
[
|
||||||
layout_relation[cls.BACKGROUND.value],
|
layout_relation[cls.BACKGROUND.value],
|
||||||
@ -100,14 +61,6 @@ class Scene3DItemEnum(str, Enum):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def object_mapping(cls, layout_relation):
|
def object_mapping(cls, layout_relation):
|
||||||
"""Returns a mapping from object to category.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layout_relation: Dictionary mapping categories to objects.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping object names to their category.
|
|
||||||
"""
|
|
||||||
relation_mapping = {
|
relation_mapping = {
|
||||||
# layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
|
# layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
|
||||||
layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
|
layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
|
||||||
@ -131,15 +84,6 @@ class Scene3DItemEnum(str, Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SpatialRelationEnum(str, Enum):
|
class SpatialRelationEnum(str, Enum):
|
||||||
"""Enumeration of spatial relations for objects in a scene.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
ON: Objects on a surface (e.g., table).
|
|
||||||
IN: Objects in a container or room.
|
|
||||||
INSIDE: Objects inside a shelf or rack.
|
|
||||||
FLOOR: Objects on the floor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
ON = "ON" # objects on the table
|
ON = "ON" # objects on the table
|
||||||
IN = "IN" # objects in the room
|
IN = "IN" # objects in the room
|
||||||
INSIDE = "INSIDE" # objects inside the shelf/rack
|
INSIDE = "INSIDE" # objects inside the shelf/rack
|
||||||
@ -148,14 +92,6 @@ class SpatialRelationEnum(str, Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RobotItemEnum(str, Enum):
|
class RobotItemEnum(str, Enum):
|
||||||
"""Enumeration of supported robot types.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
FRANKA: Franka robot.
|
|
||||||
UR5: UR5 robot.
|
|
||||||
PIPER: Piper robot.
|
|
||||||
"""
|
|
||||||
|
|
||||||
FRANKA = "franka"
|
FRANKA = "franka"
|
||||||
UR5 = "ur5"
|
UR5 = "ur5"
|
||||||
PIPER = "piper"
|
PIPER = "piper"
|
||||||
@ -163,18 +99,6 @@ class RobotItemEnum(str, Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LayoutInfo(DataClassJsonMixin):
|
class LayoutInfo(DataClassJsonMixin):
|
||||||
"""Data structure for layout information in a 3D scene.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
tree: Hierarchical structure of scene objects.
|
|
||||||
relation: Spatial relations between objects.
|
|
||||||
objs_desc: Descriptions of objects.
|
|
||||||
objs_mapping: Mapping from object names to categories.
|
|
||||||
assets: Asset file paths for objects.
|
|
||||||
quality: Quality information for assets.
|
|
||||||
position: Position coordinates for objects.
|
|
||||||
"""
|
|
||||||
|
|
||||||
tree: dict[str, list]
|
tree: dict[str, list]
|
||||||
relation: dict[str, str | list[str]]
|
relation: dict[str, str | list[str]]
|
||||||
objs_desc: dict[str, str] = field(default_factory=dict)
|
objs_desc: dict[str, str] = field(default_factory=dict)
|
||||||
@ -182,64 +106,3 @@ class LayoutInfo(DataClassJsonMixin):
|
|||||||
assets: dict[str, str] = field(default_factory=dict)
|
assets: dict[str, str] = field(default_factory=dict)
|
||||||
quality: dict[str, str] = field(default_factory=dict)
|
quality: dict[str, str] = field(default_factory=dict)
|
||||||
position: dict[str, list[float]] = field(default_factory=dict)
|
position: dict[str, list[float]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AssetType(str):
|
|
||||||
"""Enumeration for asset types.
|
|
||||||
|
|
||||||
Supported types:
|
|
||||||
MJCF: MuJoCo XML format.
|
|
||||||
USD: Universal Scene Description format.
|
|
||||||
URDF: Unified Robot Description Format.
|
|
||||||
MESH: Mesh file format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
MJCF = "mjcf"
|
|
||||||
USD = "usd"
|
|
||||||
URDF = "urdf"
|
|
||||||
MESH = "mesh"
|
|
||||||
|
|
||||||
|
|
||||||
class SimAssetMapper:
|
|
||||||
"""Maps simulator names to asset types.
|
|
||||||
|
|
||||||
Provides a mapping from simulator names to their corresponding asset type.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.utils.enum import SimAssetMapper
|
|
||||||
asset_type = SimAssetMapper["isaacsim"]
|
|
||||||
print(asset_type) # Output: 'usd'
|
|
||||||
```
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
__class_getitem__(key): Returns the asset type for a given simulator name.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_mapping = dict(
|
|
||||||
ISAACSIM=AssetType.USD,
|
|
||||||
ISAACGYM=AssetType.URDF,
|
|
||||||
MUJOCO=AssetType.MJCF,
|
|
||||||
GENESIS=AssetType.MJCF,
|
|
||||||
SAPIEN=AssetType.URDF,
|
|
||||||
PYBULLET=AssetType.URDF,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __class_getitem__(cls, key: str):
|
|
||||||
"""Returns the asset type for a given simulator name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: Name of the simulator.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AssetType corresponding to the simulator.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If the simulator name is not recognized.
|
|
||||||
"""
|
|
||||||
key = key.upper()
|
|
||||||
if key.startswith("SAPIEN"):
|
|
||||||
key = "SAPIEN"
|
|
||||||
return cls._mapping[key]
|
|
||||||
|
|||||||
@ -45,13 +45,13 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
||||||
"""Converts a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
|
"""Convert a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
matrix (np.ndarray): 4x4 transformation matrix.
|
matrix (np.ndarray): 4x4 transformation matrix.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
List[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
||||||
"""
|
"""
|
||||||
x, y, z = matrix[:3, 3]
|
x, y, z = matrix[:3, 3]
|
||||||
rot_mat = matrix[:3, :3]
|
rot_mat = matrix[:3, :3]
|
||||||
@ -62,13 +62,13 @@ def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
|||||||
|
|
||||||
|
|
||||||
def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
||||||
"""Converts pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
|
"""Convert pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pose (list[float]): Pose as [x, y, z, qx, qy, qz, qw].
|
List[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: 4x4 transformation matrix.
|
matrix (np.ndarray): 4x4 transformation matrix.
|
||||||
"""
|
"""
|
||||||
x, y, z, qx, qy, qz, qw = pose
|
x, y, z, qx, qy, qz, qw = pose
|
||||||
r = R.from_quat([qx, qy, qz, qw])
|
r = R.from_quat([qx, qy, qz, qw])
|
||||||
@ -82,16 +82,6 @@ def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
|||||||
def compute_xy_bbox(
|
def compute_xy_bbox(
|
||||||
vertices: np.ndarray, col_x: int = 0, col_y: int = 1
|
vertices: np.ndarray, col_x: int = 0, col_y: int = 1
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""Computes the bounding box in XY plane for given vertices.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vertices (np.ndarray): Vertex coordinates.
|
|
||||||
col_x (int, optional): Column index for X.
|
|
||||||
col_y (int, optional): Column index for Y.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[float]: [min_x, max_x, min_y, max_y]
|
|
||||||
"""
|
|
||||||
x_vals = vertices[:, col_x]
|
x_vals = vertices[:, col_x]
|
||||||
y_vals = vertices[:, col_y]
|
y_vals = vertices[:, col_y]
|
||||||
return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
|
return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
|
||||||
@ -102,16 +92,6 @@ def has_iou_conflict(
|
|||||||
placed_boxes: list[list[float]],
|
placed_boxes: list[list[float]],
|
||||||
iou_threshold: float = 0.0,
|
iou_threshold: float = 0.0,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Checks for intersection-over-union conflict between boxes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
new_box (list[float]): New box coordinates.
|
|
||||||
placed_boxes (list[list[float]]): List of placed box coordinates.
|
|
||||||
iou_threshold (float, optional): IOU threshold.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if conflict exists, False otherwise.
|
|
||||||
"""
|
|
||||||
new_min_x, new_max_x, new_min_y, new_max_y = new_box
|
new_min_x, new_max_x, new_min_y, new_max_y = new_box
|
||||||
for min_x, max_x, min_y, max_y in placed_boxes:
|
for min_x, max_x, min_y, max_y in placed_boxes:
|
||||||
ix1 = max(new_min_x, min_x)
|
ix1 = max(new_min_x, min_x)
|
||||||
@ -125,14 +105,7 @@ def has_iou_conflict(
|
|||||||
|
|
||||||
|
|
||||||
def with_seed(seed_attr_name: str = "seed"):
|
def with_seed(seed_attr_name: str = "seed"):
|
||||||
"""Decorator to temporarily set the random seed for reproducibility.
|
"""A parameterized decorator that temporarily sets the random seed."""
|
||||||
|
|
||||||
Args:
|
|
||||||
seed_attr_name (str, optional): Name of the seed argument.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
function: Decorator function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
@ -170,20 +143,6 @@ def compute_convex_hull_path(
|
|||||||
y_axis: int = 1,
|
y_axis: int = 1,
|
||||||
z_axis: int = 2,
|
z_axis: int = 2,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""Computes a dense convex hull path for the top surface of a mesh.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vertices (np.ndarray): Mesh vertices.
|
|
||||||
z_threshold (float, optional): Z threshold for top surface.
|
|
||||||
interp_per_edge (int, optional): Interpolation points per edge.
|
|
||||||
margin (float, optional): Margin for polygon buffer.
|
|
||||||
x_axis (int, optional): X axis index.
|
|
||||||
y_axis (int, optional): Y axis index.
|
|
||||||
z_axis (int, optional): Z axis index.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path: Matplotlib path object for the convex hull.
|
|
||||||
"""
|
|
||||||
top_vertices = vertices[
|
top_vertices = vertices[
|
||||||
vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
|
vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
|
||||||
]
|
]
|
||||||
@ -211,15 +170,6 @@ def compute_convex_hull_path(
|
|||||||
|
|
||||||
|
|
||||||
def find_parent_node(node: str, tree: dict) -> str | None:
|
def find_parent_node(node: str, tree: dict) -> str | None:
|
||||||
"""Finds the parent node of a given node in a tree.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node (str): Node name.
|
|
||||||
tree (dict): Tree structure.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str | None: Parent node name or None.
|
|
||||||
"""
|
|
||||||
for parent, children in tree.items():
|
for parent, children in tree.items():
|
||||||
if any(child[0] == node for child in children):
|
if any(child[0] == node for child in children):
|
||||||
return parent
|
return parent
|
||||||
@ -227,16 +177,6 @@ def find_parent_node(node: str, tree: dict) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
||||||
"""Checks if at least `threshold` corners of a box are inside a hull.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hull (Path): Convex hull path.
|
|
||||||
box (list): Box coordinates [x1, x2, y1, y2].
|
|
||||||
threshold (int, optional): Minimum corners inside.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if enough corners are inside.
|
|
||||||
"""
|
|
||||||
x1, x2, y1, y2 = box
|
x1, x2, y1, y2 = box
|
||||||
corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
|
corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
|
||||||
|
|
||||||
@ -247,15 +187,6 @@ def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
|||||||
def compute_axis_rotation_quat(
|
def compute_axis_rotation_quat(
|
||||||
axis: Literal["x", "y", "z"], angle_rad: float
|
axis: Literal["x", "y", "z"], angle_rad: float
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""Computes quaternion for rotation around a given axis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
axis (Literal["x", "y", "z"]): Axis of rotation.
|
|
||||||
angle_rad (float): Rotation angle in radians.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[float]: Quaternion [x, y, z, w].
|
|
||||||
"""
|
|
||||||
if axis.lower() == "x":
|
if axis.lower() == "x":
|
||||||
q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
|
q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
|
||||||
elif axis.lower() == "y":
|
elif axis.lower() == "y":
|
||||||
@ -271,15 +202,6 @@ def compute_axis_rotation_quat(
|
|||||||
def quaternion_multiply(
|
def quaternion_multiply(
|
||||||
init_quat: list[float], rotate_quat: list[float]
|
init_quat: list[float], rotate_quat: list[float]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""Multiplies two quaternions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
init_quat (list[float]): Initial quaternion [x, y, z, w].
|
|
||||||
rotate_quat (list[float]): Rotation quaternion [x, y, z, w].
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[float]: Resulting quaternion [x, y, z, w].
|
|
||||||
"""
|
|
||||||
qx, qy, qz, qw = init_quat
|
qx, qy, qz, qw = init_quat
|
||||||
q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
|
q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
|
||||||
qx, qy, qz, qw = rotate_quat
|
qx, qy, qz, qw = rotate_quat
|
||||||
@ -295,17 +217,7 @@ def check_reachable(
|
|||||||
min_reach: float = 0.25,
|
min_reach: float = 0.25,
|
||||||
max_reach: float = 0.85,
|
max_reach: float = 0.85,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Checks if the target point is within the reachable range.
|
"""Check if the target point is within the reachable range."""
|
||||||
|
|
||||||
Args:
|
|
||||||
base_xyz (np.ndarray): Base position.
|
|
||||||
reach_xyz (np.ndarray): Target position.
|
|
||||||
min_reach (float, optional): Minimum reach distance.
|
|
||||||
max_reach (float, optional): Maximum reach distance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if reachable, False otherwise.
|
|
||||||
"""
|
|
||||||
distance = np.linalg.norm(reach_xyz - base_xyz)
|
distance = np.linalg.norm(reach_xyz - base_xyz)
|
||||||
|
|
||||||
return min_reach < distance < max_reach
|
return min_reach < distance < max_reach
|
||||||
@ -326,31 +238,26 @@ def bfs_placement(
|
|||||||
robot_dim: float = 0.12,
|
robot_dim: float = 0.12,
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
) -> LayoutInfo:
|
) -> LayoutInfo:
|
||||||
"""Places objects in a scene layout using BFS traversal.
|
"""Place objects in the layout using BFS traversal.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
layout_file (str): Path to layout JSON file generated from `layout-cli`.
|
layout_file: Path to the JSON file defining the layout structure and assets.
|
||||||
floor_margin (float, optional): Z-offset for objects placed on the floor.
|
floor_margin: Z-offset for the background object, typically for objects placed on the floor.
|
||||||
beside_margin (float, optional): Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
|
beside_margin: Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
|
||||||
max_attempts (int, optional): Max attempts for a non-overlapping placement.
|
max_attempts: Maximum number of attempts to find a non-overlapping position for an object.
|
||||||
init_rpy (tuple, optional): Initial rotation (rpy).
|
init_rpy: Initial Roll-Pitch-Yaw rotation rad applied to all object meshes to align the mesh's
|
||||||
rotate_objs (bool, optional): Whether to random rotate objects.
|
coordinate system with the world's (e.g., Z-up).
|
||||||
rotate_bg (bool, optional): Whether to random rotate background.
|
rotate_objs: If True, apply a random rotation around the Z-axis for manipulated and distractor objects.
|
||||||
rotate_context (bool, optional): Whether to random rotate context asset.
|
rotate_bg: If True, apply a random rotation around the Y-axis for the background object.
|
||||||
limit_reach_range (tuple[float, float] | None, optional): If set, enforce a check that manipulated objects are within the robot's reach range, in meter.
|
rotate_context: If True, apply a random rotation around the Z-axis for the context object.
|
||||||
max_orient_diff (float | None, optional): If set, enforce a check that manipulated objects are within the robot's orientation range, in degree.
|
limit_reach_range: If set, enforce a check that manipulated objects are within the robot's reach range, in meter.
|
||||||
robot_dim (float, optional): The approximate robot size.
|
max_orient_diff: If set, enforce a check that manipulated objects are within the robot's orientation range, in degree.
|
||||||
seed (int, optional): Random seed for reproducible placement.
|
robot_dim: The approximate dimension (e.g., diameter) of the robot for box representation.
|
||||||
|
seed: Random seed for reproducible placement.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LayoutInfo: Layout information with object poses.
|
A :class:`LayoutInfo` object containing the objects and their final computed 7D poses
|
||||||
|
([x, y, z, qx, qy, qz, qw]).
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.utils.geometry import bfs_placement
|
|
||||||
layout = bfs_placement("scene_layout.json", seed=42)
|
|
||||||
print(layout.position)
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
|
layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
|
||||||
asset_dir = os.path.dirname(layout_file)
|
asset_dir = os.path.dirname(layout_file)
|
||||||
@ -571,13 +478,6 @@ def bfs_placement(
|
|||||||
def compose_mesh_scene(
|
def compose_mesh_scene(
|
||||||
layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
|
layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Composes a mesh scene from layout information and saves to file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layout_info (LayoutInfo): Layout information.
|
|
||||||
out_scene_path (str): Output scene file path.
|
|
||||||
with_bg (bool, optional): Include background mesh.
|
|
||||||
"""
|
|
||||||
object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
|
object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
|
||||||
scene = trimesh.Scene()
|
scene = trimesh.Scene()
|
||||||
for node in layout_info.assets:
|
for node in layout_info.assets:
|
||||||
@ -605,16 +505,6 @@ def compose_mesh_scene(
|
|||||||
def compute_pinhole_intrinsics(
|
def compute_pinhole_intrinsics(
|
||||||
image_w: int, image_h: int, fov_deg: float
|
image_w: int, image_h: int, fov_deg: float
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Computes pinhole camera intrinsic matrix from image size and FOV.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_w (int): Image width.
|
|
||||||
image_h (int): Image height.
|
|
||||||
fov_deg (float): Field of view in degrees.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Intrinsic matrix K.
|
|
||||||
"""
|
|
||||||
fov_rad = np.deg2rad(fov_deg)
|
fov_rad = np.deg2rad(fov_deg)
|
||||||
fx = image_w / (2 * np.tan(fov_rad / 2))
|
fx = image_w / (2 * np.tan(fov_rad / 2))
|
||||||
fy = fx # assuming square pixels
|
fy = fx # assuming square pixels
|
||||||
|
|||||||
@ -45,35 +45,7 @@ CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
|
|||||||
|
|
||||||
|
|
||||||
class GPTclient:
|
class GPTclient:
|
||||||
"""A client to interact with GPT models via OpenAI or Azure API.
|
"""A client to interact with the GPT model via OpenAI or Azure API."""
|
||||||
|
|
||||||
Supports text and image prompts, connection checking, and configurable parameters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
endpoint (str): API endpoint URL.
|
|
||||||
api_key (str): API key for authentication.
|
|
||||||
model_name (str, optional): Model name to use.
|
|
||||||
api_version (str, optional): API version (for Azure).
|
|
||||||
check_connection (bool, optional): Whether to check API connection.
|
|
||||||
verbose (bool, optional): Enable verbose logging.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```sh
|
|
||||||
export ENDPOINT="https://yfb-openai-sweden.openai.azure.com"
|
|
||||||
export API_KEY="xxxxxx"
|
|
||||||
export API_VERSION="2025-03-01-preview"
|
|
||||||
export MODEL_NAME="yfb-gpt-4o-sweden"
|
|
||||||
```
|
|
||||||
```py
|
|
||||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
|
||||||
|
|
||||||
response = GPT_CLIENT.query("Describe the physics of a falling apple.")
|
|
||||||
response = GPT_CLIENT.query(
|
|
||||||
text_prompt="Describe the content in each image."
|
|
||||||
image_base64=["path/to/image1.png", "path/to/image2.jpg"],
|
|
||||||
)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -110,7 +82,6 @@ class GPTclient:
|
|||||||
stop=(stop_after_attempt(10) | stop_after_delay(30)),
|
stop=(stop_after_attempt(10) | stop_after_delay(30)),
|
||||||
)
|
)
|
||||||
def completion_with_backoff(self, **kwargs):
|
def completion_with_backoff(self, **kwargs):
|
||||||
"""Performs a chat completion request with retry/backoff."""
|
|
||||||
return self.client.chat.completions.create(**kwargs)
|
return self.client.chat.completions.create(**kwargs)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
@ -120,16 +91,19 @@ class GPTclient:
|
|||||||
system_role: Optional[str] = None,
|
system_role: Optional[str] = None,
|
||||||
params: Optional[dict] = None,
|
params: Optional[dict] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Queries the GPT model with text and optional image prompts.
|
"""Queries the GPT model with a text and optional image prompts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text_prompt (str): Main text input.
|
text_prompt (str): The main text input that the model responds to.
|
||||||
image_base64 (Optional[list[str | Image.Image]], optional): List of image base64 strings, file paths, or PIL Images.
|
image_base64 (Optional[List[str]]): A list of image base64 strings
|
||||||
system_role (Optional[str], optional): System-level instructions.
|
or local image paths or PIL.Image to accompany the text prompt.
|
||||||
params (Optional[dict], optional): Additional GPT parameters.
|
system_role (Optional[str]): Optional system-level instructions
|
||||||
|
that specify the behavior of the assistant.
|
||||||
|
params (Optional[dict]): Additional parameters for GPT setting.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[str]: Model response content, or None if error.
|
Optional[str]: The response content generated by the model based on
|
||||||
|
the prompt. Returns `None` if an error occurs.
|
||||||
"""
|
"""
|
||||||
if system_role is None:
|
if system_role is None:
|
||||||
system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
|
system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
|
||||||
@ -203,11 +177,7 @@ class GPTclient:
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
def check_connection(self) -> None:
|
def check_connection(self) -> None:
|
||||||
"""Checks whether the GPT API connection is working.
|
"""Check whether the GPT API connection is working."""
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConnectionError: If connection fails.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
response = self.completion_with_backoff(
|
response = self.completion_with_backoff(
|
||||||
messages=[
|
messages=[
|
||||||
|
|||||||
@ -69,40 +69,6 @@ def render_asset3d(
|
|||||||
no_index_file: bool = False,
|
no_index_file: bool = False,
|
||||||
with_mtl: bool = True,
|
with_mtl: bool = True,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Renders a 3D mesh asset and returns output image paths.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mesh_path (str): Path to the mesh file.
|
|
||||||
output_root (str): Directory to save outputs.
|
|
||||||
distance (float, optional): Camera distance.
|
|
||||||
num_images (int, optional): Number of views to render.
|
|
||||||
elevation (list[float], optional): Camera elevation angles.
|
|
||||||
pbr_light_factor (float, optional): PBR lighting factor.
|
|
||||||
return_key (str, optional): Glob pattern for output images.
|
|
||||||
output_subdir (str, optional): Subdirectory for outputs.
|
|
||||||
gen_color_mp4 (bool, optional): Generate color MP4 video.
|
|
||||||
gen_viewnormal_mp4 (bool, optional): Generate view normal MP4.
|
|
||||||
gen_glonormal_mp4 (bool, optional): Generate global normal MP4.
|
|
||||||
no_index_file (bool, optional): Skip index file saving.
|
|
||||||
with_mtl (bool, optional): Use mesh material.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: List of output image file paths.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.utils.process_media import render_asset3d
|
|
||||||
|
|
||||||
image_paths = render_asset3d(
|
|
||||||
mesh_path="path_to_mesh.obj",
|
|
||||||
output_root="path_to_save_dir",
|
|
||||||
num_images=6,
|
|
||||||
elevation=(30, -30),
|
|
||||||
output_subdir="renders",
|
|
||||||
no_index_file=True,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
input_args = dict(
|
input_args = dict(
|
||||||
mesh_path=mesh_path,
|
mesh_path=mesh_path,
|
||||||
output_root=output_root,
|
output_root=output_root,
|
||||||
@ -129,13 +95,6 @@ def render_asset3d(
|
|||||||
|
|
||||||
|
|
||||||
def merge_images_video(color_images, normal_images, output_path) -> None:
|
def merge_images_video(color_images, normal_images, output_path) -> None:
|
||||||
"""Merges color and normal images into a video.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
color_images (list[np.ndarray]): List of color images.
|
|
||||||
normal_images (list[np.ndarray]): List of normal images.
|
|
||||||
output_path (str): Path to save the output video.
|
|
||||||
"""
|
|
||||||
width = color_images[0].shape[1]
|
width = color_images[0].shape[1]
|
||||||
combined_video = [
|
combined_video = [
|
||||||
np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
|
np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
|
||||||
@ -149,13 +108,7 @@ def merge_images_video(color_images, normal_images, output_path) -> None:
|
|||||||
def merge_video_video(
|
def merge_video_video(
|
||||||
video_path1: str, video_path2: str, output_path: str
|
video_path1: str, video_path2: str, output_path: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Merges two videos by combining their left and right halves.
|
"""Merge two videos by combining their left and right halves."""
|
||||||
|
|
||||||
Args:
|
|
||||||
video_path1 (str): Path to first video.
|
|
||||||
video_path2 (str): Path to second video.
|
|
||||||
output_path (str): Path to save the merged video.
|
|
||||||
"""
|
|
||||||
clip1 = VideoFileClip(video_path1)
|
clip1 = VideoFileClip(video_path1)
|
||||||
clip2 = VideoFileClip(video_path2)
|
clip2 = VideoFileClip(video_path2)
|
||||||
|
|
||||||
@ -174,16 +127,6 @@ def filter_small_connected_components(
|
|||||||
area_ratio: float,
|
area_ratio: float,
|
||||||
connectivity: int = 8,
|
connectivity: int = 8,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Removes small connected components from a binary mask.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mask (Union[Image.Image, np.ndarray]): Input mask.
|
|
||||||
area_ratio (float): Minimum area ratio for components.
|
|
||||||
connectivity (int, optional): Connectivity for labeling.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Mask with small components removed.
|
|
||||||
"""
|
|
||||||
if isinstance(mask, Image.Image):
|
if isinstance(mask, Image.Image):
|
||||||
mask = np.array(mask)
|
mask = np.array(mask)
|
||||||
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
|
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
|
||||||
@ -209,16 +152,6 @@ def filter_image_small_connected_components(
|
|||||||
area_ratio: float = 10,
|
area_ratio: float = 10,
|
||||||
connectivity: int = 8,
|
connectivity: int = 8,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Removes small connected components from the alpha channel of an image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Union[Image.Image, np.ndarray]): Input image.
|
|
||||||
area_ratio (float, optional): Minimum area ratio.
|
|
||||||
connectivity (int, optional): Connectivity for labeling.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Image with filtered alpha channel.
|
|
||||||
"""
|
|
||||||
if isinstance(image, Image.Image):
|
if isinstance(image, Image.Image):
|
||||||
image = image.convert("RGBA")
|
image = image.convert("RGBA")
|
||||||
image = np.array(image)
|
image = np.array(image)
|
||||||
@ -236,24 +169,6 @@ def combine_images_to_grid(
|
|||||||
target_wh: tuple[int, int] = (512, 512),
|
target_wh: tuple[int, int] = (512, 512),
|
||||||
image_mode: str = "RGB",
|
image_mode: str = "RGB",
|
||||||
) -> list[Image.Image]:
|
) -> list[Image.Image]:
|
||||||
"""Combines multiple images into a grid.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (list[str | Image.Image]): List of image paths or PIL Images.
|
|
||||||
cat_row_col (tuple[int, int], optional): Grid rows and columns.
|
|
||||||
target_wh (tuple[int, int], optional): Target image size.
|
|
||||||
image_mode (str, optional): Image mode.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Image.Image]: List containing the grid image.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.utils.process_media import combine_images_to_grid
|
|
||||||
grid = combine_images_to_grid(["img1.png", "img2.png"])
|
|
||||||
grid[0].save("grid.png")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
n_images = len(images)
|
n_images = len(images)
|
||||||
if n_images == 1:
|
if n_images == 1:
|
||||||
return images
|
return images
|
||||||
@ -281,19 +196,6 @@ def combine_images_to_grid(
|
|||||||
|
|
||||||
|
|
||||||
class SceneTreeVisualizer:
|
class SceneTreeVisualizer:
|
||||||
"""Visualizes a scene tree layout using networkx and matplotlib.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layout_info (LayoutInfo): Layout information for the scene.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.utils.process_media import SceneTreeVisualizer
|
|
||||||
visualizer = SceneTreeVisualizer(layout_info)
|
|
||||||
visualizer.render(save_path="tree.png")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, layout_info: LayoutInfo) -> None:
|
def __init__(self, layout_info: LayoutInfo) -> None:
|
||||||
self.tree = layout_info.tree
|
self.tree = layout_info.tree
|
||||||
self.relation = layout_info.relation
|
self.relation = layout_info.relation
|
||||||
@ -372,14 +274,6 @@ class SceneTreeVisualizer:
|
|||||||
dpi=300,
|
dpi=300,
|
||||||
title: str = "Scene 3D Hierarchy Tree",
|
title: str = "Scene 3D Hierarchy Tree",
|
||||||
):
|
):
|
||||||
"""Renders the scene tree and saves to file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
save_path (str): Path to save the rendered image.
|
|
||||||
figsize (tuple, optional): Figure size.
|
|
||||||
dpi (int, optional): Image DPI.
|
|
||||||
title (str, optional): Plot image title.
|
|
||||||
"""
|
|
||||||
node_colors = [
|
node_colors = [
|
||||||
self.role_colors[self._get_node_role(n)] for n in self.G.nodes
|
self.role_colors[self._get_node_role(n)] for n in self.G.nodes
|
||||||
]
|
]
|
||||||
@ -456,14 +350,6 @@ class SceneTreeVisualizer:
|
|||||||
|
|
||||||
|
|
||||||
def load_scene_dict(file_path: str) -> dict:
|
def load_scene_dict(file_path: str) -> dict:
|
||||||
"""Loads a scene description dictionary from a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path (str): Path to the scene description file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Mapping from scene ID to description.
|
|
||||||
"""
|
|
||||||
scene_dict = {}
|
scene_dict = {}
|
||||||
with open(file_path, "r", encoding='utf-8') as f:
|
with open(file_path, "r", encoding='utf-8') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
@ -477,28 +363,12 @@ def load_scene_dict(file_path: str) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def is_image_file(filename: str) -> bool:
|
def is_image_file(filename: str) -> bool:
|
||||||
"""Checks if a filename is an image file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename (str): Filename to check.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if image file, False otherwise.
|
|
||||||
"""
|
|
||||||
mime_type, _ = mimetypes.guess_type(filename)
|
mime_type, _ = mimetypes.guess_type(filename)
|
||||||
|
|
||||||
return mime_type is not None and mime_type.startswith('image')
|
return mime_type is not None and mime_type.startswith('image')
|
||||||
|
|
||||||
|
|
||||||
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
||||||
"""Parses text prompts from a list or file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompts (list[str]): List of prompts or a file path.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: List of parsed prompts.
|
|
||||||
"""
|
|
||||||
if len(prompts) == 1 and prompts[0].endswith(".txt"):
|
if len(prompts) == 1 and prompts[0].endswith(".txt"):
|
||||||
with open(prompts[0], "r") as f:
|
with open(prompts[0], "r") as f:
|
||||||
prompts = [
|
prompts = [
|
||||||
@ -516,18 +386,13 @@ def alpha_blend_rgba(
|
|||||||
"""Alpha blends a foreground RGBA image over a background RGBA image.
|
"""Alpha blends a foreground RGBA image over a background RGBA image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fg_image: Foreground image (str, PIL Image, or ndarray).
|
fg_image: Foreground image. Can be a file path (str), a PIL Image,
|
||||||
bg_image: Background image (str, PIL Image, or ndarray).
|
or a NumPy ndarray.
|
||||||
|
bg_image: Background image. Can be a file path (str), a PIL Image,
|
||||||
|
or a NumPy ndarray.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Image.Image: Alpha-blended RGBA image.
|
A PIL Image representing the alpha-blended result in RGBA mode.
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.utils.process_media import alpha_blend_rgba
|
|
||||||
result = alpha_blend_rgba("fg.png", "bg.png")
|
|
||||||
result.save("blended.png")
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
if isinstance(fg_image, str):
|
if isinstance(fg_image, str):
|
||||||
fg_image = Image.open(fg_image)
|
fg_image = Image.open(fg_image)
|
||||||
@ -556,11 +421,13 @@ def check_object_edge_truncated(
|
|||||||
"""Checks if a binary object mask is truncated at the image edges.
|
"""Checks if a binary object mask is truncated at the image edges.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mask (np.ndarray): 2D binary mask.
|
mask: A 2D binary NumPy array where nonzero values indicate the object region.
|
||||||
edge_threshold (int, optional): Edge pixel threshold.
|
edge_threshold: Number of pixels from each image edge to consider for truncation.
|
||||||
|
Defaults to 5.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if object is fully enclosed, False if truncated.
|
True if the object is fully enclosed (not truncated).
|
||||||
|
False if the object touches or crosses any image boundary.
|
||||||
"""
|
"""
|
||||||
top = mask[:edge_threshold, :].any()
|
top = mask[:edge_threshold, :].any()
|
||||||
bottom = mask[-edge_threshold:, :].any()
|
bottom = mask[-edge_threshold:, :].any()
|
||||||
@ -573,22 +440,6 @@ def check_object_edge_truncated(
|
|||||||
def vcat_pil_images(
|
def vcat_pil_images(
|
||||||
images: list[Image.Image], image_mode: str = "RGB"
|
images: list[Image.Image], image_mode: str = "RGB"
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""Vertically concatenates a list of PIL images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (list[Image.Image]): List of images.
|
|
||||||
image_mode (str, optional): Image mode.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: Vertically concatenated image.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.utils.process_media import vcat_pil_images
|
|
||||||
img = vcat_pil_images([Image.open("a.png"), Image.open("b.png")])
|
|
||||||
img.save("vcat.png")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
widths, heights = zip(*(img.size for img in images))
|
widths, heights = zip(*(img.size for img in images))
|
||||||
total_height = sum(heights)
|
total_height = sum(heights)
|
||||||
max_width = max(widths)
|
max_width = max(widths)
|
||||||
|
|||||||
@ -69,21 +69,6 @@ def load_actor_from_urdf(
|
|||||||
update_mass: bool = False,
|
update_mass: bool = False,
|
||||||
scale: float | np.ndarray = 1.0,
|
scale: float | np.ndarray = 1.0,
|
||||||
) -> sapien.pysapien.Entity:
|
) -> sapien.pysapien.Entity:
|
||||||
"""Load a SAPIEN actor from a URDF file and add it to the scene.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene (sapien.Scene | ManiSkillScene): The simulation scene.
|
|
||||||
file_path (str): Path to the URDF file.
|
|
||||||
pose (sapien.Pose | None): Initial pose of the actor.
|
|
||||||
env_idx (int): Environment index for multi-env setup.
|
|
||||||
use_static (bool): Whether the actor is static.
|
|
||||||
update_mass (bool): Whether to update the actor's mass from URDF.
|
|
||||||
scale (float | np.ndarray): Scale factor for the actor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
sapien.pysapien.Entity: The created actor entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
|
def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
|
||||||
local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
|
local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
|
||||||
if origin_tag is not None:
|
if origin_tag is not None:
|
||||||
@ -169,17 +154,14 @@ def load_assets_from_layout_file(
|
|||||||
init_quat: list[float] = [0, 0, 0, 1],
|
init_quat: list[float] = [0, 0, 0, 1],
|
||||||
env_idx: int = None,
|
env_idx: int = None,
|
||||||
) -> dict[str, sapien.pysapien.Entity]:
|
) -> dict[str, sapien.pysapien.Entity]:
|
||||||
"""Load assets from an EmbodiedGen layout file and create sapien actors in the scene.
|
"""Load assets from `EmbodiedGen` layout-gen output and create actors in the scene.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scene (ManiSkillScene | sapien.Scene): The sapien simulation scene.
|
scene (sapien.Scene | ManiSkillScene): The SAPIEN or ManiSkill scene to load assets into.
|
||||||
layout (str): Path to the EmbodiedGen layout file.
|
layout (str): The layout file path.
|
||||||
z_offset (float): Z offset for non-context objects.
|
z_offset (float): Offset to apply to the Z-coordinate of non-context objects.
|
||||||
init_quat (list[float]): Initial quaternion for orientation.
|
init_quat (List[float]): Initial quaternion (x, y, z, w) for orientation adjustment.
|
||||||
env_idx (int): Environment index.
|
env_idx (int): Environment index for multi-environment setup.
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, sapien.pysapien.Entity]: Mapping from object names to actor entities.
|
|
||||||
"""
|
"""
|
||||||
asset_root = os.path.dirname(layout)
|
asset_root = os.path.dirname(layout)
|
||||||
layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
|
layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
|
||||||
@ -224,19 +206,6 @@ def load_mani_skill_robot(
|
|||||||
control_mode: str = "pd_joint_pos",
|
control_mode: str = "pd_joint_pos",
|
||||||
backend_str: tuple[str, str] = ("cpu", "gpu"),
|
backend_str: tuple[str, str] = ("cpu", "gpu"),
|
||||||
) -> BaseAgent:
|
) -> BaseAgent:
|
||||||
"""Load a ManiSkill robot agent into the scene.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene (sapien.Scene | ManiSkillScene): The simulation scene.
|
|
||||||
layout (LayoutInfo | str): Layout info or path to layout file.
|
|
||||||
control_freq (int): Control frequency.
|
|
||||||
robot_init_qpos_noise (float): Noise for initial joint positions.
|
|
||||||
control_mode (str): Robot control mode.
|
|
||||||
backend_str (tuple[str, str]): Simulation/render backend.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseAgent: The loaded robot agent.
|
|
||||||
"""
|
|
||||||
from mani_skill.agents import REGISTERED_AGENTS
|
from mani_skill.agents import REGISTERED_AGENTS
|
||||||
from mani_skill.envs.scene import ManiSkillScene
|
from mani_skill.envs.scene import ManiSkillScene
|
||||||
from mani_skill.envs.utils.system.backend import (
|
from mani_skill.envs.utils.system.backend import (
|
||||||
@ -309,14 +278,14 @@ def render_images(
|
|||||||
]
|
]
|
||||||
] = None,
|
] = None,
|
||||||
) -> dict[str, Image.Image]:
|
) -> dict[str, Image.Image]:
|
||||||
"""Render images from a given SAPIEN camera.
|
"""Render images from a given SAPIEN camera.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
camera (sapien.render.RenderCameraComponent): Camera to render from.
|
camera (sapien.render.RenderCameraComponent): The camera to render from.
|
||||||
render_keys (list[str], optional): Types of images to render.
|
render_keys (List[str]): Types of images to render (e.g., Color, Segmentation).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, Image.Image]: Dictionary of rendered images.
|
Dict[str, Image.Image]: Dictionary of rendered images.
|
||||||
"""
|
"""
|
||||||
if render_keys is None:
|
if render_keys is None:
|
||||||
render_keys = [
|
render_keys = [
|
||||||
@ -372,33 +341,11 @@ def render_images(
|
|||||||
|
|
||||||
|
|
||||||
class SapienSceneManager:
|
class SapienSceneManager:
|
||||||
"""Manages SAPIEN simulation scenes, cameras, and rendering.
|
"""A class to manage SAPIEN simulator."""
|
||||||
|
|
||||||
This class provides utilities for setting up scenes, adding cameras,
|
|
||||||
stepping simulation, and rendering images.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
sim_freq (int): Simulation frequency.
|
|
||||||
ray_tracing (bool): Whether to use ray tracing.
|
|
||||||
device (str): Device for simulation.
|
|
||||||
renderer (sapien.SapienRenderer): SAPIEN renderer.
|
|
||||||
scene (sapien.Scene): Simulation scene.
|
|
||||||
cameras (list): List of camera components.
|
|
||||||
actors (dict): Mapping of actor names to entities.
|
|
||||||
|
|
||||||
For an example, see `embodied_gen/scripts/simulate_sapien.py`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
|
self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the scene manager.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sim_freq (int): Simulation frequency.
|
|
||||||
ray_tracing (bool): Enable ray tracing.
|
|
||||||
device (str): Device for simulation.
|
|
||||||
"""
|
|
||||||
self.sim_freq = sim_freq
|
self.sim_freq = sim_freq
|
||||||
self.ray_tracing = ray_tracing
|
self.ray_tracing = ray_tracing
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -408,11 +355,7 @@ class SapienSceneManager:
|
|||||||
self.actors: dict[str, sapien.pysapien.Entity] = {}
|
self.actors: dict[str, sapien.pysapien.Entity] = {}
|
||||||
|
|
||||||
def _setup_scene(self) -> sapien.Scene:
|
def _setup_scene(self) -> sapien.Scene:
|
||||||
"""Set up the SAPIEN scene with lighting and ground.
|
"""Set up the SAPIEN scene with lighting and ground."""
|
||||||
|
|
||||||
Returns:
|
|
||||||
sapien.Scene: The initialized scene.
|
|
||||||
"""
|
|
||||||
# Ray tracing settings
|
# Ray tracing settings
|
||||||
if self.ray_tracing:
|
if self.ray_tracing:
|
||||||
sapien.render.set_camera_shader_dir("rt")
|
sapien.render.set_camera_shader_dir("rt")
|
||||||
@ -454,18 +397,6 @@ class SapienSceneManager:
|
|||||||
render_keys: list[str],
|
render_keys: list[str],
|
||||||
sim_steps_per_control: int = 1,
|
sim_steps_per_control: int = 1,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Step the simulation and render images from cameras.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent (BaseAgent): The robot agent.
|
|
||||||
action (torch.Tensor): Action to apply.
|
|
||||||
cameras (list): List of camera components.
|
|
||||||
render_keys (list[str]): Types of images to render.
|
|
||||||
sim_steps_per_control (int): Simulation steps per control.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Dictionary of rendered frames per camera.
|
|
||||||
"""
|
|
||||||
agent.set_action(action)
|
agent.set_action(action)
|
||||||
frames = defaultdict(list)
|
frames = defaultdict(list)
|
||||||
for _ in range(sim_steps_per_control):
|
for _ in range(sim_steps_per_control):
|
||||||
@ -486,13 +417,13 @@ class SapienSceneManager:
|
|||||||
image_hw: tuple[int, int],
|
image_hw: tuple[int, int],
|
||||||
fovy_deg: float,
|
fovy_deg: float,
|
||||||
) -> sapien.render.RenderCameraComponent:
|
) -> sapien.render.RenderCameraComponent:
|
||||||
"""Create a camera in the scene.
|
"""Create a single camera in the scene.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cam_name (str): Camera name.
|
cam_name (str): Name of the camera.
|
||||||
pose (sapien.Pose): Camera pose.
|
pose (sapien.Pose): Camera pose with p=(x, y, z) and q=(w, x, y, z).
|
||||||
image_hw (tuple[int, int]): Image resolution (height, width).
|
image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
|
||||||
fovy_deg (float): Field of view in degrees.
|
fovy_deg (float): Field of view in degrees for cameras.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
sapien.render.RenderCameraComponent: The created camera.
|
sapien.render.RenderCameraComponent: The created camera.
|
||||||
@ -525,15 +456,15 @@ class SapienSceneManager:
|
|||||||
"""Initialize multiple cameras arranged in a circle.
|
"""Initialize multiple cameras arranged in a circle.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_cameras (int): Number of cameras.
|
num_cameras (int): Number of cameras to create.
|
||||||
radius (float): Circle radius.
|
radius (float): Radius of the camera circle.
|
||||||
height (float): Camera height.
|
height (float): Fixed Z-coordinate of the cameras.
|
||||||
target_pt (list[float]): Target point to look at.
|
target_pt (list[float]): 3D point (x, y, z) that cameras look at.
|
||||||
image_hw (tuple[int, int]): Image resolution.
|
image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
|
||||||
fovy_deg (float): Field of view in degrees.
|
fovy_deg (float): Field of view in degrees for cameras.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[sapien.render.RenderCameraComponent]: List of cameras.
|
List[sapien.render.RenderCameraComponent]: List of created cameras.
|
||||||
"""
|
"""
|
||||||
angle_step = 2 * np.pi / num_cameras
|
angle_step = 2 * np.pi / num_cameras
|
||||||
world_up_vec = np.array([0.0, 0.0, 1.0])
|
world_up_vec = np.array([0.0, 0.0, 1.0])
|
||||||
@ -579,19 +510,6 @@ class SapienSceneManager:
|
|||||||
|
|
||||||
|
|
||||||
class FrankaPandaGrasper(object):
|
class FrankaPandaGrasper(object):
|
||||||
"""Provides grasp planning and control for Franka Panda robot.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
agent (BaseAgent): The robot agent.
|
|
||||||
robot: The robot instance.
|
|
||||||
control_freq (float): Control frequency.
|
|
||||||
control_timestep (float): Control timestep.
|
|
||||||
joint_vel_limits (float): Joint velocity limits.
|
|
||||||
joint_acc_limits (float): Joint acceleration limits.
|
|
||||||
finger_length (float): Length of gripper fingers.
|
|
||||||
planners: Motion planners for each environment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
agent: BaseAgent,
|
agent: BaseAgent,
|
||||||
@ -600,7 +518,6 @@ class FrankaPandaGrasper(object):
|
|||||||
joint_acc_limits: float = 1.0,
|
joint_acc_limits: float = 1.0,
|
||||||
finger_length: float = 0.025,
|
finger_length: float = 0.025,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the grasper."""
|
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.robot = agent.robot
|
self.robot = agent.robot
|
||||||
self.control_freq = control_freq
|
self.control_freq = control_freq
|
||||||
@ -636,15 +553,6 @@ class FrankaPandaGrasper(object):
|
|||||||
gripper_state: Literal[-1, 1],
|
gripper_state: Literal[-1, 1],
|
||||||
n_step: int = 10,
|
n_step: int = 10,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Generate gripper control actions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gripper_state (Literal[-1, 1]): Desired gripper state.
|
|
||||||
n_step (int): Number of steps.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Array of gripper actions.
|
|
||||||
"""
|
|
||||||
qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
|
qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
|
||||||
actions = []
|
actions = []
|
||||||
for _ in range(n_step):
|
for _ in range(n_step):
|
||||||
@ -663,20 +571,6 @@ class FrankaPandaGrasper(object):
|
|||||||
action_key: str = "position",
|
action_key: str = "position",
|
||||||
env_idx: int = 0,
|
env_idx: int = 0,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Plan and execute motion to a target pose.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pose (sapien.Pose): Target pose.
|
|
||||||
control_timestep (float): Control timestep.
|
|
||||||
gripper_state (Literal[-1, 1]): Desired gripper state.
|
|
||||||
use_point_cloud (bool): Use point cloud for planning.
|
|
||||||
n_max_step (int): Max number of steps.
|
|
||||||
action_key (str): Key for action in result.
|
|
||||||
env_idx (int): Environment index.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Array of actions to reach the pose.
|
|
||||||
"""
|
|
||||||
result = self.planners[env_idx].plan_qpos_to_pose(
|
result = self.planners[env_idx].plan_qpos_to_pose(
|
||||||
np.concatenate([pose.p, pose.q]),
|
np.concatenate([pose.p, pose.q]),
|
||||||
self.robot.get_qpos().cpu().numpy()[0],
|
self.robot.get_qpos().cpu().numpy()[0],
|
||||||
@ -714,17 +608,6 @@ class FrankaPandaGrasper(object):
|
|||||||
offset: tuple[float, float, float] = [0, 0, -0.05],
|
offset: tuple[float, float, float] = [0, 0, -0.05],
|
||||||
env_idx: int = 0,
|
env_idx: int = 0,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Compute grasp actions for a target actor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
actor (sapien.pysapien.Entity): Target actor to grasp.
|
|
||||||
reach_target_only (bool): Only reach the target pose if True.
|
|
||||||
offset (tuple[float, float, float]): Offset for reach pose.
|
|
||||||
env_idx (int): Environment index.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Array of grasp actions.
|
|
||||||
"""
|
|
||||||
physx_rigid = actor.components[1]
|
physx_rigid = actor.components[1]
|
||||||
mesh = get_component_mesh(physx_rigid, to_world_frame=True)
|
mesh = get_component_mesh(physx_rigid, to_world_frame=True)
|
||||||
obb = mesh.bounding_box_oriented
|
obb = mesh.bounding_box_oriented
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
VERSION = "v0.1.6"
|
VERSION = "v0.1.5"
|
||||||
|
|||||||
@ -27,22 +27,14 @@ from PIL import Image
|
|||||||
|
|
||||||
|
|
||||||
class AestheticPredictor:
|
class AestheticPredictor:
|
||||||
"""Aesthetic Score Predictor using CLIP and a pre-trained MLP.
|
"""Aesthetic Score Predictor.
|
||||||
|
|
||||||
Checkpoints from `https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main`.
|
Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
clip_model_dir (str, optional): Path to CLIP model directory.
|
clip_model_dir (str): Path to the directory of the CLIP model.
|
||||||
sac_model_path (str, optional): Path to SAC model weights.
|
sac_model_path (str): Path to the pre-trained SAC model.
|
||||||
device (str, optional): Device for computation ("cuda" or "cpu").
|
device (str): Device to use for computation ("cuda" or "cpu").
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
|
|
||||||
predictor = AestheticPredictor(device="cuda")
|
|
||||||
score = predictor.predict("image.png")
|
|
||||||
print("Aesthetic score:", score)
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
|
def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
|
||||||
@ -117,7 +109,7 @@ class AestheticPredictor:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def predict(self, image_path):
|
def predict(self, image_path):
|
||||||
"""Predicts the aesthetic score for a given image.
|
"""Predict the aesthetic score for a given image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_path (str): Path to the image file.
|
image_path (str): Path to the image file.
|
||||||
|
|||||||
@ -40,16 +40,6 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class BaseChecker:
|
class BaseChecker:
|
||||||
"""Base class for quality checkers using GPT clients.
|
|
||||||
|
|
||||||
Provides a common interface for querying and validating responses.
|
|
||||||
Subclasses must implement the `query` method.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
prompt (str): The prompt used for queries.
|
|
||||||
verbose (bool): Whether to enable verbose logging.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
|
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
@ -80,15 +70,6 @@ class BaseChecker:
|
|||||||
def validate(
|
def validate(
|
||||||
checkers: list["BaseChecker"], images_list: list[list[str]]
|
checkers: list["BaseChecker"], images_list: list[list[str]]
|
||||||
) -> list:
|
) -> list:
|
||||||
"""Validates a list of checkers against corresponding image lists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
checkers (list[BaseChecker]): List of checker instances.
|
|
||||||
images_list (list[list[str]]): List of image path lists.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: Validation results with overall outcome.
|
|
||||||
"""
|
|
||||||
assert len(checkers) == len(images_list)
|
assert len(checkers) == len(images_list)
|
||||||
results = []
|
results = []
|
||||||
overall_result = True
|
overall_result = True
|
||||||
@ -211,7 +192,7 @@ class ImageSegChecker(BaseChecker):
|
|||||||
|
|
||||||
|
|
||||||
class ImageAestheticChecker(BaseChecker):
|
class ImageAestheticChecker(BaseChecker):
|
||||||
"""Evaluates the aesthetic quality of images using a CLIP-based predictor.
|
"""A class for evaluating the aesthetic quality of images.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
clip_model_dir (str): Path to the CLIP model directory.
|
clip_model_dir (str): Path to the CLIP model directory.
|
||||||
@ -219,14 +200,6 @@ class ImageAestheticChecker(BaseChecker):
|
|||||||
thresh (float): Threshold above which images are considered aesthetically acceptable.
|
thresh (float): Threshold above which images are considered aesthetically acceptable.
|
||||||
verbose (bool): Whether to print detailed log messages.
|
verbose (bool): Whether to print detailed log messages.
|
||||||
predictor (AestheticPredictor): The model used to predict aesthetic scores.
|
predictor (AestheticPredictor): The model used to predict aesthetic scores.
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.validators.quality_checkers import ImageAestheticChecker
|
|
||||||
checker = ImageAestheticChecker(thresh=4.5)
|
|
||||||
flag, score = checker(["image1.png", "image2.png"])
|
|
||||||
print("Aesthetic OK:", flag, "Score:", score)
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -254,16 +227,6 @@ class ImageAestheticChecker(BaseChecker):
|
|||||||
|
|
||||||
|
|
||||||
class SemanticConsistChecker(BaseChecker):
|
class SemanticConsistChecker(BaseChecker):
|
||||||
"""Checks semantic consistency between text descriptions and segmented images.
|
|
||||||
|
|
||||||
Uses GPT to evaluate if the image matches the text in object type, geometry, and color.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
gpt_client (GPTclient): GPT client for queries.
|
|
||||||
prompt (str): Prompt for consistency evaluation.
|
|
||||||
verbose (bool): Whether to enable verbose logging.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -313,16 +276,6 @@ class SemanticConsistChecker(BaseChecker):
|
|||||||
|
|
||||||
|
|
||||||
class TextGenAlignChecker(BaseChecker):
|
class TextGenAlignChecker(BaseChecker):
|
||||||
"""Evaluates alignment between text prompts and generated 3D asset images.
|
|
||||||
|
|
||||||
Assesses if the rendered images match the text description in category and geometry.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
gpt_client (GPTclient): GPT client for queries.
|
|
||||||
prompt (str): Prompt for alignment evaluation.
|
|
||||||
verbose (bool): Whether to enable verbose logging.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -536,17 +489,6 @@ class PanoHeightEstimator(object):
|
|||||||
|
|
||||||
|
|
||||||
class SemanticMatcher(BaseChecker):
|
class SemanticMatcher(BaseChecker):
|
||||||
"""Matches query text to semantically similar scene descriptions.
|
|
||||||
|
|
||||||
Uses GPT to find the most similar scene IDs from a dictionary.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
gpt_client (GPTclient): GPT client for queries.
|
|
||||||
prompt (str): Prompt for semantic matching.
|
|
||||||
verbose (bool): Whether to enable verbose logging.
|
|
||||||
seed (int): Random seed for selection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -601,17 +543,6 @@ class SemanticMatcher(BaseChecker):
|
|||||||
def query(
|
def query(
|
||||||
self, text: str, context: dict, rand: bool = True, params: dict = None
|
self, text: str, context: dict, rand: bool = True, params: dict = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Queries for semantically similar scene IDs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text (str): Query text.
|
|
||||||
context (dict): Dictionary of scene descriptions.
|
|
||||||
rand (bool, optional): Whether to randomly select from top matches.
|
|
||||||
params (dict, optional): Additional GPT parameters.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Matched scene ID.
|
|
||||||
"""
|
|
||||||
match_list = self.gpt_client.query(
|
match_list = self.gpt_client.query(
|
||||||
self.prompt.format(context=context, text=text),
|
self.prompt.format(context=context, text=text),
|
||||||
params=params,
|
params=params,
|
||||||
|
|||||||
@ -80,31 +80,6 @@ URDF_TEMPLATE = """
|
|||||||
|
|
||||||
|
|
||||||
class URDFGenerator(object):
|
class URDFGenerator(object):
|
||||||
"""Generates URDF files for 3D assets with physical and semantic attributes.
|
|
||||||
|
|
||||||
Uses GPT to estimate object properties and generates a URDF file with mesh, friction, mass, and metadata.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gpt_client (GPTclient): GPT client for attribute estimation.
|
|
||||||
mesh_file_list (list[str], optional): Additional mesh files to copy.
|
|
||||||
prompt_template (str, optional): Prompt template for GPT queries.
|
|
||||||
attrs_name (list[str], optional): List of attribute names to include.
|
|
||||||
render_dir (str, optional): Directory for rendered images.
|
|
||||||
render_view_num (int, optional): Number of views to render.
|
|
||||||
decompose_convex (bool, optional): Whether to decompose mesh for collision.
|
|
||||||
rotate_xyzw (list[float], optional): Quaternion for mesh rotation.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```py
|
|
||||||
from embodied_gen.validators.urdf_convertor import URDFGenerator
|
|
||||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
|
||||||
|
|
||||||
urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
|
|
||||||
urdf_path = urdf_gen(mesh_path="mesh.obj", output_root="output_dir")
|
|
||||||
print("Generated URDF:", urdf_path)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -193,14 +168,6 @@ class URDFGenerator(object):
|
|||||||
self.rotate_xyzw = rotate_xyzw
|
self.rotate_xyzw = rotate_xyzw
|
||||||
|
|
||||||
def parse_response(self, response: str) -> dict[str, any]:
|
def parse_response(self, response: str) -> dict[str, any]:
|
||||||
"""Parses GPT response to extract asset attributes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response (str): GPT response string.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, any]: Parsed attributes.
|
|
||||||
"""
|
|
||||||
lines = response.split("\n")
|
lines = response.split("\n")
|
||||||
lines = [line.strip() for line in lines if line]
|
lines = [line.strip() for line in lines if line]
|
||||||
category = lines[0].split(": ")[1]
|
category = lines[0].split(": ")[1]
|
||||||
@ -240,9 +207,11 @@ class URDFGenerator(object):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_mesh (str): Path to the input mesh file.
|
input_mesh (str): Path to the input mesh file.
|
||||||
output_dir (str): Directory to store the generated URDF and mesh.
|
output_dir (str): Directory to store the generated URDF
|
||||||
attr_dict (dict): Dictionary of asset attributes.
|
and processed mesh.
|
||||||
output_name (str, optional): Name for the URDF and robot.
|
attr_dict (dict): Dictionary containing attributes like height,
|
||||||
|
mass, and friction coefficients.
|
||||||
|
output_name (str, optional): Name for the generated URDF and robot.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Path to the generated URDF file.
|
str: Path to the generated URDF file.
|
||||||
@ -367,16 +336,6 @@ class URDFGenerator(object):
|
|||||||
attr_root: str = ".//link/extra_info",
|
attr_root: str = ".//link/extra_info",
|
||||||
attr_name: str = "scale",
|
attr_name: str = "scale",
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Extracts an attribute value from a URDF file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
urdf_path (str): Path to the URDF file.
|
|
||||||
attr_root (str, optional): XML path to attribute root.
|
|
||||||
attr_name (str, optional): Attribute name.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: Attribute value, or None if not found.
|
|
||||||
"""
|
|
||||||
if not os.path.exists(urdf_path):
|
if not os.path.exists(urdf_path):
|
||||||
raise FileNotFoundError(f"URDF file not found: {urdf_path}")
|
raise FileNotFoundError(f"URDF file not found: {urdf_path}")
|
||||||
|
|
||||||
@ -399,13 +358,6 @@ class URDFGenerator(object):
|
|||||||
def add_quality_tag(
|
def add_quality_tag(
|
||||||
urdf_path: str, results: list, output_path: str = None
|
urdf_path: str, results: list, output_path: str = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Adds a quality tag to a URDF file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
urdf_path (str): Path to the URDF file.
|
|
||||||
results (list): List of [checker_name, result] pairs.
|
|
||||||
output_path (str, optional): Output file path.
|
|
||||||
"""
|
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
output_path = urdf_path
|
output_path = urdf_path
|
||||||
|
|
||||||
@ -430,14 +382,6 @@ class URDFGenerator(object):
|
|||||||
logger.info(f"URDF files saved to {output_path}")
|
logger.info(f"URDF files saved to {output_path}")
|
||||||
|
|
||||||
def get_estimated_attributes(self, asset_attrs: dict):
|
def get_estimated_attributes(self, asset_attrs: dict):
|
||||||
"""Calculates estimated attributes from asset properties.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
asset_attrs (dict): Asset attributes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Estimated attributes (height, mass, mu, category).
|
|
||||||
"""
|
|
||||||
estimated_attrs = {
|
estimated_attrs = {
|
||||||
"height": round(
|
"height": round(
|
||||||
(asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
|
(asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
|
||||||
@ -459,18 +403,6 @@ class URDFGenerator(object):
|
|||||||
category: str = "unknown",
|
category: str = "unknown",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Generates a URDF file for a mesh asset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mesh_path (str): Path to mesh file.
|
|
||||||
output_root (str): Directory for outputs.
|
|
||||||
text_prompt (str, optional): Prompt for GPT.
|
|
||||||
category (str, optional): Asset category.
|
|
||||||
**kwargs: Additional attributes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Path to generated URDF file.
|
|
||||||
"""
|
|
||||||
if text_prompt is None or len(text_prompt) == 0:
|
if text_prompt is None or len(text_prompt) == 0:
|
||||||
text_prompt = self.prompt_template
|
text_prompt = self.prompt_template
|
||||||
text_prompt = text_prompt.format(category=category.lower())
|
text_prompt = text_prompt.format(category=category.lower())
|
||||||
|
|||||||
@ -5,9 +5,8 @@ source "$SCRIPT_DIR/_utils.sh"
|
|||||||
|
|
||||||
PIP_INSTALL_PACKAGES=(
|
PIP_INSTALL_PACKAGES=(
|
||||||
"pip==22.3.1"
|
"pip==22.3.1"
|
||||||
"torch==2.4.0+cu121 torchvision==0.19.0+cu121 --index-url https://download.pytorch.org/whl/cu121"
|
"torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu118"
|
||||||
"xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121"
|
"xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu118"
|
||||||
"flash-attn==2.7.0.post2 --no-build-isolation"
|
|
||||||
"-r requirements.txt --use-deprecated=legacy-resolver"
|
"-r requirements.txt --use-deprecated=legacy-resolver"
|
||||||
"flash-attn==2.7.0.post2"
|
"flash-attn==2.7.0.post2"
|
||||||
"utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15"
|
"utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15"
|
||||||
|
|||||||
117
mkdocs.yml
@ -1,117 +0,0 @@
|
|||||||
site_name: Documentation
|
|
||||||
site_url: https://horizonrobotics.github.io/EmbodiedGen/
|
|
||||||
repo_name: "EmbodiedGen"
|
|
||||||
repo_url: https://github.com/HorizonRobotics/EmbodiedGen
|
|
||||||
copyright: "Copyright (c) 2025 Horizon Robotics"
|
|
||||||
use_directory_urls: false
|
|
||||||
|
|
||||||
nav:
|
|
||||||
- 🏠 Home: index.md
|
|
||||||
- 🚀 Installation: install.md
|
|
||||||
- 🧩 Services:
|
|
||||||
- Overview: services/index.md
|
|
||||||
- Image-to-3D: services/image_to_3d.md
|
|
||||||
- Text-to-3D: services/text_to_3d.md
|
|
||||||
- Texture Generation: services/texture_edit.md
|
|
||||||
- Asset Visualizer: services/visualize_asset.md
|
|
||||||
- 📘 Tutorials:
|
|
||||||
- Overview: tutorials/index.md
|
|
||||||
- Image-to-3D: tutorials/image_to_3d.md
|
|
||||||
- Text-to-3D: tutorials/text_to_3d.md
|
|
||||||
- Texture Generation: tutorials/texture_edit.md
|
|
||||||
# - Articulated Object Generation: tutorials/articulated_gen.md
|
|
||||||
- 3D Scene Generation: tutorials/scene_gen.md
|
|
||||||
- Interactive 3D Scenes: tutorials/layout_gen.md
|
|
||||||
- Gym Parallel Envs: tutorials/gym_env.md
|
|
||||||
- Any Simulators: tutorials/any_simulators.md
|
|
||||||
- Digital Twin Creation: tutorials/digital_twin.md
|
|
||||||
- 📚 API Reference:
|
|
||||||
- Overview: api/index.md
|
|
||||||
- Data: api/data.md
|
|
||||||
- Envs: api/envs.md
|
|
||||||
- Models: api/models.md
|
|
||||||
- Trainer: api/trainer.md
|
|
||||||
- Utilities: api/utils.md
|
|
||||||
- Validators: api/validators.md
|
|
||||||
- ✨ Acknowledgement: acknowledgement.md
|
|
||||||
|
|
||||||
extra:
|
|
||||||
social:
|
|
||||||
- icon: simple/huggingface
|
|
||||||
link: https://huggingface.co/collections/HorizonRobotics/embodiedgen
|
|
||||||
- icon: fontawesome/brands/github
|
|
||||||
link: https://github.com/HorizonRobotics/EmbodiedGen
|
|
||||||
- icon: simple/arxiv
|
|
||||||
link: https://arxiv.org/abs/2506.10600
|
|
||||||
- icon: fontawesome/solid/globe
|
|
||||||
link: https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html
|
|
||||||
- icon: fontawesome/brands/youtube
|
|
||||||
link: https://www.youtube.com/watch?v=rG4odybuJRk
|
|
||||||
|
|
||||||
theme:
|
|
||||||
name: material
|
|
||||||
language: en
|
|
||||||
logo: assets/logo.png
|
|
||||||
favicon: assets/logo.png
|
|
||||||
icon:
|
|
||||||
repo: fontawesome/brands/github
|
|
||||||
features:
|
|
||||||
- navigation.instant
|
|
||||||
- navigation.instant.prefetch
|
|
||||||
- navigation.instant.progress
|
|
||||||
- navigation.path
|
|
||||||
- navigation.tabs
|
|
||||||
- navigation.top
|
|
||||||
- search.highlight
|
|
||||||
- content.code.copy
|
|
||||||
- content.action.edit
|
|
||||||
palette:
|
|
||||||
- media: "(prefers-color-scheme: light)"
|
|
||||||
scheme: default
|
|
||||||
primary: brown
|
|
||||||
accent: red
|
|
||||||
toggle:
|
|
||||||
icon: material/weather-sunny
|
|
||||||
name: Switch to dark mode
|
|
||||||
- media: "(prefers-color-scheme: dark)"
|
|
||||||
scheme: slate
|
|
||||||
primary: brown
|
|
||||||
accent: red
|
|
||||||
toggle:
|
|
||||||
icon: material/weather-night
|
|
||||||
name: Switch to light mode
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- search
|
|
||||||
- mkdocstrings:
|
|
||||||
handlers:
|
|
||||||
python:
|
|
||||||
paths: [embodied_gen]
|
|
||||||
options:
|
|
||||||
show_signature_annotations: true
|
|
||||||
separate_signature: true
|
|
||||||
show_root_toc_entry: false
|
|
||||||
docstring_style: google
|
|
||||||
show_source: true
|
|
||||||
merge_init_into_class: true
|
|
||||||
show_root_heading: true
|
|
||||||
show_root_full_path: true
|
|
||||||
|
|
||||||
- git-revision-date-localized:
|
|
||||||
enable_creation_date: true
|
|
||||||
|
|
||||||
extra_css:
|
|
||||||
- stylesheets/extra.css
|
|
||||||
- https://cdn.jsdelivr.net/npm/swiper/swiper-bundle.min.css
|
|
||||||
|
|
||||||
extra_javascript:
|
|
||||||
- https://cdn.jsdelivr.net/npm/swiper/swiper-bundle.min.js
|
|
||||||
- path: https://unpkg.com/@google/model-viewer/dist/model-viewer.min.js
|
|
||||||
type: module
|
|
||||||
- js/model_viewer.js
|
|
||||||
|
|
||||||
markdown_extensions:
|
|
||||||
- pymdownx.highlight
|
|
||||||
- pymdownx.superfences
|
|
||||||
- admonition
|
|
||||||
|
|
||||||
@ -7,7 +7,7 @@ packages = ["embodied_gen"]
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "embodied_gen"
|
name = "embodied_gen"
|
||||||
version = "v0.1.6"
|
version = "v0.1.5"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
license-files = ["LICENSE", "NOTICE"]
|
license-files = ["LICENSE", "NOTICE"]
|
||||||
@ -24,10 +24,6 @@ dev = [
|
|||||||
"isort",
|
"isort",
|
||||||
"pytest",
|
"pytest",
|
||||||
"pytest-mock",
|
"pytest-mock",
|
||||||
"mkdocs",
|
|
||||||
"mkdocs-material",
|
|
||||||
"mkdocstrings[python]",
|
|
||||||
"mkdocs-git-revision-date-localized-plugin",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@ -1,16 +0,0 @@
|
|||||||
IMAGE=embodiedgen:v0.1.1
|
|
||||||
CONTAINER=EmbodiedGen-docker-${USER}
|
|
||||||
docker pull ${IMAGE}
|
|
||||||
docker run -itd --shm-size="64g" --gpus all --cap-add=SYS_PTRACE \
|
|
||||||
--security-opt seccomp=unconfined --privileged --net=host \
|
|
||||||
-v /data1/liy/projects/EmbodiedGen:/EmbodiedGen \
|
|
||||||
--name ${CONTAINER} ${IMAGE}
|
|
||||||
|
|
||||||
docker exec -it ${CONTAINER} bash
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export no_proxy="localhost,127.0.0.1,192.168.48.210,120.48.161.22"
|
|
||||||
export ENDPOINT="https://llmproxy.d-robotics.cc/v1"
|
|
||||||
export API_KEY="sk-B8urDShf4TLeruwI3dB8286485Aa4984A722E945F566EfF4"
|
|
||||||
export MODEL_NAME="azure/gpt-4.1"
|
|
||||||
@ -1,2 +1,2 @@
|
|||||||
[pycodestyle]
|
[pycodestyle]
|
||||||
ignore = E203,W503,E402,E501,E251
|
ignore = E203,W503,E402,E501
|
||||||
|
|||||||
@ -1,12 +1,6 @@
|
|||||||
from typing import Literal
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from embodied_gen.data.asset_converter import (
|
from embodied_gen.data.asset_converter import AssetConverterFactory, AssetType
|
||||||
AssetConverterFactory,
|
|
||||||
cvt_embodiedgen_asset_to_anysim,
|
|
||||||
)
|
|
||||||
from embodied_gen.utils.enum import AssetType, SimAssetMapper
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -62,26 +56,3 @@ def test_MeshtoUSDConverter(data_dir):
|
|||||||
|
|
||||||
assert output_file.exists(), f"Output not generated: {output_file}"
|
assert output_file.exists(), f"Output not generated: {output_file}"
|
||||||
assert output_file.stat().st_size > 0
|
assert output_file.stat().st_size > 0
|
||||||
|
|
||||||
|
|
||||||
def test_cvt_embodiedgen_asset_to_anysim(
|
|
||||||
simulator_name: Literal[
|
|
||||||
"isaacsim",
|
|
||||||
"isaacgym",
|
|
||||||
"genesis",
|
|
||||||
"pybullet",
|
|
||||||
"sapien3",
|
|
||||||
"mujoco",
|
|
||||||
] = "mujoco",
|
|
||||||
):
|
|
||||||
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
|
||||||
urdf_files=[
|
|
||||||
"outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf",
|
|
||||||
],
|
|
||||||
target_dirs=[
|
|
||||||
"outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd",
|
|
||||||
],
|
|
||||||
target_type=SimAssetMapper[simulator_name],
|
|
||||||
source_type=AssetType.MESH,
|
|
||||||
overwrite=True,
|
|
||||||
)
|
|
||||||
|
|||||||
1
thirdparty/TRELLIS
vendored
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit 55a8e8164b195bbf927e0978f00e76c835e6011f
|
||||||
1
thirdparty/pano2room
vendored
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit bbf93ae57086ed700edc6ee445852d4457a9d704
|
||||||