feat(docs): Improve docstrings across the codebase and docs. (#56)
This commit is contained in:
parent
cd94669770
commit
a256674bf2
@ -37,7 +37,7 @@
|
|||||||
```sh
|
```sh
|
||||||
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
|
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
|
||||||
cd EmbodiedGen
|
cd EmbodiedGen
|
||||||
git checkout v0.1.5
|
git checkout v0.1.6
|
||||||
git submodule update --init --recursive --progress
|
git submodule update --init --recursive --progress
|
||||||
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
|
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
|
||||||
conda activate embodiedgen
|
conda activate embodiedgen
|
||||||
|
|||||||
@ -31,8 +31,8 @@ from typing import Any, Dict, Tuple
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import yaml
|
|
||||||
from app_style import custom_theme, lighting_css
|
from app_style import custom_theme, lighting_css
|
||||||
|
from embodied_gen.utils.tags import VERSION
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT as gpt_client
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT as gpt_client
|
||||||
@ -48,7 +48,6 @@ except Exception as e:
|
|||||||
|
|
||||||
|
|
||||||
# --- Configuration & Data Loading ---
|
# --- Configuration & Data Loading ---
|
||||||
VERSION = "v0.1.5"
|
|
||||||
RUNNING_MODE = "local" # local or hf_remote
|
RUNNING_MODE = "local" # local or hf_remote
|
||||||
CSV_FILE = "dataset_index.csv"
|
CSV_FILE = "dataset_index.csv"
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ hide:
|
|||||||
```sh
|
```sh
|
||||||
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
|
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
|
||||||
cd EmbodiedGen
|
cd EmbodiedGen
|
||||||
git checkout v0.1.5
|
git checkout v0.1.6
|
||||||
git submodule update --init --recursive --progress
|
git submodule update --init --recursive --progress
|
||||||
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
|
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
|
||||||
conda activate embodiedgen
|
conda activate embodiedgen
|
||||||
|
|||||||
@ -35,7 +35,8 @@ Leverage **EmbodiedGen-generated assets** with *accurate physical collisions* an
|
|||||||
## 🧱 Example: Conversion to Target Simulator
|
## 🧱 Example: Conversion to Target Simulator
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from embodied_gen.data.asset_converter import SimAssetMapper, cvt_embodiedgen_asset_to_anysim
|
from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
|
||||||
|
from embodied_gen.utils.enum import AssetType, SimAssetMapper
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
simulator_name: Literal[
|
simulator_name: Literal[
|
||||||
@ -52,6 +53,10 @@ dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
|||||||
"path1_to_embodiedgen_asset/asset.urdf",
|
"path1_to_embodiedgen_asset/asset.urdf",
|
||||||
"path2_to_embodiedgen_asset/asset.urdf",
|
"path2_to_embodiedgen_asset/asset.urdf",
|
||||||
],
|
],
|
||||||
|
target_dirs=[
|
||||||
|
"path1_to_target_dir/asset.usd",
|
||||||
|
"path2_to_target_dir/asset.usd",
|
||||||
|
],
|
||||||
target_type=SimAssetMapper[simulator_name],
|
target_type=SimAssetMapper[simulator_name],
|
||||||
source_type=AssetType.MESH,
|
source_type=AssetType.MESH,
|
||||||
overwrite=True,
|
overwrite=True,
|
||||||
|
|||||||
@ -4,12 +4,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from shutil import copy, copytree, rmtree
|
from shutil import copy, copytree, rmtree
|
||||||
|
|
||||||
import trimesh
|
import trimesh
|
||||||
from scipy.spatial.transform import Rotation
|
from scipy.spatial.transform import Rotation
|
||||||
|
from embodied_gen.utils.enum import AssetType
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -17,75 +17,62 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AssetConverterFactory",
|
"AssetConverterFactory",
|
||||||
"AssetType",
|
|
||||||
"MeshtoMJCFConverter",
|
"MeshtoMJCFConverter",
|
||||||
"MeshtoUSDConverter",
|
"MeshtoUSDConverter",
|
||||||
"URDFtoUSDConverter",
|
"URDFtoUSDConverter",
|
||||||
"cvt_embodiedgen_asset_to_anysim",
|
"cvt_embodiedgen_asset_to_anysim",
|
||||||
"PhysicsUSDAdder",
|
"PhysicsUSDAdder",
|
||||||
"SimAssetMapper",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AssetType(str):
|
|
||||||
"""Asset type enumeration."""
|
|
||||||
|
|
||||||
MJCF = "mjcf"
|
|
||||||
USD = "usd"
|
|
||||||
URDF = "urdf"
|
|
||||||
MESH = "mesh"
|
|
||||||
|
|
||||||
|
|
||||||
class SimAssetMapper:
|
|
||||||
_mapping = dict(
|
|
||||||
ISAACSIM=AssetType.USD,
|
|
||||||
ISAACGYM=AssetType.URDF,
|
|
||||||
MUJOCO=AssetType.MJCF,
|
|
||||||
GENESIS=AssetType.MJCF,
|
|
||||||
SAPIEN=AssetType.URDF,
|
|
||||||
PYBULLET=AssetType.URDF,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __class_getitem__(cls, key: str):
|
|
||||||
key = key.upper()
|
|
||||||
if key.startswith("SAPIEN"):
|
|
||||||
key = "SAPIEN"
|
|
||||||
return cls._mapping[key]
|
|
||||||
|
|
||||||
|
|
||||||
def cvt_embodiedgen_asset_to_anysim(
|
def cvt_embodiedgen_asset_to_anysim(
|
||||||
urdf_files: list[str],
|
urdf_files: list[str],
|
||||||
|
target_dirs: list[str],
|
||||||
target_type: AssetType,
|
target_type: AssetType,
|
||||||
source_type: AssetType,
|
source_type: AssetType,
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Convert URDF files generated by EmbodiedGen into the format required by all simulators.
|
"""Convert URDF files generated by EmbodiedGen into formats required by simulators.
|
||||||
|
|
||||||
Supported simulators include SAPIEN, Isaac Sim, MuJoCo, Isaac Gym, Genesis, and Pybullet.
|
Supported simulators include SAPIEN, Isaac Sim, MuJoCo, Isaac Gym, Genesis, and Pybullet.
|
||||||
|
Converting to the `USD` format requires `isaacsim` to be installed.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
|
||||||
|
from embodied_gen.utils.enum import AssetType
|
||||||
|
|
||||||
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
||||||
urdf_files,
|
urdf_files=[
|
||||||
target_type=SimAssetMapper[simulator_name],
|
"path1_to_embodiedgen_asset/asset.urdf",
|
||||||
|
"path2_to_embodiedgen_asset/asset.urdf",
|
||||||
|
],
|
||||||
|
target_dirs=[
|
||||||
|
"path1_to_target_dir/asset.usd",
|
||||||
|
"path2_to_target_dir/asset.usd",
|
||||||
|
],
|
||||||
|
target_type=AssetType.USD,
|
||||||
source_type=AssetType.MESH,
|
source_type=AssetType.MESH,
|
||||||
)
|
)
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
urdf_files (List[str]): List of URDF file paths to be converted.
|
urdf_files (list[str]): List of URDF file paths.
|
||||||
target_type (AssetType): The target asset type.
|
target_dirs (list[str]): List of target directories.
|
||||||
source_type (AssetType): The source asset type.
|
target_type (AssetType): Target asset type.
|
||||||
overwrite (bool): Whether to overwrite existing converted files.
|
source_type (AssetType): Source asset type.
|
||||||
**kwargs: Additional keyword arguments for the converter.
|
overwrite (bool, optional): Overwrite existing files.
|
||||||
|
**kwargs: Additional converter arguments.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, str]: A dictionary mapping the original URDF file path to the converted asset file path.
|
dict[str, str]: Mapping from URDF file to converted asset file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(urdf_files, str):
|
if isinstance(urdf_files, str):
|
||||||
urdf_files = [urdf_files]
|
urdf_files = [urdf_files]
|
||||||
|
if isinstance(target_dirs, str):
|
||||||
|
urdf_files = [target_dirs]
|
||||||
|
|
||||||
# If the target type is URDF, no conversion is needed.
|
# If the target type is URDF, no conversion is needed.
|
||||||
if target_type == AssetType.URDF:
|
if target_type == AssetType.URDF:
|
||||||
@ -99,18 +86,17 @@ def cvt_embodiedgen_asset_to_anysim(
|
|||||||
asset_paths = dict()
|
asset_paths = dict()
|
||||||
|
|
||||||
with asset_converter:
|
with asset_converter:
|
||||||
for urdf_file in urdf_files:
|
for urdf_file, target_dir in zip(urdf_files, target_dirs):
|
||||||
filename = os.path.basename(urdf_file).replace(".urdf", "")
|
filename = os.path.basename(urdf_file).replace(".urdf", "")
|
||||||
asset_dir = os.path.dirname(urdf_file)
|
|
||||||
if target_type == AssetType.MJCF:
|
if target_type == AssetType.MJCF:
|
||||||
target_file = f"{asset_dir}/../mjcf/{filename}.xml"
|
target_file = f"{target_dir}/{filename}.xml"
|
||||||
elif target_type == AssetType.USD:
|
elif target_type == AssetType.USD:
|
||||||
target_file = f"{asset_dir}/../usd/{filename}.usd"
|
target_file = f"{target_dir}/{filename}.usd"
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Target type {target_type} not supported."
|
f"Target type {target_type} not supported."
|
||||||
)
|
)
|
||||||
if not os.path.exists(target_file):
|
if not os.path.exists(target_file) or overwrite:
|
||||||
asset_converter.convert(urdf_file, target_file)
|
asset_converter.convert(urdf_file, target_file)
|
||||||
|
|
||||||
asset_paths[urdf_file] = target_file
|
asset_paths[urdf_file] = target_file
|
||||||
@ -119,16 +105,35 @@ def cvt_embodiedgen_asset_to_anysim(
|
|||||||
|
|
||||||
|
|
||||||
class AssetConverterBase(ABC):
|
class AssetConverterBase(ABC):
|
||||||
"""Converter abstract base class."""
|
"""Abstract base class for asset converters.
|
||||||
|
|
||||||
|
Provides context management and mesh transformation utilities.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
|
def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
|
||||||
|
"""Convert an asset file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urdf_path (str): Path to input URDF file.
|
||||||
|
output_path (str): Path to output file.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Path to converted asset.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def transform_mesh(
|
def transform_mesh(
|
||||||
self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
|
self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Apply transform to the mesh based on the origin element in URDF."""
|
"""Apply transform to mesh based on URDF origin element.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_mesh (str): Path to input mesh.
|
||||||
|
output_mesh (str): Path to output mesh.
|
||||||
|
mesh_origin (ET.Element): Origin element from URDF.
|
||||||
|
"""
|
||||||
mesh = trimesh.load(input_mesh, group_material=False)
|
mesh = trimesh.load(input_mesh, group_material=False)
|
||||||
rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
|
rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
|
||||||
rotation = Rotation.from_euler("xyz", rpy, degrees=False)
|
rotation = Rotation.from_euler("xyz", rpy, degrees=False)
|
||||||
@ -150,14 +155,19 @@ class AssetConverterBase(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
"""Context manager entry."""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Context manager exit."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class MeshtoMJCFConverter(AssetConverterBase):
|
class MeshtoMJCFConverter(AssetConverterBase):
|
||||||
"""Convert URDF files into MJCF format."""
|
"""Converts mesh-based URDF files to MJCF format.
|
||||||
|
|
||||||
|
Handles geometry, materials, and asset copying.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -166,6 +176,12 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def _copy_asset_file(self, src: str, dst: str) -> None:
|
def _copy_asset_file(self, src: str, dst: str) -> None:
|
||||||
|
"""Copies asset file if not already present.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src (str): Source file path.
|
||||||
|
dst (str): Destination file path.
|
||||||
|
"""
|
||||||
if os.path.exists(dst):
|
if os.path.exists(dst):
|
||||||
return
|
return
|
||||||
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
||||||
@ -183,7 +199,19 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
material: ET.Element | None = None,
|
material: ET.Element | None = None,
|
||||||
is_collision: bool = False,
|
is_collision: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add geometry to the MJCF body from the URDF link."""
|
"""Adds geometry to MJCF body from URDF link.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mujoco_element (ET.Element): MJCF asset element.
|
||||||
|
link (ET.Element): URDF link element.
|
||||||
|
body (ET.Element): MJCF body element.
|
||||||
|
tag (str): Tag name ("visual" or "collision").
|
||||||
|
input_dir (str): Input directory.
|
||||||
|
output_dir (str): Output directory.
|
||||||
|
mesh_name (str): Mesh name.
|
||||||
|
material (ET.Element, optional): Material element.
|
||||||
|
is_collision (bool, optional): If True, treat as collision geometry.
|
||||||
|
"""
|
||||||
element = link.find(tag)
|
element = link.find(tag)
|
||||||
geometry = element.find("geometry")
|
geometry = element.find("geometry")
|
||||||
mesh = geometry.find("mesh")
|
mesh = geometry.find("mesh")
|
||||||
@ -242,7 +270,20 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
name: str,
|
name: str,
|
||||||
reflectance: float = 0.2,
|
reflectance: float = 0.2,
|
||||||
) -> ET.Element:
|
) -> ET.Element:
|
||||||
"""Add materials to the MJCF asset from the URDF link."""
|
"""Adds materials to MJCF asset from URDF link.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mujoco_element (ET.Element): MJCF asset element.
|
||||||
|
link (ET.Element): URDF link element.
|
||||||
|
tag (str): Tag name.
|
||||||
|
input_dir (str): Input directory.
|
||||||
|
output_dir (str): Output directory.
|
||||||
|
name (str): Material name.
|
||||||
|
reflectance (float, optional): Reflectance value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ET.Element: Material element.
|
||||||
|
"""
|
||||||
element = link.find(tag)
|
element = link.find(tag)
|
||||||
geometry = element.find("geometry")
|
geometry = element.find("geometry")
|
||||||
mesh = geometry.find("mesh")
|
mesh = geometry.find("mesh")
|
||||||
@ -282,7 +323,12 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
return material
|
return material
|
||||||
|
|
||||||
def convert(self, urdf_path: str, mjcf_path: str):
|
def convert(self, urdf_path: str, mjcf_path: str):
|
||||||
"""Convert a URDF file to MJCF format."""
|
"""Converts a URDF file to MJCF format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urdf_path (str): Path to URDF file.
|
||||||
|
mjcf_path (str): Path to output MJCF file.
|
||||||
|
"""
|
||||||
tree = ET.parse(urdf_path)
|
tree = ET.parse(urdf_path)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
|
|
||||||
@ -336,10 +382,22 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|||||||
|
|
||||||
|
|
||||||
class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
||||||
"""Convert URDF files with joints to MJCF format, handling transformations from joints."""
|
"""Converts URDF files with joints to MJCF format, handling joint transformations.
|
||||||
|
|
||||||
|
Handles fixed joints and hierarchical body structure.
|
||||||
|
"""
|
||||||
|
|
||||||
def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str:
|
def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str:
|
||||||
"""Convert a URDF file with joints to MJCF format."""
|
"""Converts a URDF file with joints to MJCF format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urdf_path (str): Path to URDF file.
|
||||||
|
mjcf_path (str): Path to output MJCF file.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Path to converted MJCF file.
|
||||||
|
"""
|
||||||
tree = ET.parse(urdf_path)
|
tree = ET.parse(urdf_path)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
|
|
||||||
@ -423,7 +481,10 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
|||||||
|
|
||||||
|
|
||||||
class MeshtoUSDConverter(AssetConverterBase):
|
class MeshtoUSDConverter(AssetConverterBase):
|
||||||
"""Convert Mesh file from URDF into USD format."""
|
"""Converts mesh-based URDF files to USD format.
|
||||||
|
|
||||||
|
Adds physics APIs and post-processes collision meshes.
|
||||||
|
"""
|
||||||
|
|
||||||
DEFAULT_BIND_APIS = [
|
DEFAULT_BIND_APIS = [
|
||||||
"MaterialBindingAPI",
|
"MaterialBindingAPI",
|
||||||
@ -443,6 +504,14 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
simulation_app=None,
|
simulation_app=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Initializes the converter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_usd_conversion (bool, optional): Force USD conversion.
|
||||||
|
make_instanceable (bool, optional): Make prims instanceable.
|
||||||
|
simulation_app (optional): Simulation app instance.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
"""
|
||||||
if simulation_app is not None:
|
if simulation_app is not None:
|
||||||
self.simulation_app = simulation_app
|
self.simulation_app = simulation_app
|
||||||
|
|
||||||
@ -458,6 +527,7 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
"""Context manager entry, launches simulation app if needed."""
|
||||||
from isaaclab.app import AppLauncher
|
from isaaclab.app import AppLauncher
|
||||||
|
|
||||||
if not hasattr(self, "simulation_app"):
|
if not hasattr(self, "simulation_app"):
|
||||||
@ -476,6 +546,7 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Context manager exit, closes simulation app if created."""
|
||||||
# Close the simulation app if it was created here
|
# Close the simulation app if it was created here
|
||||||
if hasattr(self, "app_launcher") and self.exit_close:
|
if hasattr(self, "app_launcher") and self.exit_close:
|
||||||
self.simulation_app.close()
|
self.simulation_app.close()
|
||||||
@ -486,7 +557,12 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def convert(self, urdf_path: str, output_file: str):
|
def convert(self, urdf_path: str, output_file: str):
|
||||||
"""Convert a URDF file to USD and post-process collision meshes."""
|
"""Converts a URDF file to USD and post-processes collision meshes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urdf_path (str): Path to URDF file.
|
||||||
|
output_file (str): Path to output USD file.
|
||||||
|
"""
|
||||||
from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
|
from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
|
||||||
from pxr import PhysxSchema, Sdf, Usd, UsdShade
|
from pxr import PhysxSchema, Sdf, Usd, UsdShade
|
||||||
|
|
||||||
@ -556,6 +632,11 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|||||||
|
|
||||||
|
|
||||||
class PhysicsUSDAdder(MeshtoUSDConverter):
|
class PhysicsUSDAdder(MeshtoUSDConverter):
|
||||||
|
"""Adds physics APIs and collision properties to USD assets.
|
||||||
|
|
||||||
|
Useful for post-processing USD files for simulation.
|
||||||
|
"""
|
||||||
|
|
||||||
DEFAULT_BIND_APIS = [
|
DEFAULT_BIND_APIS = [
|
||||||
"MaterialBindingAPI",
|
"MaterialBindingAPI",
|
||||||
# "PhysicsMeshCollisionAPI",
|
# "PhysicsMeshCollisionAPI",
|
||||||
@ -566,6 +647,12 @@ class PhysicsUSDAdder(MeshtoUSDConverter):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def convert(self, usd_path: str, output_file: str = None):
|
def convert(self, usd_path: str, output_file: str = None):
|
||||||
|
"""Adds physics APIs and collision properties to a USD file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
usd_path (str): Path to input USD file.
|
||||||
|
output_file (str, optional): Path to output USD file.
|
||||||
|
"""
|
||||||
from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics
|
from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics
|
||||||
|
|
||||||
if output_file is None:
|
if output_file is None:
|
||||||
@ -626,14 +713,18 @@ class PhysicsUSDAdder(MeshtoUSDConverter):
|
|||||||
|
|
||||||
|
|
||||||
class URDFtoUSDConverter(MeshtoUSDConverter):
|
class URDFtoUSDConverter(MeshtoUSDConverter):
|
||||||
"""Convert URDF files into USD format.
|
"""Converts URDF files to USD format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fix_base (bool): Whether to fix the base link.
|
fix_base (bool, optional): Fix the base link.
|
||||||
merge_fixed_joints (bool): Whether to merge fixed joints.
|
merge_fixed_joints (bool, optional): Merge fixed joints.
|
||||||
make_instanceable (bool): Whether to make prims instanceable.
|
make_instanceable (bool, optional): Make prims instanceable.
|
||||||
force_usd_conversion (bool): Force conversion to USD.
|
force_usd_conversion (bool, optional): Force conversion to USD.
|
||||||
collision_from_visuals (bool): Generate collisions from visuals if not provided.
|
collision_from_visuals (bool, optional): Generate collisions from visuals.
|
||||||
|
joint_drive (optional): Joint drive configuration.
|
||||||
|
rotate_wxyz (tuple[float], optional): Quaternion for rotation.
|
||||||
|
simulation_app (optional): Simulation app instance.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -648,6 +739,19 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|||||||
simulation_app=None,
|
simulation_app=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Initializes the converter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fix_base (bool, optional): Fix the base link.
|
||||||
|
merge_fixed_joints (bool, optional): Merge fixed joints.
|
||||||
|
make_instanceable (bool, optional): Make prims instanceable.
|
||||||
|
force_usd_conversion (bool, optional): Force conversion to USD.
|
||||||
|
collision_from_visuals (bool, optional): Generate collisions from visuals.
|
||||||
|
joint_drive (optional): Joint drive configuration.
|
||||||
|
rotate_wxyz (tuple[float], optional): Quaternion for rotation.
|
||||||
|
simulation_app (optional): Simulation app instance.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
"""
|
||||||
self.usd_parms = dict(
|
self.usd_parms = dict(
|
||||||
fix_base=fix_base,
|
fix_base=fix_base,
|
||||||
merge_fixed_joints=merge_fixed_joints,
|
merge_fixed_joints=merge_fixed_joints,
|
||||||
@ -662,7 +766,12 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|||||||
self.simulation_app = simulation_app
|
self.simulation_app = simulation_app
|
||||||
|
|
||||||
def convert(self, urdf_path: str, output_file: str):
|
def convert(self, urdf_path: str, output_file: str):
|
||||||
"""Convert a URDF file to USD and post-process collision meshes."""
|
"""Converts a URDF file to USD and post-processes collision meshes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urdf_path (str): Path to URDF file.
|
||||||
|
output_file (str): Path to output USD file.
|
||||||
|
"""
|
||||||
from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
|
from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
|
||||||
from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
|
from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
|
||||||
|
|
||||||
@ -723,13 +832,36 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|||||||
|
|
||||||
|
|
||||||
class AssetConverterFactory:
|
class AssetConverterFactory:
|
||||||
"""Factory class for creating asset converters based on target and source types."""
|
"""Factory for creating asset converters based on target and source types.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.data.asset_converter import AssetConverterFactory
|
||||||
|
from embodied_gen.utils.enum import AssetType
|
||||||
|
|
||||||
|
converter = AssetConverterFactory.create(
|
||||||
|
target_type=AssetType.USD, source_type=AssetType.MESH
|
||||||
|
)
|
||||||
|
with converter:
|
||||||
|
for urdf_path, output_file in zip(urdf_paths, output_files):
|
||||||
|
converter.convert(urdf_path, output_file)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(
|
def create(
|
||||||
target_type: AssetType, source_type: AssetType = "urdf", **kwargs
|
target_type: AssetType, source_type: AssetType = "urdf", **kwargs
|
||||||
) -> AssetConverterBase:
|
) -> AssetConverterBase:
|
||||||
"""Create an asset converter instance based on target and source types."""
|
"""Creates an asset converter instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_type (AssetType): Target asset type.
|
||||||
|
source_type (AssetType, optional): Source asset type.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssetConverterBase: Converter instance.
|
||||||
|
"""
|
||||||
if target_type == AssetType.MJCF and source_type == AssetType.MESH:
|
if target_type == AssetType.MJCF and source_type == AssetType.MESH:
|
||||||
converter = MeshtoMJCFConverter(**kwargs)
|
converter = MeshtoMJCFConverter(**kwargs)
|
||||||
elif target_type == AssetType.MJCF and source_type == AssetType.URDF:
|
elif target_type == AssetType.MJCF and source_type == AssetType.URDF:
|
||||||
@ -751,7 +883,14 @@ if __name__ == "__main__":
|
|||||||
# target_asset_type = AssetType.USD
|
# target_asset_type = AssetType.USD
|
||||||
|
|
||||||
urdf_paths = [
|
urdf_paths = [
|
||||||
"outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf",
|
'outputs/EmbodiedGenData/demo_assets/banana/result/banana.urdf',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/book/result/book.urdf',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/lamp/result/lamp.urdf',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/mug/result/mug.urdf',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/remote_control/result/remote_control.urdf',
|
||||||
|
"outputs/EmbodiedGenData/demo_assets/rubik's_cube/result/rubik's_cube.urdf",
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/table/result/table.urdf',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/vase/result/vase.urdf',
|
||||||
]
|
]
|
||||||
|
|
||||||
if target_asset_type == AssetType.MJCF:
|
if target_asset_type == AssetType.MJCF:
|
||||||
@ -765,7 +904,14 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
elif target_asset_type == AssetType.USD:
|
elif target_asset_type == AssetType.USD:
|
||||||
output_files = [
|
output_files = [
|
||||||
"outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd",
|
'outputs/EmbodiedGenData/demo_assets/banana/usd/banana.usd',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/book/usd/book.usd',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/lamp/usd/lamp.usd',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/mug/usd/mug.usd',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/remote_control/usd/remote_control.usd',
|
||||||
|
"outputs/EmbodiedGenData/demo_assets/rubik's_cube/usd/rubik's_cube.usd",
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/table/usd/table.usd',
|
||||||
|
'outputs/EmbodiedGenData/demo_assets/vase/usd/vase.usd',
|
||||||
]
|
]
|
||||||
asset_converter = AssetConverterFactory.create(
|
asset_converter = AssetConverterFactory.create(
|
||||||
target_type=AssetType.USD,
|
target_type=AssetType.USD,
|
||||||
@ -776,33 +922,33 @@ if __name__ == "__main__":
|
|||||||
for urdf_path, output_file in zip(urdf_paths, output_files):
|
for urdf_path, output_file in zip(urdf_paths, output_files):
|
||||||
asset_converter.convert(urdf_path, output_file)
|
asset_converter.convert(urdf_path, output_file)
|
||||||
|
|
||||||
urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf"
|
# urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf"
|
||||||
output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd"
|
# output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd"
|
||||||
|
|
||||||
asset_converter = AssetConverterFactory.create(
|
# asset_converter = AssetConverterFactory.create(
|
||||||
target_type=AssetType.USD,
|
# target_type=AssetType.USD,
|
||||||
source_type=AssetType.URDF,
|
# source_type=AssetType.URDF,
|
||||||
rotate_wxyz=(0.7071, 0.7071, 0, 0), # rotate 90 deg around the X-axis
|
# rotate_wxyz=(0.7071, 0.7071, 0, 0), # rotate 90 deg around the X-axis
|
||||||
)
|
# )
|
||||||
|
|
||||||
with asset_converter:
|
# with asset_converter:
|
||||||
asset_converter.convert(urdf_path, output_file)
|
# asset_converter.convert(urdf_path, output_file)
|
||||||
|
|
||||||
# Convert infinigen urdf to mjcf
|
# # Convert infinigen urdf to mjcf
|
||||||
urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf"
|
# urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf"
|
||||||
output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml"
|
# output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml"
|
||||||
asset_converter = AssetConverterFactory.create(
|
# asset_converter = AssetConverterFactory.create(
|
||||||
target_type=AssetType.MJCF,
|
# target_type=AssetType.MJCF,
|
||||||
source_type=AssetType.URDF,
|
# source_type=AssetType.URDF,
|
||||||
keep_materials=["diffuse"],
|
# keep_materials=["diffuse"],
|
||||||
)
|
# )
|
||||||
with asset_converter:
|
# with asset_converter:
|
||||||
asset_converter.convert(urdf_path, output_file)
|
# asset_converter.convert(urdf_path, output_file)
|
||||||
|
|
||||||
# Convert infinigen usdc to physics usdc
|
# # Convert infinigen usdc to physics usdc
|
||||||
converter = PhysicsUSDAdder()
|
# converter = PhysicsUSDAdder()
|
||||||
with converter:
|
# with converter:
|
||||||
converter.convert(
|
# converter.convert(
|
||||||
usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc",
|
# usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc",
|
||||||
output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc",
|
# output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc",
|
||||||
)
|
# )
|
||||||
|
|||||||
@ -58,7 +58,16 @@ __all__ = [
|
|||||||
def _transform_vertices(
|
def _transform_vertices(
|
||||||
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
|
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Transform 3D vertices using a projection matrix."""
|
"""Transforms 3D vertices using a projection matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mtx (torch.Tensor): Projection matrix.
|
||||||
|
pos (torch.Tensor): Vertex positions.
|
||||||
|
keepdim (bool, optional): If True, keeps the batch dimension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Transformed vertices.
|
||||||
|
"""
|
||||||
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
|
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
|
||||||
if pos.size(-1) == 3:
|
if pos.size(-1) == 3:
|
||||||
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
|
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
|
||||||
@ -71,7 +80,17 @@ def _transform_vertices(
|
|||||||
def _bilinear_interpolation_scattering(
|
def _bilinear_interpolation_scattering(
|
||||||
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
|
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Bilinear interpolation scattering for grid-based value accumulation."""
|
"""Performs bilinear interpolation scattering for grid-based value accumulation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_h (int): Image height.
|
||||||
|
image_w (int): Image width.
|
||||||
|
coords (torch.Tensor): Normalized coordinates.
|
||||||
|
values (torch.Tensor): Values to scatter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Interpolated grid.
|
||||||
|
"""
|
||||||
device = values.device
|
device = values.device
|
||||||
dtype = values.dtype
|
dtype = values.dtype
|
||||||
C = values.shape[-1]
|
C = values.shape[-1]
|
||||||
@ -135,7 +154,18 @@ def _texture_inpaint_smooth(
|
|||||||
faces: np.ndarray,
|
faces: np.ndarray,
|
||||||
uv_map: np.ndarray,
|
uv_map: np.ndarray,
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""Perform texture inpainting using vertex-based color propagation."""
|
"""Performs texture inpainting using vertex-based color propagation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texture (np.ndarray): Texture image.
|
||||||
|
mask (np.ndarray): Mask image.
|
||||||
|
vertices (np.ndarray): Mesh vertices.
|
||||||
|
faces (np.ndarray): Mesh faces.
|
||||||
|
uv_map (np.ndarray): UV coordinates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[np.ndarray, np.ndarray]: Inpainted texture and updated mask.
|
||||||
|
"""
|
||||||
image_h, image_w, C = texture.shape
|
image_h, image_w, C = texture.shape
|
||||||
N = vertices.shape[0]
|
N = vertices.shape[0]
|
||||||
|
|
||||||
@ -231,29 +261,41 @@ def _texture_inpaint_smooth(
|
|||||||
class TextureBacker:
|
class TextureBacker:
|
||||||
"""Texture baking pipeline for multi-view projection and fusion.
|
"""Texture baking pipeline for multi-view projection and fusion.
|
||||||
|
|
||||||
This class performs UV-based texture generation for a 3D mesh using
|
This class generates UV-based textures for a 3D mesh using multi-view images,
|
||||||
multi-view color images, depth, and normal information. The pipeline
|
depth, and normal information. It includes mesh normalization, UV unwrapping,
|
||||||
includes mesh normalization and UV unwrapping, visibility-aware
|
visibility-aware back-projection, confidence-weighted fusion, and inpainting.
|
||||||
back-projection, confidence-weighted texture fusion, and inpainting
|
|
||||||
of missing texture regions.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
camera_params (CameraSetting): Camera intrinsics and extrinsics used
|
camera_params (CameraSetting): Camera intrinsics and extrinsics.
|
||||||
for rendering each view.
|
view_weights (list[float]): Weights for each view in texture fusion.
|
||||||
view_weights (list[float]): A list of weights for each view, used
|
render_wh (tuple[int, int], optional): Intermediate rendering resolution.
|
||||||
to blend confidence maps during texture fusion.
|
texture_wh (tuple[int, int], optional): Output texture resolution.
|
||||||
render_wh (tuple[int, int], optional): Resolution (width, height) for
|
bake_angle_thresh (int, optional): Max angle for valid projection.
|
||||||
intermediate rendering passes. Defaults to (2048, 2048).
|
mask_thresh (float, optional): Threshold for visibility masks.
|
||||||
texture_wh (tuple[int, int], optional): Output texture resolution
|
smooth_texture (bool, optional): Apply post-processing to texture.
|
||||||
(width, height). Defaults to (2048, 2048).
|
inpaint_smooth (bool, optional): Apply inpainting smoothing.
|
||||||
bake_angle_thresh (int, optional): Maximum angle (in degrees) between
|
|
||||||
view direction and surface normal for projection to be considered valid.
|
Example:
|
||||||
Defaults to 75.
|
```py
|
||||||
mask_thresh (float, optional): Threshold applied to visibility masks
|
from embodied_gen.data.backproject_v2 import TextureBacker
|
||||||
during rendering. Defaults to 0.5.
|
from embodied_gen.data.utils import CameraSetting
|
||||||
smooth_texture (bool, optional): If True, apply post-processing (e.g.,
|
import trimesh
|
||||||
blurring) to the final texture. Defaults to True.
|
from PIL import Image
|
||||||
inpaint_smooth (bool, optional): If True, apply inpainting to smooth.
|
|
||||||
|
camera_params = CameraSetting(
|
||||||
|
num_images=6,
|
||||||
|
elevation=[20, -10],
|
||||||
|
distance=5,
|
||||||
|
resolution_hw=(2048,2048),
|
||||||
|
fov=math.radians(30),
|
||||||
|
device='cuda',
|
||||||
|
)
|
||||||
|
view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
|
||||||
|
mesh = trimesh.load('mesh.obj')
|
||||||
|
images = [Image.open(f'view_{i}.png') for i in range(6)]
|
||||||
|
texture_backer = TextureBacker(camera_params, view_weights)
|
||||||
|
textured_mesh = texture_backer(images, mesh, 'output.obj')
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -283,6 +325,12 @@ class TextureBacker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _lazy_init_render(self, camera_params, mask_thresh):
|
def _lazy_init_render(self, camera_params, mask_thresh):
|
||||||
|
"""Lazily initializes the renderer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_params (CameraSetting): Camera settings.
|
||||||
|
mask_thresh (float): Mask threshold.
|
||||||
|
"""
|
||||||
if self.renderer is None:
|
if self.renderer is None:
|
||||||
camera = init_kal_camera(camera_params)
|
camera = init_kal_camera(camera_params)
|
||||||
mv = camera.view_matrix() # (n 4 4) world2cam
|
mv = camera.view_matrix() # (n 4 4) world2cam
|
||||||
@ -301,6 +349,14 @@ class TextureBacker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
|
def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
|
||||||
|
"""Normalizes mesh and unwraps UVs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh (trimesh.Trimesh): Input mesh.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
trimesh.Trimesh: Mesh with normalized vertices and UVs.
|
||||||
|
"""
|
||||||
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
|
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
|
||||||
self.scale, self.center = scale, center
|
self.scale, self.center = scale, center
|
||||||
|
|
||||||
@ -318,6 +374,16 @@ class TextureBacker:
|
|||||||
scale: float = None,
|
scale: float = None,
|
||||||
center: np.ndarray = None,
|
center: np.ndarray = None,
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""Gets mesh attributes as numpy arrays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh (trimesh.Trimesh): Input mesh.
|
||||||
|
scale (float, optional): Scale factor.
|
||||||
|
center (np.ndarray, optional): Center offset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (vertices, faces, uv_map)
|
||||||
|
"""
|
||||||
vertices = mesh.vertices.copy()
|
vertices = mesh.vertices.copy()
|
||||||
faces = mesh.faces.copy()
|
faces = mesh.faces.copy()
|
||||||
uv_map = mesh.visual.uv.copy()
|
uv_map = mesh.visual.uv.copy()
|
||||||
@ -331,6 +397,14 @@ class TextureBacker:
|
|||||||
return vertices, faces, uv_map
|
return vertices, faces, uv_map
|
||||||
|
|
||||||
def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
|
def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Computes edge image from depth map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth_image (torch.Tensor): Depth map.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Edge image.
|
||||||
|
"""
|
||||||
depth_image_np = depth_image.cpu().numpy()
|
depth_image_np = depth_image.cpu().numpy()
|
||||||
depth_image_np = (depth_image_np * 255).astype(np.uint8)
|
depth_image_np = (depth_image_np * 255).astype(np.uint8)
|
||||||
depth_edges = cv2.Canny(depth_image_np, 30, 80)
|
depth_edges = cv2.Canny(depth_image_np, 30, 80)
|
||||||
@ -344,6 +418,16 @@ class TextureBacker:
|
|||||||
def compute_enhanced_viewnormal(
|
def compute_enhanced_viewnormal(
|
||||||
self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
|
self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""Computes enhanced view normals for mesh faces.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mv_mtx (torch.Tensor): View matrices.
|
||||||
|
vertices (torch.Tensor): Mesh vertices.
|
||||||
|
faces (torch.Tensor): Mesh faces.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: View normals.
|
||||||
|
"""
|
||||||
rast, _ = self.renderer.compute_dr_raster(vertices, faces)
|
rast, _ = self.renderer.compute_dr_raster(vertices, faces)
|
||||||
rendered_view_normals = []
|
rendered_view_normals = []
|
||||||
for idx in range(len(mv_mtx)):
|
for idx in range(len(mv_mtx)):
|
||||||
@ -376,6 +460,18 @@ class TextureBacker:
|
|||||||
def back_project(
|
def back_project(
|
||||||
self, image, vis_mask, depth, normal, uv
|
self, image, vis_mask, depth, normal, uv
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Back-projects image and confidence to UV texture space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (PIL.Image or np.ndarray): Input image.
|
||||||
|
vis_mask (torch.Tensor): Visibility mask.
|
||||||
|
depth (torch.Tensor): Depth map.
|
||||||
|
normal (torch.Tensor): Normal map.
|
||||||
|
uv (torch.Tensor): UV coordinates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, torch.Tensor]: Texture and confidence map.
|
||||||
|
"""
|
||||||
image = np.array(image)
|
image = np.array(image)
|
||||||
image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
|
image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
|
||||||
if image.ndim == 2:
|
if image.ndim == 2:
|
||||||
@ -418,6 +514,17 @@ class TextureBacker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _scatter_texture(self, uv, data, mask):
|
def _scatter_texture(self, uv, data, mask):
|
||||||
|
"""Scatters data to texture using UV coordinates and mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uv (torch.Tensor): UV coordinates.
|
||||||
|
data (torch.Tensor): Data to scatter.
|
||||||
|
mask (torch.Tensor): Mask for valid pixels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Scattered texture.
|
||||||
|
"""
|
||||||
|
|
||||||
def __filter_data(data, mask):
|
def __filter_data(data, mask):
|
||||||
return data.view(-1, data.shape[-1])[mask]
|
return data.view(-1, data.shape[-1])[mask]
|
||||||
|
|
||||||
@ -432,6 +539,15 @@ class TextureBacker:
|
|||||||
def fast_bake_texture(
|
def fast_bake_texture(
|
||||||
self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
|
self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Fuses multiple textures and confidence maps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
textures (list[torch.Tensor]): List of textures.
|
||||||
|
confidence_maps (list[torch.Tensor]): List of confidence maps.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, torch.Tensor]: Fused texture and mask.
|
||||||
|
"""
|
||||||
channel = textures[0].shape[-1]
|
channel = textures[0].shape[-1]
|
||||||
texture_merge = torch.zeros(self.texture_wh + [channel]).to(
|
texture_merge = torch.zeros(self.texture_wh + [channel]).to(
|
||||||
self.device
|
self.device
|
||||||
@ -451,6 +567,16 @@ class TextureBacker:
|
|||||||
def uv_inpaint(
|
def uv_inpaint(
|
||||||
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
|
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Inpaints missing regions in the UV texture.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh (trimesh.Trimesh): Mesh.
|
||||||
|
texture (np.ndarray): Texture image.
|
||||||
|
mask (np.ndarray): Mask image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Inpainted texture.
|
||||||
|
"""
|
||||||
if self.inpaint_smooth:
|
if self.inpaint_smooth:
|
||||||
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
|
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
|
||||||
texture, mask = _texture_inpaint_smooth(
|
texture, mask = _texture_inpaint_smooth(
|
||||||
@ -473,6 +599,15 @@ class TextureBacker:
|
|||||||
colors: list[Image.Image],
|
colors: list[Image.Image],
|
||||||
mesh: trimesh.Trimesh,
|
mesh: trimesh.Trimesh,
|
||||||
) -> trimesh.Trimesh:
|
) -> trimesh.Trimesh:
|
||||||
|
"""Computes the fused texture for the mesh from multi-view images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
colors (list[Image.Image]): List of view images.
|
||||||
|
mesh (trimesh.Trimesh): Mesh to texture.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[np.ndarray, np.ndarray]: Texture and mask.
|
||||||
|
"""
|
||||||
self._lazy_init_render(self.camera_params, self.mask_thresh)
|
self._lazy_init_render(self.camera_params, self.mask_thresh)
|
||||||
|
|
||||||
vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
|
vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
|
||||||
@ -517,7 +652,7 @@ class TextureBacker:
|
|||||||
Args:
|
Args:
|
||||||
colors (list[Image.Image]): List of input view images.
|
colors (list[Image.Image]): List of input view images.
|
||||||
mesh (trimesh.Trimesh): Input mesh to be textured.
|
mesh (trimesh.Trimesh): Input mesh to be textured.
|
||||||
output_path (str): Path to save the output textured mesh (.obj or .glb).
|
output_path (str): Path to save the output textured mesh.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
trimesh.Trimesh: The textured mesh with UV and texture image.
|
trimesh.Trimesh: The textured mesh with UV and texture image.
|
||||||
@ -540,6 +675,11 @@ class TextureBacker:
|
|||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
"""Parses command-line arguments for texture backprojection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
argparse.Namespace: Parsed arguments.
|
||||||
|
"""
|
||||||
parser = argparse.ArgumentParser(description="Backproject texture")
|
parser = argparse.ArgumentParser(description="Backproject texture")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--color_path",
|
"--color_path",
|
||||||
@ -636,6 +776,16 @@ def entrypoint(
|
|||||||
imagesr_model: ImageRealESRGAN = None,
|
imagesr_model: ImageRealESRGAN = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> trimesh.Trimesh:
|
) -> trimesh.Trimesh:
|
||||||
|
"""Entrypoint for texture backprojection from multi-view images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delight_model (DelightingModel, optional): Delighting model.
|
||||||
|
imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
|
||||||
|
**kwargs: Additional arguments to override CLI.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
trimesh.Trimesh: Textured mesh.
|
||||||
|
"""
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if hasattr(args, k) and v is not None:
|
if hasattr(args, k) and v is not None:
|
||||||
|
|||||||
@ -39,6 +39,22 @@ def decompose_convex_coacd(
|
|||||||
auto_scale: bool = True,
|
auto_scale: bool = True,
|
||||||
scale_factor: float = 1.0,
|
scale_factor: float = 1.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Decomposes a mesh using CoACD and saves the result.
|
||||||
|
|
||||||
|
This function loads a mesh from a file, runs the CoACD algorithm with the
|
||||||
|
given parameters, optionally scales the resulting convex hulls to match the
|
||||||
|
original mesh's bounding box, and exports the combined result to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Path to the input mesh file.
|
||||||
|
outfile: Path to save the decomposed output mesh.
|
||||||
|
params: A dictionary of parameters for the CoACD algorithm.
|
||||||
|
verbose: If True, sets the CoACD log level to 'info'.
|
||||||
|
auto_scale: If True, automatically computes a scale factor to match the
|
||||||
|
decomposed mesh's bounding box to the visual mesh's bounding box.
|
||||||
|
scale_factor: An additional scaling factor applied to the vertices of
|
||||||
|
the decomposed mesh parts.
|
||||||
|
"""
|
||||||
coacd.set_log_level("info" if verbose else "warn")
|
coacd.set_log_level("info" if verbose else "warn")
|
||||||
|
|
||||||
mesh = trimesh.load(filename, force="mesh")
|
mesh = trimesh.load(filename, force="mesh")
|
||||||
@ -83,7 +99,38 @@ def decompose_convex_mesh(
|
|||||||
scale_factor: float = 1.005,
|
scale_factor: float = 1.005,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Decompose a mesh into convex parts using the CoACD algorithm."""
|
"""Decomposes a mesh into convex parts with retry logic.
|
||||||
|
|
||||||
|
This function serves as a wrapper for `decompose_convex_coacd`, providing
|
||||||
|
explicit parameters for the CoACD algorithm and implementing a retry
|
||||||
|
mechanism. If the initial decomposition fails, it attempts again with
|
||||||
|
`preprocess_mode` set to 'on'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Path to the input mesh file.
|
||||||
|
outfile: Path to save the decomposed output mesh.
|
||||||
|
threshold: CoACD parameter. See CoACD documentation for details.
|
||||||
|
max_convex_hull: CoACD parameter. See CoACD documentation for details.
|
||||||
|
preprocess_mode: CoACD parameter. See CoACD documentation for details.
|
||||||
|
preprocess_resolution: CoACD parameter. See CoACD documentation for details.
|
||||||
|
resolution: CoACD parameter. See CoACD documentation for details.
|
||||||
|
mcts_nodes: CoACD parameter. See CoACD documentation for details.
|
||||||
|
mcts_iterations: CoACD parameter. See CoACD documentation for details.
|
||||||
|
mcts_max_depth: CoACD parameter. See CoACD documentation for details.
|
||||||
|
pca: CoACD parameter. See CoACD documentation for details.
|
||||||
|
merge: CoACD parameter. See CoACD documentation for details.
|
||||||
|
seed: CoACD parameter. See CoACD documentation for details.
|
||||||
|
auto_scale: If True, automatically scale the output to match the input
|
||||||
|
bounding box.
|
||||||
|
scale_factor: Additional scaling factor to apply.
|
||||||
|
verbose: If True, enables detailed logging.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The path to the output file if decomposition is successful.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If convex decomposition fails after all attempts.
|
||||||
|
"""
|
||||||
coacd.set_log_level("info" if verbose else "warn")
|
coacd.set_log_level("info" if verbose else "warn")
|
||||||
|
|
||||||
if os.path.exists(outfile):
|
if os.path.exists(outfile):
|
||||||
@ -148,9 +195,37 @@ def decompose_convex_mp(
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
auto_scale: bool = True,
|
auto_scale: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Decompose a mesh into convex parts using the CoACD algorithm in a separate process.
|
"""Decomposes a mesh into convex parts in a separate process.
|
||||||
|
|
||||||
|
This function uses the `multiprocessing` module to run the CoACD algorithm
|
||||||
|
in a spawned subprocess. This is useful for isolating the decomposition
|
||||||
|
process to prevent potential memory leaks or crashes in the main process.
|
||||||
|
It includes a retry mechanism similar to `decompose_convex_mesh`.
|
||||||
|
|
||||||
See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
|
See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Path to the input mesh file.
|
||||||
|
outfile: Path to save the decomposed output mesh.
|
||||||
|
threshold: CoACD parameter.
|
||||||
|
max_convex_hull: CoACD parameter.
|
||||||
|
preprocess_mode: CoACD parameter.
|
||||||
|
preprocess_resolution: CoACD parameter.
|
||||||
|
resolution: CoACD parameter.
|
||||||
|
mcts_nodes: CoACD parameter.
|
||||||
|
mcts_iterations: CoACD parameter.
|
||||||
|
mcts_max_depth: CoACD parameter.
|
||||||
|
pca: CoACD parameter.
|
||||||
|
merge: CoACD parameter.
|
||||||
|
seed: CoACD parameter.
|
||||||
|
verbose: If True, enables detailed logging in the subprocess.
|
||||||
|
auto_scale: If True, automatically scale the output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The path to the output file if decomposition is successful.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If convex decomposition fails after all attempts.
|
||||||
"""
|
"""
|
||||||
params = dict(
|
params = dict(
|
||||||
threshold=threshold,
|
threshold=threshold,
|
||||||
|
|||||||
@ -66,6 +66,14 @@ def create_mp4_from_images(
|
|||||||
fps: int = 10,
|
fps: int = 10,
|
||||||
prompt: str = None,
|
prompt: str = None,
|
||||||
):
|
):
|
||||||
|
"""Creates an MP4 video from a list of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (list[np.ndarray]): List of images as numpy arrays.
|
||||||
|
output_path (str): Path to save the MP4 file.
|
||||||
|
fps (int, optional): Frames per second. Defaults to 10.
|
||||||
|
prompt (str, optional): Optional text prompt overlay.
|
||||||
|
"""
|
||||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
font_scale = 0.5
|
font_scale = 0.5
|
||||||
font_thickness = 1
|
font_thickness = 1
|
||||||
@ -96,6 +104,13 @@ def create_mp4_from_images(
|
|||||||
def create_gif_from_images(
|
def create_gif_from_images(
|
||||||
images: list[np.ndarray], output_path: str, fps: int = 10
|
images: list[np.ndarray], output_path: str, fps: int = 10
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Creates a GIF animation from a list of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (list[np.ndarray]): List of images as numpy arrays.
|
||||||
|
output_path (str): Path to save the GIF file.
|
||||||
|
fps (int, optional): Frames per second. Defaults to 10.
|
||||||
|
"""
|
||||||
pil_images = []
|
pil_images = []
|
||||||
for image in images:
|
for image in images:
|
||||||
image = image.clip(min=0, max=1)
|
image = image.clip(min=0, max=1)
|
||||||
@ -116,32 +131,47 @@ def create_gif_from_images(
|
|||||||
|
|
||||||
|
|
||||||
class ImageRender(object):
|
class ImageRender(object):
|
||||||
"""A differentiable mesh renderer supporting multi-view rendering.
|
"""Differentiable mesh renderer supporting multi-view rendering.
|
||||||
|
|
||||||
This class wraps a differentiable rasterization using `nvdiffrast` to
|
This class wraps differentiable rasterization using `nvdiffrast` to render mesh
|
||||||
render mesh geometry to various maps (normal, depth, alpha, albedo, etc.).
|
geometry to various maps (normal, depth, alpha, albedo, etc.) and supports
|
||||||
|
saving images and videos.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
render_items (list[RenderItems]): A list of rendering targets to
|
render_items (list[RenderItems]): List of rendering targets.
|
||||||
generate (e.g., IMAGE, DEPTH, NORMAL, etc.).
|
camera_params (CameraSetting): Camera parameters for rendering.
|
||||||
camera_params (CameraSetting): The camera parameters for rendering,
|
recompute_vtx_normal (bool, optional): Recompute vertex normals. Defaults to True.
|
||||||
including intrinsic and extrinsic matrices.
|
with_mtl (bool, optional): Load mesh material files. Defaults to False.
|
||||||
recompute_vtx_normal (bool, optional): If True, recomputes
|
gen_color_gif (bool, optional): Generate GIF of color images. Defaults to False.
|
||||||
vertex normals from the mesh geometry. Defaults to True.
|
gen_color_mp4 (bool, optional): Generate MP4 of color images. Defaults to False.
|
||||||
with_mtl (bool, optional): Whether to load `.mtl` material files
|
gen_viewnormal_mp4 (bool, optional): Generate MP4 of view-space normals. Defaults to False.
|
||||||
for meshes. Defaults to False.
|
gen_glonormal_mp4 (bool, optional): Generate MP4 of global-space normals. Defaults to False.
|
||||||
gen_color_gif (bool, optional): Generate a GIF of rendered
|
no_index_file (bool, optional): Skip saving index file. Defaults to False.
|
||||||
color images. Defaults to False.
|
light_factor (float, optional): PBR light intensity multiplier. Defaults to 1.0.
|
||||||
gen_color_mp4 (bool, optional): Generate an MP4 video of rendered
|
|
||||||
color images. Defaults to False.
|
Example:
|
||||||
gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of
|
```py
|
||||||
view-space normals. Defaults to False.
|
from embodied_gen.data.differentiable_render import ImageRender
|
||||||
gen_glonormal_mp4 (bool, optional): Generate an MP4 video of
|
from embodied_gen.data.utils import CameraSetting
|
||||||
global-space normals. Defaults to False.
|
from embodied_gen.utils.enum import RenderItems
|
||||||
no_index_file (bool, optional): If True, skip saving the `index.json`
|
|
||||||
summary file. Defaults to False.
|
camera_params = CameraSetting(
|
||||||
light_factor (float, optional): A scalar multiplier for
|
num_images=6,
|
||||||
PBR light intensity. Defaults to 1.0.
|
elevation=[20, -10],
|
||||||
|
distance=5,
|
||||||
|
resolution_hw=(512,512),
|
||||||
|
fov=math.radians(30),
|
||||||
|
device='cuda',
|
||||||
|
)
|
||||||
|
render_items = [RenderItems.IMAGE.value, RenderItems.DEPTH.value]
|
||||||
|
renderer = ImageRender(
|
||||||
|
render_items,
|
||||||
|
camera_params,
|
||||||
|
with_mtl=args.with_mtl,
|
||||||
|
gen_color_mp4=True,
|
||||||
|
)
|
||||||
|
renderer.render_mesh(mesh_path='mesh.obj', output_root='./renders')
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -198,6 +228,14 @@ class ImageRender(object):
|
|||||||
uuid: Union[str, List[str]] = None,
|
uuid: Union[str, List[str]] = None,
|
||||||
prompts: List[str] = None,
|
prompts: List[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Renders one or more meshes and saves outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh_path (Union[str, List[str]]): Path(s) to mesh files.
|
||||||
|
output_root (str): Directory to save outputs.
|
||||||
|
uuid (Union[str, List[str]], optional): Unique IDs for outputs.
|
||||||
|
prompts (List[str], optional): Text prompts for videos.
|
||||||
|
"""
|
||||||
mesh_path = as_list(mesh_path)
|
mesh_path = as_list(mesh_path)
|
||||||
if uuid is None:
|
if uuid is None:
|
||||||
uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
|
uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
|
||||||
@ -227,18 +265,15 @@ class ImageRender(object):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, mesh_path: str, output_dir: str, prompt: str = None
|
self, mesh_path: str, output_dir: str, prompt: str = None
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Render a single mesh and return paths to the rendered outputs.
|
"""Renders a single mesh and returns output paths.
|
||||||
|
|
||||||
Processes the input mesh, renders multiple modalities (e.g., normals,
|
|
||||||
depth, albedo), and optionally saves video or image sequences.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mesh_path (str): Path to the mesh file (.obj/.glb).
|
mesh_path (str): Path to mesh file.
|
||||||
output_dir (str): Directory to save rendered outputs.
|
output_dir (str): Directory to save outputs.
|
||||||
prompt (str, optional): Optional caption prompt for MP4 metadata.
|
prompt (str, optional): Caption prompt for MP4 metadata.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, str]: A mapping render types to the saved image paths.
|
dict[str, str]: Mapping of render types to saved image paths.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
|
mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
|
||||||
|
|||||||
@ -16,17 +16,13 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
|
||||||
import os
|
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import coacd
|
|
||||||
import igraph
|
import igraph
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyvista as pv
|
import pyvista as pv
|
||||||
import spaces
|
import spaces
|
||||||
import torch
|
import torch
|
||||||
import trimesh
|
|
||||||
import utils3d
|
import utils3d
|
||||||
from pymeshfix import _meshfix
|
from pymeshfix import _meshfix
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|||||||
@ -51,6 +51,33 @@ __all__ = ["PickEmbodiedGen"]
|
|||||||
|
|
||||||
@register_env("PickEmbodiedGen-v1", max_episode_steps=100)
|
@register_env("PickEmbodiedGen-v1", max_episode_steps=100)
|
||||||
class PickEmbodiedGen(BaseEnv):
|
class PickEmbodiedGen(BaseEnv):
|
||||||
|
"""PickEmbodiedGen as gym env example for object pick-and-place tasks.
|
||||||
|
|
||||||
|
This environment simulates a robot interacting with 3D assets in the
|
||||||
|
embodiedgen generated scene in SAPIEN. It supports multi-environment setups,
|
||||||
|
dynamic reconfiguration, and hybrid rendering with 3D Gaussian Splatting.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Use `gym.make` to create the `PickEmbodiedGen-v1` parallel environment.
|
||||||
|
```python
|
||||||
|
import gymnasium as gym
|
||||||
|
env = gym.make(
|
||||||
|
"PickEmbodiedGen-v1",
|
||||||
|
num_envs=cfg.num_envs,
|
||||||
|
render_mode=cfg.render_mode,
|
||||||
|
enable_shadow=cfg.enable_shadow,
|
||||||
|
layout_file=cfg.layout_file,
|
||||||
|
control_mode=cfg.control_mode,
|
||||||
|
camera_cfg=dict(
|
||||||
|
camera_eye=cfg.camera_eye,
|
||||||
|
camera_target_pt=cfg.camera_target_pt,
|
||||||
|
image_hw=cfg.image_hw,
|
||||||
|
fovy_deg=cfg.fovy_deg,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
|
SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
|
||||||
goal_thresh = 0.0
|
goal_thresh = 0.0
|
||||||
|
|
||||||
@ -63,6 +90,19 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
reconfiguration_freq: int = None,
|
reconfiguration_freq: int = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Initializes the PickEmbodiedGen environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: Variable length argument list for the base class.
|
||||||
|
robot_uids: The robot(s) to use in the environment.
|
||||||
|
robot_init_qpos_noise: Noise added to the robot's initial joint
|
||||||
|
positions.
|
||||||
|
num_envs: The number of parallel environments to create.
|
||||||
|
reconfiguration_freq: How often to reconfigure the scene. If None,
|
||||||
|
it is set based on num_envs.
|
||||||
|
**kwargs: Additional keyword arguments for environment setup,
|
||||||
|
including layout_file, replace_objs, enable_grasp, etc.
|
||||||
|
"""
|
||||||
self.robot_init_qpos_noise = robot_init_qpos_noise
|
self.robot_init_qpos_noise = robot_init_qpos_noise
|
||||||
if reconfiguration_freq is None:
|
if reconfiguration_freq is None:
|
||||||
if num_envs == 1:
|
if num_envs == 1:
|
||||||
@ -116,6 +156,22 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def init_env_layouts(
|
def init_env_layouts(
|
||||||
layout_file: str, num_envs: int, replace_objs: bool
|
layout_file: str, num_envs: int, replace_objs: bool
|
||||||
) -> list[LayoutInfo]:
|
) -> list[LayoutInfo]:
|
||||||
|
"""Initializes and saves layout files for each environment instance.
|
||||||
|
|
||||||
|
For each environment, this method creates a layout configuration. If
|
||||||
|
`replace_objs` is True, it generates new object placements for each
|
||||||
|
subsequent environment. The generated layouts are saved as new JSON
|
||||||
|
files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layout_file: Path to the base layout JSON file.
|
||||||
|
num_envs: The number of environments to create layouts for.
|
||||||
|
replace_objs: If True, generates new object placements for each
|
||||||
|
environment after the first one using BFS placement.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of file paths to the generated layout for each environment.
|
||||||
|
"""
|
||||||
layouts = []
|
layouts = []
|
||||||
for env_idx in range(num_envs):
|
for env_idx in range(num_envs):
|
||||||
if replace_objs and env_idx > 0:
|
if replace_objs and env_idx > 0:
|
||||||
@ -136,6 +192,18 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def compute_robot_init_pose(
|
def compute_robot_init_pose(
|
||||||
layouts: list[str], num_envs: int, z_offset: float = 0.0
|
layouts: list[str], num_envs: int, z_offset: float = 0.0
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
|
"""Computes the initial pose for the robot in each environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layouts: A list of file paths to the environment layouts.
|
||||||
|
num_envs: The number of environments.
|
||||||
|
z_offset: An optional vertical offset to apply to the robot's
|
||||||
|
position to prevent collisions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of initial poses ([x, y, z, qw, qx, qy, qz]) for the robot
|
||||||
|
in each environment.
|
||||||
|
"""
|
||||||
robot_pose = []
|
robot_pose = []
|
||||||
for env_idx in range(num_envs):
|
for env_idx in range(num_envs):
|
||||||
layout = json.load(open(layouts[env_idx], "r"))
|
layout = json.load(open(layouts[env_idx], "r"))
|
||||||
@ -148,6 +216,11 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_sim_config(self):
|
def _default_sim_config(self):
|
||||||
|
"""Returns the default simulation configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The default simulation configuration object.
|
||||||
|
"""
|
||||||
return SimConfig(
|
return SimConfig(
|
||||||
scene_config=SceneConfig(
|
scene_config=SceneConfig(
|
||||||
solver_position_iterations=30,
|
solver_position_iterations=30,
|
||||||
@ -163,6 +236,11 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_sensor_configs(self):
|
def _default_sensor_configs(self):
|
||||||
|
"""Returns the default sensor configurations for the agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list containing the default camera configuration.
|
||||||
|
"""
|
||||||
pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
|
pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
|
||||||
|
|
||||||
return [
|
return [
|
||||||
@ -171,6 +249,11 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_human_render_camera_configs(self):
|
def _default_human_render_camera_configs(self):
|
||||||
|
"""Returns the default camera configuration for human-friendly rendering.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The default camera configuration for the renderer.
|
||||||
|
"""
|
||||||
pose = sapien_utils.look_at(
|
pose = sapien_utils.look_at(
|
||||||
eye=self.camera_cfg["camera_eye"],
|
eye=self.camera_cfg["camera_eye"],
|
||||||
target=self.camera_cfg["camera_target_pt"],
|
target=self.camera_cfg["camera_target_pt"],
|
||||||
@ -187,10 +270,24 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _load_agent(self, options: dict):
|
def _load_agent(self, options: dict):
|
||||||
|
"""Loads the agent (robot) and a ground plane into the scene.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
options: A dictionary of options for loading the agent.
|
||||||
|
"""
|
||||||
self.ground = build_ground(self.scene)
|
self.ground = build_ground(self.scene)
|
||||||
super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
|
super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
|
||||||
|
|
||||||
def _load_scene(self, options: dict):
|
def _load_scene(self, options: dict):
|
||||||
|
"""Loads all assets, objects, and the goal site into the scene.
|
||||||
|
|
||||||
|
This method iterates through the layouts for each environment, loads the
|
||||||
|
specified assets, and adds them to the simulation. It also creates a
|
||||||
|
kinematic sphere to represent the goal site.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
options: A dictionary of options for loading the scene.
|
||||||
|
"""
|
||||||
all_objects = []
|
all_objects = []
|
||||||
logger.info(f"Loading EmbodiedGen assets...")
|
logger.info(f"Loading EmbodiedGen assets...")
|
||||||
for env_idx in range(self.num_envs):
|
for env_idx in range(self.num_envs):
|
||||||
@ -222,6 +319,15 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
self._hidden_objects.append(self.goal_site)
|
self._hidden_objects.append(self.goal_site)
|
||||||
|
|
||||||
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
|
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
|
||||||
|
"""Initializes an episode for a given set of environments.
|
||||||
|
|
||||||
|
This method sets the goal position, resets the robot's joint positions
|
||||||
|
with optional noise, and sets its root pose.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_idx: A tensor of environment indices to initialize.
|
||||||
|
options: A dictionary of options for initialization.
|
||||||
|
"""
|
||||||
with torch.device(self.device):
|
with torch.device(self.device):
|
||||||
b = len(env_idx)
|
b = len(env_idx)
|
||||||
goal_xyz = torch.zeros((b, 3))
|
goal_xyz = torch.zeros((b, 3))
|
||||||
@ -256,6 +362,21 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def render_gs3d_images(
|
def render_gs3d_images(
|
||||||
self, layouts: list[str], num_envs: int, init_quat: list[float]
|
self, layouts: list[str], num_envs: int, init_quat: list[float]
|
||||||
) -> dict[str, np.ndarray]:
|
) -> dict[str, np.ndarray]:
|
||||||
|
"""Renders background images using a pre-trained Gaussian Splatting model.
|
||||||
|
|
||||||
|
This method pre-renders the static background for each environment from
|
||||||
|
the perspective of all cameras to be used for hybrid rendering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layouts: A list of file paths to the environment layouts.
|
||||||
|
num_envs: The number of environments.
|
||||||
|
init_quat: An initial quaternion to orient the Gaussian Splatting
|
||||||
|
model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary mapping a unique key (e.g., 'camera-env_idx') to the
|
||||||
|
rendered background image as a numpy array.
|
||||||
|
"""
|
||||||
sim_coord_align = (
|
sim_coord_align = (
|
||||||
torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
|
torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
|
||||||
)
|
)
|
||||||
@ -293,6 +414,15 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
return bg_images
|
return bg_images
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
|
"""Renders the environment based on the configured render_mode.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If `render_mode` is not set.
|
||||||
|
NotImplementedError: If the `render_mode` is not supported.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The rendered output, which varies depending on the render mode.
|
||||||
|
"""
|
||||||
if self.render_mode is None:
|
if self.render_mode is None:
|
||||||
raise RuntimeError("render_mode is not set.")
|
raise RuntimeError("render_mode is not set.")
|
||||||
if self.render_mode == "human":
|
if self.render_mode == "human":
|
||||||
@ -315,6 +445,17 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def render_rgb_array(
|
def render_rgb_array(
|
||||||
self, camera_name: str = None, return_alpha: bool = False
|
self, camera_name: str = None, return_alpha: bool = False
|
||||||
):
|
):
|
||||||
|
"""Renders an RGB image from the human-facing render camera.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_name: The name of the camera to render from. If None, uses
|
||||||
|
all human render cameras.
|
||||||
|
return_alpha: Whether to include the alpha channel in the output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy array representing the rendered image(s). If multiple
|
||||||
|
cameras are used, the images are tiled.
|
||||||
|
"""
|
||||||
for obj in self._hidden_objects:
|
for obj in self._hidden_objects:
|
||||||
obj.show_visual()
|
obj.show_visual()
|
||||||
self.scene.update_render(
|
self.scene.update_render(
|
||||||
@ -335,6 +476,11 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
return tile_images(images)
|
return tile_images(images)
|
||||||
|
|
||||||
def render_sensors(self):
|
def render_sensors(self):
|
||||||
|
"""Renders images from all on-board sensor cameras.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tiled image of all sensor outputs as a numpy array.
|
||||||
|
"""
|
||||||
images = []
|
images = []
|
||||||
sensor_images = self.get_sensor_images()
|
sensor_images = self.get_sensor_images()
|
||||||
for image in sensor_images.values():
|
for image in sensor_images.values():
|
||||||
@ -343,6 +489,14 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
return tile_images(images)
|
return tile_images(images)
|
||||||
|
|
||||||
def hybrid_render(self):
|
def hybrid_render(self):
|
||||||
|
"""Renders a hybrid image by blending simulated foreground with a background.
|
||||||
|
|
||||||
|
The foreground is rendered with an alpha channel and then blended with
|
||||||
|
the pre-rendered Gaussian Splatting background image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A torch tensor of the final blended RGB images.
|
||||||
|
"""
|
||||||
fg_images = self.render_rgb_array(
|
fg_images = self.render_rgb_array(
|
||||||
return_alpha=True
|
return_alpha=True
|
||||||
) # (n_env, h, w, 3)
|
) # (n_env, h, w, 3)
|
||||||
@ -362,6 +516,16 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
return images[..., :3]
|
return images[..., :3]
|
||||||
|
|
||||||
def evaluate(self):
|
def evaluate(self):
|
||||||
|
"""Evaluates the current state of the environment.
|
||||||
|
|
||||||
|
Checks for task success criteria such as whether the object is grasped,
|
||||||
|
placed at the goal, and if the robot is static.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing boolean tensors for various success
|
||||||
|
metrics, including 'is_grasped', 'is_obj_placed', and overall
|
||||||
|
'success'.
|
||||||
|
"""
|
||||||
obj_to_goal_pos = (
|
obj_to_goal_pos = (
|
||||||
self.obj.pose.p
|
self.obj.pose.p
|
||||||
) # self.goal_site.pose.p - self.obj.pose.p
|
) # self.goal_site.pose.p - self.obj.pose.p
|
||||||
@ -381,10 +545,31 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_obs_extra(self, info: dict):
|
def _get_obs_extra(self, info: dict):
|
||||||
|
"""Gets extra information for the observation dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
info: A dictionary containing evaluation information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An empty dictionary, as no extra observations are added.
|
||||||
|
"""
|
||||||
|
|
||||||
return dict()
|
return dict()
|
||||||
|
|
||||||
def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
|
def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
|
||||||
|
"""Computes a dense reward for the current step.
|
||||||
|
|
||||||
|
The reward is a composite of reaching, grasping, placing, and
|
||||||
|
maintaining a static final pose.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs: The current observation.
|
||||||
|
action: The action taken in the current step.
|
||||||
|
info: A dictionary containing evaluation information from `evaluate()`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor containing the dense reward for each environment.
|
||||||
|
"""
|
||||||
tcp_to_obj_dist = torch.linalg.norm(
|
tcp_to_obj_dist = torch.linalg.norm(
|
||||||
self.obj.pose.p - self.agent.tcp.pose.p, axis=1
|
self.obj.pose.p - self.agent.tcp.pose.p, axis=1
|
||||||
)
|
)
|
||||||
@ -417,4 +602,14 @@ class PickEmbodiedGen(BaseEnv):
|
|||||||
def compute_normalized_dense_reward(
|
def compute_normalized_dense_reward(
|
||||||
self, obs: any, action: torch.Tensor, info: dict
|
self, obs: any, action: torch.Tensor, info: dict
|
||||||
):
|
):
|
||||||
|
"""Computes a dense reward normalized to be between 0 and 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs: The current observation.
|
||||||
|
action: The action taken in the current step.
|
||||||
|
info: A dictionary containing evaluation information from `evaluate()`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor containing the normalized dense reward for each environment.
|
||||||
|
"""
|
||||||
return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
|
return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
|
||||||
|
|||||||
@ -40,7 +40,7 @@ class DelightingModel(object):
|
|||||||
"""A model to remove the lighting in image space.
|
"""A model to remove the lighting in image space.
|
||||||
|
|
||||||
This model is encapsulated based on the Hunyuan3D-Delight model
|
This model is encapsulated based on the Hunyuan3D-Delight model
|
||||||
from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa
|
from `https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0` # noqa
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
image_guide_scale (float): Weight of image guidance in diffusion process.
|
image_guide_scale (float): Weight of image guidance in diffusion process.
|
||||||
|
|||||||
@ -38,26 +38,61 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class BasePipelineLoader(ABC):
|
class BasePipelineLoader(ABC):
|
||||||
|
"""Abstract base class for loading Hugging Face image generation pipelines.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
device (str): Device to load the pipeline on.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
load(): Loads and returns the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, device="cuda"):
|
def __init__(self, device="cuda"):
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self):
|
def load(self):
|
||||||
|
"""Load and return the pipeline instance."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BasePipelineRunner(ABC):
|
class BasePipelineRunner(ABC):
|
||||||
|
"""Abstract base class for running image generation pipelines.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
pipe: The loaded pipeline.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
run(prompt, **kwargs): Runs the pipeline with a prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, pipe):
|
def __init__(self, pipe):
|
||||||
self.pipe = pipe
|
self.pipe = pipe
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||||
|
"""Run the pipeline with the given prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): Text prompt for image generation.
|
||||||
|
**kwargs: Additional pipeline arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Generated image(s).
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# ===== SD3.5-medium =====
|
# ===== SD3.5-medium =====
|
||||||
class SD35Loader(BasePipelineLoader):
|
class SD35Loader(BasePipelineLoader):
|
||||||
|
"""Loader for Stable Diffusion 3.5 medium pipeline."""
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
"""Load the Stable Diffusion 3.5 medium pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StableDiffusion3Pipeline: Loaded pipeline.
|
||||||
|
"""
|
||||||
pipe = StableDiffusion3Pipeline.from_pretrained(
|
pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||||
"stabilityai/stable-diffusion-3.5-medium",
|
"stabilityai/stable-diffusion-3.5-medium",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
@ -70,12 +105,25 @@ class SD35Loader(BasePipelineLoader):
|
|||||||
|
|
||||||
|
|
||||||
class SD35Runner(BasePipelineRunner):
|
class SD35Runner(BasePipelineRunner):
|
||||||
|
"""Runner for Stable Diffusion 3.5 medium pipeline."""
|
||||||
|
|
||||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||||
|
"""Generate images using Stable Diffusion 3.5 medium.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): Text prompt.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Generated image(s).
|
||||||
|
"""
|
||||||
return self.pipe(prompt=prompt, **kwargs).images
|
return self.pipe(prompt=prompt, **kwargs).images
|
||||||
|
|
||||||
|
|
||||||
# ===== Cosmos2 =====
|
# ===== Cosmos2 =====
|
||||||
class CosmosLoader(BasePipelineLoader):
|
class CosmosLoader(BasePipelineLoader):
|
||||||
|
"""Loader for Cosmos2 text-to-image pipeline."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
|
model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
|
||||||
@ -87,6 +135,8 @@ class CosmosLoader(BasePipelineLoader):
|
|||||||
self.local_dir = local_dir
|
self.local_dir = local_dir
|
||||||
|
|
||||||
def _patch(self):
|
def _patch(self):
|
||||||
|
"""Patch model and processor for optimized loading."""
|
||||||
|
|
||||||
def patch_model(cls):
|
def patch_model(cls):
|
||||||
orig = cls.from_pretrained
|
orig = cls.from_pretrained
|
||||||
|
|
||||||
@ -110,6 +160,11 @@ class CosmosLoader(BasePipelineLoader):
|
|||||||
patch_processor(SiglipProcessor)
|
patch_processor(SiglipProcessor)
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
"""Load the Cosmos2 text-to-image pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cosmos2TextToImagePipeline: Loaded pipeline.
|
||||||
|
"""
|
||||||
self._patch()
|
self._patch()
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=self.model_id,
|
repo_id=self.model_id,
|
||||||
@ -141,7 +196,19 @@ class CosmosLoader(BasePipelineLoader):
|
|||||||
|
|
||||||
|
|
||||||
class CosmosRunner(BasePipelineRunner):
|
class CosmosRunner(BasePipelineRunner):
|
||||||
|
"""Runner for Cosmos2 text-to-image pipeline."""
|
||||||
|
|
||||||
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
||||||
|
"""Generate images using Cosmos2 pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): Text prompt.
|
||||||
|
negative_prompt (str, optional): Negative prompt.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Generated image(s).
|
||||||
|
"""
|
||||||
return self.pipe(
|
return self.pipe(
|
||||||
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
||||||
).images
|
).images
|
||||||
@ -149,7 +216,14 @@ class CosmosRunner(BasePipelineRunner):
|
|||||||
|
|
||||||
# ===== Kolors =====
|
# ===== Kolors =====
|
||||||
class KolorsLoader(BasePipelineLoader):
|
class KolorsLoader(BasePipelineLoader):
|
||||||
|
"""Loader for Kolors pipeline."""
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
"""Load the Kolors pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KolorsPipeline: Loaded pipeline.
|
||||||
|
"""
|
||||||
pipe = KolorsPipeline.from_pretrained(
|
pipe = KolorsPipeline.from_pretrained(
|
||||||
"Kwai-Kolors/Kolors-diffusers",
|
"Kwai-Kolors/Kolors-diffusers",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
@ -164,13 +238,31 @@ class KolorsLoader(BasePipelineLoader):
|
|||||||
|
|
||||||
|
|
||||||
class KolorsRunner(BasePipelineRunner):
|
class KolorsRunner(BasePipelineRunner):
|
||||||
|
"""Runner for Kolors pipeline."""
|
||||||
|
|
||||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||||
|
"""Generate images using Kolors pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): Text prompt.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Generated image(s).
|
||||||
|
"""
|
||||||
return self.pipe(prompt=prompt, **kwargs).images
|
return self.pipe(prompt=prompt, **kwargs).images
|
||||||
|
|
||||||
|
|
||||||
# ===== Flux =====
|
# ===== Flux =====
|
||||||
class FluxLoader(BasePipelineLoader):
|
class FluxLoader(BasePipelineLoader):
|
||||||
|
"""Loader for Flux pipeline."""
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
"""Load the Flux pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FluxPipeline: Loaded pipeline.
|
||||||
|
"""
|
||||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||||
pipe = FluxPipeline.from_pretrained(
|
pipe = FluxPipeline.from_pretrained(
|
||||||
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
|
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
|
||||||
@ -182,20 +274,50 @@ class FluxLoader(BasePipelineLoader):
|
|||||||
|
|
||||||
|
|
||||||
class FluxRunner(BasePipelineRunner):
|
class FluxRunner(BasePipelineRunner):
|
||||||
|
"""Runner for Flux pipeline."""
|
||||||
|
|
||||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||||
|
"""Generate images using Flux pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): Text prompt.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Generated image(s).
|
||||||
|
"""
|
||||||
return self.pipe(prompt=prompt, **kwargs).images
|
return self.pipe(prompt=prompt, **kwargs).images
|
||||||
|
|
||||||
|
|
||||||
# ===== Chroma =====
|
# ===== Chroma =====
|
||||||
class ChromaLoader(BasePipelineLoader):
|
class ChromaLoader(BasePipelineLoader):
|
||||||
|
"""Loader for Chroma pipeline."""
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
"""Load the Chroma pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChromaPipeline: Loaded pipeline.
|
||||||
|
"""
|
||||||
return ChromaPipeline.from_pretrained(
|
return ChromaPipeline.from_pretrained(
|
||||||
"lodestones/Chroma", torch_dtype=torch.bfloat16
|
"lodestones/Chroma", torch_dtype=torch.bfloat16
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
|
|
||||||
class ChromaRunner(BasePipelineRunner):
|
class ChromaRunner(BasePipelineRunner):
|
||||||
|
"""Runner for Chroma pipeline."""
|
||||||
|
|
||||||
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
||||||
|
"""Generate images using Chroma pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): Text prompt.
|
||||||
|
negative_prompt (str, optional): Negative prompt.
|
||||||
|
**kwargs: Additional arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Generated image(s).
|
||||||
|
"""
|
||||||
return self.pipe(
|
return self.pipe(
|
||||||
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
||||||
).images
|
).images
|
||||||
@ -211,6 +333,22 @@ PIPELINE_REGISTRY = {
|
|||||||
|
|
||||||
|
|
||||||
def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
|
def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
|
||||||
|
"""Build a Hugging Face image generation pipeline runner by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): Name of the pipeline (e.g., "sd35", "cosmos").
|
||||||
|
device (str): Device to load the pipeline on.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BasePipelineRunner: Pipeline runner instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
|
||||||
|
runner = build_hf_image_pipeline("sd35")
|
||||||
|
images = runner.run(prompt="A robot holding a sign that says 'Hello'")
|
||||||
|
```
|
||||||
|
"""
|
||||||
if name not in PIPELINE_REGISTRY:
|
if name not in PIPELINE_REGISTRY:
|
||||||
raise ValueError(f"Unsupported model: {name}")
|
raise ValueError(f"Unsupported model: {name}")
|
||||||
loader_cls, runner_cls = PIPELINE_REGISTRY[name]
|
loader_cls, runner_cls = PIPELINE_REGISTRY[name]
|
||||||
|
|||||||
@ -376,6 +376,21 @@ LAYOUT_DESCRIBER_PROMPT = """
|
|||||||
|
|
||||||
|
|
||||||
class LayoutDesigner(object):
|
class LayoutDesigner(object):
|
||||||
|
"""A class for querying GPT-based scene layout reasoning and formatting responses.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
prompt (str): The system prompt for GPT.
|
||||||
|
verbose (bool): Whether to log responses.
|
||||||
|
gpt_client (GPTclient): The GPT client instance.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
query(prompt, params): Query GPT with a prompt and parameters.
|
||||||
|
format_response(response): Parse and clean JSON response.
|
||||||
|
format_response_repair(response): Repair and parse JSON response.
|
||||||
|
save_output(output, save_path): Save output to file.
|
||||||
|
__call__(prompt, save_path, params): Query and process output.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -387,6 +402,15 @@ class LayoutDesigner(object):
|
|||||||
self.gpt_client = gpt_client
|
self.gpt_client = gpt_client
|
||||||
|
|
||||||
def query(self, prompt: str, params: dict = None) -> str:
|
def query(self, prompt: str, params: dict = None) -> str:
|
||||||
|
"""Query GPT with the system prompt and user prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): User prompt.
|
||||||
|
params (dict, optional): GPT parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: GPT response.
|
||||||
|
"""
|
||||||
full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
|
full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
|
||||||
|
|
||||||
response = self.gpt_client.query(
|
response = self.gpt_client.query(
|
||||||
@ -400,6 +424,17 @@ class LayoutDesigner(object):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
def format_response(self, response: str) -> dict:
|
def format_response(self, response: str) -> dict:
|
||||||
|
"""Format and parse GPT response as JSON.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response (str): Raw GPT response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Parsed JSON output.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
json.JSONDecodeError: If parsing fails.
|
||||||
|
"""
|
||||||
cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
|
cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
|
||||||
try:
|
try:
|
||||||
output = json.loads(cleaned)
|
output = json.loads(cleaned)
|
||||||
@ -411,9 +446,23 @@ class LayoutDesigner(object):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def format_response_repair(self, response: str) -> dict:
|
def format_response_repair(self, response: str) -> dict:
|
||||||
|
"""Repair and parse possibly broken JSON response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response (str): Raw GPT response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Parsed JSON output.
|
||||||
|
"""
|
||||||
return json_repair.loads(response)
|
return json_repair.loads(response)
|
||||||
|
|
||||||
def save_output(self, output: dict, save_path: str) -> None:
|
def save_output(self, output: dict, save_path: str) -> None:
|
||||||
|
"""Save output dictionary to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output (dict): Output data.
|
||||||
|
save_path (str): Path to save the file.
|
||||||
|
"""
|
||||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
with open(save_path, 'w') as f:
|
with open(save_path, 'w') as f:
|
||||||
json.dump(output, f, indent=4)
|
json.dump(output, f, indent=4)
|
||||||
@ -421,6 +470,16 @@ class LayoutDesigner(object):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, prompt: str, save_path: str = None, params: dict = None
|
self, prompt: str, save_path: str = None, params: dict = None
|
||||||
) -> dict | str:
|
) -> dict | str:
|
||||||
|
"""Query GPT and process the output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): User prompt.
|
||||||
|
save_path (str, optional): Path to save output.
|
||||||
|
params (dict, optional): GPT parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict | str: Output data.
|
||||||
|
"""
|
||||||
response = self.query(prompt, params=params)
|
response = self.query(prompt, params=params)
|
||||||
output = self.format_response_repair(response)
|
output = self.format_response_repair(response)
|
||||||
self.save_output(output, save_path) if save_path else None
|
self.save_output(output, save_path) if save_path else None
|
||||||
@ -442,6 +501,29 @@ LAYOUT_DESCRIBER = LayoutDesigner(
|
|||||||
def build_scene_layout(
|
def build_scene_layout(
|
||||||
task_desc: str, output_path: str = None, gpt_params: dict = None
|
task_desc: str, output_path: str = None, gpt_params: dict = None
|
||||||
) -> LayoutInfo:
|
) -> LayoutInfo:
|
||||||
|
"""Build a 3D scene layout from a natural language task description.
|
||||||
|
|
||||||
|
This function uses GPT-based reasoning to generate a structured scene layout,
|
||||||
|
including object hierarchy, spatial relations, and style descriptions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_desc (str): Natural language description of the robotic task.
|
||||||
|
output_path (str, optional): Path to save the visualized scene tree.
|
||||||
|
gpt_params (dict, optional): Parameters for GPT queries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LayoutInfo: Structured layout information for the scene.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.layout import build_scene_layout
|
||||||
|
layout_info = build_scene_layout(
|
||||||
|
task_desc="Put the apples on the table on the plate",
|
||||||
|
output_path="outputs/scene_tree.jpg",
|
||||||
|
)
|
||||||
|
print(layout_info)
|
||||||
|
```
|
||||||
|
"""
|
||||||
layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
|
layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
|
||||||
layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
|
layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
|
||||||
object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
|
object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
|
||||||
|
|||||||
@ -48,12 +48,19 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class SAMRemover(object):
|
class SAMRemover(object):
|
||||||
"""Loading SAM models and performing background removal on images.
|
"""Loads SAM models and performs background removal on images.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
checkpoint (str): Path to the model checkpoint.
|
checkpoint (str): Path to the model checkpoint.
|
||||||
model_type (str): Type of the SAM model to load (default: "vit_h").
|
model_type (str): Type of the SAM model to load.
|
||||||
area_ratio (float): Area ratio filtering small connected components.
|
area_ratio (float): Area ratio for filtering small connected components.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.segment_model import SAMRemover
|
||||||
|
remover = SAMRemover(model_type="vit_h")
|
||||||
|
result = remover("input.jpg", "output.png")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -78,6 +85,14 @@ class SAMRemover(object):
|
|||||||
self.mask_generator = self._load_sam_model(checkpoint)
|
self.mask_generator = self._load_sam_model(checkpoint)
|
||||||
|
|
||||||
def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
|
def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
|
||||||
|
"""Loads the SAM model and returns a mask generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint (str): Path to model checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SamAutomaticMaskGenerator: Mask generator instance.
|
||||||
|
"""
|
||||||
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
||||||
sam.to(device=self.device)
|
sam.to(device=self.device)
|
||||||
|
|
||||||
@ -89,13 +104,11 @@ class SAMRemover(object):
|
|||||||
"""Removes the background from an image using the SAM model.
|
"""Removes the background from an image using the SAM model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image (Union[str, Image.Image, np.ndarray]): Input image,
|
image (Union[str, Image.Image, np.ndarray]): Input image.
|
||||||
can be a file path, PIL Image, or numpy array.
|
save_path (str, optional): Path to save the output image.
|
||||||
save_path (str): Path to save the output image (default: None).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Image.Image: The image with background removed,
|
Image.Image: Image with background removed (RGBA).
|
||||||
including an alpha channel.
|
|
||||||
"""
|
"""
|
||||||
# Convert input to numpy array
|
# Convert input to numpy array
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
@ -134,6 +147,15 @@ class SAMRemover(object):
|
|||||||
|
|
||||||
|
|
||||||
class SAMPredictor(object):
|
class SAMPredictor(object):
|
||||||
|
"""Loads SAM models and predicts segmentation masks from user points.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint (str, optional): Path to model checkpoint.
|
||||||
|
model_type (str, optional): SAM model type.
|
||||||
|
binary_thresh (float, optional): Threshold for binary mask.
|
||||||
|
device (str, optional): Device for inference.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
checkpoint: str = None,
|
checkpoint: str = None,
|
||||||
@ -157,12 +179,28 @@ class SAMPredictor(object):
|
|||||||
self.binary_thresh = binary_thresh
|
self.binary_thresh = binary_thresh
|
||||||
|
|
||||||
def _load_sam_model(self, checkpoint: str) -> SamPredictor:
|
def _load_sam_model(self, checkpoint: str) -> SamPredictor:
|
||||||
|
"""Loads the SAM model and returns a predictor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint (str): Path to model checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SamPredictor: Predictor instance.
|
||||||
|
"""
|
||||||
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
||||||
sam.to(device=self.device)
|
sam.to(device=self.device)
|
||||||
|
|
||||||
return SamPredictor(sam)
|
return SamPredictor(sam)
|
||||||
|
|
||||||
def preprocess_image(self, image: Image.Image) -> np.ndarray:
|
def preprocess_image(self, image: Image.Image) -> np.ndarray:
|
||||||
|
"""Preprocesses input image for SAM prediction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image.Image): Input image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Preprocessed image array.
|
||||||
|
"""
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image)
|
image = Image.open(image)
|
||||||
elif isinstance(image, np.ndarray):
|
elif isinstance(image, np.ndarray):
|
||||||
@ -178,6 +216,15 @@ class SAMPredictor(object):
|
|||||||
image: np.ndarray,
|
image: np.ndarray,
|
||||||
selected_points: list[list[int]],
|
selected_points: list[list[int]],
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Generates segmentation masks from selected points.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (np.ndarray): Input image array.
|
||||||
|
selected_points (list[list[int]]): List of points and labels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[tuple[np.ndarray, str]]: List of masks and names.
|
||||||
|
"""
|
||||||
if len(selected_points) == 0:
|
if len(selected_points) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -220,6 +267,15 @@ class SAMPredictor(object):
|
|||||||
def get_segmented_image(
|
def get_segmented_image(
|
||||||
self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
|
self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""Combines masks and returns segmented image with alpha channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (np.ndarray): Input image array.
|
||||||
|
masks (list[tuple[np.ndarray, str]]): List of masks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Segmented RGBA image.
|
||||||
|
"""
|
||||||
seg_image = Image.fromarray(image, mode="RGB")
|
seg_image = Image.fromarray(image, mode="RGB")
|
||||||
alpha_channel = np.zeros(
|
alpha_channel = np.zeros(
|
||||||
(seg_image.height, seg_image.width), dtype=np.uint8
|
(seg_image.height, seg_image.width), dtype=np.uint8
|
||||||
@ -241,6 +297,15 @@ class SAMPredictor(object):
|
|||||||
image: Union[str, Image.Image, np.ndarray],
|
image: Union[str, Image.Image, np.ndarray],
|
||||||
selected_points: list[list[int]],
|
selected_points: list[list[int]],
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""Segments image using selected points.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Union[str, Image.Image, np.ndarray]): Input image.
|
||||||
|
selected_points (list[list[int]]): List of points and labels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Segmented RGBA image.
|
||||||
|
"""
|
||||||
image = self.preprocess_image(image)
|
image = self.preprocess_image(image)
|
||||||
self.predictor.set_image(image)
|
self.predictor.set_image(image)
|
||||||
masks = self.generate_masks(image, selected_points)
|
masks = self.generate_masks(image, selected_points)
|
||||||
@ -249,12 +314,32 @@ class SAMPredictor(object):
|
|||||||
|
|
||||||
|
|
||||||
class RembgRemover(object):
|
class RembgRemover(object):
|
||||||
|
"""Removes background from images using the rembg library.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.segment_model import RembgRemover
|
||||||
|
remover = RembgRemover()
|
||||||
|
result = remover("input.jpg", "output.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
"""Initializes the RembgRemover."""
|
||||||
self.rembg_session = rembg.new_session("u2net")
|
self.rembg_session = rembg.new_session("u2net")
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""Removes background from an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Union[str, Image.Image, np.ndarray]): Input image.
|
||||||
|
save_path (str, optional): Path to save the output image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Image with background removed (RGBA).
|
||||||
|
"""
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image)
|
image = Image.open(image)
|
||||||
elif isinstance(image, np.ndarray):
|
elif isinstance(image, np.ndarray):
|
||||||
@ -271,7 +356,18 @@ class RembgRemover(object):
|
|||||||
|
|
||||||
|
|
||||||
class BMGG14Remover(object):
|
class BMGG14Remover(object):
|
||||||
|
"""Removes background using the RMBG-1.4 segmentation model.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.segment_model import BMGG14Remover
|
||||||
|
remover = BMGG14Remover()
|
||||||
|
result = remover("input.jpg", "output.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
"""Initializes the BMGG14Remover."""
|
||||||
self.model = pipeline(
|
self.model = pipeline(
|
||||||
"image-segmentation",
|
"image-segmentation",
|
||||||
model="briaai/RMBG-1.4",
|
model="briaai/RMBG-1.4",
|
||||||
@ -281,6 +377,15 @@ class BMGG14Remover(object):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
||||||
):
|
):
|
||||||
|
"""Removes background from an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Union[str, Image.Image, np.ndarray]): Input image.
|
||||||
|
save_path (str, optional): Path to save the output image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Image with background removed.
|
||||||
|
"""
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image)
|
image = Image.open(image)
|
||||||
elif isinstance(image, np.ndarray):
|
elif isinstance(image, np.ndarray):
|
||||||
@ -299,6 +404,16 @@ class BMGG14Remover(object):
|
|||||||
def invert_rgba_pil(
|
def invert_rgba_pil(
|
||||||
image: Image.Image, mask: Image.Image, save_path: str = None
|
image: Image.Image, mask: Image.Image, save_path: str = None
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""Inverts the alpha channel of an RGBA image using a mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image.Image): Input RGB image.
|
||||||
|
mask (Image.Image): Mask image for alpha inversion.
|
||||||
|
save_path (str, optional): Path to save the output image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: RGBA image with inverted alpha.
|
||||||
|
"""
|
||||||
mask = (255 - np.array(mask))[..., None]
|
mask = (255 - np.array(mask))[..., None]
|
||||||
image_array = np.concatenate([np.array(image), mask], axis=-1)
|
image_array = np.concatenate([np.array(image), mask], axis=-1)
|
||||||
inverted_image = Image.fromarray(image_array, "RGBA")
|
inverted_image = Image.fromarray(image_array, "RGBA")
|
||||||
@ -318,6 +433,20 @@ def get_segmented_image_by_agent(
|
|||||||
save_path: str = None,
|
save_path: str = None,
|
||||||
mode: Literal["loose", "strict"] = "loose",
|
mode: Literal["loose", "strict"] = "loose",
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""Segments an image using SAM and rembg, with quality checking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image.Image): Input image.
|
||||||
|
sam_remover (SAMRemover): SAM-based remover.
|
||||||
|
rbg_remover (RembgRemover): rembg-based remover.
|
||||||
|
seg_checker (ImageSegChecker, optional): Quality checker.
|
||||||
|
save_path (str, optional): Path to save the output image.
|
||||||
|
mode (Literal["loose", "strict"], optional): Segmentation mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Segmented RGBA image.
|
||||||
|
"""
|
||||||
|
|
||||||
def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
|
def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
|
||||||
if seg_checker is None:
|
if seg_checker is None:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@ -39,13 +39,38 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class ImageStableSR:
|
class ImageStableSR:
|
||||||
"""Super-resolution image upscaler using Stable Diffusion x4 upscaling model from StabilityAI."""
|
"""Super-resolution image upscaler using Stable Diffusion x4 upscaling model.
|
||||||
|
|
||||||
|
This class wraps the StabilityAI Stable Diffusion x4 upscaler for high-quality
|
||||||
|
image super-resolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str, optional): Path or HuggingFace repo for the model.
|
||||||
|
device (str, optional): Device for inference.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.sr_model import ImageStableSR
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
sr_model = ImageStableSR()
|
||||||
|
img = Image.open("input.png")
|
||||||
|
upscaled = sr_model(img)
|
||||||
|
upscaled.save("output.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
|
model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
|
||||||
device="cuda",
|
device="cuda",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initializes the Stable Diffusion x4 upscaler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str, optional): Model path or repo.
|
||||||
|
device (str, optional): Device for inference.
|
||||||
|
"""
|
||||||
from diffusers import StableDiffusionUpscalePipeline
|
from diffusers import StableDiffusionUpscalePipeline
|
||||||
|
|
||||||
self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
|
self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
|
||||||
@ -62,6 +87,16 @@ class ImageStableSR:
|
|||||||
prompt: str = "",
|
prompt: str = "",
|
||||||
infer_step: int = 20,
|
infer_step: int = 20,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""Performs super-resolution on the input image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Union[Image.Image, np.ndarray]): Input image.
|
||||||
|
prompt (str, optional): Text prompt for upscaling.
|
||||||
|
infer_step (int, optional): Number of inference steps.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Upscaled image.
|
||||||
|
"""
|
||||||
if isinstance(image, np.ndarray):
|
if isinstance(image, np.ndarray):
|
||||||
image = Image.fromarray(image)
|
image = Image.fromarray(image)
|
||||||
|
|
||||||
@ -86,9 +121,26 @@ class ImageRealESRGAN:
|
|||||||
Attributes:
|
Attributes:
|
||||||
outscale (int): The output image scale factor (e.g., 2, 4).
|
outscale (int): The output image scale factor (e.g., 2, 4).
|
||||||
model_path (str): Path to the pre-trained model weights.
|
model_path (str): Path to the pre-trained model weights.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.sr_model import ImageRealESRGAN
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
sr_model = ImageRealESRGAN(outscale=4)
|
||||||
|
img = Image.open("input.png")
|
||||||
|
upscaled = sr_model(img)
|
||||||
|
upscaled.save("output.png")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, outscale: int, model_path: str = None) -> None:
|
def __init__(self, outscale: int, model_path: str = None) -> None:
|
||||||
|
"""Initializes the RealESRGAN upscaler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outscale (int): Output scale factor.
|
||||||
|
model_path (str, optional): Path to model weights.
|
||||||
|
"""
|
||||||
# monkey patch to support torchvision>=0.16
|
# monkey patch to support torchvision>=0.16
|
||||||
import torchvision
|
import torchvision
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@ -122,6 +174,7 @@ class ImageRealESRGAN:
|
|||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
|
||||||
def _lazy_init(self):
|
def _lazy_init(self):
|
||||||
|
"""Lazily initializes the RealESRGAN model."""
|
||||||
if self.upsampler is None:
|
if self.upsampler is None:
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
@ -145,6 +198,14 @@ class ImageRealESRGAN:
|
|||||||
|
|
||||||
@spaces.GPU
|
@spaces.GPU
|
||||||
def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
|
def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
|
||||||
|
"""Performs super-resolution on the input image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Union[Image.Image, np.ndarray]): Input image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Upscaled image.
|
||||||
|
"""
|
||||||
self._lazy_init()
|
self._lazy_init()
|
||||||
|
|
||||||
if isinstance(image, Image.Image):
|
if isinstance(image, Image.Image):
|
||||||
|
|||||||
@ -60,6 +60,11 @@ PROMPT_KAPPEND = "Single {object}, in the center of the image, white background,
|
|||||||
|
|
||||||
|
|
||||||
def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
|
def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
|
||||||
|
"""Downloads Kolors model weights from HuggingFace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_dir (str, optional): Local directory to store weights.
|
||||||
|
"""
|
||||||
logger.info(f"Download kolors weights from huggingface...")
|
logger.info(f"Download kolors weights from huggingface...")
|
||||||
os.makedirs(local_dir, exist_ok=True)
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
@ -93,6 +98,22 @@ def build_text2img_ip_pipeline(
|
|||||||
ref_scale: float,
|
ref_scale: float,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> StableDiffusionXLPipelineIP:
|
) -> StableDiffusionXLPipelineIP:
|
||||||
|
"""Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_dir (str): Directory containing model checkpoints.
|
||||||
|
ref_scale (float): Reference scale for IP-Adapter.
|
||||||
|
device (str, optional): Device for inference.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StableDiffusionXLPipelineIP: Configured pipeline.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.text_model import build_text2img_ip_pipeline
|
||||||
|
pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
|
||||||
|
```
|
||||||
|
"""
|
||||||
download_kolors_weights(ckpt_dir)
|
download_kolors_weights(ckpt_dir)
|
||||||
|
|
||||||
text_encoder = ChatGLMModel.from_pretrained(
|
text_encoder = ChatGLMModel.from_pretrained(
|
||||||
@ -146,6 +167,21 @@ def build_text2img_pipeline(
|
|||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> StableDiffusionXLPipeline:
|
) -> StableDiffusionXLPipeline:
|
||||||
|
"""Builds a Stable Diffusion XL pipeline for text-to-image generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_dir (str): Directory containing model checkpoints.
|
||||||
|
device (str, optional): Device for inference.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StableDiffusionXLPipeline: Configured pipeline.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.text_model import build_text2img_pipeline
|
||||||
|
pipe = build_text2img_pipeline("weights/Kolors")
|
||||||
|
```
|
||||||
|
"""
|
||||||
download_kolors_weights(ckpt_dir)
|
download_kolors_weights(ckpt_dir)
|
||||||
|
|
||||||
text_encoder = ChatGLMModel.from_pretrained(
|
text_encoder = ChatGLMModel.from_pretrained(
|
||||||
@ -185,6 +221,29 @@ def text2img_gen(
|
|||||||
ip_image_size: int = 512,
|
ip_image_size: int = 512,
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
) -> list[Image.Image]:
|
) -> list[Image.Image]:
|
||||||
|
"""Generates images from text prompts using a Stable Diffusion XL pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): Text prompt for image generation.
|
||||||
|
n_sample (int): Number of images to generate.
|
||||||
|
guidance_scale (float): Guidance scale for diffusion.
|
||||||
|
pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP): Pipeline instance.
|
||||||
|
ip_image (Image.Image | str, optional): Reference image for IP-Adapter.
|
||||||
|
image_wh (tuple[int, int], optional): Output image size (width, height).
|
||||||
|
infer_step (int, optional): Number of inference steps.
|
||||||
|
ip_image_size (int, optional): Size for IP-Adapter image.
|
||||||
|
seed (int, optional): Random seed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Image.Image]: List of generated images.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.models.text_model import text2img_gen
|
||||||
|
images = text2img_gen(prompt="banana", n_sample=3, guidance_scale=7.5)
|
||||||
|
images[0].save("banana.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
prompt = PROMPT_KAPPEND.format(object=prompt.strip())
|
prompt = PROMPT_KAPPEND.format(object=prompt.strip())
|
||||||
logger.info(f"Processing prompt: {prompt}")
|
logger.info(f"Processing prompt: {prompt}")
|
||||||
|
|
||||||
|
|||||||
@ -53,26 +53,31 @@ from thirdparty.pano2room.utils.functions import (
|
|||||||
|
|
||||||
|
|
||||||
class Pano2MeshSRPipeline:
|
class Pano2MeshSRPipeline:
|
||||||
"""Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement.
|
"""Pipeline for converting panoramic RGB images into 3D mesh representations.
|
||||||
|
|
||||||
This class integrates several key components including:
|
This class integrates depth estimation, inpainting, mesh conversion, multi-view mesh repair,
|
||||||
- Depth estimation from RGB panorama
|
and 3D Gaussian Splatting (3DGS) dataset generation.
|
||||||
- Inpainting of missing regions under offsets
|
|
||||||
- RGB-D to mesh conversion
|
|
||||||
- Multi-view mesh repair
|
|
||||||
- 3D Gaussian Splatting (3DGS) dataset generation
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
|
config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```py
|
||||||
|
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
|
||||||
|
from embodied_gen.utils.config import Pano2MeshSRConfig
|
||||||
|
|
||||||
|
config = Pano2MeshSRConfig()
|
||||||
pipeline = Pano2MeshSRPipeline(config)
|
pipeline = Pano2MeshSRPipeline(config)
|
||||||
pipeline(pano_image='example.png', output_dir='./output')
|
pipeline(pano_image='example.png', output_dir='./output')
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: Pano2MeshSRConfig) -> None:
|
def __init__(self, config: Pano2MeshSRConfig) -> None:
|
||||||
|
"""Initializes the pipeline with models and camera poses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (Pano2MeshSRConfig): Configuration object.
|
||||||
|
"""
|
||||||
self.cfg = config
|
self.cfg = config
|
||||||
self.device = config.device
|
self.device = config.device
|
||||||
|
|
||||||
@ -93,6 +98,7 @@ class Pano2MeshSRPipeline:
|
|||||||
self.kernel = torch.from_numpy(kernel).float().to(self.device)
|
self.kernel = torch.from_numpy(kernel).float().to(self.device)
|
||||||
|
|
||||||
def init_mesh_params(self) -> None:
|
def init_mesh_params(self) -> None:
|
||||||
|
"""Initializes mesh parameters and inpaint mask."""
|
||||||
torch.set_default_device(self.device)
|
torch.set_default_device(self.device)
|
||||||
self.inpaint_mask = torch.ones(
|
self.inpaint_mask = torch.ones(
|
||||||
(self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
|
(self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
|
||||||
@ -103,6 +109,14 @@ class Pano2MeshSRPipeline:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def read_camera_pose_file(filepath: str) -> np.ndarray:
|
def read_camera_pose_file(filepath: str) -> np.ndarray:
|
||||||
|
"""Reads a camera pose file and returns the pose matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath (str): Path to the camera pose file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: 4x4 camera pose matrix.
|
||||||
|
"""
|
||||||
with open(filepath, "r") as f:
|
with open(filepath, "r") as f:
|
||||||
values = [float(num) for line in f for num in line.split()]
|
values = [float(num) for line in f for num in line.split()]
|
||||||
|
|
||||||
@ -111,6 +125,14 @@ class Pano2MeshSRPipeline:
|
|||||||
def load_camera_poses(
|
def load_camera_poses(
|
||||||
self, trajectory_dir: str
|
self, trajectory_dir: str
|
||||||
) -> tuple[np.ndarray, list[torch.Tensor]]:
|
) -> tuple[np.ndarray, list[torch.Tensor]]:
|
||||||
|
"""Loads camera poses from a directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trajectory_dir (str): Directory containing camera pose files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[np.ndarray, list[torch.Tensor]]: List of relative camera poses.
|
||||||
|
"""
|
||||||
pose_filenames = sorted(
|
pose_filenames = sorted(
|
||||||
[
|
[
|
||||||
fname
|
fname
|
||||||
@ -148,6 +170,14 @@ class Pano2MeshSRPipeline:
|
|||||||
def load_inpaint_poses(
|
def load_inpaint_poses(
|
||||||
self, poses: torch.Tensor
|
self, poses: torch.Tensor
|
||||||
) -> dict[int, torch.Tensor]:
|
) -> dict[int, torch.Tensor]:
|
||||||
|
"""Samples and loads poses for inpainting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
poses (torch.Tensor): Tensor of camera poses.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[int, torch.Tensor]: Dictionary mapping indices to pose tensors.
|
||||||
|
"""
|
||||||
inpaint_poses = dict()
|
inpaint_poses = dict()
|
||||||
sampled_views = poses[:: self.cfg.inpaint_frame_stride]
|
sampled_views = poses[:: self.cfg.inpaint_frame_stride]
|
||||||
init_pose = torch.eye(4)
|
init_pose = torch.eye(4)
|
||||||
@ -162,6 +192,14 @@ class Pano2MeshSRPipeline:
|
|||||||
return inpaint_poses
|
return inpaint_poses
|
||||||
|
|
||||||
def project(self, world_to_cam: torch.Tensor):
|
def project(self, world_to_cam: torch.Tensor):
|
||||||
|
"""Projects the mesh to an image using the given camera pose.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
world_to_cam (torch.Tensor): World-to-camera transformation matrix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Projected RGB image, inpaint mask, and depth map.
|
||||||
|
"""
|
||||||
(
|
(
|
||||||
project_image,
|
project_image,
|
||||||
project_depth,
|
project_depth,
|
||||||
@ -185,6 +223,14 @@ class Pano2MeshSRPipeline:
|
|||||||
return project_image[:3, ...], inpaint_mask, project_depth
|
return project_image[:3, ...], inpaint_mask, project_depth
|
||||||
|
|
||||||
def render_pano(self, pose: torch.Tensor):
|
def render_pano(self, pose: torch.Tensor):
|
||||||
|
"""Renders a panorama from the mesh using the given pose.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pose (torch.Tensor): Camera pose.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: RGB panorama, depth map, and mask.
|
||||||
|
"""
|
||||||
cubemap_list = []
|
cubemap_list = []
|
||||||
for cubemap_pose in self.cubemap_w2cs:
|
for cubemap_pose in self.cubemap_w2cs:
|
||||||
project_pose = cubemap_pose @ pose
|
project_pose = cubemap_pose @ pose
|
||||||
@ -213,6 +259,15 @@ class Pano2MeshSRPipeline:
|
|||||||
world_to_cam: torch.Tensor = None,
|
world_to_cam: torch.Tensor = None,
|
||||||
using_distance_map: bool = True,
|
using_distance_map: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Converts RGB-D images to mesh and updates mesh parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rgb (torch.Tensor): RGB image tensor.
|
||||||
|
depth (torch.Tensor): Depth map tensor.
|
||||||
|
inpaint_mask (torch.Tensor): Inpaint mask tensor.
|
||||||
|
world_to_cam (torch.Tensor, optional): Camera pose.
|
||||||
|
using_distance_map (bool, optional): Whether to use distance map.
|
||||||
|
"""
|
||||||
if world_to_cam is None:
|
if world_to_cam is None:
|
||||||
world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
|
world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
|
||||||
|
|
||||||
@ -239,6 +294,15 @@ class Pano2MeshSRPipeline:
|
|||||||
def get_edge_image_by_depth(
|
def get_edge_image_by_depth(
|
||||||
self, depth: torch.Tensor, dilate_iter: int = 1
|
self, depth: torch.Tensor, dilate_iter: int = 1
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Computes edge image from depth map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth (torch.Tensor): Depth map tensor.
|
||||||
|
dilate_iter (int, optional): Number of dilation iterations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Edge image.
|
||||||
|
"""
|
||||||
if isinstance(depth, torch.Tensor):
|
if isinstance(depth, torch.Tensor):
|
||||||
depth = depth.cpu().detach().numpy()
|
depth = depth.cpu().detach().numpy()
|
||||||
|
|
||||||
@ -253,6 +317,15 @@ class Pano2MeshSRPipeline:
|
|||||||
def mesh_repair_by_greedy_view_selection(
|
def mesh_repair_by_greedy_view_selection(
|
||||||
self, pose_dict: dict[str, torch.Tensor], output_dir: str
|
self, pose_dict: dict[str, torch.Tensor], output_dir: str
|
||||||
) -> list:
|
) -> list:
|
||||||
|
"""Repairs mesh by selecting views greedily and inpainting missing regions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pose_dict (dict[str, torch.Tensor]): Dictionary of poses for inpainting.
|
||||||
|
output_dir (str): Directory to save visualizations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of inpainted panoramas with poses.
|
||||||
|
"""
|
||||||
inpainted_panos_w_pose = []
|
inpainted_panos_w_pose = []
|
||||||
while len(pose_dict) > 0:
|
while len(pose_dict) > 0:
|
||||||
logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
|
logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
|
||||||
@ -343,6 +416,17 @@ class Pano2MeshSRPipeline:
|
|||||||
distances: torch.Tensor,
|
distances: torch.Tensor,
|
||||||
pano_mask: torch.Tensor,
|
pano_mask: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor]:
|
) -> tuple[torch.Tensor]:
|
||||||
|
"""Inpaints missing regions in a panorama.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx (int): Index of the panorama.
|
||||||
|
colors (torch.Tensor): RGB image tensor.
|
||||||
|
distances (torch.Tensor): Distance map tensor.
|
||||||
|
pano_mask (torch.Tensor): Mask tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor]: Inpainted RGB image, distances, and normals.
|
||||||
|
"""
|
||||||
mask = (pano_mask[None, ..., None] > 0.5).float()
|
mask = (pano_mask[None, ..., None] > 0.5).float()
|
||||||
mask = mask.permute(0, 3, 1, 2)
|
mask = mask.permute(0, 3, 1, 2)
|
||||||
mask = dilation(mask, kernel=self.kernel)
|
mask = dilation(mask, kernel=self.kernel)
|
||||||
@ -364,6 +448,14 @@ class Pano2MeshSRPipeline:
|
|||||||
def preprocess_pano(
|
def preprocess_pano(
|
||||||
self, image: Image.Image | str
|
self, image: Image.Image | str
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Preprocesses a panoramic image for mesh generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image.Image | str): Input image or path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, torch.Tensor]: Preprocessed RGB and depth tensors.
|
||||||
|
"""
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image)
|
image = Image.open(image)
|
||||||
|
|
||||||
@ -387,6 +479,17 @@ class Pano2MeshSRPipeline:
|
|||||||
def pano_to_perpective(
|
def pano_to_perpective(
|
||||||
self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
|
self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""Converts a panoramic image to a perspective view.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pano_image (torch.Tensor): Panoramic image tensor.
|
||||||
|
pitch (float): Pitch angle.
|
||||||
|
yaw (float): Yaw angle.
|
||||||
|
fov (float): Field of view.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Perspective image tensor.
|
||||||
|
"""
|
||||||
rots = dict(
|
rots = dict(
|
||||||
roll=0,
|
roll=0,
|
||||||
pitch=pitch,
|
pitch=pitch,
|
||||||
@ -404,6 +507,14 @@ class Pano2MeshSRPipeline:
|
|||||||
return perspective
|
return perspective
|
||||||
|
|
||||||
def pano_to_cubemap(self, pano_rgb: torch.Tensor):
|
def pano_to_cubemap(self, pano_rgb: torch.Tensor):
|
||||||
|
"""Converts a panoramic RGB image to six cubemap views.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pano_rgb (torch.Tensor): Panoramic RGB image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of cubemap RGB tensors.
|
||||||
|
"""
|
||||||
# Define six canonical cube directions in (pitch, yaw)
|
# Define six canonical cube directions in (pitch, yaw)
|
||||||
directions = [
|
directions = [
|
||||||
(0, 0),
|
(0, 0),
|
||||||
@ -424,6 +535,11 @@ class Pano2MeshSRPipeline:
|
|||||||
return cubemaps_rgb
|
return cubemaps_rgb
|
||||||
|
|
||||||
def save_mesh(self, output_path: str) -> None:
|
def save_mesh(self, output_path: str) -> None:
|
||||||
|
"""Saves the mesh to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_path (str): Path to save the mesh file.
|
||||||
|
"""
|
||||||
vertices_np = self.vertices.T.cpu().numpy()
|
vertices_np = self.vertices.T.cpu().numpy()
|
||||||
colors_np = self.colors.T.cpu().numpy()
|
colors_np = self.colors.T.cpu().numpy()
|
||||||
faces_np = self.faces.T.cpu().numpy()
|
faces_np = self.faces.T.cpu().numpy()
|
||||||
@ -434,6 +550,14 @@ class Pano2MeshSRPipeline:
|
|||||||
mesh.export(output_path)
|
mesh.export(output_path)
|
||||||
|
|
||||||
def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
|
def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
|
||||||
|
"""Converts mesh pose to 3D Gaussian Splatting pose.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh_pose (torch.Tensor): Mesh pose tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Converted pose matrix.
|
||||||
|
"""
|
||||||
pose = mesh_pose.clone()
|
pose = mesh_pose.clone()
|
||||||
pose[0, :] *= -1
|
pose[0, :] *= -1
|
||||||
pose[1, :] *= -1
|
pose[1, :] *= -1
|
||||||
@ -450,6 +574,15 @@ class Pano2MeshSRPipeline:
|
|||||||
return c2w
|
return c2w
|
||||||
|
|
||||||
def __call__(self, pano_image: Image.Image | str, output_dir: str):
|
def __call__(self, pano_image: Image.Image | str, output_dir: str):
|
||||||
|
"""Runs the pipeline to generate mesh and 3DGS data from a panoramic image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pano_image (Image.Image | str): Input panoramic image or path.
|
||||||
|
output_dir (str): Directory to save outputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
self.init_mesh_params()
|
self.init_mesh_params()
|
||||||
pano_rgb, pano_depth = self.preprocess_pano(pano_image)
|
pano_rgb, pano_depth = self.preprocess_pano(pano_image)
|
||||||
self.sup_pool = SupInfoPool()
|
self.sup_pool = SupInfoPool()
|
||||||
|
|||||||
@ -24,11 +24,27 @@ __all__ = [
|
|||||||
"Scene3DItemEnum",
|
"Scene3DItemEnum",
|
||||||
"SpatialRelationEnum",
|
"SpatialRelationEnum",
|
||||||
"RobotItemEnum",
|
"RobotItemEnum",
|
||||||
|
"LayoutInfo",
|
||||||
|
"AssetType",
|
||||||
|
"SimAssetMapper",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RenderItems(str, Enum):
|
class RenderItems(str, Enum):
|
||||||
|
"""Enumeration of render item types for 3D scenes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
IMAGE: Color image.
|
||||||
|
ALPHA: Mask image.
|
||||||
|
VIEW_NORMAL: View-space normal image.
|
||||||
|
GLOBAL_NORMAL: World-space normal image.
|
||||||
|
POSITION_MAP: Position map image.
|
||||||
|
DEPTH: Depth image.
|
||||||
|
ALBEDO: Albedo image.
|
||||||
|
DIFFUSE: Diffuse image.
|
||||||
|
"""
|
||||||
|
|
||||||
IMAGE = "image_color"
|
IMAGE = "image_color"
|
||||||
ALPHA = "image_mask"
|
ALPHA = "image_mask"
|
||||||
VIEW_NORMAL = "image_view_normal"
|
VIEW_NORMAL = "image_view_normal"
|
||||||
@ -41,6 +57,21 @@ class RenderItems(str, Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Scene3DItemEnum(str, Enum):
|
class Scene3DItemEnum(str, Enum):
|
||||||
|
"""Enumeration of 3D scene item categories.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
BACKGROUND: Background objects.
|
||||||
|
CONTEXT: Contextual objects.
|
||||||
|
ROBOT: Robot entity.
|
||||||
|
MANIPULATED_OBJS: Objects manipulated by the robot.
|
||||||
|
DISTRACTOR_OBJS: Distractor objects.
|
||||||
|
OTHERS: Other objects.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
object_list(layout_relation): Returns a list of objects in the scene.
|
||||||
|
object_mapping(layout_relation): Returns a mapping from object to category.
|
||||||
|
"""
|
||||||
|
|
||||||
BACKGROUND = "background"
|
BACKGROUND = "background"
|
||||||
CONTEXT = "context"
|
CONTEXT = "context"
|
||||||
ROBOT = "robot"
|
ROBOT = "robot"
|
||||||
@ -50,6 +81,14 @@ class Scene3DItemEnum(str, Enum):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def object_list(cls, layout_relation: dict) -> list:
|
def object_list(cls, layout_relation: dict) -> list:
|
||||||
|
"""Returns a list of objects in the scene.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layout_relation: Dictionary mapping categories to objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of objects in the scene.
|
||||||
|
"""
|
||||||
return (
|
return (
|
||||||
[
|
[
|
||||||
layout_relation[cls.BACKGROUND.value],
|
layout_relation[cls.BACKGROUND.value],
|
||||||
@ -61,6 +100,14 @@ class Scene3DItemEnum(str, Enum):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def object_mapping(cls, layout_relation):
|
def object_mapping(cls, layout_relation):
|
||||||
|
"""Returns a mapping from object to category.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layout_relation: Dictionary mapping categories to objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping object names to their category.
|
||||||
|
"""
|
||||||
relation_mapping = {
|
relation_mapping = {
|
||||||
# layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
|
# layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
|
||||||
layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
|
layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
|
||||||
@ -84,6 +131,15 @@ class Scene3DItemEnum(str, Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SpatialRelationEnum(str, Enum):
|
class SpatialRelationEnum(str, Enum):
|
||||||
|
"""Enumeration of spatial relations for objects in a scene.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
ON: Objects on a surface (e.g., table).
|
||||||
|
IN: Objects in a container or room.
|
||||||
|
INSIDE: Objects inside a shelf or rack.
|
||||||
|
FLOOR: Objects on the floor.
|
||||||
|
"""
|
||||||
|
|
||||||
ON = "ON" # objects on the table
|
ON = "ON" # objects on the table
|
||||||
IN = "IN" # objects in the room
|
IN = "IN" # objects in the room
|
||||||
INSIDE = "INSIDE" # objects inside the shelf/rack
|
INSIDE = "INSIDE" # objects inside the shelf/rack
|
||||||
@ -92,6 +148,14 @@ class SpatialRelationEnum(str, Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RobotItemEnum(str, Enum):
|
class RobotItemEnum(str, Enum):
|
||||||
|
"""Enumeration of supported robot types.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
FRANKA: Franka robot.
|
||||||
|
UR5: UR5 robot.
|
||||||
|
PIPER: Piper robot.
|
||||||
|
"""
|
||||||
|
|
||||||
FRANKA = "franka"
|
FRANKA = "franka"
|
||||||
UR5 = "ur5"
|
UR5 = "ur5"
|
||||||
PIPER = "piper"
|
PIPER = "piper"
|
||||||
@ -99,6 +163,18 @@ class RobotItemEnum(str, Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LayoutInfo(DataClassJsonMixin):
|
class LayoutInfo(DataClassJsonMixin):
|
||||||
|
"""Data structure for layout information in a 3D scene.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
tree: Hierarchical structure of scene objects.
|
||||||
|
relation: Spatial relations between objects.
|
||||||
|
objs_desc: Descriptions of objects.
|
||||||
|
objs_mapping: Mapping from object names to categories.
|
||||||
|
assets: Asset file paths for objects.
|
||||||
|
quality: Quality information for assets.
|
||||||
|
position: Position coordinates for objects.
|
||||||
|
"""
|
||||||
|
|
||||||
tree: dict[str, list]
|
tree: dict[str, list]
|
||||||
relation: dict[str, str | list[str]]
|
relation: dict[str, str | list[str]]
|
||||||
objs_desc: dict[str, str] = field(default_factory=dict)
|
objs_desc: dict[str, str] = field(default_factory=dict)
|
||||||
@ -106,3 +182,64 @@ class LayoutInfo(DataClassJsonMixin):
|
|||||||
assets: dict[str, str] = field(default_factory=dict)
|
assets: dict[str, str] = field(default_factory=dict)
|
||||||
quality: dict[str, str] = field(default_factory=dict)
|
quality: dict[str, str] = field(default_factory=dict)
|
||||||
position: dict[str, list[float]] = field(default_factory=dict)
|
position: dict[str, list[float]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssetType(str):
|
||||||
|
"""Enumeration for asset types.
|
||||||
|
|
||||||
|
Supported types:
|
||||||
|
MJCF: MuJoCo XML format.
|
||||||
|
USD: Universal Scene Description format.
|
||||||
|
URDF: Unified Robot Description Format.
|
||||||
|
MESH: Mesh file format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
MJCF = "mjcf"
|
||||||
|
USD = "usd"
|
||||||
|
URDF = "urdf"
|
||||||
|
MESH = "mesh"
|
||||||
|
|
||||||
|
|
||||||
|
class SimAssetMapper:
|
||||||
|
"""Maps simulator names to asset types.
|
||||||
|
|
||||||
|
Provides a mapping from simulator names to their corresponding asset type.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.utils.enum import SimAssetMapper
|
||||||
|
asset_type = SimAssetMapper["isaacsim"]
|
||||||
|
print(asset_type) # Output: 'usd'
|
||||||
|
```
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
__class_getitem__(key): Returns the asset type for a given simulator name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_mapping = dict(
|
||||||
|
ISAACSIM=AssetType.USD,
|
||||||
|
ISAACGYM=AssetType.URDF,
|
||||||
|
MUJOCO=AssetType.MJCF,
|
||||||
|
GENESIS=AssetType.MJCF,
|
||||||
|
SAPIEN=AssetType.URDF,
|
||||||
|
PYBULLET=AssetType.URDF,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __class_getitem__(cls, key: str):
|
||||||
|
"""Returns the asset type for a given simulator name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Name of the simulator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssetType corresponding to the simulator.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the simulator name is not recognized.
|
||||||
|
"""
|
||||||
|
key = key.upper()
|
||||||
|
if key.startswith("SAPIEN"):
|
||||||
|
key = "SAPIEN"
|
||||||
|
return cls._mapping[key]
|
||||||
|
|||||||
@ -45,13 +45,13 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
||||||
"""Convert a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
|
"""Converts a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
matrix (np.ndarray): 4x4 transformation matrix.
|
matrix (np.ndarray): 4x4 transformation matrix.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
list[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
||||||
"""
|
"""
|
||||||
x, y, z = matrix[:3, 3]
|
x, y, z = matrix[:3, 3]
|
||||||
rot_mat = matrix[:3, :3]
|
rot_mat = matrix[:3, :3]
|
||||||
@ -62,13 +62,13 @@ def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
|||||||
|
|
||||||
|
|
||||||
def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
||||||
"""Convert pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
|
"""Converts pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
List[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
pose (list[float]): Pose as [x, y, z, qx, qy, qz, qw].
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
matrix (np.ndarray): 4x4 transformation matrix.
|
np.ndarray: 4x4 transformation matrix.
|
||||||
"""
|
"""
|
||||||
x, y, z, qx, qy, qz, qw = pose
|
x, y, z, qx, qy, qz, qw = pose
|
||||||
r = R.from_quat([qx, qy, qz, qw])
|
r = R.from_quat([qx, qy, qz, qw])
|
||||||
@ -82,6 +82,16 @@ def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
|||||||
def compute_xy_bbox(
|
def compute_xy_bbox(
|
||||||
vertices: np.ndarray, col_x: int = 0, col_y: int = 1
|
vertices: np.ndarray, col_x: int = 0, col_y: int = 1
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
|
"""Computes the bounding box in XY plane for given vertices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vertices (np.ndarray): Vertex coordinates.
|
||||||
|
col_x (int, optional): Column index for X.
|
||||||
|
col_y (int, optional): Column index for Y.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[float]: [min_x, max_x, min_y, max_y]
|
||||||
|
"""
|
||||||
x_vals = vertices[:, col_x]
|
x_vals = vertices[:, col_x]
|
||||||
y_vals = vertices[:, col_y]
|
y_vals = vertices[:, col_y]
|
||||||
return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
|
return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
|
||||||
@ -92,6 +102,16 @@ def has_iou_conflict(
|
|||||||
placed_boxes: list[list[float]],
|
placed_boxes: list[list[float]],
|
||||||
iou_threshold: float = 0.0,
|
iou_threshold: float = 0.0,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""Checks for intersection-over-union conflict between boxes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_box (list[float]): New box coordinates.
|
||||||
|
placed_boxes (list[list[float]]): List of placed box coordinates.
|
||||||
|
iou_threshold (float, optional): IOU threshold.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if conflict exists, False otherwise.
|
||||||
|
"""
|
||||||
new_min_x, new_max_x, new_min_y, new_max_y = new_box
|
new_min_x, new_max_x, new_min_y, new_max_y = new_box
|
||||||
for min_x, max_x, min_y, max_y in placed_boxes:
|
for min_x, max_x, min_y, max_y in placed_boxes:
|
||||||
ix1 = max(new_min_x, min_x)
|
ix1 = max(new_min_x, min_x)
|
||||||
@ -105,7 +125,14 @@ def has_iou_conflict(
|
|||||||
|
|
||||||
|
|
||||||
def with_seed(seed_attr_name: str = "seed"):
|
def with_seed(seed_attr_name: str = "seed"):
|
||||||
"""A parameterized decorator that temporarily sets the random seed."""
|
"""Decorator to temporarily set the random seed for reproducibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed_attr_name (str, optional): Name of the seed argument.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: Decorator function.
|
||||||
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
@ -143,6 +170,20 @@ def compute_convex_hull_path(
|
|||||||
y_axis: int = 1,
|
y_axis: int = 1,
|
||||||
z_axis: int = 2,
|
z_axis: int = 2,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
|
"""Computes a dense convex hull path for the top surface of a mesh.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vertices (np.ndarray): Mesh vertices.
|
||||||
|
z_threshold (float, optional): Z threshold for top surface.
|
||||||
|
interp_per_edge (int, optional): Interpolation points per edge.
|
||||||
|
margin (float, optional): Margin for polygon buffer.
|
||||||
|
x_axis (int, optional): X axis index.
|
||||||
|
y_axis (int, optional): Y axis index.
|
||||||
|
z_axis (int, optional): Z axis index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Matplotlib path object for the convex hull.
|
||||||
|
"""
|
||||||
top_vertices = vertices[
|
top_vertices = vertices[
|
||||||
vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
|
vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
|
||||||
]
|
]
|
||||||
@ -170,6 +211,15 @@ def compute_convex_hull_path(
|
|||||||
|
|
||||||
|
|
||||||
def find_parent_node(node: str, tree: dict) -> str | None:
|
def find_parent_node(node: str, tree: dict) -> str | None:
|
||||||
|
"""Finds the parent node of a given node in a tree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (str): Node name.
|
||||||
|
tree (dict): Tree structure.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str | None: Parent node name or None.
|
||||||
|
"""
|
||||||
for parent, children in tree.items():
|
for parent, children in tree.items():
|
||||||
if any(child[0] == node for child in children):
|
if any(child[0] == node for child in children):
|
||||||
return parent
|
return parent
|
||||||
@ -177,6 +227,16 @@ def find_parent_node(node: str, tree: dict) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
||||||
|
"""Checks if at least `threshold` corners of a box are inside a hull.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hull (Path): Convex hull path.
|
||||||
|
box (list): Box coordinates [x1, x2, y1, y2].
|
||||||
|
threshold (int, optional): Minimum corners inside.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if enough corners are inside.
|
||||||
|
"""
|
||||||
x1, x2, y1, y2 = box
|
x1, x2, y1, y2 = box
|
||||||
corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
|
corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
|
||||||
|
|
||||||
@ -187,6 +247,15 @@ def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
|||||||
def compute_axis_rotation_quat(
|
def compute_axis_rotation_quat(
|
||||||
axis: Literal["x", "y", "z"], angle_rad: float
|
axis: Literal["x", "y", "z"], angle_rad: float
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
|
"""Computes quaternion for rotation around a given axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
axis (Literal["x", "y", "z"]): Axis of rotation.
|
||||||
|
angle_rad (float): Rotation angle in radians.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[float]: Quaternion [x, y, z, w].
|
||||||
|
"""
|
||||||
if axis.lower() == "x":
|
if axis.lower() == "x":
|
||||||
q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
|
q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
|
||||||
elif axis.lower() == "y":
|
elif axis.lower() == "y":
|
||||||
@ -202,6 +271,15 @@ def compute_axis_rotation_quat(
|
|||||||
def quaternion_multiply(
|
def quaternion_multiply(
|
||||||
init_quat: list[float], rotate_quat: list[float]
|
init_quat: list[float], rotate_quat: list[float]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
|
"""Multiplies two quaternions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_quat (list[float]): Initial quaternion [x, y, z, w].
|
||||||
|
rotate_quat (list[float]): Rotation quaternion [x, y, z, w].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[float]: Resulting quaternion [x, y, z, w].
|
||||||
|
"""
|
||||||
qx, qy, qz, qw = init_quat
|
qx, qy, qz, qw = init_quat
|
||||||
q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
|
q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
|
||||||
qx, qy, qz, qw = rotate_quat
|
qx, qy, qz, qw = rotate_quat
|
||||||
@ -217,7 +295,17 @@ def check_reachable(
|
|||||||
min_reach: float = 0.25,
|
min_reach: float = 0.25,
|
||||||
max_reach: float = 0.85,
|
max_reach: float = 0.85,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if the target point is within the reachable range."""
|
"""Checks if the target point is within the reachable range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_xyz (np.ndarray): Base position.
|
||||||
|
reach_xyz (np.ndarray): Target position.
|
||||||
|
min_reach (float, optional): Minimum reach distance.
|
||||||
|
max_reach (float, optional): Maximum reach distance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if reachable, False otherwise.
|
||||||
|
"""
|
||||||
distance = np.linalg.norm(reach_xyz - base_xyz)
|
distance = np.linalg.norm(reach_xyz - base_xyz)
|
||||||
|
|
||||||
return min_reach < distance < max_reach
|
return min_reach < distance < max_reach
|
||||||
@ -238,26 +326,31 @@ def bfs_placement(
|
|||||||
robot_dim: float = 0.12,
|
robot_dim: float = 0.12,
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
) -> LayoutInfo:
|
) -> LayoutInfo:
|
||||||
"""Place objects in the layout using BFS traversal.
|
"""Places objects in a scene layout using BFS traversal.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
layout_file: Path to the JSON file defining the layout structure and assets.
|
layout_file (str): Path to layout JSON file generated from `layout-cli`.
|
||||||
floor_margin: Z-offset for the background object, typically for objects placed on the floor.
|
floor_margin (float, optional): Z-offset for objects placed on the floor.
|
||||||
beside_margin: Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
|
beside_margin (float, optional): Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
|
||||||
max_attempts: Maximum number of attempts to find a non-overlapping position for an object.
|
max_attempts (int, optional): Max attempts for a non-overlapping placement.
|
||||||
init_rpy: Initial Roll-Pitch-Yaw rotation rad applied to all object meshes to align the mesh's
|
init_rpy (tuple, optional): Initial rotation (rpy).
|
||||||
coordinate system with the world's (e.g., Z-up).
|
rotate_objs (bool, optional): Whether to random rotate objects.
|
||||||
rotate_objs: If True, apply a random rotation around the Z-axis for manipulated and distractor objects.
|
rotate_bg (bool, optional): Whether to random rotate background.
|
||||||
rotate_bg: If True, apply a random rotation around the Y-axis for the background object.
|
rotate_context (bool, optional): Whether to random rotate context asset.
|
||||||
rotate_context: If True, apply a random rotation around the Z-axis for the context object.
|
limit_reach_range (tuple[float, float] | None, optional): If set, enforce a check that manipulated objects are within the robot's reach range, in meter.
|
||||||
limit_reach_range: If set, enforce a check that manipulated objects are within the robot's reach range, in meter.
|
max_orient_diff (float | None, optional): If set, enforce a check that manipulated objects are within the robot's orientation range, in degree.
|
||||||
max_orient_diff: If set, enforce a check that manipulated objects are within the robot's orientation range, in degree.
|
robot_dim (float, optional): The approximate robot size.
|
||||||
robot_dim: The approximate dimension (e.g., diameter) of the robot for box representation.
|
seed (int, optional): Random seed for reproducible placement.
|
||||||
seed: Random seed for reproducible placement.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A :class:`LayoutInfo` object containing the objects and their final computed 7D poses
|
LayoutInfo: Layout information with object poses.
|
||||||
([x, y, z, qx, qy, qz, qw]).
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.utils.geometry import bfs_placement
|
||||||
|
layout = bfs_placement("scene_layout.json", seed=42)
|
||||||
|
print(layout.position)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
|
layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
|
||||||
asset_dir = os.path.dirname(layout_file)
|
asset_dir = os.path.dirname(layout_file)
|
||||||
@ -478,6 +571,13 @@ def bfs_placement(
|
|||||||
def compose_mesh_scene(
|
def compose_mesh_scene(
|
||||||
layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
|
layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Composes a mesh scene from layout information and saves to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layout_info (LayoutInfo): Layout information.
|
||||||
|
out_scene_path (str): Output scene file path.
|
||||||
|
with_bg (bool, optional): Include background mesh.
|
||||||
|
"""
|
||||||
object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
|
object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
|
||||||
scene = trimesh.Scene()
|
scene = trimesh.Scene()
|
||||||
for node in layout_info.assets:
|
for node in layout_info.assets:
|
||||||
@ -505,6 +605,16 @@ def compose_mesh_scene(
|
|||||||
def compute_pinhole_intrinsics(
|
def compute_pinhole_intrinsics(
|
||||||
image_w: int, image_h: int, fov_deg: float
|
image_w: int, image_h: int, fov_deg: float
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Computes pinhole camera intrinsic matrix from image size and FOV.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_w (int): Image width.
|
||||||
|
image_h (int): Image height.
|
||||||
|
fov_deg (float): Field of view in degrees.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Intrinsic matrix K.
|
||||||
|
"""
|
||||||
fov_rad = np.deg2rad(fov_deg)
|
fov_rad = np.deg2rad(fov_deg)
|
||||||
fx = image_w / (2 * np.tan(fov_rad / 2))
|
fx = image_w / (2 * np.tan(fov_rad / 2))
|
||||||
fy = fx # assuming square pixels
|
fy = fx # assuming square pixels
|
||||||
|
|||||||
@ -45,7 +45,35 @@ CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
|
|||||||
|
|
||||||
|
|
||||||
class GPTclient:
|
class GPTclient:
|
||||||
"""A client to interact with the GPT model via OpenAI or Azure API."""
|
"""A client to interact with GPT models via OpenAI or Azure API.
|
||||||
|
|
||||||
|
Supports text and image prompts, connection checking, and configurable parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
endpoint (str): API endpoint URL.
|
||||||
|
api_key (str): API key for authentication.
|
||||||
|
model_name (str, optional): Model name to use.
|
||||||
|
api_version (str, optional): API version (for Azure).
|
||||||
|
check_connection (bool, optional): Whether to check API connection.
|
||||||
|
verbose (bool, optional): Enable verbose logging.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```sh
|
||||||
|
export ENDPOINT="https://yfb-openai-sweden.openai.azure.com"
|
||||||
|
export API_KEY="xxxxxx"
|
||||||
|
export API_VERSION="2025-03-01-preview"
|
||||||
|
export MODEL_NAME="yfb-gpt-4o-sweden"
|
||||||
|
```
|
||||||
|
```py
|
||||||
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||||
|
|
||||||
|
response = GPT_CLIENT.query("Describe the physics of a falling apple.")
|
||||||
|
response = GPT_CLIENT.query(
|
||||||
|
text_prompt="Describe the content in each image."
|
||||||
|
image_base64=["path/to/image1.png", "path/to/image2.jpg"],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -82,6 +110,7 @@ class GPTclient:
|
|||||||
stop=(stop_after_attempt(10) | stop_after_delay(30)),
|
stop=(stop_after_attempt(10) | stop_after_delay(30)),
|
||||||
)
|
)
|
||||||
def completion_with_backoff(self, **kwargs):
|
def completion_with_backoff(self, **kwargs):
|
||||||
|
"""Performs a chat completion request with retry/backoff."""
|
||||||
return self.client.chat.completions.create(**kwargs)
|
return self.client.chat.completions.create(**kwargs)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
@ -91,19 +120,16 @@ class GPTclient:
|
|||||||
system_role: Optional[str] = None,
|
system_role: Optional[str] = None,
|
||||||
params: Optional[dict] = None,
|
params: Optional[dict] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Queries the GPT model with a text and optional image prompts.
|
"""Queries the GPT model with text and optional image prompts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text_prompt (str): The main text input that the model responds to.
|
text_prompt (str): Main text input.
|
||||||
image_base64 (Optional[List[str]]): A list of image base64 strings
|
image_base64 (Optional[list[str | Image.Image]], optional): List of image base64 strings, file paths, or PIL Images.
|
||||||
or local image paths or PIL.Image to accompany the text prompt.
|
system_role (Optional[str], optional): System-level instructions.
|
||||||
system_role (Optional[str]): Optional system-level instructions
|
params (Optional[dict], optional): Additional GPT parameters.
|
||||||
that specify the behavior of the assistant.
|
|
||||||
params (Optional[dict]): Additional parameters for GPT setting.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[str]: The response content generated by the model based on
|
Optional[str]: Model response content, or None if error.
|
||||||
the prompt. Returns `None` if an error occurs.
|
|
||||||
"""
|
"""
|
||||||
if system_role is None:
|
if system_role is None:
|
||||||
system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
|
system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
|
||||||
@ -177,7 +203,11 @@ class GPTclient:
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
def check_connection(self) -> None:
|
def check_connection(self) -> None:
|
||||||
"""Check whether the GPT API connection is working."""
|
"""Checks whether the GPT API connection is working.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConnectionError: If connection fails.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
response = self.completion_with_backoff(
|
response = self.completion_with_backoff(
|
||||||
messages=[
|
messages=[
|
||||||
|
|||||||
@ -69,6 +69,40 @@ def render_asset3d(
|
|||||||
no_index_file: bool = False,
|
no_index_file: bool = False,
|
||||||
with_mtl: bool = True,
|
with_mtl: bool = True,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
|
"""Renders a 3D mesh asset and returns output image paths.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh_path (str): Path to the mesh file.
|
||||||
|
output_root (str): Directory to save outputs.
|
||||||
|
distance (float, optional): Camera distance.
|
||||||
|
num_images (int, optional): Number of views to render.
|
||||||
|
elevation (list[float], optional): Camera elevation angles.
|
||||||
|
pbr_light_factor (float, optional): PBR lighting factor.
|
||||||
|
return_key (str, optional): Glob pattern for output images.
|
||||||
|
output_subdir (str, optional): Subdirectory for outputs.
|
||||||
|
gen_color_mp4 (bool, optional): Generate color MP4 video.
|
||||||
|
gen_viewnormal_mp4 (bool, optional): Generate view normal MP4.
|
||||||
|
gen_glonormal_mp4 (bool, optional): Generate global normal MP4.
|
||||||
|
no_index_file (bool, optional): Skip index file saving.
|
||||||
|
with_mtl (bool, optional): Use mesh material.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: List of output image file paths.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.utils.process_media import render_asset3d
|
||||||
|
|
||||||
|
image_paths = render_asset3d(
|
||||||
|
mesh_path="path_to_mesh.obj",
|
||||||
|
output_root="path_to_save_dir",
|
||||||
|
num_images=6,
|
||||||
|
elevation=(30, -30),
|
||||||
|
output_subdir="renders",
|
||||||
|
no_index_file=True,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
input_args = dict(
|
input_args = dict(
|
||||||
mesh_path=mesh_path,
|
mesh_path=mesh_path,
|
||||||
output_root=output_root,
|
output_root=output_root,
|
||||||
@ -95,6 +129,13 @@ def render_asset3d(
|
|||||||
|
|
||||||
|
|
||||||
def merge_images_video(color_images, normal_images, output_path) -> None:
|
def merge_images_video(color_images, normal_images, output_path) -> None:
|
||||||
|
"""Merges color and normal images into a video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
color_images (list[np.ndarray]): List of color images.
|
||||||
|
normal_images (list[np.ndarray]): List of normal images.
|
||||||
|
output_path (str): Path to save the output video.
|
||||||
|
"""
|
||||||
width = color_images[0].shape[1]
|
width = color_images[0].shape[1]
|
||||||
combined_video = [
|
combined_video = [
|
||||||
np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
|
np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
|
||||||
@ -108,7 +149,13 @@ def merge_images_video(color_images, normal_images, output_path) -> None:
|
|||||||
def merge_video_video(
|
def merge_video_video(
|
||||||
video_path1: str, video_path2: str, output_path: str
|
video_path1: str, video_path2: str, output_path: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Merge two videos by the left half and the right half of the videos."""
|
"""Merges two videos by combining their left and right halves.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path1 (str): Path to first video.
|
||||||
|
video_path2 (str): Path to second video.
|
||||||
|
output_path (str): Path to save the merged video.
|
||||||
|
"""
|
||||||
clip1 = VideoFileClip(video_path1)
|
clip1 = VideoFileClip(video_path1)
|
||||||
clip2 = VideoFileClip(video_path2)
|
clip2 = VideoFileClip(video_path2)
|
||||||
|
|
||||||
@ -127,6 +174,16 @@ def filter_small_connected_components(
|
|||||||
area_ratio: float,
|
area_ratio: float,
|
||||||
connectivity: int = 8,
|
connectivity: int = 8,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Removes small connected components from a binary mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask (Union[Image.Image, np.ndarray]): Input mask.
|
||||||
|
area_ratio (float): Minimum area ratio for components.
|
||||||
|
connectivity (int, optional): Connectivity for labeling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Mask with small components removed.
|
||||||
|
"""
|
||||||
if isinstance(mask, Image.Image):
|
if isinstance(mask, Image.Image):
|
||||||
mask = np.array(mask)
|
mask = np.array(mask)
|
||||||
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
|
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
|
||||||
@ -152,6 +209,16 @@ def filter_image_small_connected_components(
|
|||||||
area_ratio: float = 10,
|
area_ratio: float = 10,
|
||||||
connectivity: int = 8,
|
connectivity: int = 8,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Removes small connected components from the alpha channel of an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Union[Image.Image, np.ndarray]): Input image.
|
||||||
|
area_ratio (float, optional): Minimum area ratio.
|
||||||
|
connectivity (int, optional): Connectivity for labeling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Image with filtered alpha channel.
|
||||||
|
"""
|
||||||
if isinstance(image, Image.Image):
|
if isinstance(image, Image.Image):
|
||||||
image = image.convert("RGBA")
|
image = image.convert("RGBA")
|
||||||
image = np.array(image)
|
image = np.array(image)
|
||||||
@ -169,6 +236,24 @@ def combine_images_to_grid(
|
|||||||
target_wh: tuple[int, int] = (512, 512),
|
target_wh: tuple[int, int] = (512, 512),
|
||||||
image_mode: str = "RGB",
|
image_mode: str = "RGB",
|
||||||
) -> list[Image.Image]:
|
) -> list[Image.Image]:
|
||||||
|
"""Combines multiple images into a grid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (list[str | Image.Image]): List of image paths or PIL Images.
|
||||||
|
cat_row_col (tuple[int, int], optional): Grid rows and columns.
|
||||||
|
target_wh (tuple[int, int], optional): Target image size.
|
||||||
|
image_mode (str, optional): Image mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Image.Image]: List containing the grid image.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.utils.process_media import combine_images_to_grid
|
||||||
|
grid = combine_images_to_grid(["img1.png", "img2.png"])
|
||||||
|
grid[0].save("grid.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
n_images = len(images)
|
n_images = len(images)
|
||||||
if n_images == 1:
|
if n_images == 1:
|
||||||
return images
|
return images
|
||||||
@ -196,6 +281,19 @@ def combine_images_to_grid(
|
|||||||
|
|
||||||
|
|
||||||
class SceneTreeVisualizer:
|
class SceneTreeVisualizer:
|
||||||
|
"""Visualizes a scene tree layout using networkx and matplotlib.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layout_info (LayoutInfo): Layout information for the scene.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.utils.process_media import SceneTreeVisualizer
|
||||||
|
visualizer = SceneTreeVisualizer(layout_info)
|
||||||
|
visualizer.render(save_path="tree.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, layout_info: LayoutInfo) -> None:
|
def __init__(self, layout_info: LayoutInfo) -> None:
|
||||||
self.tree = layout_info.tree
|
self.tree = layout_info.tree
|
||||||
self.relation = layout_info.relation
|
self.relation = layout_info.relation
|
||||||
@ -274,6 +372,14 @@ class SceneTreeVisualizer:
|
|||||||
dpi=300,
|
dpi=300,
|
||||||
title: str = "Scene 3D Hierarchy Tree",
|
title: str = "Scene 3D Hierarchy Tree",
|
||||||
):
|
):
|
||||||
|
"""Renders the scene tree and saves to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_path (str): Path to save the rendered image.
|
||||||
|
figsize (tuple, optional): Figure size.
|
||||||
|
dpi (int, optional): Image DPI.
|
||||||
|
title (str, optional): Plot image title.
|
||||||
|
"""
|
||||||
node_colors = [
|
node_colors = [
|
||||||
self.role_colors[self._get_node_role(n)] for n in self.G.nodes
|
self.role_colors[self._get_node_role(n)] for n in self.G.nodes
|
||||||
]
|
]
|
||||||
@ -350,6 +456,14 @@ class SceneTreeVisualizer:
|
|||||||
|
|
||||||
|
|
||||||
def load_scene_dict(file_path: str) -> dict:
|
def load_scene_dict(file_path: str) -> dict:
|
||||||
|
"""Loads a scene description dictionary from a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): Path to the scene description file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Mapping from scene ID to description.
|
||||||
|
"""
|
||||||
scene_dict = {}
|
scene_dict = {}
|
||||||
with open(file_path, "r", encoding='utf-8') as f:
|
with open(file_path, "r", encoding='utf-8') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
@ -363,12 +477,28 @@ def load_scene_dict(file_path: str) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def is_image_file(filename: str) -> bool:
|
def is_image_file(filename: str) -> bool:
|
||||||
|
"""Checks if a filename is an image file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): Filename to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if image file, False otherwise.
|
||||||
|
"""
|
||||||
mime_type, _ = mimetypes.guess_type(filename)
|
mime_type, _ = mimetypes.guess_type(filename)
|
||||||
|
|
||||||
return mime_type is not None and mime_type.startswith('image')
|
return mime_type is not None and mime_type.startswith('image')
|
||||||
|
|
||||||
|
|
||||||
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
||||||
|
"""Parses text prompts from a list or file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts (list[str]): List of prompts or a file path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: List of parsed prompts.
|
||||||
|
"""
|
||||||
if len(prompts) == 1 and prompts[0].endswith(".txt"):
|
if len(prompts) == 1 and prompts[0].endswith(".txt"):
|
||||||
with open(prompts[0], "r") as f:
|
with open(prompts[0], "r") as f:
|
||||||
prompts = [
|
prompts = [
|
||||||
@ -386,13 +516,18 @@ def alpha_blend_rgba(
|
|||||||
"""Alpha blends a foreground RGBA image over a background RGBA image.
|
"""Alpha blends a foreground RGBA image over a background RGBA image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fg_image: Foreground image. Can be a file path (str), a PIL Image,
|
fg_image: Foreground image (str, PIL Image, or ndarray).
|
||||||
or a NumPy ndarray.
|
bg_image: Background image (str, PIL Image, or ndarray).
|
||||||
bg_image: Background image. Can be a file path (str), a PIL Image,
|
|
||||||
or a NumPy ndarray.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A PIL Image representing the alpha-blended result in RGBA mode.
|
Image.Image: Alpha-blended RGBA image.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.utils.process_media import alpha_blend_rgba
|
||||||
|
result = alpha_blend_rgba("fg.png", "bg.png")
|
||||||
|
result.save("blended.png")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
if isinstance(fg_image, str):
|
if isinstance(fg_image, str):
|
||||||
fg_image = Image.open(fg_image)
|
fg_image = Image.open(fg_image)
|
||||||
@ -421,13 +556,11 @@ def check_object_edge_truncated(
|
|||||||
"""Checks if a binary object mask is truncated at the image edges.
|
"""Checks if a binary object mask is truncated at the image edges.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mask: A 2D binary NumPy array where nonzero values indicate the object region.
|
mask (np.ndarray): 2D binary mask.
|
||||||
edge_threshold: Number of pixels from each image edge to consider for truncation.
|
edge_threshold (int, optional): Edge pixel threshold.
|
||||||
Defaults to 5.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the object is fully enclosed (not truncated).
|
bool: True if object is fully enclosed, False if truncated.
|
||||||
False if the object touches or crosses any image boundary.
|
|
||||||
"""
|
"""
|
||||||
top = mask[:edge_threshold, :].any()
|
top = mask[:edge_threshold, :].any()
|
||||||
bottom = mask[-edge_threshold:, :].any()
|
bottom = mask[-edge_threshold:, :].any()
|
||||||
@ -440,6 +573,22 @@ def check_object_edge_truncated(
|
|||||||
def vcat_pil_images(
|
def vcat_pil_images(
|
||||||
images: list[Image.Image], image_mode: str = "RGB"
|
images: list[Image.Image], image_mode: str = "RGB"
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""Vertically concatenates a list of PIL images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (list[Image.Image]): List of images.
|
||||||
|
image_mode (str, optional): Image mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: Vertically concatenated image.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.utils.process_media import vcat_pil_images
|
||||||
|
img = vcat_pil_images([Image.open("a.png"), Image.open("b.png")])
|
||||||
|
img.save("vcat.png")
|
||||||
|
```
|
||||||
|
"""
|
||||||
widths, heights = zip(*(img.size for img in images))
|
widths, heights = zip(*(img.size for img in images))
|
||||||
total_height = sum(heights)
|
total_height = sum(heights)
|
||||||
max_width = max(widths)
|
max_width = max(widths)
|
||||||
|
|||||||
@ -69,6 +69,21 @@ def load_actor_from_urdf(
|
|||||||
update_mass: bool = False,
|
update_mass: bool = False,
|
||||||
scale: float | np.ndarray = 1.0,
|
scale: float | np.ndarray = 1.0,
|
||||||
) -> sapien.pysapien.Entity:
|
) -> sapien.pysapien.Entity:
|
||||||
|
"""Load an sapien actor from a URDF file and add it to the scene.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene (sapien.Scene | ManiSkillScene): The simulation scene.
|
||||||
|
file_path (str): Path to the URDF file.
|
||||||
|
pose (sapien.Pose | None): Initial pose of the actor.
|
||||||
|
env_idx (int): Environment index for multi-env setup.
|
||||||
|
use_static (bool): Whether the actor is static.
|
||||||
|
update_mass (bool): Whether to update the actor's mass from URDF.
|
||||||
|
scale (float | np.ndarray): Scale factor for the actor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sapien.pysapien.Entity: The created actor entity.
|
||||||
|
"""
|
||||||
|
|
||||||
def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
|
def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
|
||||||
local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
|
local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
|
||||||
if origin_tag is not None:
|
if origin_tag is not None:
|
||||||
@ -154,14 +169,17 @@ def load_assets_from_layout_file(
|
|||||||
init_quat: list[float] = [0, 0, 0, 1],
|
init_quat: list[float] = [0, 0, 0, 1],
|
||||||
env_idx: int = None,
|
env_idx: int = None,
|
||||||
) -> dict[str, sapien.pysapien.Entity]:
|
) -> dict[str, sapien.pysapien.Entity]:
|
||||||
"""Load assets from `EmbodiedGen` layout-gen output and create actors in the scene.
|
"""Load assets from an EmbodiedGen layout file and create sapien actors in the scene.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scene (sapien.Scene | ManiSkillScene): The SAPIEN or ManiSkill scene to load assets into.
|
scene (ManiSkillScene | sapien.Scene): The sapien simulation scene.
|
||||||
layout (str): The layout file path.
|
layout (str): Path to the embodiedgen layout file.
|
||||||
z_offset (float): Offset to apply to the Z-coordinate of non-context objects.
|
z_offset (float): Z offset for non-context objects.
|
||||||
init_quat (List[float]): Initial quaternion (x, y, z, w) for orientation adjustment.
|
init_quat (list[float]): Initial quaternion for orientation.
|
||||||
env_idx (int): Environment index for multi-environment setup.
|
env_idx (int): Environment index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, sapien.pysapien.Entity]: Mapping from object names to actor entities.
|
||||||
"""
|
"""
|
||||||
asset_root = os.path.dirname(layout)
|
asset_root = os.path.dirname(layout)
|
||||||
layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
|
layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
|
||||||
@ -206,6 +224,19 @@ def load_mani_skill_robot(
|
|||||||
control_mode: str = "pd_joint_pos",
|
control_mode: str = "pd_joint_pos",
|
||||||
backend_str: tuple[str, str] = ("cpu", "gpu"),
|
backend_str: tuple[str, str] = ("cpu", "gpu"),
|
||||||
) -> BaseAgent:
|
) -> BaseAgent:
|
||||||
|
"""Load a ManiSkill robot agent into the scene.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene (sapien.Scene | ManiSkillScene): The simulation scene.
|
||||||
|
layout (LayoutInfo | str): Layout info or path to layout file.
|
||||||
|
control_freq (int): Control frequency.
|
||||||
|
robot_init_qpos_noise (float): Noise for initial joint positions.
|
||||||
|
control_mode (str): Robot control mode.
|
||||||
|
backend_str (tuple[str, str]): Simulation/render backend.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseAgent: The loaded robot agent.
|
||||||
|
"""
|
||||||
from mani_skill.agents import REGISTERED_AGENTS
|
from mani_skill.agents import REGISTERED_AGENTS
|
||||||
from mani_skill.envs.scene import ManiSkillScene
|
from mani_skill.envs.scene import ManiSkillScene
|
||||||
from mani_skill.envs.utils.system.backend import (
|
from mani_skill.envs.utils.system.backend import (
|
||||||
@ -278,14 +309,14 @@ def render_images(
|
|||||||
]
|
]
|
||||||
] = None,
|
] = None,
|
||||||
) -> dict[str, Image.Image]:
|
) -> dict[str, Image.Image]:
|
||||||
"""Render images from a given sapien camera.
|
"""Render images from a given SAPIEN camera.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
camera (sapien.render.RenderCameraComponent): The camera to render from.
|
camera (sapien.render.RenderCameraComponent): Camera to render from.
|
||||||
render_keys (List[str]): Types of images to render (e.g., Color, Segmentation).
|
render_keys (list[str], optional): Types of images to render.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Image.Image]: Dictionary of rendered images.
|
dict[str, Image.Image]: Dictionary of rendered images.
|
||||||
"""
|
"""
|
||||||
if render_keys is None:
|
if render_keys is None:
|
||||||
render_keys = [
|
render_keys = [
|
||||||
@ -341,11 +372,33 @@ def render_images(
|
|||||||
|
|
||||||
|
|
||||||
class SapienSceneManager:
|
class SapienSceneManager:
|
||||||
"""A class to manage SAPIEN simulator."""
|
"""Manages SAPIEN simulation scenes, cameras, and rendering.
|
||||||
|
|
||||||
|
This class provides utilities for setting up scenes, adding cameras,
|
||||||
|
stepping simulation, and rendering images.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sim_freq (int): Simulation frequency.
|
||||||
|
ray_tracing (bool): Whether to use ray tracing.
|
||||||
|
device (str): Device for simulation.
|
||||||
|
renderer (sapien.SapienRenderer): SAPIEN renderer.
|
||||||
|
scene (sapien.Scene): Simulation scene.
|
||||||
|
cameras (list): List of camera components.
|
||||||
|
actors (dict): Mapping of actor names to entities.
|
||||||
|
|
||||||
|
Example see `embodied_gen/scripts/simulate_sapien.py`.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
|
self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the scene manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sim_freq (int): Simulation frequency.
|
||||||
|
ray_tracing (bool): Enable ray tracing.
|
||||||
|
device (str): Device for simulation.
|
||||||
|
"""
|
||||||
self.sim_freq = sim_freq
|
self.sim_freq = sim_freq
|
||||||
self.ray_tracing = ray_tracing
|
self.ray_tracing = ray_tracing
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -355,7 +408,11 @@ class SapienSceneManager:
|
|||||||
self.actors: dict[str, sapien.pysapien.Entity] = {}
|
self.actors: dict[str, sapien.pysapien.Entity] = {}
|
||||||
|
|
||||||
def _setup_scene(self) -> sapien.Scene:
|
def _setup_scene(self) -> sapien.Scene:
|
||||||
"""Set up the SAPIEN scene with lighting and ground."""
|
"""Set up the SAPIEN scene with lighting and ground.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sapien.Scene: The initialized scene.
|
||||||
|
"""
|
||||||
# Ray tracing settings
|
# Ray tracing settings
|
||||||
if self.ray_tracing:
|
if self.ray_tracing:
|
||||||
sapien.render.set_camera_shader_dir("rt")
|
sapien.render.set_camera_shader_dir("rt")
|
||||||
@ -397,6 +454,18 @@ class SapienSceneManager:
|
|||||||
render_keys: list[str],
|
render_keys: list[str],
|
||||||
sim_steps_per_control: int = 1,
|
sim_steps_per_control: int = 1,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
"""Step the simulation and render images from cameras.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent (BaseAgent): The robot agent.
|
||||||
|
action (torch.Tensor): Action to apply.
|
||||||
|
cameras (list): List of camera components.
|
||||||
|
render_keys (list[str]): Types of images to render.
|
||||||
|
sim_steps_per_control (int): Simulation steps per control.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Dictionary of rendered frames per camera.
|
||||||
|
"""
|
||||||
agent.set_action(action)
|
agent.set_action(action)
|
||||||
frames = defaultdict(list)
|
frames = defaultdict(list)
|
||||||
for _ in range(sim_steps_per_control):
|
for _ in range(sim_steps_per_control):
|
||||||
@ -417,13 +486,13 @@ class SapienSceneManager:
|
|||||||
image_hw: tuple[int, int],
|
image_hw: tuple[int, int],
|
||||||
fovy_deg: float,
|
fovy_deg: float,
|
||||||
) -> sapien.render.RenderCameraComponent:
|
) -> sapien.render.RenderCameraComponent:
|
||||||
"""Create a single camera in the scene.
|
"""Create a camera in the scene.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cam_name (str): Name of the camera.
|
cam_name (str): Camera name.
|
||||||
pose (sapien.Pose): Camera pose p=(x, y, z), q=(w, x, y, z)
|
pose (sapien.Pose): Camera pose.
|
||||||
image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
|
image_hw (tuple[int, int]): Image resolution (height, width).
|
||||||
fovy_deg (float): Field of view in degrees for cameras.
|
fovy_deg (float): Field of view in degrees.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
sapien.render.RenderCameraComponent: The created camera.
|
sapien.render.RenderCameraComponent: The created camera.
|
||||||
@ -456,15 +525,15 @@ class SapienSceneManager:
|
|||||||
"""Initialize multiple cameras arranged in a circle.
|
"""Initialize multiple cameras arranged in a circle.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_cameras (int): Number of cameras to create.
|
num_cameras (int): Number of cameras.
|
||||||
radius (float): Radius of the camera circle.
|
radius (float): Circle radius.
|
||||||
height (float): Fixed Z-coordinate of the cameras.
|
height (float): Camera height.
|
||||||
target_pt (list[float]): 3D point (x, y, z) that cameras look at.
|
target_pt (list[float]): Target point to look at.
|
||||||
image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
|
image_hw (tuple[int, int]): Image resolution.
|
||||||
fovy_deg (float): Field of view in degrees for cameras.
|
fovy_deg (float): Field of view in degrees.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[sapien.render.RenderCameraComponent]: List of created cameras.
|
list[sapien.render.RenderCameraComponent]: List of cameras.
|
||||||
"""
|
"""
|
||||||
angle_step = 2 * np.pi / num_cameras
|
angle_step = 2 * np.pi / num_cameras
|
||||||
world_up_vec = np.array([0.0, 0.0, 1.0])
|
world_up_vec = np.array([0.0, 0.0, 1.0])
|
||||||
@ -510,6 +579,19 @@ class SapienSceneManager:
|
|||||||
|
|
||||||
|
|
||||||
class FrankaPandaGrasper(object):
|
class FrankaPandaGrasper(object):
|
||||||
|
"""Provides grasp planning and control for Franka Panda robot.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
agent (BaseAgent): The robot agent.
|
||||||
|
robot: The robot instance.
|
||||||
|
control_freq (float): Control frequency.
|
||||||
|
control_timestep (float): Control timestep.
|
||||||
|
joint_vel_limits (float): Joint velocity limits.
|
||||||
|
joint_acc_limits (float): Joint acceleration limits.
|
||||||
|
finger_length (float): Length of gripper fingers.
|
||||||
|
planners: Motion planners for each environment.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
agent: BaseAgent,
|
agent: BaseAgent,
|
||||||
@ -518,6 +600,7 @@ class FrankaPandaGrasper(object):
|
|||||||
joint_acc_limits: float = 1.0,
|
joint_acc_limits: float = 1.0,
|
||||||
finger_length: float = 0.025,
|
finger_length: float = 0.025,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the grasper."""
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.robot = agent.robot
|
self.robot = agent.robot
|
||||||
self.control_freq = control_freq
|
self.control_freq = control_freq
|
||||||
@ -553,6 +636,15 @@ class FrankaPandaGrasper(object):
|
|||||||
gripper_state: Literal[-1, 1],
|
gripper_state: Literal[-1, 1],
|
||||||
n_step: int = 10,
|
n_step: int = 10,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Generate gripper control actions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gripper_state (Literal[-1, 1]): Desired gripper state.
|
||||||
|
n_step (int): Number of steps.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Array of gripper actions.
|
||||||
|
"""
|
||||||
qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
|
qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
|
||||||
actions = []
|
actions = []
|
||||||
for _ in range(n_step):
|
for _ in range(n_step):
|
||||||
@ -571,6 +663,20 @@ class FrankaPandaGrasper(object):
|
|||||||
action_key: str = "position",
|
action_key: str = "position",
|
||||||
env_idx: int = 0,
|
env_idx: int = 0,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Plan and execute motion to a target pose.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pose (sapien.Pose): Target pose.
|
||||||
|
control_timestep (float): Control timestep.
|
||||||
|
gripper_state (Literal[-1, 1]): Desired gripper state.
|
||||||
|
use_point_cloud (bool): Use point cloud for planning.
|
||||||
|
n_max_step (int): Max number of steps.
|
||||||
|
action_key (str): Key for action in result.
|
||||||
|
env_idx (int): Environment index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Array of actions to reach the pose.
|
||||||
|
"""
|
||||||
result = self.planners[env_idx].plan_qpos_to_pose(
|
result = self.planners[env_idx].plan_qpos_to_pose(
|
||||||
np.concatenate([pose.p, pose.q]),
|
np.concatenate([pose.p, pose.q]),
|
||||||
self.robot.get_qpos().cpu().numpy()[0],
|
self.robot.get_qpos().cpu().numpy()[0],
|
||||||
@ -608,6 +714,17 @@ class FrankaPandaGrasper(object):
|
|||||||
offset: tuple[float, float, float] = [0, 0, -0.05],
|
offset: tuple[float, float, float] = [0, 0, -0.05],
|
||||||
env_idx: int = 0,
|
env_idx: int = 0,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Compute grasp actions for a target actor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor (sapien.pysapien.Entity): Target actor to grasp.
|
||||||
|
reach_target_only (bool): Only reach the target pose if True.
|
||||||
|
offset (tuple[float, float, float]): Offset for reach pose.
|
||||||
|
env_idx (int): Environment index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Array of grasp actions.
|
||||||
|
"""
|
||||||
physx_rigid = actor.components[1]
|
physx_rigid = actor.components[1]
|
||||||
mesh = get_component_mesh(physx_rigid, to_world_frame=True)
|
mesh = get_component_mesh(physx_rigid, to_world_frame=True)
|
||||||
obb = mesh.bounding_box_oriented
|
obb = mesh.bounding_box_oriented
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
VERSION = "v0.1.5"
|
VERSION = "v0.1.6"
|
||||||
|
|||||||
@ -27,14 +27,22 @@ from PIL import Image
|
|||||||
|
|
||||||
|
|
||||||
class AestheticPredictor:
|
class AestheticPredictor:
|
||||||
"""Aesthetic Score Predictor.
|
"""Aesthetic Score Predictor using CLIP and a pre-trained MLP.
|
||||||
|
|
||||||
Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main
|
Checkpoints from `https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
clip_model_dir (str): Path to the directory of the CLIP model.
|
clip_model_dir (str, optional): Path to CLIP model directory.
|
||||||
sac_model_path (str): Path to the pre-trained SAC model.
|
sac_model_path (str, optional): Path to SAC model weights.
|
||||||
device (str): Device to use for computation ("cuda" or "cpu").
|
device (str, optional): Device for computation ("cuda" or "cpu").
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
|
||||||
|
predictor = AestheticPredictor(device="cuda")
|
||||||
|
score = predictor.predict("image.png")
|
||||||
|
print("Aesthetic score:", score)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
|
def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
|
||||||
@ -109,7 +117,7 @@ class AestheticPredictor:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def predict(self, image_path):
|
def predict(self, image_path):
|
||||||
"""Predict the aesthetic score for a given image.
|
"""Predicts the aesthetic score for a given image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_path (str): Path to the image file.
|
image_path (str): Path to the image file.
|
||||||
|
|||||||
@ -40,6 +40,16 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class BaseChecker:
|
class BaseChecker:
|
||||||
|
"""Base class for quality checkers using GPT clients.
|
||||||
|
|
||||||
|
Provides a common interface for querying and validating responses.
|
||||||
|
Subclasses must implement the `query` method.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
prompt (str): The prompt used for queries.
|
||||||
|
verbose (bool): Whether to enable verbose logging.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
|
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
@ -70,6 +80,15 @@ class BaseChecker:
|
|||||||
def validate(
|
def validate(
|
||||||
checkers: list["BaseChecker"], images_list: list[list[str]]
|
checkers: list["BaseChecker"], images_list: list[list[str]]
|
||||||
) -> list:
|
) -> list:
|
||||||
|
"""Validates a list of checkers against corresponding image lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkers (list[BaseChecker]): List of checker instances.
|
||||||
|
images_list (list[list[str]]): List of image path lists.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: Validation results with overall outcome.
|
||||||
|
"""
|
||||||
assert len(checkers) == len(images_list)
|
assert len(checkers) == len(images_list)
|
||||||
results = []
|
results = []
|
||||||
overall_result = True
|
overall_result = True
|
||||||
@ -192,7 +211,7 @@ class ImageSegChecker(BaseChecker):
|
|||||||
|
|
||||||
|
|
||||||
class ImageAestheticChecker(BaseChecker):
|
class ImageAestheticChecker(BaseChecker):
|
||||||
"""A class for evaluating the aesthetic quality of images.
|
"""Evaluates the aesthetic quality of images using a CLIP-based predictor.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
clip_model_dir (str): Path to the CLIP model directory.
|
clip_model_dir (str): Path to the CLIP model directory.
|
||||||
@ -200,6 +219,14 @@ class ImageAestheticChecker(BaseChecker):
|
|||||||
thresh (float): Threshold above which images are considered aesthetically acceptable.
|
thresh (float): Threshold above which images are considered aesthetically acceptable.
|
||||||
verbose (bool): Whether to print detailed log messages.
|
verbose (bool): Whether to print detailed log messages.
|
||||||
predictor (AestheticPredictor): The model used to predict aesthetic scores.
|
predictor (AestheticPredictor): The model used to predict aesthetic scores.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.validators.quality_checkers import ImageAestheticChecker
|
||||||
|
checker = ImageAestheticChecker(thresh=4.5)
|
||||||
|
flag, score = checker(["image1.png", "image2.png"])
|
||||||
|
print("Aesthetic OK:", flag, "Score:", score)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -227,6 +254,16 @@ class ImageAestheticChecker(BaseChecker):
|
|||||||
|
|
||||||
|
|
||||||
class SemanticConsistChecker(BaseChecker):
|
class SemanticConsistChecker(BaseChecker):
|
||||||
|
"""Checks semantic consistency between text descriptions and segmented images.
|
||||||
|
|
||||||
|
Uses GPT to evaluate if the image matches the text in object type, geometry, and color.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
gpt_client (GPTclient): GPT client for queries.
|
||||||
|
prompt (str): Prompt for consistency evaluation.
|
||||||
|
verbose (bool): Whether to enable verbose logging.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -276,6 +313,16 @@ class SemanticConsistChecker(BaseChecker):
|
|||||||
|
|
||||||
|
|
||||||
class TextGenAlignChecker(BaseChecker):
|
class TextGenAlignChecker(BaseChecker):
|
||||||
|
"""Evaluates alignment between text prompts and generated 3D asset images.
|
||||||
|
|
||||||
|
Assesses if the rendered images match the text description in category and geometry.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
gpt_client (GPTclient): GPT client for queries.
|
||||||
|
prompt (str): Prompt for alignment evaluation.
|
||||||
|
verbose (bool): Whether to enable verbose logging.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -489,6 +536,17 @@ class PanoHeightEstimator(object):
|
|||||||
|
|
||||||
|
|
||||||
class SemanticMatcher(BaseChecker):
|
class SemanticMatcher(BaseChecker):
|
||||||
|
"""Matches query text to semantically similar scene descriptions.
|
||||||
|
|
||||||
|
Uses GPT to find the most similar scene IDs from a dictionary.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
gpt_client (GPTclient): GPT client for queries.
|
||||||
|
prompt (str): Prompt for semantic matching.
|
||||||
|
verbose (bool): Whether to enable verbose logging.
|
||||||
|
seed (int): Random seed for selection.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -543,6 +601,17 @@ class SemanticMatcher(BaseChecker):
|
|||||||
def query(
|
def query(
|
||||||
self, text: str, context: dict, rand: bool = True, params: dict = None
|
self, text: str, context: dict, rand: bool = True, params: dict = None
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Queries for semantically similar scene IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): Query text.
|
||||||
|
context (dict): Dictionary of scene descriptions.
|
||||||
|
rand (bool, optional): Whether to randomly select from top matches.
|
||||||
|
params (dict, optional): Additional GPT parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Matched scene ID.
|
||||||
|
"""
|
||||||
match_list = self.gpt_client.query(
|
match_list = self.gpt_client.query(
|
||||||
self.prompt.format(context=context, text=text),
|
self.prompt.format(context=context, text=text),
|
||||||
params=params,
|
params=params,
|
||||||
|
|||||||
@ -80,6 +80,31 @@ URDF_TEMPLATE = """
|
|||||||
|
|
||||||
|
|
||||||
class URDFGenerator(object):
|
class URDFGenerator(object):
|
||||||
|
"""Generates URDF files for 3D assets with physical and semantic attributes.
|
||||||
|
|
||||||
|
Uses GPT to estimate object properties and generates a URDF file with mesh, friction, mass, and metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gpt_client (GPTclient): GPT client for attribute estimation.
|
||||||
|
mesh_file_list (list[str], optional): Additional mesh files to copy.
|
||||||
|
prompt_template (str, optional): Prompt template for GPT queries.
|
||||||
|
attrs_name (list[str], optional): List of attribute names to include.
|
||||||
|
render_dir (str, optional): Directory for rendered images.
|
||||||
|
render_view_num (int, optional): Number of views to render.
|
||||||
|
decompose_convex (bool, optional): Whether to decompose mesh for collision.
|
||||||
|
rotate_xyzw (list[float], optional): Quaternion for mesh rotation.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
from embodied_gen.validators.urdf_convertor import URDFGenerator
|
||||||
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||||
|
|
||||||
|
urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
|
||||||
|
urdf_path = urdf_gen(mesh_path="mesh.obj", output_root="output_dir")
|
||||||
|
print("Generated URDF:", urdf_path)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpt_client: GPTclient,
|
gpt_client: GPTclient,
|
||||||
@ -168,6 +193,14 @@ class URDFGenerator(object):
|
|||||||
self.rotate_xyzw = rotate_xyzw
|
self.rotate_xyzw = rotate_xyzw
|
||||||
|
|
||||||
def parse_response(self, response: str) -> dict[str, any]:
|
def parse_response(self, response: str) -> dict[str, any]:
|
||||||
|
"""Parses GPT response to extract asset attributes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response (str): GPT response string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, any]: Parsed attributes.
|
||||||
|
"""
|
||||||
lines = response.split("\n")
|
lines = response.split("\n")
|
||||||
lines = [line.strip() for line in lines if line]
|
lines = [line.strip() for line in lines if line]
|
||||||
category = lines[0].split(": ")[1]
|
category = lines[0].split(": ")[1]
|
||||||
@ -207,11 +240,9 @@ class URDFGenerator(object):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_mesh (str): Path to the input mesh file.
|
input_mesh (str): Path to the input mesh file.
|
||||||
output_dir (str): Directory to store the generated URDF
|
output_dir (str): Directory to store the generated URDF and mesh.
|
||||||
and processed mesh.
|
attr_dict (dict): Dictionary of asset attributes.
|
||||||
attr_dict (dict): Dictionary containing attributes like height,
|
output_name (str, optional): Name for the URDF and robot.
|
||||||
mass, and friction coefficients.
|
|
||||||
output_name (str, optional): Name for the generated URDF and robot.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Path to the generated URDF file.
|
str: Path to the generated URDF file.
|
||||||
@ -336,6 +367,16 @@ class URDFGenerator(object):
|
|||||||
attr_root: str = ".//link/extra_info",
|
attr_root: str = ".//link/extra_info",
|
||||||
attr_name: str = "scale",
|
attr_name: str = "scale",
|
||||||
) -> float:
|
) -> float:
|
||||||
|
"""Extracts an attribute value from a URDF file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urdf_path (str): Path to the URDF file.
|
||||||
|
attr_root (str, optional): XML path to attribute root.
|
||||||
|
attr_name (str, optional): Attribute name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Attribute value, or None if not found.
|
||||||
|
"""
|
||||||
if not os.path.exists(urdf_path):
|
if not os.path.exists(urdf_path):
|
||||||
raise FileNotFoundError(f"URDF file not found: {urdf_path}")
|
raise FileNotFoundError(f"URDF file not found: {urdf_path}")
|
||||||
|
|
||||||
@ -358,6 +399,13 @@ class URDFGenerator(object):
|
|||||||
def add_quality_tag(
|
def add_quality_tag(
|
||||||
urdf_path: str, results: list, output_path: str = None
|
urdf_path: str, results: list, output_path: str = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Adds a quality tag to a URDF file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urdf_path (str): Path to the URDF file.
|
||||||
|
results (list): List of [checker_name, result] pairs.
|
||||||
|
output_path (str, optional): Output file path.
|
||||||
|
"""
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
output_path = urdf_path
|
output_path = urdf_path
|
||||||
|
|
||||||
@ -382,6 +430,14 @@ class URDFGenerator(object):
|
|||||||
logger.info(f"URDF files saved to {output_path}")
|
logger.info(f"URDF files saved to {output_path}")
|
||||||
|
|
||||||
def get_estimated_attributes(self, asset_attrs: dict):
|
def get_estimated_attributes(self, asset_attrs: dict):
|
||||||
|
"""Calculates estimated attributes from asset properties.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
asset_attrs (dict): Asset attributes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Estimated attributes (height, mass, mu, category).
|
||||||
|
"""
|
||||||
estimated_attrs = {
|
estimated_attrs = {
|
||||||
"height": round(
|
"height": round(
|
||||||
(asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
|
(asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
|
||||||
@ -403,6 +459,18 @@ class URDFGenerator(object):
|
|||||||
category: str = "unknown",
|
category: str = "unknown",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Generates a URDF file for a mesh asset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh_path (str): Path to mesh file.
|
||||||
|
output_root (str): Directory for outputs.
|
||||||
|
text_prompt (str, optional): Prompt for GPT.
|
||||||
|
category (str, optional): Asset category.
|
||||||
|
**kwargs: Additional attributes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Path to generated URDF file.
|
||||||
|
"""
|
||||||
if text_prompt is None or len(text_prompt) == 0:
|
if text_prompt is None or len(text_prompt) == 0:
|
||||||
text_prompt = self.prompt_template
|
text_prompt = self.prompt_template
|
||||||
text_prompt = text_prompt.format(category=category.lower())
|
text_prompt = text_prompt.format(category=category.lower())
|
||||||
|
|||||||
@ -7,7 +7,7 @@ packages = ["embodied_gen"]
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "embodied_gen"
|
name = "embodied_gen"
|
||||||
version = "v0.1.5"
|
version = "v0.1.6"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
license-files = ["LICENSE", "NOTICE"]
|
license-files = ["LICENSE", "NOTICE"]
|
||||||
|
|||||||
@ -4,10 +4,9 @@ import pytest
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from embodied_gen.data.asset_converter import (
|
from embodied_gen.data.asset_converter import (
|
||||||
AssetConverterFactory,
|
AssetConverterFactory,
|
||||||
AssetType,
|
|
||||||
SimAssetMapper,
|
|
||||||
cvt_embodiedgen_asset_to_anysim,
|
cvt_embodiedgen_asset_to_anysim,
|
||||||
)
|
)
|
||||||
|
from embodied_gen.utils.enum import AssetType, SimAssetMapper
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -77,7 +76,10 @@ def test_cvt_embodiedgen_asset_to_anysim(
|
|||||||
):
|
):
|
||||||
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
||||||
urdf_files=[
|
urdf_files=[
|
||||||
"outputs/embodiedgen_assets/demo_assets/remote_control2/result/remote_control.urdf",
|
"outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf",
|
||||||
|
],
|
||||||
|
target_dirs=[
|
||||||
|
"outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd",
|
||||||
],
|
],
|
||||||
target_type=SimAssetMapper[simulator_name],
|
target_type=SimAssetMapper[simulator_name],
|
||||||
source_type=AssetType.MESH,
|
source_type=AssetType.MESH,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user