feat(docs): Improve docstrings across the codebase and docs. (#56)
parent cd94669770
commit a256674bf2
@ -37,7 +37,7 @@
```sh
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
cd EmbodiedGen
git checkout v0.1.5
git checkout v0.1.6
git submodule update --init --recursive --progress
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
conda activate embodiedgen
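# Optional sanity check (an editor's sketch, not part of the official docs):
# VERSION is exposed by embodied_gen.utils.tags, as used elsewhere in this commit.
python -c "from embodied_gen.utils.tags import VERSION; print(VERSION)"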
@ -31,8 +31,8 @@ from typing import Any, Dict, Tuple

import gradio as gr
import pandas as pd
import yaml
from app_style import custom_theme, lighting_css
from embodied_gen.utils.tags import VERSION

try:
    from embodied_gen.utils.gpt_clients import GPT_CLIENT as gpt_client
@ -48,7 +48,6 @@ except Exception as e:


# --- Configuration & Data Loading ---
VERSION = "v0.1.5"
RUNNING_MODE = "local" # local or hf_remote
CSV_FILE = "dataset_index.csv"
@ -7,7 +7,7 @@ hide:
```sh
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
cd EmbodiedGen
git checkout v0.1.5
git checkout v0.1.6
git submodule update --init --recursive --progress
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
conda activate embodiedgen
@ -35,7 +35,8 @@ Leverage **EmbodiedGen-generated assets** with *accurate physical collisions* an
## 🧱 Example: Conversion to Target Simulator

```python
from embodied_gen.data.asset_converter import SimAssetMapper, cvt_embodiedgen_asset_to_anysim
from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
from embodied_gen.utils.enum import AssetType, SimAssetMapper
from typing import Literal

simulator_name: Literal[
@ -52,6 +53,10 @@ dst_asset_path = cvt_embodiedgen_asset_to_anysim(
        "path1_to_embodiedgen_asset/asset.urdf",
        "path2_to_embodiedgen_asset/asset.urdf",
    ],
    target_dirs=[
        "path1_to_target_dir/asset.usd",
        "path2_to_target_dir/asset.usd",
    ],
    target_type=SimAssetMapper[simulator_name],
    source_type=AssetType.MESH,
    overwrite=True,
@ -4,12 +4,12 @@ import logging
import os
import xml.etree.ElementTree as ET
from abc import ABC, abstractmethod
from dataclasses import dataclass
from glob import glob
from shutil import copy, copytree, rmtree

import trimesh
from scipy.spatial.transform import Rotation
from embodied_gen.utils.enum import AssetType

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@ -17,75 +17,62 @@ logger = logging.getLogger(__name__)

__all__ = [
    "AssetConverterFactory",
    "AssetType",
    "MeshtoMJCFConverter",
    "MeshtoUSDConverter",
    "URDFtoUSDConverter",
    "cvt_embodiedgen_asset_to_anysim",
    "PhysicsUSDAdder",
    "SimAssetMapper",
]


@dataclass
class AssetType(str):
    """Asset type enumeration."""

    MJCF = "mjcf"
    USD = "usd"
    URDF = "urdf"
    MESH = "mesh"


class SimAssetMapper:
    _mapping = dict(
        ISAACSIM=AssetType.USD,
        ISAACGYM=AssetType.URDF,
        MUJOCO=AssetType.MJCF,
        GENESIS=AssetType.MJCF,
        SAPIEN=AssetType.URDF,
        PYBULLET=AssetType.URDF,
    )

    @classmethod
    def __class_getitem__(cls, key: str):
        key = key.upper()
        if key.startswith("SAPIEN"):
            key = "SAPIEN"
        return cls._mapping[key]
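As a quick illustration of the mapper's lookup behavior, a minimal sketch assuming `AssetType` and `SimAssetMapper` are importable from `embodied_gen.utils.enum`, as in the docs example above:

```python
from embodied_gen.utils.enum import AssetType, SimAssetMapper

# Keys are case-insensitive; any "SAPIEN*" variant collapses to SAPIEN.
assert SimAssetMapper["mujoco"] == AssetType.MJCF
assert SimAssetMapper["isaacsim"] == AssetType.USD
assert SimAssetMapper["sapien3"] == AssetType.URDF
```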
def cvt_embodiedgen_asset_to_anysim(
    urdf_files: list[str],
    target_dirs: list[str],
    target_type: AssetType,
    source_type: AssetType,
    overwrite: bool = False,
    **kwargs,
) -> dict[str, str]:
    """Convert URDF files generated by EmbodiedGen into the format required by all simulators.
    """Convert URDF files generated by EmbodiedGen into formats required by simulators.

    Supported simulators include SAPIEN, Isaac Sim, MuJoCo, Isaac Gym, Genesis, and PyBullet.
    Converting to the `USD` format requires `isaacsim` to be installed.

    Example:
        ```py
        from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
        from embodied_gen.utils.enum import AssetType

        dst_asset_path = cvt_embodiedgen_asset_to_anysim(
            urdf_files,
            target_type=SimAssetMapper[simulator_name],
            urdf_files=[
                "path1_to_embodiedgen_asset/asset.urdf",
                "path2_to_embodiedgen_asset/asset.urdf",
            ],
            target_dirs=[
                "path1_to_target_dir/asset.usd",
                "path2_to_target_dir/asset.usd",
            ],
            target_type=AssetType.USD,
            source_type=AssetType.MESH,
        )
        ```

    Args:
        urdf_files (List[str]): List of URDF file paths to be converted.
        target_type (AssetType): The target asset type.
        source_type (AssetType): The source asset type.
        overwrite (bool): Whether to overwrite existing converted files.
        **kwargs: Additional keyword arguments for the converter.
        urdf_files (list[str]): List of URDF file paths.
        target_dirs (list[str]): List of target directories.
        target_type (AssetType): Target asset type.
        source_type (AssetType): Source asset type.
        overwrite (bool, optional): Overwrite existing files.
        **kwargs: Additional converter arguments.

    Returns:
        Dict[str, str]: A dictionary mapping the original URDF file path to the converted asset file path.
        dict[str, str]: Mapping from URDF file to converted asset file.
    """

    if isinstance(urdf_files, str):
        urdf_files = [urdf_files]
    if isinstance(target_dirs, str):
        target_dirs = [target_dirs]

    # If the target type is URDF, no conversion is needed.
    if target_type == AssetType.URDF:
@ -99,18 +86,17 @@ def cvt_embodiedgen_asset_to_anysim(
    asset_paths = dict()

    with asset_converter:
        for urdf_file in urdf_files:
        for urdf_file, target_dir in zip(urdf_files, target_dirs):
            filename = os.path.basename(urdf_file).replace(".urdf", "")
            asset_dir = os.path.dirname(urdf_file)
            if target_type == AssetType.MJCF:
                target_file = f"{asset_dir}/../mjcf/{filename}.xml"
                target_file = f"{target_dir}/{filename}.xml"
            elif target_type == AssetType.USD:
                target_file = f"{asset_dir}/../usd/{filename}.usd"
                target_file = f"{target_dir}/{filename}.usd"
            else:
                raise NotImplementedError(
                    f"Target type {target_type} not supported."
                )
            if not os.path.exists(target_file):
            if not os.path.exists(target_file) or overwrite:
                asset_converter.convert(urdf_file, target_file)

            asset_paths[urdf_file] = target_file
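To make the return value concrete, an illustrative call (paths are hypothetical; the output name follows the `{target_dir}/{filename}.usd` pattern in the loop above):

```python
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
    urdf_files=["demo/mug/result/mug.urdf"],
    target_dirs=["demo/mug/usd"],
    target_type=AssetType.USD,
    source_type=AssetType.MESH,
)
print(dst_asset_path)  # {"demo/mug/result/mug.urdf": "demo/mug/usd/mug.usd"}
```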
@ -119,16 +105,35 @@ def cvt_embodiedgen_asset_to_anysim(
|
||||
|
||||
|
||||
class AssetConverterBase(ABC):
|
||||
"""Converter abstract base class."""
|
||||
"""Abstract base class for asset converters.
|
||||
|
||||
Provides context management and mesh transformation utilities.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
|
||||
"""Convert an asset file.
|
||||
|
||||
Args:
|
||||
urdf_path (str): Path to input URDF file.
|
||||
output_path (str): Path to output file.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
str: Path to converted asset.
|
||||
"""
|
||||
pass
|
||||
|
||||
def transform_mesh(
|
||||
self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
|
||||
) -> None:
|
||||
"""Apply transform to the mesh based on the origin element in URDF."""
|
||||
"""Apply transform to mesh based on URDF origin element.
|
||||
|
||||
Args:
|
||||
input_mesh (str): Path to input mesh.
|
||||
output_mesh (str): Path to output mesh.
|
||||
mesh_origin (ET.Element): Origin element from URDF.
|
||||
"""
|
||||
mesh = trimesh.load(input_mesh, group_material=False)
|
||||
rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
|
||||
rotation = Rotation.from_euler("xyz", rpy, degrees=False)
|
||||
@ -150,14 +155,19 @@ class AssetConverterBase(ABC):
|
||||
return
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit."""
|
||||
return False
|
||||
|
||||
|
||||
class MeshtoMJCFConverter(AssetConverterBase):
|
||||
"""Convert URDF files into MJCF format."""
|
||||
"""Converts mesh-based URDF files to MJCF format.
|
||||
|
||||
Handles geometry, materials, and asset copying.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -166,6 +176,12 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
||||
self.kwargs = kwargs
|
||||
|
||||
def _copy_asset_file(self, src: str, dst: str) -> None:
|
||||
"""Copies asset file if not already present.
|
||||
|
||||
Args:
|
||||
src (str): Source file path.
|
||||
dst (str): Destination file path.
|
||||
"""
|
||||
if os.path.exists(dst):
|
||||
return
|
||||
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
||||
@ -183,7 +199,19 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
||||
material: ET.Element | None = None,
|
||||
is_collision: bool = False,
|
||||
) -> None:
|
||||
"""Add geometry to the MJCF body from the URDF link."""
|
||||
"""Adds geometry to MJCF body from URDF link.
|
||||
|
||||
Args:
|
||||
mujoco_element (ET.Element): MJCF asset element.
|
||||
link (ET.Element): URDF link element.
|
||||
body (ET.Element): MJCF body element.
|
||||
tag (str): Tag name ("visual" or "collision").
|
||||
input_dir (str): Input directory.
|
||||
output_dir (str): Output directory.
|
||||
mesh_name (str): Mesh name.
|
||||
material (ET.Element, optional): Material element.
|
||||
is_collision (bool, optional): If True, treat as collision geometry.
|
||||
"""
|
||||
element = link.find(tag)
|
||||
geometry = element.find("geometry")
|
||||
mesh = geometry.find("mesh")
|
||||
@ -242,7 +270,20 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
||||
name: str,
|
||||
reflectance: float = 0.2,
|
||||
) -> ET.Element:
|
||||
"""Add materials to the MJCF asset from the URDF link."""
|
||||
"""Adds materials to MJCF asset from URDF link.
|
||||
|
||||
Args:
|
||||
mujoco_element (ET.Element): MJCF asset element.
|
||||
link (ET.Element): URDF link element.
|
||||
tag (str): Tag name.
|
||||
input_dir (str): Input directory.
|
||||
output_dir (str): Output directory.
|
||||
name (str): Material name.
|
||||
reflectance (float, optional): Reflectance value.
|
||||
|
||||
Returns:
|
||||
ET.Element: Material element.
|
||||
"""
|
||||
element = link.find(tag)
|
||||
geometry = element.find("geometry")
|
||||
mesh = geometry.find("mesh")
|
||||
@ -282,7 +323,12 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
||||
return material
|
||||
|
||||
def convert(self, urdf_path: str, mjcf_path: str):
|
||||
"""Convert a URDF file to MJCF format."""
|
||||
"""Converts a URDF file to MJCF format.
|
||||
|
||||
Args:
|
||||
urdf_path (str): Path to URDF file.
|
||||
mjcf_path (str): Path to output MJCF file.
|
||||
"""
|
||||
tree = ET.parse(urdf_path)
|
||||
root = tree.getroot()
|
||||
|
||||
@ -336,10 +382,22 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
||||
|
||||
|
||||
class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
||||
"""Convert URDF files with joints to MJCF format, handling transformations from joints."""
|
||||
"""Converts URDF files with joints to MJCF format, handling joint transformations.
|
||||
|
||||
Handles fixed joints and hierarchical body structure.
|
||||
"""
|
||||
|
||||
def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str:
|
||||
"""Convert a URDF file with joints to MJCF format."""
|
||||
"""Converts a URDF file with joints to MJCF format.
|
||||
|
||||
Args:
|
||||
urdf_path (str): Path to URDF file.
|
||||
mjcf_path (str): Path to output MJCF file.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
str: Path to converted MJCF file.
|
||||
"""
|
||||
tree = ET.parse(urdf_path)
|
||||
root = tree.getroot()
|
||||
|
||||
@ -423,7 +481,10 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
||||
|
||||
|
||||
class MeshtoUSDConverter(AssetConverterBase):
|
||||
"""Convert Mesh file from URDF into USD format."""
|
||||
"""Converts mesh-based URDF files to USD format.
|
||||
|
||||
Adds physics APIs and post-processes collision meshes.
|
||||
"""
|
||||
|
||||
DEFAULT_BIND_APIS = [
|
||||
"MaterialBindingAPI",
|
||||
@ -443,6 +504,14 @@ class MeshtoUSDConverter(AssetConverterBase):
|
||||
simulation_app=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes the converter.
|
||||
|
||||
Args:
|
||||
force_usd_conversion (bool, optional): Force USD conversion.
|
||||
make_instanceable (bool, optional): Make prims instanceable.
|
||||
simulation_app (optional): Simulation app instance.
|
||||
**kwargs: Additional arguments.
|
||||
"""
|
||||
if simulation_app is not None:
|
||||
self.simulation_app = simulation_app
|
||||
|
||||
@ -458,6 +527,7 @@ class MeshtoUSDConverter(AssetConverterBase):
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry, launches simulation app if needed."""
|
||||
from isaaclab.app import AppLauncher
|
||||
|
||||
if not hasattr(self, "simulation_app"):
|
||||
@ -476,6 +546,7 @@ class MeshtoUSDConverter(AssetConverterBase):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit, closes simulation app if created."""
|
||||
# Close the simulation app if it was created here
|
||||
if hasattr(self, "app_launcher") and self.exit_close:
|
||||
self.simulation_app.close()
|
||||
@ -486,7 +557,12 @@ class MeshtoUSDConverter(AssetConverterBase):
|
||||
return False
|
||||
|
||||
def convert(self, urdf_path: str, output_file: str):
|
||||
"""Convert a URDF file to USD and post-process collision meshes."""
|
||||
"""Converts a URDF file to USD and post-processes collision meshes.
|
||||
|
||||
Args:
|
||||
urdf_path (str): Path to URDF file.
|
||||
output_file (str): Path to output USD file.
|
||||
"""
|
||||
from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
|
||||
from pxr import PhysxSchema, Sdf, Usd, UsdShade
|
||||
|
||||
@ -556,6 +632,11 @@ class MeshtoUSDConverter(AssetConverterBase):
|
||||
|
||||
|
||||
class PhysicsUSDAdder(MeshtoUSDConverter):
|
||||
"""Adds physics APIs and collision properties to USD assets.
|
||||
|
||||
Useful for post-processing USD files for simulation.
|
||||
"""
|
||||
|
||||
DEFAULT_BIND_APIS = [
|
||||
"MaterialBindingAPI",
|
||||
# "PhysicsMeshCollisionAPI",
|
||||
@ -566,6 +647,12 @@ class PhysicsUSDAdder(MeshtoUSDConverter):
|
||||
]
|
||||
|
||||
def convert(self, usd_path: str, output_file: str = None):
|
||||
"""Adds physics APIs and collision properties to a USD file.
|
||||
|
||||
Args:
|
||||
usd_path (str): Path to input USD file.
|
||||
output_file (str, optional): Path to output USD file.
|
||||
"""
|
||||
from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics
|
||||
|
||||
if output_file is None:
|
||||
@ -626,14 +713,18 @@ class PhysicsUSDAdder(MeshtoUSDConverter):
|
||||
|
||||
|
||||
class URDFtoUSDConverter(MeshtoUSDConverter):
|
||||
"""Convert URDF files into USD format.
|
||||
"""Converts URDF files to USD format.
|
||||
|
||||
Args:
|
||||
fix_base (bool): Whether to fix the base link.
|
||||
merge_fixed_joints (bool): Whether to merge fixed joints.
|
||||
make_instanceable (bool): Whether to make prims instanceable.
|
||||
force_usd_conversion (bool): Force conversion to USD.
|
||||
collision_from_visuals (bool): Generate collisions from visuals if not provided.
|
||||
fix_base (bool, optional): Fix the base link.
|
||||
merge_fixed_joints (bool, optional): Merge fixed joints.
|
||||
make_instanceable (bool, optional): Make prims instanceable.
|
||||
force_usd_conversion (bool, optional): Force conversion to USD.
|
||||
collision_from_visuals (bool, optional): Generate collisions from visuals.
|
||||
joint_drive (optional): Joint drive configuration.
|
||||
rotate_wxyz (tuple[float], optional): Quaternion for rotation.
|
||||
simulation_app (optional): Simulation app instance.
|
||||
**kwargs: Additional arguments.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -648,6 +739,19 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
||||
simulation_app=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes the converter.
|
||||
|
||||
Args:
|
||||
fix_base (bool, optional): Fix the base link.
|
||||
merge_fixed_joints (bool, optional): Merge fixed joints.
|
||||
make_instanceable (bool, optional): Make prims instanceable.
|
||||
force_usd_conversion (bool, optional): Force conversion to USD.
|
||||
collision_from_visuals (bool, optional): Generate collisions from visuals.
|
||||
joint_drive (optional): Joint drive configuration.
|
||||
rotate_wxyz (tuple[float], optional): Quaternion for rotation.
|
||||
simulation_app (optional): Simulation app instance.
|
||||
**kwargs: Additional arguments.
|
||||
"""
|
||||
self.usd_parms = dict(
|
||||
fix_base=fix_base,
|
||||
merge_fixed_joints=merge_fixed_joints,
|
||||
@ -662,7 +766,12 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
||||
self.simulation_app = simulation_app
|
||||
|
||||
def convert(self, urdf_path: str, output_file: str):
|
||||
"""Convert a URDF file to USD and post-process collision meshes."""
|
||||
"""Converts a URDF file to USD and post-processes collision meshes.
|
||||
|
||||
Args:
|
||||
urdf_path (str): Path to URDF file.
|
||||
output_file (str): Path to output USD file.
|
||||
"""
|
||||
from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
|
||||
from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
|
||||
|
||||
@ -723,13 +832,36 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
||||
|
||||
|
||||
class AssetConverterFactory:
|
||||
"""Factory class for creating asset converters based on target and source types."""
|
||||
"""Factory for creating asset converters based on target and source types.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.data.asset_converter import AssetConverterFactory
|
||||
from embodied_gen.utils.enum import AssetType
|
||||
|
||||
converter = AssetConverterFactory.create(
|
||||
target_type=AssetType.USD, source_type=AssetType.MESH
|
||||
)
|
||||
with converter:
|
||||
for urdf_path, output_file in zip(urdf_paths, output_files):
|
||||
converter.convert(urdf_path, output_file)
|
||||
```
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
target_type: AssetType, source_type: AssetType = "urdf", **kwargs
|
||||
) -> AssetConverterBase:
|
||||
"""Create an asset converter instance based on target and source types."""
|
||||
"""Creates an asset converter instance.
|
||||
|
||||
Args:
|
||||
target_type (AssetType): Target asset type.
|
||||
source_type (AssetType, optional): Source asset type.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
AssetConverterBase: Converter instance.
|
||||
"""
|
||||
if target_type == AssetType.MJCF and source_type == AssetType.MESH:
|
||||
converter = MeshtoMJCFConverter(**kwargs)
|
||||
elif target_type == AssetType.MJCF and source_type == AssetType.URDF:
|
||||
@ -751,7 +883,14 @@ if __name__ == "__main__":
|
||||
# target_asset_type = AssetType.USD
|
||||
|
||||
urdf_paths = [
|
||||
"outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf",
|
||||
'outputs/EmbodiedGenData/demo_assets/banana/result/banana.urdf',
|
||||
'outputs/EmbodiedGenData/demo_assets/book/result/book.urdf',
|
||||
'outputs/EmbodiedGenData/demo_assets/lamp/result/lamp.urdf',
|
||||
'outputs/EmbodiedGenData/demo_assets/mug/result/mug.urdf',
|
||||
'outputs/EmbodiedGenData/demo_assets/remote_control/result/remote_control.urdf',
|
||||
"outputs/EmbodiedGenData/demo_assets/rubik's_cube/result/rubik's_cube.urdf",
|
||||
'outputs/EmbodiedGenData/demo_assets/table/result/table.urdf',
|
||||
'outputs/EmbodiedGenData/demo_assets/vase/result/vase.urdf',
|
||||
]
|
||||
|
||||
if target_asset_type == AssetType.MJCF:
|
||||
@ -765,7 +904,14 @@ if __name__ == "__main__":
|
||||
|
||||
elif target_asset_type == AssetType.USD:
|
||||
output_files = [
|
||||
"outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd",
|
||||
'outputs/EmbodiedGenData/demo_assets/banana/usd/banana.usd',
|
||||
'outputs/EmbodiedGenData/demo_assets/book/usd/book.usd',
|
||||
'outputs/EmbodiedGenData/demo_assets/lamp/usd/lamp.usd',
|
||||
'outputs/EmbodiedGenData/demo_assets/mug/usd/mug.usd',
|
||||
'outputs/EmbodiedGenData/demo_assets/remote_control/usd/remote_control.usd',
|
||||
"outputs/EmbodiedGenData/demo_assets/rubik's_cube/usd/rubik's_cube.usd",
|
||||
'outputs/EmbodiedGenData/demo_assets/table/usd/table.usd',
|
||||
'outputs/EmbodiedGenData/demo_assets/vase/usd/vase.usd',
|
||||
]
|
||||
asset_converter = AssetConverterFactory.create(
|
||||
target_type=AssetType.USD,
|
||||
@ -776,33 +922,33 @@ if __name__ == "__main__":
|
||||
for urdf_path, output_file in zip(urdf_paths, output_files):
|
||||
asset_converter.convert(urdf_path, output_file)
|
||||
|
||||
urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf"
|
||||
output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd"
|
||||
# urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf"
|
||||
# output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd"
|
||||
|
||||
asset_converter = AssetConverterFactory.create(
|
||||
target_type=AssetType.USD,
|
||||
source_type=AssetType.URDF,
|
||||
rotate_wxyz=(0.7071, 0.7071, 0, 0), # rotate 90 deg around the X-axis
|
||||
)
|
||||
# asset_converter = AssetConverterFactory.create(
|
||||
# target_type=AssetType.USD,
|
||||
# source_type=AssetType.URDF,
|
||||
# rotate_wxyz=(0.7071, 0.7071, 0, 0), # rotate 90 deg around the X-axis
|
||||
# )
|
||||
|
||||
with asset_converter:
|
||||
asset_converter.convert(urdf_path, output_file)
|
||||
# with asset_converter:
|
||||
# asset_converter.convert(urdf_path, output_file)
|
||||
|
||||
# Convert infinigen urdf to mjcf
|
||||
urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf"
|
||||
output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml"
|
||||
asset_converter = AssetConverterFactory.create(
|
||||
target_type=AssetType.MJCF,
|
||||
source_type=AssetType.URDF,
|
||||
keep_materials=["diffuse"],
|
||||
)
|
||||
with asset_converter:
|
||||
asset_converter.convert(urdf_path, output_file)
|
||||
# # Convert infinigen urdf to mjcf
|
||||
# urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf"
|
||||
# output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml"
|
||||
# asset_converter = AssetConverterFactory.create(
|
||||
# target_type=AssetType.MJCF,
|
||||
# source_type=AssetType.URDF,
|
||||
# keep_materials=["diffuse"],
|
||||
# )
|
||||
# with asset_converter:
|
||||
# asset_converter.convert(urdf_path, output_file)
|
||||
|
||||
# Convert infinigen usdc to physics usdc
|
||||
converter = PhysicsUSDAdder()
|
||||
with converter:
|
||||
converter.convert(
|
||||
usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc",
|
||||
output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc",
|
||||
)
|
||||
# # Convert infinigen usdc to physics usdc
|
||||
# converter = PhysicsUSDAdder()
|
||||
# with converter:
|
||||
# converter.convert(
|
||||
# usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc",
|
||||
# output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc",
|
||||
# )
|
||||
|
||||
@ -58,7 +58,16 @@ __all__ = [
|
||||
def _transform_vertices(
|
||||
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""Transform 3D vertices using a projection matrix."""
|
||||
"""Transforms 3D vertices using a projection matrix.
|
||||
|
||||
Args:
|
||||
mtx (torch.Tensor): Projection matrix.
|
||||
pos (torch.Tensor): Vertex positions.
|
||||
keepdim (bool, optional): If True, keeps the batch dimension.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed vertices.
|
||||
"""
|
||||
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
|
||||
if pos.size(-1) == 3:
|
||||
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
|
||||
@ -71,7 +80,17 @@ def _transform_vertices(
|
||||
def _bilinear_interpolation_scattering(
|
||||
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Bilinear interpolation scattering for grid-based value accumulation."""
|
||||
"""Performs bilinear interpolation scattering for grid-based value accumulation.
|
||||
|
||||
Args:
|
||||
image_h (int): Image height.
|
||||
image_w (int): Image width.
|
||||
coords (torch.Tensor): Normalized coordinates.
|
||||
values (torch.Tensor): Values to scatter.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Interpolated grid.
|
||||
"""
|
||||
device = values.device
|
||||
dtype = values.dtype
|
||||
C = values.shape[-1]
|
||||
@ -135,7 +154,18 @@ def _texture_inpaint_smooth(
|
||||
faces: np.ndarray,
|
||||
uv_map: np.ndarray,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Perform texture inpainting using vertex-based color propagation."""
|
||||
"""Performs texture inpainting using vertex-based color propagation.
|
||||
|
||||
Args:
|
||||
texture (np.ndarray): Texture image.
|
||||
mask (np.ndarray): Mask image.
|
||||
vertices (np.ndarray): Mesh vertices.
|
||||
faces (np.ndarray): Mesh faces.
|
||||
uv_map (np.ndarray): UV coordinates.
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray, np.ndarray]: Inpainted texture and updated mask.
|
||||
"""
|
||||
image_h, image_w, C = texture.shape
|
||||
N = vertices.shape[0]
|
||||
|
||||
@ -231,29 +261,41 @@ def _texture_inpaint_smooth(
|
||||
class TextureBacker:
|
||||
"""Texture baking pipeline for multi-view projection and fusion.
|
||||
|
||||
This class performs UV-based texture generation for a 3D mesh using
|
||||
multi-view color images, depth, and normal information. The pipeline
|
||||
includes mesh normalization and UV unwrapping, visibility-aware
|
||||
back-projection, confidence-weighted texture fusion, and inpainting
|
||||
of missing texture regions.
|
||||
This class generates UV-based textures for a 3D mesh using multi-view images,
|
||||
depth, and normal information. It includes mesh normalization, UV unwrapping,
|
||||
visibility-aware back-projection, confidence-weighted fusion, and inpainting.
|
||||
|
||||
Args:
|
||||
camera_params (CameraSetting): Camera intrinsics and extrinsics used
|
||||
for rendering each view.
|
||||
view_weights (list[float]): A list of weights for each view, used
|
||||
to blend confidence maps during texture fusion.
|
||||
render_wh (tuple[int, int], optional): Resolution (width, height) for
|
||||
intermediate rendering passes. Defaults to (2048, 2048).
|
||||
texture_wh (tuple[int, int], optional): Output texture resolution
|
||||
(width, height). Defaults to (2048, 2048).
|
||||
bake_angle_thresh (int, optional): Maximum angle (in degrees) between
|
||||
view direction and surface normal for projection to be considered valid.
|
||||
Defaults to 75.
|
||||
mask_thresh (float, optional): Threshold applied to visibility masks
|
||||
during rendering. Defaults to 0.5.
|
||||
smooth_texture (bool, optional): If True, apply post-processing (e.g.,
|
||||
blurring) to the final texture. Defaults to True.
|
||||
inpaint_smooth (bool, optional): If True, apply inpainting to smooth.
|
||||
camera_params (CameraSetting): Camera intrinsics and extrinsics.
|
||||
view_weights (list[float]): Weights for each view in texture fusion.
|
||||
render_wh (tuple[int, int], optional): Intermediate rendering resolution.
|
||||
texture_wh (tuple[int, int], optional): Output texture resolution.
|
||||
bake_angle_thresh (int, optional): Max angle for valid projection.
|
||||
mask_thresh (float, optional): Threshold for visibility masks.
|
||||
smooth_texture (bool, optional): Apply post-processing to texture.
|
||||
inpaint_smooth (bool, optional): Apply inpainting smoothing.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.data.backproject_v2 import TextureBacker
|
||||
from embodied_gen.data.utils import CameraSetting
|
||||
import trimesh
|
||||
from PIL import Image
|
||||
|
||||
camera_params = CameraSetting(
|
||||
num_images=6,
|
||||
elevation=[20, -10],
|
||||
distance=5,
|
||||
resolution_hw=(2048,2048),
|
||||
fov=math.radians(30),
|
||||
device='cuda',
|
||||
)
|
||||
view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
|
||||
mesh = trimesh.load('mesh.obj')
|
||||
images = [Image.open(f'view_{i}.png') for i in range(6)]
|
||||
texture_backer = TextureBacker(camera_params, view_weights)
|
||||
textured_mesh = texture_backer(images, mesh, 'output.obj')
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -283,6 +325,12 @@ class TextureBacker:
|
||||
)
|
||||
|
||||
def _lazy_init_render(self, camera_params, mask_thresh):
|
||||
"""Lazily initializes the renderer.
|
||||
|
||||
Args:
|
||||
camera_params (CameraSetting): Camera settings.
|
||||
mask_thresh (float): Mask threshold.
|
||||
"""
|
||||
if self.renderer is None:
|
||||
camera = init_kal_camera(camera_params)
|
||||
mv = camera.view_matrix() # (n 4 4) world2cam
|
||||
@ -301,6 +349,14 @@ class TextureBacker:
|
||||
)
|
||||
|
||||
def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
|
||||
"""Normalizes mesh and unwraps UVs.
|
||||
|
||||
Args:
|
||||
mesh (trimesh.Trimesh): Input mesh.
|
||||
|
||||
Returns:
|
||||
trimesh.Trimesh: Mesh with normalized vertices and UVs.
|
||||
"""
|
||||
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
|
||||
self.scale, self.center = scale, center
|
||||
|
||||
@ -318,6 +374,16 @@ class TextureBacker:
|
||||
scale: float = None,
|
||||
center: np.ndarray = None,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Gets mesh attributes as numpy arrays.
|
||||
|
||||
Args:
|
||||
mesh (trimesh.Trimesh): Input mesh.
|
||||
scale (float, optional): Scale factor.
|
||||
center (np.ndarray, optional): Center offset.
|
||||
|
||||
Returns:
|
||||
tuple: (vertices, faces, uv_map)
|
||||
"""
|
||||
vertices = mesh.vertices.copy()
|
||||
faces = mesh.faces.copy()
|
||||
uv_map = mesh.visual.uv.copy()
|
||||
@ -331,6 +397,14 @@ class TextureBacker:
|
||||
return vertices, faces, uv_map
|
||||
|
||||
def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes edge image from depth map.
|
||||
|
||||
Args:
|
||||
depth_image (torch.Tensor): Depth map.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Edge image.
|
||||
"""
|
||||
depth_image_np = depth_image.cpu().numpy()
|
||||
depth_image_np = (depth_image_np * 255).astype(np.uint8)
|
||||
depth_edges = cv2.Canny(depth_image_np, 30, 80)
|
||||
@ -344,6 +418,16 @@ class TextureBacker:
|
||||
def compute_enhanced_viewnormal(
|
||||
self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Computes enhanced view normals for mesh faces.
|
||||
|
||||
Args:
|
||||
mv_mtx (torch.Tensor): View matrices.
|
||||
vertices (torch.Tensor): Mesh vertices.
|
||||
faces (torch.Tensor): Mesh faces.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: View normals.
|
||||
"""
|
||||
rast, _ = self.renderer.compute_dr_raster(vertices, faces)
|
||||
rendered_view_normals = []
|
||||
for idx in range(len(mv_mtx)):
|
||||
@ -376,6 +460,18 @@ class TextureBacker:
|
||||
def back_project(
|
||||
self, image, vis_mask, depth, normal, uv
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Back-projects image and confidence to UV texture space.
|
||||
|
||||
Args:
|
||||
image (PIL.Image or np.ndarray): Input image.
|
||||
vis_mask (torch.Tensor): Visibility mask.
|
||||
depth (torch.Tensor): Depth map.
|
||||
normal (torch.Tensor): Normal map.
|
||||
uv (torch.Tensor): UV coordinates.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: Texture and confidence map.
|
||||
"""
|
||||
image = np.array(image)
|
||||
image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
|
||||
if image.ndim == 2:
|
||||
@ -418,6 +514,17 @@ class TextureBacker:
|
||||
)
|
||||
|
||||
def _scatter_texture(self, uv, data, mask):
|
||||
"""Scatters data to texture using UV coordinates and mask.
|
||||
|
||||
Args:
|
||||
uv (torch.Tensor): UV coordinates.
|
||||
data (torch.Tensor): Data to scatter.
|
||||
mask (torch.Tensor): Mask for valid pixels.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Scattered texture.
|
||||
"""
|
||||
|
||||
def __filter_data(data, mask):
|
||||
return data.view(-1, data.shape[-1])[mask]
|
||||
|
||||
@ -432,6 +539,15 @@ class TextureBacker:
|
||||
def fast_bake_texture(
|
||||
self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Fuses multiple textures and confidence maps.
|
||||
|
||||
Args:
|
||||
textures (list[torch.Tensor]): List of textures.
|
||||
confidence_maps (list[torch.Tensor]): List of confidence maps.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: Fused texture and mask.
|
||||
"""
|
||||
channel = textures[0].shape[-1]
|
||||
texture_merge = torch.zeros(self.texture_wh + [channel]).to(
|
||||
self.device
|
||||
@ -451,6 +567,16 @@ class TextureBacker:
|
||||
def uv_inpaint(
|
||||
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""Inpaints missing regions in the UV texture.
|
||||
|
||||
Args:
|
||||
mesh (trimesh.Trimesh): Mesh.
|
||||
texture (np.ndarray): Texture image.
|
||||
mask (np.ndarray): Mask image.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Inpainted texture.
|
||||
"""
|
||||
if self.inpaint_smooth:
|
||||
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
|
||||
texture, mask = _texture_inpaint_smooth(
|
||||
@ -473,6 +599,15 @@ class TextureBacker:
|
||||
colors: list[Image.Image],
|
||||
mesh: trimesh.Trimesh,
|
||||
) -> trimesh.Trimesh:
|
||||
"""Computes the fused texture for the mesh from multi-view images.
|
||||
|
||||
Args:
|
||||
colors (list[Image.Image]): List of view images.
|
||||
mesh (trimesh.Trimesh): Mesh to texture.
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray, np.ndarray]: Texture and mask.
|
||||
"""
|
||||
self._lazy_init_render(self.camera_params, self.mask_thresh)
|
||||
|
||||
vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
|
||||
@ -517,7 +652,7 @@ class TextureBacker:
|
||||
Args:
|
||||
colors (list[Image.Image]): List of input view images.
|
||||
mesh (trimesh.Trimesh): Input mesh to be textured.
|
||||
output_path (str): Path to save the output textured mesh (.obj or .glb).
|
||||
output_path (str): Path to save the output textured mesh.
|
||||
|
||||
Returns:
|
||||
trimesh.Trimesh: The textured mesh with UV and texture image.
|
||||
@ -540,6 +675,11 @@ class TextureBacker:
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parses command-line arguments for texture backprojection.
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Backproject texture")
|
||||
parser.add_argument(
|
||||
"--color_path",
|
||||
@ -636,6 +776,16 @@ def entrypoint(
|
||||
imagesr_model: ImageRealESRGAN = None,
|
||||
**kwargs,
|
||||
) -> trimesh.Trimesh:
|
||||
"""Entrypoint for texture backprojection from multi-view images.
|
||||
|
||||
Args:
|
||||
delight_model (DelightingModel, optional): Delighting model.
|
||||
imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
|
||||
**kwargs: Additional arguments to override CLI.
|
||||
|
||||
Returns:
|
||||
trimesh.Trimesh: Textured mesh.
|
||||
"""
|
||||
args = parse_args()
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(args, k) and v is not None:
|
||||
|
||||
@ -39,6 +39,22 @@ def decompose_convex_coacd(
|
||||
auto_scale: bool = True,
|
||||
scale_factor: float = 1.0,
|
||||
) -> None:
|
||||
"""Decomposes a mesh using CoACD and saves the result.
|
||||
|
||||
This function loads a mesh from a file, runs the CoACD algorithm with the
|
||||
given parameters, optionally scales the resulting convex hulls to match the
|
||||
original mesh's bounding box, and exports the combined result to a file.
|
||||
|
||||
Args:
|
||||
filename: Path to the input mesh file.
|
||||
outfile: Path to save the decomposed output mesh.
|
||||
params: A dictionary of parameters for the CoACD algorithm.
|
||||
verbose: If True, sets the CoACD log level to 'info'.
|
||||
auto_scale: If True, automatically computes a scale factor to match the
|
||||
decomposed mesh's bounding box to the visual mesh's bounding box.
|
||||
scale_factor: An additional scaling factor applied to the vertices of
|
||||
the decomposed mesh parts.
|
||||
"""
|
||||
coacd.set_log_level("info" if verbose else "warn")
|
||||
|
||||
mesh = trimesh.load(filename, force="mesh")
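For orientation, a hedged usage sketch of `decompose_convex_coacd`; the paths are hypothetical and the `params` keys are assumptions mirroring the CoACD parameters listed for `decompose_convex_mesh` below:

```python
# Illustrative only: input/output paths and parameter values are placeholders.
decompose_convex_coacd(
    filename="asset/mesh/asset.obj",
    outfile="asset/mesh/asset_collision.obj",
    params={"threshold": 0.05, "max_convex_hull": 32, "seed": 0},
    verbose=False,
    auto_scale=True,
)
```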
|
||||
@ -83,7 +99,38 @@ def decompose_convex_mesh(
|
||||
scale_factor: float = 1.005,
|
||||
verbose: bool = False,
|
||||
) -> str:
|
||||
"""Decompose a mesh into convex parts using the CoACD algorithm."""
|
||||
"""Decomposes a mesh into convex parts with retry logic.
|
||||
|
||||
This function serves as a wrapper for `decompose_convex_coacd`, providing
|
||||
explicit parameters for the CoACD algorithm and implementing a retry
|
||||
mechanism. If the initial decomposition fails, it attempts again with
|
||||
`preprocess_mode` set to 'on'.
|
||||
|
||||
Args:
|
||||
filename: Path to the input mesh file.
|
||||
outfile: Path to save the decomposed output mesh.
|
||||
threshold: CoACD parameter. See CoACD documentation for details.
|
||||
max_convex_hull: CoACD parameter. See CoACD documentation for details.
|
||||
preprocess_mode: CoACD parameter. See CoACD documentation for details.
|
||||
preprocess_resolution: CoACD parameter. See CoACD documentation for details.
|
||||
resolution: CoACD parameter. See CoACD documentation for details.
|
||||
mcts_nodes: CoACD parameter. See CoACD documentation for details.
|
||||
mcts_iterations: CoACD parameter. See CoACD documentation for details.
|
||||
mcts_max_depth: CoACD parameter. See CoACD documentation for details.
|
||||
pca: CoACD parameter. See CoACD documentation for details.
|
||||
merge: CoACD parameter. See CoACD documentation for details.
|
||||
seed: CoACD parameter. See CoACD documentation for details.
|
||||
auto_scale: If True, automatically scale the output to match the input
|
||||
bounding box.
|
||||
scale_factor: Additional scaling factor to apply.
|
||||
verbose: If True, enables detailed logging.
|
||||
|
||||
Returns:
|
||||
The path to the output file if decomposition is successful.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If convex decomposition fails after all attempts.
|
||||
"""
|
||||
coacd.set_log_level("info" if verbose else "warn")
|
||||
|
||||
if os.path.exists(outfile):
|
||||
@ -148,9 +195,37 @@ def decompose_convex_mp(
|
||||
verbose: bool = False,
|
||||
auto_scale: bool = True,
|
||||
) -> str:
|
||||
"""Decompose a mesh into convex parts using the CoACD algorithm in a separate process.
|
||||
"""Decomposes a mesh into convex parts in a separate process.
|
||||
|
||||
This function uses the `multiprocessing` module to run the CoACD algorithm
|
||||
in a spawned subprocess. This is useful for isolating the decomposition
|
||||
process to prevent potential memory leaks or crashes in the main process.
|
||||
It includes a retry mechanism similar to `decompose_convex_mesh`.
|
||||
|
||||
See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
|
||||
|
||||
Args:
|
||||
filename: Path to the input mesh file.
|
||||
outfile: Path to save the decomposed output mesh.
|
||||
threshold: CoACD parameter.
|
||||
max_convex_hull: CoACD parameter.
|
||||
preprocess_mode: CoACD parameter.
|
||||
preprocess_resolution: CoACD parameter.
|
||||
resolution: CoACD parameter.
|
||||
mcts_nodes: CoACD parameter.
|
||||
mcts_iterations: CoACD parameter.
|
||||
mcts_max_depth: CoACD parameter.
|
||||
pca: CoACD parameter.
|
||||
merge: CoACD parameter.
|
||||
seed: CoACD parameter.
|
||||
verbose: If True, enables detailed logging in the subprocess.
|
||||
auto_scale: If True, automatically scale the output.
|
||||
|
||||
Returns:
|
||||
The path to the output file if decomposition is successful.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If convex decomposition fails after all attempts.
|
||||
"""
|
||||
params = dict(
|
||||
threshold=threshold,
|
||||
|
||||
@ -66,6 +66,14 @@ def create_mp4_from_images(
|
||||
fps: int = 10,
|
||||
prompt: str = None,
|
||||
):
|
||||
"""Creates an MP4 video from a list of images.
|
||||
|
||||
Args:
|
||||
images (list[np.ndarray]): List of images as numpy arrays.
|
||||
output_path (str): Path to save the MP4 file.
|
||||
fps (int, optional): Frames per second. Defaults to 10.
|
||||
prompt (str, optional): Optional text prompt overlay.
|
||||
"""
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.5
|
||||
font_thickness = 1
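A short usage sketch for `create_mp4_from_images` (the float-in-[0, 1] frame format is an assumption, suggested by the clipping in the companion GIF helper):

```python
import numpy as np

frames = [np.random.rand(240, 320, 3) for _ in range(30)]  # hypothetical float frames in [0, 1]
create_mp4_from_images(frames, "preview.mp4", fps=10, prompt="demo clip")
```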
|
||||
@ -96,6 +104,13 @@ def create_mp4_from_images(
|
||||
def create_gif_from_images(
|
||||
images: list[np.ndarray], output_path: str, fps: int = 10
|
||||
) -> None:
|
||||
"""Creates a GIF animation from a list of images.
|
||||
|
||||
Args:
|
||||
images (list[np.ndarray]): List of images as numpy arrays.
|
||||
output_path (str): Path to save the GIF file.
|
||||
fps (int, optional): Frames per second. Defaults to 10.
|
||||
"""
|
||||
pil_images = []
|
||||
for image in images:
|
||||
image = image.clip(min=0, max=1)
|
||||
@ -116,32 +131,47 @@ def create_gif_from_images(
|
||||
|
||||
|
||||
class ImageRender(object):
|
||||
"""A differentiable mesh renderer supporting multi-view rendering.
|
||||
"""Differentiable mesh renderer supporting multi-view rendering.
|
||||
|
||||
This class wraps a differentiable rasterization using `nvdiffrast` to
|
||||
render mesh geometry to various maps (normal, depth, alpha, albedo, etc.).
|
||||
This class wraps differentiable rasterization using `nvdiffrast` to render mesh
|
||||
geometry to various maps (normal, depth, alpha, albedo, etc.) and supports
|
||||
saving images and videos.
|
||||
|
||||
Args:
|
||||
render_items (list[RenderItems]): A list of rendering targets to
|
||||
generate (e.g., IMAGE, DEPTH, NORMAL, etc.).
|
||||
camera_params (CameraSetting): The camera parameters for rendering,
|
||||
including intrinsic and extrinsic matrices.
|
||||
recompute_vtx_normal (bool, optional): If True, recomputes
|
||||
vertex normals from the mesh geometry. Defaults to True.
|
||||
with_mtl (bool, optional): Whether to load `.mtl` material files
|
||||
for meshes. Defaults to False.
|
||||
gen_color_gif (bool, optional): Generate a GIF of rendered
|
||||
color images. Defaults to False.
|
||||
gen_color_mp4 (bool, optional): Generate an MP4 video of rendered
|
||||
color images. Defaults to False.
|
||||
gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of
|
||||
view-space normals. Defaults to False.
|
||||
gen_glonormal_mp4 (bool, optional): Generate an MP4 video of
|
||||
global-space normals. Defaults to False.
|
||||
no_index_file (bool, optional): If True, skip saving the `index.json`
|
||||
summary file. Defaults to False.
|
||||
light_factor (float, optional): A scalar multiplier for
|
||||
PBR light intensity. Defaults to 1.0.
|
||||
render_items (list[RenderItems]): List of rendering targets.
|
||||
camera_params (CameraSetting): Camera parameters for rendering.
|
||||
recompute_vtx_normal (bool, optional): Recompute vertex normals. Defaults to True.
|
||||
with_mtl (bool, optional): Load mesh material files. Defaults to False.
|
||||
gen_color_gif (bool, optional): Generate GIF of color images. Defaults to False.
|
||||
gen_color_mp4 (bool, optional): Generate MP4 of color images. Defaults to False.
|
||||
gen_viewnormal_mp4 (bool, optional): Generate MP4 of view-space normals. Defaults to False.
|
||||
gen_glonormal_mp4 (bool, optional): Generate MP4 of global-space normals. Defaults to False.
|
||||
no_index_file (bool, optional): Skip saving index file. Defaults to False.
|
||||
light_factor (float, optional): PBR light intensity multiplier. Defaults to 1.0.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.data.differentiable_render import ImageRender
|
||||
from embodied_gen.data.utils import CameraSetting
|
||||
from embodied_gen.utils.enum import RenderItems
|
||||
|
||||
camera_params = CameraSetting(
|
||||
num_images=6,
|
||||
elevation=[20, -10],
|
||||
distance=5,
|
||||
resolution_hw=(512,512),
|
||||
fov=math.radians(30),
|
||||
device='cuda',
|
||||
)
|
||||
render_items = [RenderItems.IMAGE.value, RenderItems.DEPTH.value]
|
||||
renderer = ImageRender(
|
||||
render_items,
|
||||
camera_params,
|
||||
with_mtl=args.with_mtl,
|
||||
gen_color_mp4=True,
|
||||
)
|
||||
renderer.render_mesh(mesh_path='mesh.obj', output_root='./renders')
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -198,6 +228,14 @@ class ImageRender(object):
|
||||
uuid: Union[str, List[str]] = None,
|
||||
prompts: List[str] = None,
|
||||
) -> None:
|
||||
"""Renders one or more meshes and saves outputs.
|
||||
|
||||
Args:
|
||||
mesh_path (Union[str, List[str]]): Path(s) to mesh files.
|
||||
output_root (str): Directory to save outputs.
|
||||
uuid (Union[str, List[str]], optional): Unique IDs for outputs.
|
||||
prompts (List[str], optional): Text prompts for videos.
|
||||
"""
|
||||
mesh_path = as_list(mesh_path)
|
||||
if uuid is None:
|
||||
uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
|
||||
@ -227,18 +265,15 @@ class ImageRender(object):
|
||||
def __call__(
|
||||
self, mesh_path: str, output_dir: str, prompt: str = None
|
||||
) -> dict[str, str]:
|
||||
"""Render a single mesh and return paths to the rendered outputs.
|
||||
|
||||
Processes the input mesh, renders multiple modalities (e.g., normals,
|
||||
depth, albedo), and optionally saves video or image sequences.
|
||||
"""Renders a single mesh and returns output paths.
|
||||
|
||||
Args:
|
||||
mesh_path (str): Path to the mesh file (.obj/.glb).
|
||||
output_dir (str): Directory to save rendered outputs.
|
||||
prompt (str, optional): Optional caption prompt for MP4 metadata.
|
||||
mesh_path (str): Path to mesh file.
|
||||
output_dir (str): Directory to save outputs.
|
||||
prompt (str, optional): Caption prompt for MP4 metadata.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: A mapping render types to the saved image paths.
|
||||
dict[str, str]: Mapping of render types to saved image paths.
|
||||
"""
|
||||
try:
|
||||
mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
|
||||
|
||||
@ -16,17 +16,13 @@
|
||||
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from typing import Tuple, Union
|
||||
|
||||
import coacd
|
||||
import igraph
|
||||
import numpy as np
|
||||
import pyvista as pv
|
||||
import spaces
|
||||
import torch
|
||||
import trimesh
|
||||
import utils3d
|
||||
from pymeshfix import _meshfix
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -51,6 +51,33 @@ __all__ = ["PickEmbodiedGen"]
|
||||
|
||||
@register_env("PickEmbodiedGen-v1", max_episode_steps=100)
|
||||
class PickEmbodiedGen(BaseEnv):
|
||||
"""PickEmbodiedGen as gym env example for object pick-and-place tasks.
|
||||
|
||||
This environment simulates a robot interacting with 3D assets in the
|
||||
embodiedgen generated scene in SAPIEN. It supports multi-environment setups,
|
||||
dynamic reconfiguration, and hybrid rendering with 3D Gaussian Splatting.
|
||||
|
||||
Example:
|
||||
Use `gym.make` to create the `PickEmbodiedGen-v1` parallel environment.
|
||||
```python
|
||||
import gymnasium as gym
|
||||
env = gym.make(
|
||||
"PickEmbodiedGen-v1",
|
||||
num_envs=cfg.num_envs,
|
||||
render_mode=cfg.render_mode,
|
||||
enable_shadow=cfg.enable_shadow,
|
||||
layout_file=cfg.layout_file,
|
||||
control_mode=cfg.control_mode,
|
||||
camera_cfg=dict(
|
||||
camera_eye=cfg.camera_eye,
|
||||
camera_target_pt=cfg.camera_target_pt,
|
||||
image_hw=cfg.image_hw,
|
||||
fovy_deg=cfg.fovy_deg,
|
||||
),
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
|
||||
goal_thresh = 0.0
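Continuing the `gym.make` example from the class docstring above, a minimal rollout sketch using the standard Gymnasium step loop:

```python
obs, info = env.reset(seed=0)
for _ in range(100):
    action = env.action_space.sample()  # random placeholder policy
    obs, reward, terminated, truncated, info = env.step(action)
env.close()
```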
|
||||
|
||||
@ -63,6 +90,19 @@ class PickEmbodiedGen(BaseEnv):
|
||||
reconfiguration_freq: int = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes the PickEmbodiedGen environment.
|
||||
|
||||
Args:
|
||||
*args: Variable length argument list for the base class.
|
||||
robot_uids: The robot(s) to use in the environment.
|
||||
robot_init_qpos_noise: Noise added to the robot's initial joint
|
||||
positions.
|
||||
num_envs: The number of parallel environments to create.
|
||||
reconfiguration_freq: How often to reconfigure the scene. If None,
|
||||
it is set based on num_envs.
|
||||
**kwargs: Additional keyword arguments for environment setup,
|
||||
including layout_file, replace_objs, enable_grasp, etc.
|
||||
"""
|
||||
self.robot_init_qpos_noise = robot_init_qpos_noise
|
||||
if reconfiguration_freq is None:
|
||||
if num_envs == 1:
|
||||
@ -116,6 +156,22 @@ class PickEmbodiedGen(BaseEnv):
|
||||
def init_env_layouts(
|
||||
layout_file: str, num_envs: int, replace_objs: bool
|
||||
) -> list[LayoutInfo]:
|
||||
"""Initializes and saves layout files for each environment instance.
|
||||
|
||||
For each environment, this method creates a layout configuration. If
|
||||
`replace_objs` is True, it generates new object placements for each
|
||||
subsequent environment. The generated layouts are saved as new JSON
|
||||
files.
|
||||
|
||||
Args:
|
||||
layout_file: Path to the base layout JSON file.
|
||||
num_envs: The number of environments to create layouts for.
|
||||
replace_objs: If True, generates new object placements for each
|
||||
environment after the first one using BFS placement.
|
||||
|
||||
Returns:
|
||||
A list of file paths to the generated layout for each environment.
|
||||
"""
|
||||
layouts = []
|
||||
for env_idx in range(num_envs):
|
||||
if replace_objs and env_idx > 0:
|
||||
@ -136,6 +192,18 @@ class PickEmbodiedGen(BaseEnv):
|
||||
def compute_robot_init_pose(
|
||||
layouts: list[str], num_envs: int, z_offset: float = 0.0
|
||||
) -> list[list[float]]:
|
||||
"""Computes the initial pose for the robot in each environment.
|
||||
|
||||
Args:
|
||||
layouts: A list of file paths to the environment layouts.
|
||||
num_envs: The number of environments.
|
||||
z_offset: An optional vertical offset to apply to the robot's
|
||||
position to prevent collisions.
|
||||
|
||||
Returns:
|
||||
A list of initial poses ([x, y, z, qw, qx, qy, qz]) for the robot
|
||||
in each environment.
|
||||
"""
|
||||
robot_pose = []
|
||||
for env_idx in range(num_envs):
|
||||
layout = json.load(open(layouts[env_idx], "r"))
|
||||
@ -148,6 +216,11 @@ class PickEmbodiedGen(BaseEnv):
|
||||
|
||||
@property
|
||||
def _default_sim_config(self):
|
||||
"""Returns the default simulation configuration.
|
||||
|
||||
Returns:
|
||||
The default simulation configuration object.
|
||||
"""
|
||||
return SimConfig(
|
||||
scene_config=SceneConfig(
|
||||
solver_position_iterations=30,
|
||||
@ -163,6 +236,11 @@ class PickEmbodiedGen(BaseEnv):
|
||||
|
||||
@property
|
||||
def _default_sensor_configs(self):
|
||||
"""Returns the default sensor configurations for the agent.
|
||||
|
||||
Returns:
|
||||
A list containing the default camera configuration.
|
||||
"""
|
||||
pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
|
||||
|
||||
return [
|
||||
@ -171,6 +249,11 @@ class PickEmbodiedGen(BaseEnv):
|
||||
|
||||
@property
|
||||
def _default_human_render_camera_configs(self):
|
||||
"""Returns the default camera configuration for human-friendly rendering.
|
||||
|
||||
Returns:
|
||||
The default camera configuration for the renderer.
|
||||
"""
|
||||
pose = sapien_utils.look_at(
|
||||
eye=self.camera_cfg["camera_eye"],
|
||||
target=self.camera_cfg["camera_target_pt"],
|
||||
@ -187,10 +270,24 @@ class PickEmbodiedGen(BaseEnv):
|
||||
)
|
||||
|
||||
def _load_agent(self, options: dict):
|
||||
"""Loads the agent (robot) and a ground plane into the scene.
|
||||
|
||||
Args:
|
||||
options: A dictionary of options for loading the agent.
|
||||
"""
|
||||
self.ground = build_ground(self.scene)
|
||||
super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
|
||||
|
||||
def _load_scene(self, options: dict):
|
||||
"""Loads all assets, objects, and the goal site into the scene.
|
||||
|
||||
This method iterates through the layouts for each environment, loads the
|
||||
specified assets, and adds them to the simulation. It also creates a
|
||||
kinematic sphere to represent the goal site.
|
||||
|
||||
Args:
|
||||
options: A dictionary of options for loading the scene.
|
||||
"""
|
||||
all_objects = []
|
||||
logger.info(f"Loading EmbodiedGen assets...")
|
||||
for env_idx in range(self.num_envs):
|
||||
@ -222,6 +319,15 @@ class PickEmbodiedGen(BaseEnv):
|
||||
self._hidden_objects.append(self.goal_site)
|
||||
|
||||
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
|
||||
"""Initializes an episode for a given set of environments.
|
||||
|
||||
This method sets the goal position, resets the robot's joint positions
|
||||
with optional noise, and sets its root pose.
|
||||
|
||||
Args:
|
||||
env_idx: A tensor of environment indices to initialize.
|
||||
options: A dictionary of options for initialization.
|
||||
"""
|
||||
with torch.device(self.device):
|
||||
b = len(env_idx)
|
||||
goal_xyz = torch.zeros((b, 3))
|
||||
@ -256,6 +362,21 @@ class PickEmbodiedGen(BaseEnv):
|
||||
def render_gs3d_images(
|
||||
self, layouts: list[str], num_envs: int, init_quat: list[float]
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Renders background images using a pre-trained Gaussian Splatting model.
|
||||
|
||||
This method pre-renders the static background for each environment from
|
||||
the perspective of all cameras to be used for hybrid rendering.
|
||||
|
||||
Args:
|
||||
layouts: A list of file paths to the environment layouts.
|
||||
num_envs: The number of environments.
|
||||
init_quat: An initial quaternion to orient the Gaussian Splatting
|
||||
model.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping a unique key (e.g., 'camera-env_idx') to the
|
||||
rendered background image as a numpy array.
|
||||
"""
|
||||
sim_coord_align = (
|
||||
torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
|
||||
)
|
||||
@ -293,6 +414,15 @@ class PickEmbodiedGen(BaseEnv):
|
||||
return bg_images
|
||||
|
||||
def render(self):
|
||||
"""Renders the environment based on the configured render_mode.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `render_mode` is not set.
|
||||
NotImplementedError: If the `render_mode` is not supported.
|
||||
|
||||
Returns:
|
||||
The rendered output, which varies depending on the render mode.
|
||||
"""
|
||||
if self.render_mode is None:
|
||||
raise RuntimeError("render_mode is not set.")
|
||||
if self.render_mode == "human":
|
||||
@ -315,6 +445,17 @@ class PickEmbodiedGen(BaseEnv):
|
||||
def render_rgb_array(
|
||||
self, camera_name: str = None, return_alpha: bool = False
|
||||
):
|
||||
"""Renders an RGB image from the human-facing render camera.
|
||||
|
||||
Args:
|
||||
camera_name: The name of the camera to render from. If None, uses
|
||||
all human render cameras.
|
||||
return_alpha: Whether to include the alpha channel in the output.
|
||||
|
||||
Returns:
|
||||
A numpy array representing the rendered image(s). If multiple
|
||||
cameras are used, the images are tiled.
|
||||
"""
|
||||
for obj in self._hidden_objects:
|
||||
obj.show_visual()
|
||||
self.scene.update_render(
|
||||
@ -335,6 +476,11 @@ class PickEmbodiedGen(BaseEnv):
|
||||
return tile_images(images)
|
||||
|
||||
def render_sensors(self):
|
||||
"""Renders images from all on-board sensor cameras.
|
||||
|
||||
Returns:
|
||||
A tiled image of all sensor outputs as a numpy array.
|
||||
"""
|
||||
images = []
|
||||
sensor_images = self.get_sensor_images()
|
||||
for image in sensor_images.values():
|
||||
@ -343,6 +489,14 @@ class PickEmbodiedGen(BaseEnv):
|
||||
return tile_images(images)
|
||||
|
||||
def hybrid_render(self):
|
||||
"""Renders a hybrid image by blending simulated foreground with a background.
|
||||
|
||||
The foreground is rendered with an alpha channel and then blended with
|
||||
the pre-rendered Gaussian Splatting background image.
|
||||
|
||||
Returns:
|
||||
A torch tensor of the final blended RGB images.
|
||||
"""
|
||||
fg_images = self.render_rgb_array(
|
||||
return_alpha=True
|
||||
) # (n_env, h, w, 3)
|
||||
@ -362,6 +516,16 @@ class PickEmbodiedGen(BaseEnv):
|
||||
return images[..., :3]
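
The blend described above is standard alpha compositing. A minimal sketch, assuming `fg_rgba` carries an alpha channel in [0, 1] and `bg_rgb` is the pre-rendered splat image at the same resolution (both tensors here are hypothetical, not the environment's exact attributes):

```py
import torch

def alpha_blend(fg_rgba: torch.Tensor, bg_rgb: torch.Tensor) -> torch.Tensor:
    # fg_rgba: (n_env, h, w, 4) simulated foreground with alpha.
    # bg_rgb:  (n_env, h, w, 3) pre-rendered Gaussian Splatting background.
    alpha = fg_rgba[..., 3:4]  # keep the channel dim for broadcasting
    return fg_rgba[..., :3] * alpha + bg_rgb * (1.0 - alpha)
```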
|
||||
|
||||
def evaluate(self):
|
||||
"""Evaluates the current state of the environment.
|
||||
|
||||
Checks for task success criteria such as whether the object is grasped,
|
||||
placed at the goal, and if the robot is static.
|
||||
|
||||
Returns:
|
||||
A dictionary containing boolean tensors for various success
|
||||
metrics, including 'is_grasped', 'is_obj_placed', and overall
|
||||
'success'.
|
||||
"""
|
||||
obj_to_goal_pos = (
|
||||
self.obj.pose.p
|
||||
) # self.goal_site.pose.p - self.obj.pose.p
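
A sketch of how such boolean metrics are typically combined into the returned dictionary; the tensor names below are assumptions for illustration, not the exact attributes of `PickEmbodiedGen`:

```py
import torch

# Hypothetical per-environment boolean tensors of shape (n_env,).
is_grasped = torch.tensor([True, False])
is_obj_placed = torch.tensor([True, True])
is_robot_static = torch.tensor([True, True])

# Overall success usually requires the object to be placed and the robot settled.
success = is_obj_placed & is_robot_static
info = dict(is_grasped=is_grasped, is_obj_placed=is_obj_placed, success=success)
```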
|
||||
@ -381,10 +545,31 @@ class PickEmbodiedGen(BaseEnv):
|
||||
)
|
||||
|
||||
def _get_obs_extra(self, info: dict):
|
||||
"""Gets extra information for the observation dictionary.
|
||||
|
||||
Args:
|
||||
info: A dictionary containing evaluation information.
|
||||
|
||||
Returns:
|
||||
An empty dictionary, as no extra observations are added.
|
||||
"""
|
||||
|
||||
return dict()
|
||||
|
||||
def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
|
||||
"""Computes a dense reward for the current step.
|
||||
|
||||
The reward is a composite of reaching, grasping, placing, and
|
||||
maintaining a static final pose.
|
||||
|
||||
Args:
|
||||
obs: The current observation.
|
||||
action: The action taken in the current step.
|
||||
info: A dictionary containing evaluation information from `evaluate()`.
|
||||
|
||||
Returns:
|
||||
A tensor containing the dense reward for each environment.
|
||||
"""
|
||||
tcp_to_obj_dist = torch.linalg.norm(
|
||||
self.obj.pose.p - self.agent.tcp.pose.p, axis=1
|
||||
)
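
A common shape for the reaching term is a saturating distance penalty; the staged composition could look like the hedged sketch below (weights and stages are illustrative, not the environment's exact reward). Dividing the total by its maximum attainable value then yields the normalized variant shown in the next hunk.

```py
import torch

def staged_reward(tcp_to_obj_dist, obj_to_goal_dist, is_grasped, is_static):
    # Reaching: close to 1 when the gripper touches the object, decaying with distance.
    reward = 1.0 - torch.tanh(5.0 * tcp_to_obj_dist)
    reward = reward + is_grasped.float()                    # grasping bonus
    place = 1.0 - torch.tanh(5.0 * obj_to_goal_dist)
    reward = reward + place * is_grasped.float()            # placing, only once grasped
    reward = reward + is_static.float()                     # settle at the final pose
    return reward
```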
|
||||
@ -417,4 +602,14 @@ class PickEmbodiedGen(BaseEnv):
|
||||
def compute_normalized_dense_reward(
|
||||
self, obs: any, action: torch.Tensor, info: dict
|
||||
):
|
||||
"""Computes a dense reward normalized to be between 0 and 1.
|
||||
|
||||
Args:
|
||||
obs: The current observation.
|
||||
action: The action taken in the current step.
|
||||
info: A dictionary containing evaluation information from `evaluate()`.
|
||||
|
||||
Returns:
|
||||
A tensor containing the normalized dense reward for each environment.
|
||||
"""
|
||||
return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
|
||||
|
||||
@ -40,7 +40,7 @@ class DelightingModel(object):
|
||||
"""A model to remove the lighting in image space.
|
||||
|
||||
This model is encapsulated based on the Hunyuan3D-Delight model
|
||||
from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa
|
||||
from `https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0` # noqa
|
||||
|
||||
Attributes:
|
||||
image_guide_scale (float): Weight of image guidance in diffusion process.
|
||||
|
||||
@ -38,26 +38,61 @@ __all__ = [
|
||||
|
||||
|
||||
class BasePipelineLoader(ABC):
|
||||
"""Abstract base class for loading Hugging Face image generation pipelines.
|
||||
|
||||
Attributes:
|
||||
device (str): Device to load the pipeline on.
|
||||
|
||||
Methods:
|
||||
load(): Loads and returns the pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, device="cuda"):
|
||||
self.device = device
|
||||
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
"""Load and return the pipeline instance."""
|
||||
pass
|
||||
|
||||
|
||||
class BasePipelineRunner(ABC):
|
||||
"""Abstract base class for running image generation pipelines.
|
||||
|
||||
Attributes:
|
||||
pipe: The loaded pipeline.
|
||||
|
||||
Methods:
|
||||
run(prompt, **kwargs): Runs the pipeline with a prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, pipe):
|
||||
self.pipe = pipe
|
||||
|
||||
@abstractmethod
|
||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||
"""Run the pipeline with the given prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): Text prompt for image generation.
|
||||
**kwargs: Additional pipeline arguments.
|
||||
|
||||
Returns:
|
||||
Image.Image: Generated image(s).
|
||||
"""
|
||||
pass
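
The loader/runner split keeps pipeline construction separate from inference, so new backends slot in easily. A minimal sketch of a hypothetical pair built on the two base classes above (`DummyPipeline` is a stand-in, not a real diffusers class):

```py
from PIL import Image

class DummyPipeline:
    """Stand-in pipeline that returns a blank image for every prompt."""

    def __call__(self, prompt: str, **kwargs):
        class _Out:
            images = [Image.new("RGB", (64, 64))]
        return _Out()

class DummyLoader(BasePipelineLoader):
    def load(self):
        # Real loaders download weights and move the pipeline to self.device.
        return DummyPipeline()

class DummyRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images
```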
|
||||
|
||||
|
||||
# ===== SD3.5-medium =====
|
||||
class SD35Loader(BasePipelineLoader):
|
||||
"""Loader for Stable Diffusion 3.5 medium pipeline."""
|
||||
|
||||
def load(self):
|
||||
"""Load the Stable Diffusion 3.5 medium pipeline.
|
||||
|
||||
Returns:
|
||||
StableDiffusion3Pipeline: Loaded pipeline.
|
||||
"""
|
||||
pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3.5-medium",
|
||||
torch_dtype=torch.float16,
|
||||
@ -70,12 +105,25 @@ class SD35Loader(BasePipelineLoader):
|
||||
|
||||
|
||||
class SD35Runner(BasePipelineRunner):
|
||||
"""Runner for Stable Diffusion 3.5 medium pipeline."""
|
||||
|
||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||
"""Generate images using Stable Diffusion 3.5 medium.
|
||||
|
||||
Args:
|
||||
prompt (str): Text prompt.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
Image.Image: Generated image(s).
|
||||
"""
|
||||
return self.pipe(prompt=prompt, **kwargs).images
|
||||
|
||||
|
||||
# ===== Cosmos2 =====
|
||||
class CosmosLoader(BasePipelineLoader):
|
||||
"""Loader for Cosmos2 text-to-image pipeline."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
|
||||
@ -87,6 +135,8 @@ class CosmosLoader(BasePipelineLoader):
|
||||
self.local_dir = local_dir
|
||||
|
||||
def _patch(self):
|
||||
"""Patch model and processor for optimized loading."""
|
||||
|
||||
def patch_model(cls):
|
||||
orig = cls.from_pretrained
|
||||
|
||||
@ -110,6 +160,11 @@ class CosmosLoader(BasePipelineLoader):
|
||||
patch_processor(SiglipProcessor)
|
||||
|
||||
def load(self):
|
||||
"""Load the Cosmos2 text-to-image pipeline.
|
||||
|
||||
Returns:
|
||||
Cosmos2TextToImagePipeline: Loaded pipeline.
|
||||
"""
|
||||
self._patch()
|
||||
snapshot_download(
|
||||
repo_id=self.model_id,
|
||||
@ -141,7 +196,19 @@ class CosmosLoader(BasePipelineLoader):
|
||||
|
||||
|
||||
class CosmosRunner(BasePipelineRunner):
|
||||
"""Runner for Cosmos2 text-to-image pipeline."""
|
||||
|
||||
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
||||
"""Generate images using Cosmos2 pipeline.
|
||||
|
||||
Args:
|
||||
prompt (str): Text prompt.
|
||||
negative_prompt (str, optional): Negative prompt.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
Image.Image: Generated image(s).
|
||||
"""
|
||||
return self.pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
||||
).images
|
||||
@ -149,7 +216,14 @@ class CosmosRunner(BasePipelineRunner):
|
||||
|
||||
# ===== Kolors =====
|
||||
class KolorsLoader(BasePipelineLoader):
|
||||
"""Loader for Kolors pipeline."""
|
||||
|
||||
def load(self):
|
||||
"""Load the Kolors pipeline.
|
||||
|
||||
Returns:
|
||||
KolorsPipeline: Loaded pipeline.
|
||||
"""
|
||||
pipe = KolorsPipeline.from_pretrained(
|
||||
"Kwai-Kolors/Kolors-diffusers",
|
||||
torch_dtype=torch.float16,
|
||||
@ -164,13 +238,31 @@ class KolorsLoader(BasePipelineLoader):
|
||||
|
||||
|
||||
class KolorsRunner(BasePipelineRunner):
|
||||
"""Runner for Kolors pipeline."""
|
||||
|
||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||
"""Generate images using Kolors pipeline.
|
||||
|
||||
Args:
|
||||
prompt (str): Text prompt.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
Image.Image: Generated image(s).
|
||||
"""
|
||||
return self.pipe(prompt=prompt, **kwargs).images
|
||||
|
||||
|
||||
# ===== Flux =====
|
||||
class FluxLoader(BasePipelineLoader):
|
||||
"""Loader for Flux pipeline."""
|
||||
|
||||
def load(self):
|
||||
"""Load the Flux pipeline.
|
||||
|
||||
Returns:
|
||||
FluxPipeline: Loaded pipeline.
|
||||
"""
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
|
||||
@ -182,20 +274,50 @@ class FluxLoader(BasePipelineLoader):
|
||||
|
||||
|
||||
class FluxRunner(BasePipelineRunner):
|
||||
"""Runner for Flux pipeline."""
|
||||
|
||||
def run(self, prompt: str, **kwargs) -> Image.Image:
|
||||
"""Generate images using Flux pipeline.
|
||||
|
||||
Args:
|
||||
prompt (str): Text prompt.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
Image.Image: Generated image(s).
|
||||
"""
|
||||
return self.pipe(prompt=prompt, **kwargs).images
|
||||
|
||||
|
||||
# ===== Chroma =====
|
||||
class ChromaLoader(BasePipelineLoader):
|
||||
"""Loader for Chroma pipeline."""
|
||||
|
||||
def load(self):
|
||||
"""Load the Chroma pipeline.
|
||||
|
||||
Returns:
|
||||
ChromaPipeline: Loaded pipeline.
|
||||
"""
|
||||
return ChromaPipeline.from_pretrained(
|
||||
"lodestones/Chroma", torch_dtype=torch.bfloat16
|
||||
).to(self.device)
|
||||
|
||||
|
||||
class ChromaRunner(BasePipelineRunner):
|
||||
"""Runner for Chroma pipeline."""
|
||||
|
||||
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
||||
"""Generate images using Chroma pipeline.
|
||||
|
||||
Args:
|
||||
prompt (str): Text prompt.
|
||||
negative_prompt (str, optional): Negative prompt.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
Image.Image: Generated image(s).
|
||||
"""
|
||||
return self.pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
||||
).images
|
||||
@ -211,6 +333,22 @@ PIPELINE_REGISTRY = {
|
||||
|
||||
|
||||
def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
|
||||
"""Build a Hugging Face image generation pipeline runner by name.
|
||||
|
||||
Args:
|
||||
name (str): Name of the pipeline (e.g., "sd35", "cosmos").
|
||||
device (str): Device to load the pipeline on.
|
||||
|
||||
Returns:
|
||||
BasePipelineRunner: Pipeline runner instance.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
|
||||
runner = build_hf_image_pipeline("sd35")
|
||||
images = runner.run(prompt="A robot holding a sign that says 'Hello'")
|
||||
```
|
||||
"""
|
||||
if name not in PIPELINE_REGISTRY:
|
||||
raise ValueError(f"Unsupported model: {name}")
|
||||
loader_cls, runner_cls = PIPELINE_REGISTRY[name]
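
The registry presumably pairs each name with its loader/runner classes along these lines; only "sd35" and "cosmos" are confirmed by the docstring above, the remaining keys are assumptions:

```py
# Hedged reconstruction of the registry shape used by build_hf_image_pipeline.
PIPELINE_REGISTRY = {
    "sd35": (SD35Loader, SD35Runner),
    "cosmos": (CosmosLoader, CosmosRunner),
    "kolors": (KolorsLoader, KolorsRunner),
    "flux": (FluxLoader, FluxRunner),
    "chroma": (ChromaLoader, ChromaRunner),
}

def build_runner(name: str, device: str = "cuda"):
    loader_cls, runner_cls = PIPELINE_REGISTRY[name]
    return runner_cls(loader_cls(device=device).load())
```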
|
||||
|
||||
@ -376,6 +376,21 @@ LAYOUT_DESCRIBER_PROMPT = """
|
||||
|
||||
|
||||
class LayoutDesigner(object):
|
||||
"""A class for querying GPT-based scene layout reasoning and formatting responses.
|
||||
|
||||
Attributes:
|
||||
prompt (str): The system prompt for GPT.
|
||||
verbose (bool): Whether to log responses.
|
||||
gpt_client (GPTclient): The GPT client instance.
|
||||
|
||||
Methods:
|
||||
query(prompt, params): Query GPT with a prompt and parameters.
|
||||
format_response(response): Parse and clean JSON response.
|
||||
format_response_repair(response): Repair and parse JSON response.
|
||||
save_output(output, save_path): Save output to file.
|
||||
__call__(prompt, save_path, params): Query and process output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gpt_client: GPTclient,
|
||||
@ -387,6 +402,15 @@ class LayoutDesigner(object):
|
||||
self.gpt_client = gpt_client
|
||||
|
||||
def query(self, prompt: str, params: dict = None) -> str:
|
||||
"""Query GPT with the system prompt and user prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): User prompt.
|
||||
params (dict, optional): GPT parameters.
|
||||
|
||||
Returns:
|
||||
str: GPT response.
|
||||
"""
|
||||
full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
|
||||
|
||||
response = self.gpt_client.query(
|
||||
@ -400,6 +424,17 @@ class LayoutDesigner(object):
|
||||
return response
|
||||
|
||||
def format_response(self, response: str) -> dict:
|
||||
"""Format and parse GPT response as JSON.
|
||||
|
||||
Args:
|
||||
response (str): Raw GPT response.
|
||||
|
||||
Returns:
|
||||
dict: Parsed JSON output.
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: If parsing fails.
|
||||
"""
|
||||
cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
|
||||
try:
|
||||
output = json.loads(cleaned)
|
||||
@ -411,9 +446,23 @@ class LayoutDesigner(object):
|
||||
return output
|
||||
|
||||
def format_response_repair(self, response: str) -> dict:
|
||||
"""Repair and parse possibly broken JSON response.
|
||||
|
||||
Args:
|
||||
response (str): Raw GPT response.
|
||||
|
||||
Returns:
|
||||
dict: Parsed JSON output.
|
||||
"""
|
||||
return json_repair.loads(response)
|
||||
|
||||
def save_output(self, output: dict, save_path: str) -> None:
|
||||
"""Save output dictionary to a file.
|
||||
|
||||
Args:
|
||||
output (dict): Output data.
|
||||
save_path (str): Path to save the file.
|
||||
"""
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
with open(save_path, 'w') as f:
|
||||
json.dump(output, f, indent=4)
|
||||
@ -421,6 +470,16 @@ class LayoutDesigner(object):
|
||||
def __call__(
|
||||
self, prompt: str, save_path: str = None, params: dict = None
|
||||
) -> dict | str:
|
||||
"""Query GPT and process the output.
|
||||
|
||||
Args:
|
||||
prompt (str): User prompt.
|
||||
save_path (str, optional): Path to save output.
|
||||
params (dict, optional): GPT parameters.
|
||||
|
||||
Returns:
|
||||
dict | str: Output data.
|
||||
"""
|
||||
response = self.query(prompt, params=params)
|
||||
output = self.format_response_repair(response)
|
||||
self.save_output(output, save_path) if save_path else None
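
Putting the pieces together, a hedged usage sketch; the constructor keyword names beyond `gpt_client` are assumptions inferred from the attributes listed above, and the system prompt is hypothetical (the repository ships its own `LAYOUT_*` prompts):

```py
from embodied_gen.utils.gpt_clients import GPT_CLIENT

designer = LayoutDesigner(
    gpt_client=GPT_CLIENT,
    prompt="You design 3D scene layouts for robotic tasks.",  # hypothetical prompt
    verbose=True,
)
layout = designer(
    "Put the apples on the table on the plate",
    save_path="outputs/layout.json",
)
```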
|
||||
@ -442,6 +501,29 @@ LAYOUT_DESCRIBER = LayoutDesigner(
|
||||
def build_scene_layout(
|
||||
task_desc: str, output_path: str = None, gpt_params: dict = None
|
||||
) -> LayoutInfo:
|
||||
"""Build a 3D scene layout from a natural language task description.
|
||||
|
||||
This function uses GPT-based reasoning to generate a structured scene layout,
|
||||
including object hierarchy, spatial relations, and style descriptions.
|
||||
|
||||
Args:
|
||||
task_desc (str): Natural language description of the robotic task.
|
||||
output_path (str, optional): Path to save the visualized scene tree.
|
||||
gpt_params (dict, optional): Parameters for GPT queries.
|
||||
|
||||
Returns:
|
||||
LayoutInfo: Structured layout information for the scene.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.layout import build_scene_layout
|
||||
layout_info = build_scene_layout(
|
||||
task_desc="Put the apples on the table on the plate",
|
||||
output_path="outputs/scene_tree.jpg",
|
||||
)
|
||||
print(layout_info)
|
||||
```
|
||||
"""
|
||||
layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
|
||||
layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
|
||||
object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
|
||||
|
||||
@ -48,12 +48,19 @@ __all__ = [
|
||||
|
||||
|
||||
class SAMRemover(object):
|
||||
"""Loading SAM models and performing background removal on images.
|
||||
"""Loads SAM models and performs background removal on images.
|
||||
|
||||
Attributes:
|
||||
checkpoint (str): Path to the model checkpoint.
|
||||
model_type (str): Type of the SAM model to load (default: "vit_h").
|
||||
area_ratio (float): Area ratio filtering small connected components.
|
||||
model_type (str): Type of the SAM model to load.
|
||||
area_ratio (float): Area ratio for filtering small connected components.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.segment_model import SAMRemover
|
||||
remover = SAMRemover(model_type="vit_h")
|
||||
result = remover("input.jpg", "output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -78,6 +85,14 @@ class SAMRemover(object):
|
||||
self.mask_generator = self._load_sam_model(checkpoint)
|
||||
|
||||
def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
|
||||
"""Loads the SAM model and returns a mask generator.
|
||||
|
||||
Args:
|
||||
checkpoint (str): Path to model checkpoint.
|
||||
|
||||
Returns:
|
||||
SamAutomaticMaskGenerator: Mask generator instance.
|
||||
"""
|
||||
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
||||
sam.to(device=self.device)
|
||||
|
||||
@ -89,13 +104,11 @@ class SAMRemover(object):
|
||||
"""Removes the background from an image using the SAM model.
|
||||
|
||||
Args:
|
||||
image (Union[str, Image.Image, np.ndarray]): Input image,
|
||||
can be a file path, PIL Image, or numpy array.
|
||||
save_path (str): Path to save the output image (default: None).
|
||||
image (Union[str, Image.Image, np.ndarray]): Input image.
|
||||
save_path (str, optional): Path to save the output image.
|
||||
|
||||
Returns:
|
||||
Image.Image: The image with background removed,
|
||||
including an alpha channel.
|
||||
Image.Image: Image with background removed (RGBA).
|
||||
"""
|
||||
# Convert input to numpy array
|
||||
if isinstance(image, str):
|
||||
@ -134,6 +147,15 @@ class SAMRemover(object):
|
||||
|
||||
|
||||
class SAMPredictor(object):
|
||||
"""Loads SAM models and predicts segmentation masks from user points.
|
||||
|
||||
Args:
|
||||
checkpoint (str, optional): Path to model checkpoint.
|
||||
model_type (str, optional): SAM model type.
|
||||
binary_thresh (float, optional): Threshold for binary mask.
|
||||
device (str, optional): Device for inference.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint: str = None,
|
||||
@ -157,12 +179,28 @@ class SAMPredictor(object):
|
||||
self.binary_thresh = binary_thresh
|
||||
|
||||
def _load_sam_model(self, checkpoint: str) -> SamPredictor:
|
||||
"""Loads the SAM model and returns a predictor.
|
||||
|
||||
Args:
|
||||
checkpoint (str): Path to model checkpoint.
|
||||
|
||||
Returns:
|
||||
SamPredictor: Predictor instance.
|
||||
"""
|
||||
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
||||
sam.to(device=self.device)
|
||||
|
||||
return SamPredictor(sam)
|
||||
|
||||
def preprocess_image(self, image: Image.Image) -> np.ndarray:
|
||||
"""Preprocesses input image for SAM prediction.
|
||||
|
||||
Args:
|
||||
image (Image.Image): Input image.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Preprocessed image array.
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
elif isinstance(image, np.ndarray):
|
||||
@ -178,6 +216,15 @@ class SAMPredictor(object):
|
||||
image: np.ndarray,
|
||||
selected_points: list[list[int]],
|
||||
) -> np.ndarray:
|
||||
"""Generates segmentation masks from selected points.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): Input image array.
|
||||
selected_points (list[list[int]]): List of points and labels.
|
||||
|
||||
Returns:
|
||||
list[tuple[np.ndarray, str]]: List of masks and names.
|
||||
"""
|
||||
if len(selected_points) == 0:
|
||||
return []
|
||||
|
||||
@ -220,6 +267,15 @@ class SAMPredictor(object):
|
||||
def get_segmented_image(
|
||||
self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
|
||||
) -> Image.Image:
|
||||
"""Combines masks and returns segmented image with alpha channel.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): Input image array.
|
||||
masks (list[tuple[np.ndarray, str]]): List of masks.
|
||||
|
||||
Returns:
|
||||
Image.Image: Segmented RGBA image.
|
||||
"""
|
||||
seg_image = Image.fromarray(image, mode="RGB")
|
||||
alpha_channel = np.zeros(
|
||||
(seg_image.height, seg_image.width), dtype=np.uint8
|
||||
@ -241,6 +297,15 @@ class SAMPredictor(object):
|
||||
image: Union[str, Image.Image, np.ndarray],
|
||||
selected_points: list[list[int]],
|
||||
) -> Image.Image:
|
||||
"""Segments image using selected points.
|
||||
|
||||
Args:
|
||||
image (Union[str, Image.Image, np.ndarray]): Input image.
|
||||
selected_points (list[list[int]]): List of points and labels.
|
||||
|
||||
Returns:
|
||||
Image.Image: Segmented RGBA image.
|
||||
"""
|
||||
image = self.preprocess_image(image)
|
||||
self.predictor.set_image(image)
|
||||
masks = self.generate_masks(image, selected_points)
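
Unlike `SAMRemover`, `SAMPredictor` is point-driven. A hedged usage sketch; the exact point/label encoding expected by `selected_points` is an assumption inferred from the `list[list[int]]` annotation and should be checked against the source:

```py
from embodied_gen.models.segment_model import SAMPredictor

predictor = SAMPredictor(model_type="vit_h")
# Hypothetical click: one foreground point (label 1) at pixel (320, 240).
points = [[320, 240, 1]]
rgba = predictor("input.jpg", selected_points=points)
rgba.save("segmented.png")
```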
|
||||
@ -249,12 +314,32 @@ class SAMPredictor(object):
|
||||
|
||||
|
||||
class RembgRemover(object):
|
||||
"""Removes background from images using the rembg library.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.segment_model import RembgRemover
|
||||
remover = RembgRemover()
|
||||
result = remover("input.jpg", "output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the RembgRemover."""
|
||||
self.rembg_session = rembg.new_session("u2net")
|
||||
|
||||
def __call__(
|
||||
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
||||
) -> Image.Image:
|
||||
"""Removes background from an image.
|
||||
|
||||
Args:
|
||||
image (Union[str, Image.Image, np.ndarray]): Input image.
|
||||
save_path (str, optional): Path to save the output image.
|
||||
|
||||
Returns:
|
||||
Image.Image: Image with background removed (RGBA).
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
elif isinstance(image, np.ndarray):
|
||||
@ -271,7 +356,18 @@ class RembgRemover(object):
|
||||
|
||||
|
||||
class BMGG14Remover(object):
|
||||
"""Removes background using the RMBG-1.4 segmentation model.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.segment_model import BMGG14Remover
|
||||
remover = BMGG14Remover()
|
||||
result = remover("input.jpg", "output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initializes the BMGG14Remover."""
|
||||
self.model = pipeline(
|
||||
"image-segmentation",
|
||||
model="briaai/RMBG-1.4",
|
||||
@ -281,6 +377,15 @@ class BMGG14Remover(object):
|
||||
def __call__(
|
||||
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
||||
):
|
||||
"""Removes background from an image.
|
||||
|
||||
Args:
|
||||
image (Union[str, Image.Image, np.ndarray]): Input image.
|
||||
save_path (str, optional): Path to save the output image.
|
||||
|
||||
Returns:
|
||||
Image.Image: Image with background removed.
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
elif isinstance(image, np.ndarray):
|
||||
@ -299,6 +404,16 @@ class BMGG14Remover(object):
|
||||
def invert_rgba_pil(
|
||||
image: Image.Image, mask: Image.Image, save_path: str = None
|
||||
) -> Image.Image:
|
||||
"""Inverts the alpha channel of an RGBA image using a mask.
|
||||
|
||||
Args:
|
||||
image (Image.Image): Input RGB image.
|
||||
mask (Image.Image): Mask image for alpha inversion.
|
||||
save_path (str, optional): Path to save the output image.
|
||||
|
||||
Returns:
|
||||
Image.Image: RGBA image with inverted alpha.
|
||||
"""
|
||||
mask = (255 - np.array(mask))[..., None]
|
||||
image_array = np.concatenate([np.array(image), mask], axis=-1)
|
||||
inverted_image = Image.fromarray(image_array, "RGBA")
|
||||
@ -318,6 +433,20 @@ def get_segmented_image_by_agent(
|
||||
save_path: str = None,
|
||||
mode: Literal["loose", "strict"] = "loose",
|
||||
) -> Image.Image:
|
||||
"""Segments an image using SAM and rembg, with quality checking.
|
||||
|
||||
Args:
|
||||
image (Image.Image): Input image.
|
||||
sam_remover (SAMRemover): SAM-based remover.
|
||||
rbg_remover (RembgRemover): rembg-based remover.
|
||||
seg_checker (ImageSegChecker, optional): Quality checker.
|
||||
save_path (str, optional): Path to save the output image.
|
||||
mode (Literal["loose", "strict"], optional): Segmentation mode.
|
||||
|
||||
Returns:
|
||||
Image.Image: Segmented RGBA image.
|
||||
"""
|
||||
|
||||
def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
|
||||
if seg_checker is None:
|
||||
return True
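
A hedged end-to-end sketch of the agent-style segmentation path, assuming the keyword names match the signature shown in this hunk; the checker is passed as `None`, in which case every result is accepted:

```py
from PIL import Image
from embodied_gen.models.segment_model import (
    RembgRemover,
    SAMRemover,
    get_segmented_image_by_agent,
)

image = Image.open("input.jpg").convert("RGB")
result = get_segmented_image_by_agent(
    image=image,
    sam_remover=SAMRemover(model_type="vit_h"),
    rbg_remover=RembgRemover(),
    seg_checker=None,          # no quality checking in this sketch
    save_path="segmented.png",
    mode="loose",
)
```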
|
||||
|
||||
@ -39,13 +39,38 @@ __all__ = [
|
||||
|
||||
|
||||
class ImageStableSR:
|
||||
"""Super-resolution image upscaler using Stable Diffusion x4 upscaling model from StabilityAI."""
|
||||
"""Super-resolution image upscaler using Stable Diffusion x4 upscaling model.
|
||||
|
||||
This class wraps the StabilityAI Stable Diffusion x4 upscaler for high-quality
|
||||
image super-resolution.
|
||||
|
||||
Args:
|
||||
model_path (str, optional): Path or HuggingFace repo for the model.
|
||||
device (str, optional): Device for inference.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.sr_model import ImageStableSR
|
||||
from PIL import Image
|
||||
|
||||
sr_model = ImageStableSR()
|
||||
img = Image.open("input.png")
|
||||
upscaled = sr_model(img)
|
||||
upscaled.save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
|
||||
device="cuda",
|
||||
) -> None:
|
||||
"""Initializes the Stable Diffusion x4 upscaler.
|
||||
|
||||
Args:
|
||||
model_path (str, optional): Model path or repo.
|
||||
device (str, optional): Device for inference.
|
||||
"""
|
||||
from diffusers import StableDiffusionUpscalePipeline
|
||||
|
||||
self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
@ -62,6 +87,16 @@ class ImageStableSR:
|
||||
prompt: str = "",
|
||||
infer_step: int = 20,
|
||||
) -> Image.Image:
|
||||
"""Performs super-resolution on the input image.
|
||||
|
||||
Args:
|
||||
image (Union[Image.Image, np.ndarray]): Input image.
|
||||
prompt (str, optional): Text prompt for upscaling.
|
||||
infer_step (int, optional): Number of inference steps.
|
||||
|
||||
Returns:
|
||||
Image.Image: Upscaled image.
|
||||
"""
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
|
||||
@ -86,9 +121,26 @@ class ImageRealESRGAN:
|
||||
Attributes:
|
||||
outscale (int): The output image scale factor (e.g., 2, 4).
|
||||
model_path (str): Path to the pre-trained model weights.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.sr_model import ImageRealESRGAN
|
||||
from PIL import Image
|
||||
|
||||
sr_model = ImageRealESRGAN(outscale=4)
|
||||
img = Image.open("input.png")
|
||||
upscaled = sr_model(img)
|
||||
upscaled.save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, outscale: int, model_path: str = None) -> None:
|
||||
"""Initializes the RealESRGAN upscaler.
|
||||
|
||||
Args:
|
||||
outscale (int): Output scale factor.
|
||||
model_path (str, optional): Path to model weights.
|
||||
"""
|
||||
# monkey patch to support torchvision>=0.16
|
||||
import torchvision
|
||||
from packaging import version
|
||||
@ -122,6 +174,7 @@ class ImageRealESRGAN:
|
||||
self.model_path = model_path
|
||||
|
||||
def _lazy_init(self):
|
||||
"""Lazily initializes the RealESRGAN model."""
|
||||
if self.upsampler is None:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
@ -145,6 +198,14 @@ class ImageRealESRGAN:
|
||||
|
||||
@spaces.GPU
|
||||
def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
|
||||
"""Performs super-resolution on the input image.
|
||||
|
||||
Args:
|
||||
image (Union[Image.Image, np.ndarray]): Input image.
|
||||
|
||||
Returns:
|
||||
Image.Image: Upscaled image.
|
||||
"""
|
||||
self._lazy_init()
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
|
||||
@ -60,6 +60,11 @@ PROMPT_KAPPEND = "Single {object}, in the center of the image, white background,
|
||||
|
||||
|
||||
def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
|
||||
"""Downloads Kolors model weights from HuggingFace.
|
||||
|
||||
Args:
|
||||
local_dir (str, optional): Local directory to store weights.
|
||||
"""
|
||||
logger.info(f"Download kolors weights from huggingface...")
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
subprocess.run(
|
||||
@ -93,6 +98,22 @@ def build_text2img_ip_pipeline(
|
||||
ref_scale: float,
|
||||
device: str = "cuda",
|
||||
) -> StableDiffusionXLPipelineIP:
|
||||
"""Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation.
|
||||
|
||||
Args:
|
||||
ckpt_dir (str): Directory containing model checkpoints.
|
||||
ref_scale (float): Reference scale for IP-Adapter.
|
||||
device (str, optional): Device for inference.
|
||||
|
||||
Returns:
|
||||
StableDiffusionXLPipelineIP: Configured pipeline.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.text_model import build_text2img_ip_pipeline
|
||||
pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
|
||||
```
|
||||
"""
|
||||
download_kolors_weights(ckpt_dir)
|
||||
|
||||
text_encoder = ChatGLMModel.from_pretrained(
|
||||
@ -146,6 +167,21 @@ def build_text2img_pipeline(
|
||||
ckpt_dir: str,
|
||||
device: str = "cuda",
|
||||
) -> StableDiffusionXLPipeline:
|
||||
"""Builds a Stable Diffusion XL pipeline for text-to-image generation.
|
||||
|
||||
Args:
|
||||
ckpt_dir (str): Directory containing model checkpoints.
|
||||
device (str, optional): Device for inference.
|
||||
|
||||
Returns:
|
||||
StableDiffusionXLPipeline: Configured pipeline.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.text_model import build_text2img_pipeline
|
||||
pipe = build_text2img_pipeline("weights/Kolors")
|
||||
```
|
||||
"""
|
||||
download_kolors_weights(ckpt_dir)
|
||||
|
||||
text_encoder = ChatGLMModel.from_pretrained(
|
||||
@ -185,6 +221,29 @@ def text2img_gen(
|
||||
ip_image_size: int = 512,
|
||||
seed: int = None,
|
||||
) -> list[Image.Image]:
|
||||
"""Generates images from text prompts using a Stable Diffusion XL pipeline.
|
||||
|
||||
Args:
|
||||
prompt (str): Text prompt for image generation.
|
||||
n_sample (int): Number of images to generate.
|
||||
guidance_scale (float): Guidance scale for diffusion.
|
||||
pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP): Pipeline instance.
|
||||
ip_image (Image.Image | str, optional): Reference image for IP-Adapter.
|
||||
image_wh (tuple[int, int], optional): Output image size (width, height).
|
||||
infer_step (int, optional): Number of inference steps.
|
||||
ip_image_size (int, optional): Size for IP-Adapter image.
|
||||
seed (int, optional): Random seed.
|
||||
|
||||
Returns:
|
||||
list[Image.Image]: List of generated images.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.models.text_model import text2img_gen
|
||||
images = text2img_gen(prompt="banana", n_sample=3, guidance_scale=7.5)
|
||||
images[0].save("banana.png")
|
||||
```
|
||||
"""
|
||||
prompt = PROMPT_KAPPEND.format(object=prompt.strip())
|
||||
logger.info(f"Processing prompt: {prompt}")
|
||||
|
||||
|
||||
@ -53,26 +53,31 @@ from thirdparty.pano2room.utils.functions import (
|
||||
|
||||
|
||||
class Pano2MeshSRPipeline:
|
||||
"""Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement.
|
||||
"""Pipeline for converting panoramic RGB images into 3D mesh representations.
|
||||
|
||||
This class integrates several key components including:
|
||||
- Depth estimation from RGB panorama
|
||||
- Inpainting of missing regions under offsets
|
||||
- RGB-D to mesh conversion
|
||||
- Multi-view mesh repair
|
||||
- 3D Gaussian Splatting (3DGS) dataset generation
|
||||
This class integrates depth estimation, inpainting, mesh conversion, multi-view mesh repair,
|
||||
and 3D Gaussian Splatting (3DGS) dataset generation.
|
||||
|
||||
Args:
|
||||
config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
|
||||
|
||||
Example:
|
||||
```python
|
||||
```py
|
||||
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
|
||||
from embodied_gen.utils.config import Pano2MeshSRConfig
|
||||
|
||||
config = Pano2MeshSRConfig()
|
||||
pipeline = Pano2MeshSRPipeline(config)
|
||||
pipeline(pano_image='example.png', output_dir='./output')
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: Pano2MeshSRConfig) -> None:
|
||||
"""Initializes the pipeline with models and camera poses.
|
||||
|
||||
Args:
|
||||
config (Pano2MeshSRConfig): Configuration object.
|
||||
"""
|
||||
self.cfg = config
|
||||
self.device = config.device
|
||||
|
||||
@ -93,6 +98,7 @@ class Pano2MeshSRPipeline:
|
||||
self.kernel = torch.from_numpy(kernel).float().to(self.device)
|
||||
|
||||
def init_mesh_params(self) -> None:
|
||||
"""Initializes mesh parameters and inpaint mask."""
|
||||
torch.set_default_device(self.device)
|
||||
self.inpaint_mask = torch.ones(
|
||||
(self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
|
||||
@ -103,6 +109,14 @@ class Pano2MeshSRPipeline:
|
||||
|
||||
@staticmethod
|
||||
def read_camera_pose_file(filepath: str) -> np.ndarray:
|
||||
"""Reads a camera pose file and returns the pose matrix.
|
||||
|
||||
Args:
|
||||
filepath (str): Path to the camera pose file.
|
||||
|
||||
Returns:
|
||||
np.ndarray: 4x4 camera pose matrix.
|
||||
"""
|
||||
with open(filepath, "r") as f:
|
||||
values = [float(num) for line in f for num in line.split()]
|
||||
|
||||
@ -111,6 +125,14 @@ class Pano2MeshSRPipeline:
|
||||
def load_camera_poses(
|
||||
self, trajectory_dir: str
|
||||
) -> tuple[np.ndarray, list[torch.Tensor]]:
|
||||
"""Loads camera poses from a directory.
|
||||
|
||||
Args:
|
||||
trajectory_dir (str): Directory containing camera pose files.
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray, list[torch.Tensor]]: Loaded camera poses and the corresponding list of relative pose tensors.
|
||||
"""
|
||||
pose_filenames = sorted(
|
||||
[
|
||||
fname
|
||||
@ -148,6 +170,14 @@ class Pano2MeshSRPipeline:
|
||||
def load_inpaint_poses(
|
||||
self, poses: torch.Tensor
|
||||
) -> dict[int, torch.Tensor]:
|
||||
"""Samples and loads poses for inpainting.
|
||||
|
||||
Args:
|
||||
poses (torch.Tensor): Tensor of camera poses.
|
||||
|
||||
Returns:
|
||||
dict[int, torch.Tensor]: Dictionary mapping indices to pose tensors.
|
||||
"""
|
||||
inpaint_poses = dict()
|
||||
sampled_views = poses[:: self.cfg.inpaint_frame_stride]
|
||||
init_pose = torch.eye(4)
|
||||
@ -162,6 +192,14 @@ class Pano2MeshSRPipeline:
|
||||
return inpaint_poses
|
||||
|
||||
def project(self, world_to_cam: torch.Tensor):
|
||||
"""Projects the mesh to an image using the given camera pose.
|
||||
|
||||
Args:
|
||||
world_to_cam (torch.Tensor): World-to-camera transformation matrix.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Projected RGB image, inpaint mask, and depth map.
|
||||
"""
|
||||
(
|
||||
project_image,
|
||||
project_depth,
|
||||
@ -185,6 +223,14 @@ class Pano2MeshSRPipeline:
|
||||
return project_image[:3, ...], inpaint_mask, project_depth
|
||||
|
||||
def render_pano(self, pose: torch.Tensor):
|
||||
"""Renders a panorama from the mesh using the given pose.
|
||||
|
||||
Args:
|
||||
pose (torch.Tensor): Camera pose.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: RGB panorama, depth map, and mask.
|
||||
"""
|
||||
cubemap_list = []
|
||||
for cubemap_pose in self.cubemap_w2cs:
|
||||
project_pose = cubemap_pose @ pose
|
||||
@ -213,6 +259,15 @@ class Pano2MeshSRPipeline:
|
||||
world_to_cam: torch.Tensor = None,
|
||||
using_distance_map: bool = True,
|
||||
) -> None:
|
||||
"""Converts RGB-D images to mesh and updates mesh parameters.
|
||||
|
||||
Args:
|
||||
rgb (torch.Tensor): RGB image tensor.
|
||||
depth (torch.Tensor): Depth map tensor.
|
||||
inpaint_mask (torch.Tensor): Inpaint mask tensor.
|
||||
world_to_cam (torch.Tensor, optional): Camera pose.
|
||||
using_distance_map (bool, optional): Whether to use distance map.
|
||||
"""
|
||||
if world_to_cam is None:
|
||||
world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
|
||||
|
||||
@ -239,6 +294,15 @@ class Pano2MeshSRPipeline:
|
||||
def get_edge_image_by_depth(
|
||||
self, depth: torch.Tensor, dilate_iter: int = 1
|
||||
) -> np.ndarray:
|
||||
"""Computes edge image from depth map.
|
||||
|
||||
Args:
|
||||
depth (torch.Tensor): Depth map tensor.
|
||||
dilate_iter (int, optional): Number of dilation iterations.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Edge image.
|
||||
"""
|
||||
if isinstance(depth, torch.Tensor):
|
||||
depth = depth.cpu().detach().numpy()
|
||||
|
||||
@ -253,6 +317,15 @@ class Pano2MeshSRPipeline:
|
||||
def mesh_repair_by_greedy_view_selection(
|
||||
self, pose_dict: dict[str, torch.Tensor], output_dir: str
|
||||
) -> list:
|
||||
"""Repairs mesh by selecting views greedily and inpainting missing regions.
|
||||
|
||||
Args:
|
||||
pose_dict (dict[str, torch.Tensor]): Dictionary of poses for inpainting.
|
||||
output_dir (str): Directory to save visualizations.
|
||||
|
||||
Returns:
|
||||
list: List of inpainted panoramas with poses.
|
||||
"""
|
||||
inpainted_panos_w_pose = []
|
||||
while len(pose_dict) > 0:
|
||||
logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
|
||||
@ -343,6 +416,17 @@ class Pano2MeshSRPipeline:
|
||||
distances: torch.Tensor,
|
||||
pano_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor]:
|
||||
"""Inpaints missing regions in a panorama.
|
||||
|
||||
Args:
|
||||
idx (int): Index of the panorama.
|
||||
colors (torch.Tensor): RGB image tensor.
|
||||
distances (torch.Tensor): Distance map tensor.
|
||||
pano_mask (torch.Tensor): Mask tensor.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor]: Inpainted RGB image, distances, and normals.
|
||||
"""
|
||||
mask = (pano_mask[None, ..., None] > 0.5).float()
|
||||
mask = mask.permute(0, 3, 1, 2)
|
||||
mask = dilation(mask, kernel=self.kernel)
|
||||
@ -364,6 +448,14 @@ class Pano2MeshSRPipeline:
|
||||
def preprocess_pano(
|
||||
self, image: Image.Image | str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Preprocesses a panoramic image for mesh generation.
|
||||
|
||||
Args:
|
||||
image (Image.Image | str): Input image or path.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: Preprocessed RGB and depth tensors.
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
|
||||
@ -387,6 +479,17 @@ class Pano2MeshSRPipeline:
|
||||
def pano_to_perpective(
|
||||
self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
|
||||
) -> torch.Tensor:
|
||||
"""Converts a panoramic image to a perspective view.
|
||||
|
||||
Args:
|
||||
pano_image (torch.Tensor): Panoramic image tensor.
|
||||
pitch (float): Pitch angle.
|
||||
yaw (float): Yaw angle.
|
||||
fov (float): Field of view.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Perspective image tensor.
|
||||
"""
|
||||
rots = dict(
|
||||
roll=0,
|
||||
pitch=pitch,
|
||||
@ -404,6 +507,14 @@ class Pano2MeshSRPipeline:
|
||||
return perspective
|
||||
|
||||
def pano_to_cubemap(self, pano_rgb: torch.Tensor):
|
||||
"""Converts a panoramic RGB image to six cubemap views.
|
||||
|
||||
Args:
|
||||
pano_rgb (torch.Tensor): Panoramic RGB image tensor.
|
||||
|
||||
Returns:
|
||||
list: List of cubemap RGB tensors.
|
||||
"""
|
||||
# Define six canonical cube directions in (pitch, yaw)
|
||||
directions = [
|
||||
(0, 0),
|
||||
@ -424,6 +535,11 @@ class Pano2MeshSRPipeline:
|
||||
return cubemaps_rgb
|
||||
|
||||
def save_mesh(self, output_path: str) -> None:
|
||||
"""Saves the mesh to a file.
|
||||
|
||||
Args:
|
||||
output_path (str): Path to save the mesh file.
|
||||
"""
|
||||
vertices_np = self.vertices.T.cpu().numpy()
|
||||
colors_np = self.colors.T.cpu().numpy()
|
||||
faces_np = self.faces.T.cpu().numpy()
|
||||
@ -434,6 +550,14 @@ class Pano2MeshSRPipeline:
|
||||
mesh.export(output_path)
|
||||
|
||||
def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
|
||||
"""Converts mesh pose to 3D Gaussian Splatting pose.
|
||||
|
||||
Args:
|
||||
mesh_pose (torch.Tensor): Mesh pose tensor.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Converted pose matrix.
|
||||
"""
|
||||
pose = mesh_pose.clone()
|
||||
pose[0, :] *= -1
|
||||
pose[1, :] *= -1
|
||||
@ -450,6 +574,15 @@ class Pano2MeshSRPipeline:
|
||||
return c2w
|
||||
|
||||
def __call__(self, pano_image: Image.Image | str, output_dir: str):
|
||||
"""Runs the pipeline to generate mesh and 3DGS data from a panoramic image.
|
||||
|
||||
Args:
|
||||
pano_image (Image.Image | str): Input panoramic image or path.
|
||||
output_dir (str): Directory to save outputs.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.init_mesh_params()
|
||||
pano_rgb, pano_depth = self.preprocess_pano(pano_image)
|
||||
self.sup_pool = SupInfoPool()
|
||||
|
||||
@ -24,11 +24,27 @@ __all__ = [
|
||||
"Scene3DItemEnum",
|
||||
"SpatialRelationEnum",
|
||||
"RobotItemEnum",
|
||||
"LayoutInfo",
|
||||
"AssetType",
|
||||
"SimAssetMapper",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RenderItems(str, Enum):
|
||||
"""Enumeration of render item types for 3D scenes.
|
||||
|
||||
Attributes:
|
||||
IMAGE: Color image.
|
||||
ALPHA: Mask image.
|
||||
VIEW_NORMAL: View-space normal image.
|
||||
GLOBAL_NORMAL: World-space normal image.
|
||||
POSITION_MAP: Position map image.
|
||||
DEPTH: Depth image.
|
||||
ALBEDO: Albedo image.
|
||||
DIFFUSE: Diffuse image.
|
||||
"""
|
||||
|
||||
IMAGE = "image_color"
|
||||
ALPHA = "image_mask"
|
||||
VIEW_NORMAL = "image_view_normal"
|
||||
@ -41,6 +57,21 @@ class RenderItems(str, Enum):
|
||||
|
||||
@dataclass
|
||||
class Scene3DItemEnum(str, Enum):
|
||||
"""Enumeration of 3D scene item categories.
|
||||
|
||||
Attributes:
|
||||
BACKGROUND: Background objects.
|
||||
CONTEXT: Contextual objects.
|
||||
ROBOT: Robot entity.
|
||||
MANIPULATED_OBJS: Objects manipulated by the robot.
|
||||
DISTRACTOR_OBJS: Distractor objects.
|
||||
OTHERS: Other objects.
|
||||
|
||||
Methods:
|
||||
object_list(layout_relation): Returns a list of objects in the scene.
|
||||
object_mapping(layout_relation): Returns a mapping from object to category.
|
||||
"""
|
||||
|
||||
BACKGROUND = "background"
|
||||
CONTEXT = "context"
|
||||
ROBOT = "robot"
|
||||
@ -50,6 +81,14 @@ class Scene3DItemEnum(str, Enum):
|
||||
|
||||
@classmethod
|
||||
def object_list(cls, layout_relation: dict) -> list:
|
||||
"""Returns a list of objects in the scene.
|
||||
|
||||
Args:
|
||||
layout_relation: Dictionary mapping categories to objects.
|
||||
|
||||
Returns:
|
||||
List of objects in the scene.
|
||||
"""
|
||||
return (
|
||||
[
|
||||
layout_relation[cls.BACKGROUND.value],
|
||||
@ -61,6 +100,14 @@ class Scene3DItemEnum(str, Enum):
|
||||
|
||||
@classmethod
|
||||
def object_mapping(cls, layout_relation):
|
||||
"""Returns a mapping from object to category.
|
||||
|
||||
Args:
|
||||
layout_relation: Dictionary mapping categories to objects.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping object names to their category.
|
||||
"""
|
||||
relation_mapping = {
|
||||
# layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
|
||||
layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
|
||||
@ -84,6 +131,15 @@ class Scene3DItemEnum(str, Enum):
|
||||
|
||||
@dataclass
|
||||
class SpatialRelationEnum(str, Enum):
|
||||
"""Enumeration of spatial relations for objects in a scene.
|
||||
|
||||
Attributes:
|
||||
ON: Objects on a surface (e.g., table).
|
||||
IN: Objects in a container or room.
|
||||
INSIDE: Objects inside a shelf or rack.
|
||||
FLOOR: Objects on the floor.
|
||||
"""
|
||||
|
||||
ON = "ON" # objects on the table
|
||||
IN = "IN" # objects in the room
|
||||
INSIDE = "INSIDE" # objects inside the shelf/rack
|
||||
@ -92,6 +148,14 @@ class SpatialRelationEnum(str, Enum):
|
||||
|
||||
@dataclass
|
||||
class RobotItemEnum(str, Enum):
|
||||
"""Enumeration of supported robot types.
|
||||
|
||||
Attributes:
|
||||
FRANKA: Franka robot.
|
||||
UR5: UR5 robot.
|
||||
PIPER: Piper robot.
|
||||
"""
|
||||
|
||||
FRANKA = "franka"
|
||||
UR5 = "ur5"
|
||||
PIPER = "piper"
|
||||
@ -99,6 +163,18 @@ class RobotItemEnum(str, Enum):
|
||||
|
||||
@dataclass
|
||||
class LayoutInfo(DataClassJsonMixin):
|
||||
"""Data structure for layout information in a 3D scene.
|
||||
|
||||
Attributes:
|
||||
tree: Hierarchical structure of scene objects.
|
||||
relation: Spatial relations between objects.
|
||||
objs_desc: Descriptions of objects.
|
||||
objs_mapping: Mapping from object names to categories.
|
||||
assets: Asset file paths for objects.
|
||||
quality: Quality information for assets.
|
||||
position: Position coordinates for objects.
|
||||
"""
|
||||
|
||||
tree: dict[str, list]
|
||||
relation: dict[str, str | list[str]]
|
||||
objs_desc: dict[str, str] = field(default_factory=dict)
|
||||
@ -106,3 +182,64 @@ class LayoutInfo(DataClassJsonMixin):
|
||||
assets: dict[str, str] = field(default_factory=dict)
|
||||
quality: dict[str, str] = field(default_factory=dict)
|
||||
position: dict[str, list[float]] = field(default_factory=dict)
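
Because `LayoutInfo` mixes in `DataClassJsonMixin`, it round-trips through plain dictionaries and JSON. A minimal sketch with toy values (the key names inside `tree` and `relation` are illustrative only):

```py
layout = LayoutInfo(
    tree={"table": [["apple", "ON"]]},
    relation={"background": "kitchen", "manipulated_objs": ["apple"]},
)
as_dict = layout.to_dict()            # provided by DataClassJsonMixin
restored = LayoutInfo.from_dict(as_dict)
```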
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssetType(str):
|
||||
"""Enumeration for asset types.
|
||||
|
||||
Supported types:
|
||||
MJCF: MuJoCo XML format.
|
||||
USD: Universal Scene Description format.
|
||||
URDF: Unified Robot Description Format.
|
||||
MESH: Mesh file format.
|
||||
"""
|
||||
|
||||
MJCF = "mjcf"
|
||||
USD = "usd"
|
||||
URDF = "urdf"
|
||||
MESH = "mesh"
|
||||
|
||||
|
||||
class SimAssetMapper:
|
||||
"""Maps simulator names to asset types.
|
||||
|
||||
Provides a mapping from simulator names to their corresponding asset type.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.utils.enum import SimAssetMapper
|
||||
asset_type = SimAssetMapper["isaacsim"]
|
||||
print(asset_type) # Output: 'usd'
|
||||
```
|
||||
|
||||
Methods:
|
||||
__class_getitem__(key): Returns the asset type for a given simulator name.
|
||||
"""
|
||||
|
||||
_mapping = dict(
|
||||
ISAACSIM=AssetType.USD,
|
||||
ISAACGYM=AssetType.URDF,
|
||||
MUJOCO=AssetType.MJCF,
|
||||
GENESIS=AssetType.MJCF,
|
||||
SAPIEN=AssetType.URDF,
|
||||
PYBULLET=AssetType.URDF,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __class_getitem__(cls, key: str):
|
||||
"""Returns the asset type for a given simulator name.
|
||||
|
||||
Args:
|
||||
key: Name of the simulator.
|
||||
|
||||
Returns:
|
||||
AssetType corresponding to the simulator.
|
||||
|
||||
Raises:
|
||||
KeyError: If the simulator name is not recognized.
|
||||
"""
|
||||
key = key.upper()
|
||||
if key.startswith("SAPIEN"):
|
||||
key = "SAPIEN"
|
||||
return cls._mapping[key]
|
||||
|
||||
@ -45,13 +45,13 @@ __all__ = [
|
||||
|
||||
|
||||
def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
||||
"""Convert a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
|
||||
"""Converts a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
|
||||
|
||||
Args:
|
||||
matrix (np.ndarray): 4x4 transformation matrix.
|
||||
|
||||
Returns:
|
||||
List[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
||||
list[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
||||
"""
|
||||
x, y, z = matrix[:3, 3]
|
||||
rot_mat = matrix[:3, :3]
|
||||
@ -62,13 +62,13 @@ def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
||||
|
||||
|
||||
def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
||||
"""Convert pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
|
||||
"""Converts pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
|
||||
|
||||
Args:
|
||||
List[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
||||
pose (list[float]): Pose as [x, y, z, qx, qy, qz, qw].
|
||||
|
||||
Returns:
|
||||
matrix (np.ndarray): 4x4 transformation matrix.
|
||||
np.ndarray: 4x4 transformation matrix.
|
||||
"""
|
||||
x, y, z, qx, qy, qz, qw = pose
|
||||
r = R.from_quat([qx, qy, qz, qw])
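
A quick round-trip check of the two helpers above, assuming both are exported from `embodied_gen.utils.geometry` like `bfs_placement`:

```py
import numpy as np
from embodied_gen.utils.geometry import matrix_to_pose, pose_to_matrix

pose = [0.1, 0.2, 0.3, 0.0, 0.0, 0.0, 1.0]   # x, y, z, qx, qy, qz, qw (identity rotation)
matrix = pose_to_matrix(pose)                 # 4x4 homogeneous transform
assert np.allclose(matrix_to_pose(matrix), pose, atol=1e-6)
```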
|
||||
@ -82,6 +82,16 @@ def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
||||
def compute_xy_bbox(
|
||||
vertices: np.ndarray, col_x: int = 0, col_y: int = 1
|
||||
) -> list[float]:
|
||||
"""Computes the bounding box in XY plane for given vertices.
|
||||
|
||||
Args:
|
||||
vertices (np.ndarray): Vertex coordinates.
|
||||
col_x (int, optional): Column index for X.
|
||||
col_y (int, optional): Column index for Y.
|
||||
|
||||
Returns:
|
||||
list[float]: [min_x, max_x, min_y, max_y]
|
||||
"""
|
||||
x_vals = vertices[:, col_x]
|
||||
y_vals = vertices[:, col_y]
|
||||
return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
|
||||
@ -92,6 +102,16 @@ def has_iou_conflict(
|
||||
placed_boxes: list[list[float]],
|
||||
iou_threshold: float = 0.0,
|
||||
) -> bool:
|
||||
"""Checks for intersection-over-union conflict between boxes.
|
||||
|
||||
Args:
|
||||
new_box (list[float]): New box coordinates.
|
||||
placed_boxes (list[list[float]]): List of placed box coordinates.
|
||||
iou_threshold (float, optional): IOU threshold.
|
||||
|
||||
Returns:
|
||||
bool: True if conflict exists, False otherwise.
|
||||
"""
|
||||
new_min_x, new_max_x, new_min_y, new_max_y = new_box
|
||||
for min_x, max_x, min_y, max_y in placed_boxes:
|
||||
ix1 = max(new_min_x, min_x)
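
For reference, the pairwise overlap test presumably continues along these lines; a standalone sketch for two axis-aligned boxes, not the exact remainder of the loop:

```py
def boxes_overlap(box_a, box_b, iou_threshold: float = 0.0) -> bool:
    # Boxes given as [min_x, max_x, min_y, max_y].
    ax1, ax2, ay1, ay2 = box_a
    bx1, bx2, by1, by2 = box_b
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))   # intersection width
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))   # intersection height
    inter = iw * ih
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return union > 0 and inter / union > iou_threshold
```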
|
||||
@ -105,7 +125,14 @@ def has_iou_conflict(
|
||||
|
||||
|
||||
def with_seed(seed_attr_name: str = "seed"):
|
||||
"""A parameterized decorator that temporarily sets the random seed."""
|
||||
"""Decorator to temporarily set the random seed for reproducibility.
|
||||
|
||||
Args:
|
||||
seed_attr_name (str, optional): Name of the seed argument.
|
||||
|
||||
Returns:
|
||||
function: Decorator function.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
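
A hedged sketch of how such a seed-scoping decorator typically works: read the seed from the named keyword argument, seed the RNGs for the call, then restore the previous RNG state. This is not the repository's exact implementation.

```py
import functools
import random

import numpy as np
import torch

def with_seed_sketch(seed_attr_name: str = "seed"):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            seed = kwargs.get(seed_attr_name, None)
            if seed is None:
                return func(*args, **kwargs)
            py_state, np_state = random.getstate(), np.random.get_state()
            torch_state = torch.get_rng_state()
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            try:
                return func(*args, **kwargs)
            finally:
                # Restore the RNG state so seeding stays local to this call.
                random.setstate(py_state)
                np.random.set_state(np_state)
                torch.set_rng_state(torch_state)
        return wrapper

    return decorator
```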
|
||||
@ -143,6 +170,20 @@ def compute_convex_hull_path(
|
||||
y_axis: int = 1,
|
||||
z_axis: int = 2,
|
||||
) -> Path:
|
||||
"""Computes a dense convex hull path for the top surface of a mesh.
|
||||
|
||||
Args:
|
||||
vertices (np.ndarray): Mesh vertices.
|
||||
z_threshold (float, optional): Z threshold for top surface.
|
||||
interp_per_edge (int, optional): Interpolation points per edge.
|
||||
margin (float, optional): Margin for polygon buffer.
|
||||
x_axis (int, optional): X axis index.
|
||||
y_axis (int, optional): Y axis index.
|
||||
z_axis (int, optional): Z axis index.
|
||||
|
||||
Returns:
|
||||
Path: Matplotlib path object for the convex hull.
|
||||
"""
|
||||
top_vertices = vertices[
|
||||
vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
|
||||
]
|
||||
@ -170,6 +211,15 @@ def compute_convex_hull_path(
|
||||
|
||||
|
||||
def find_parent_node(node: str, tree: dict) -> str | None:
|
||||
"""Finds the parent node of a given node in a tree.
|
||||
|
||||
Args:
|
||||
node (str): Node name.
|
||||
tree (dict): Tree structure.
|
||||
|
||||
Returns:
|
||||
str | None: Parent node name or None.
|
||||
"""
|
||||
for parent, children in tree.items():
|
||||
if any(child[0] == node for child in children):
|
||||
return parent
|
||||
@ -177,6 +227,16 @@ def find_parent_node(node: str, tree: dict) -> str | None:
|
||||
|
||||
|
||||
def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
||||
"""Checks if at least `threshold` corners of a box are inside a hull.
|
||||
|
||||
Args:
|
||||
hull (Path): Convex hull path.
|
||||
box (list): Box coordinates [x1, x2, y1, y2].
|
||||
threshold (int, optional): Minimum corners inside.
|
||||
|
||||
Returns:
|
||||
bool: True if enough corners are inside.
|
||||
"""
|
||||
x1, x2, y1, y2 = box
|
||||
corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
|
||||
|
||||
@ -187,6 +247,15 @@ def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
||||
def compute_axis_rotation_quat(
|
||||
axis: Literal["x", "y", "z"], angle_rad: float
|
||||
) -> list[float]:
|
||||
"""Computes quaternion for rotation around a given axis.
|
||||
|
||||
Args:
|
||||
axis (Literal["x", "y", "z"]): Axis of rotation.
|
||||
angle_rad (float): Rotation angle in radians.
|
||||
|
||||
Returns:
|
||||
list[float]: Quaternion [x, y, z, w].
|
||||
"""
|
||||
if axis.lower() == "x":
|
||||
q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
|
||||
elif axis.lower() == "y":
|
||||
@ -202,6 +271,15 @@ def compute_axis_rotation_quat(
|
||||
def quaternion_multiply(
|
||||
init_quat: list[float], rotate_quat: list[float]
|
||||
) -> list[float]:
|
||||
"""Multiplies two quaternions.
|
||||
|
||||
Args:
|
||||
init_quat (list[float]): Initial quaternion [x, y, z, w].
|
||||
rotate_quat (list[float]): Rotation quaternion [x, y, z, w].
|
||||
|
||||
Returns:
|
||||
list[float]: Resulting quaternion [x, y, z, w].
|
||||
"""
|
||||
qx, qy, qz, qw = init_quat
|
||||
q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
|
||||
qx, qy, qz, qw = rotate_quat
|
||||
@ -217,7 +295,17 @@ def check_reachable(
|
||||
min_reach: float = 0.25,
|
||||
max_reach: float = 0.85,
|
||||
) -> bool:
|
||||
"""Check if the target point is within the reachable range."""
|
||||
"""Checks if the target point is within the reachable range.
|
||||
|
||||
Args:
|
||||
base_xyz (np.ndarray): Base position.
|
||||
reach_xyz (np.ndarray): Target position.
|
||||
min_reach (float, optional): Minimum reach distance.
|
||||
max_reach (float, optional): Maximum reach distance.
|
||||
|
||||
Returns:
|
||||
bool: True if reachable, False otherwise.
|
||||
"""
|
||||
distance = np.linalg.norm(reach_xyz - base_xyz)
|
||||
|
||||
return min_reach < distance < max_reach
|
||||
@ -238,26 +326,31 @@ def bfs_placement(
|
||||
robot_dim: float = 0.12,
|
||||
seed: int = None,
|
||||
) -> LayoutInfo:
|
||||
"""Place objects in the layout using BFS traversal.
|
||||
"""Places objects in a scene layout using BFS traversal.
|
||||
|
||||
Args:
|
||||
layout_file: Path to the JSON file defining the layout structure and assets.
|
||||
floor_margin: Z-offset for the background object, typically for objects placed on the floor.
|
||||
beside_margin: Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
|
||||
max_attempts: Maximum number of attempts to find a non-overlapping position for an object.
|
||||
init_rpy: Initial Roll-Pitch-Yaw rotation rad applied to all object meshes to align the mesh's
|
||||
coordinate system with the world's (e.g., Z-up).
|
||||
rotate_objs: If True, apply a random rotation around the Z-axis for manipulated and distractor objects.
|
||||
rotate_bg: If True, apply a random rotation around the Y-axis for the background object.
|
||||
rotate_context: If True, apply a random rotation around the Z-axis for the context object.
|
||||
limit_reach_range: If set, enforce a check that manipulated objects are within the robot's reach range, in meter.
|
||||
max_orient_diff: If set, enforce a check that manipulated objects are within the robot's orientation range, in degree.
|
||||
robot_dim: The approximate dimension (e.g., diameter) of the robot for box representation.
|
||||
seed: Random seed for reproducible placement.
|
||||
layout_file (str): Path to layout JSON file generated from `layout-cli`.
|
||||
floor_margin (float, optional): Z-offset for objects placed on the floor.
|
||||
beside_margin (float, optional): Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
|
||||
max_attempts (int, optional): Max attempts for a non-overlapping placement.
|
||||
init_rpy (tuple, optional): Initial rotation (rpy).
|
||||
rotate_objs (bool, optional): Whether to random rotate objects.
|
||||
rotate_bg (bool, optional): Whether to random rotate background.
|
||||
rotate_context (bool, optional): Whether to random rotate context asset.
|
||||
limit_reach_range (tuple[float, float] | None, optional): If set, enforce a check that manipulated objects are within the robot's reach range, in meter.
|
||||
max_orient_diff (float | None, optional): If set, enforce a check that manipulated objects are within the robot's orientation range, in degree.
|
||||
robot_dim (float, optional): The approximate robot size.
|
||||
seed (int, optional): Random seed for reproducible placement.
|
||||
|
||||
Returns:
|
||||
A :class:`LayoutInfo` object containing the objects and their final computed 7D poses
|
||||
([x, y, z, qx, qy, qz, qw]).
|
||||
LayoutInfo: Layout information with object poses.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from embodied_gen.utils.geometry import bfs_placement
|
||||
layout = bfs_placement("scene_layout.json", seed=42)
|
||||
print(layout.position)
|
||||
```
|
||||
"""
|
||||
layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
|
||||
asset_dir = os.path.dirname(layout_file)
|
||||
@ -478,6 +571,13 @@ def bfs_placement(
|
||||
def compose_mesh_scene(
|
||||
layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
|
||||
) -> None:
|
||||
"""Composes a mesh scene from layout information and saves to file.
|
||||
|
||||
Args:
|
||||
layout_info (LayoutInfo): Layout information.
|
||||
out_scene_path (str): Output scene file path.
|
||||
with_bg (bool, optional): Include background mesh.
|
||||
"""
|
||||
object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
|
||||
scene = trimesh.Scene()
|
||||
for node in layout_info.assets:
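A short sketch chaining the two helpers above; `compose_mesh_scene` is assumed to be importable from the same `embodied_gen.utils.geometry` module as `bfs_placement`, and the file names are placeholders:

```py
from embodied_gen.utils.geometry import bfs_placement, compose_mesh_scene

# Resolve 7D poses for all assets in the layout, then bake the placed meshes into one scene file.
layout = bfs_placement("scene_layout.json", seed=42)
compose_mesh_scene(layout, "outputs/scene.glb", with_bg=True)
```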
@ -505,6 +605,16 @@ def compose_mesh_scene(
def compute_pinhole_intrinsics(
image_w: int, image_h: int, fov_deg: float
) -> np.ndarray:
"""Computes pinhole camera intrinsic matrix from image size and FOV.

Args:
image_w (int): Image width.
image_h (int): Image height.
fov_deg (float): Field of view in degrees.

Returns:
np.ndarray: Intrinsic matrix K.
"""
fov_rad = np.deg2rad(fov_deg)
fx = image_w / (2 * np.tan(fov_rad / 2))
fy = fx # assuming square pixels
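For reference, a standalone sketch of the full intrinsic matrix this function builds toward; placing the principal point at the image centre is an assumption, not something stated in the hunk:

```py
import numpy as np

def pinhole_K(image_w: int, image_h: int, fov_deg: float) -> np.ndarray:
    fov_rad = np.deg2rad(fov_deg)
    fx = image_w / (2 * np.tan(fov_rad / 2))  # focal length from the FOV
    fy = fx                                   # square pixels assumed
    cx, cy = image_w / 2, image_h / 2         # assumed principal point at the image centre
    return np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float64)

print(pinhole_K(640, 480, 60.0))  # fx = fy ~ 554.3, cx = 320, cy = 240
```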
@ -45,7 +45,35 @@ CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"


class GPTclient:
"""A client to interact with the GPT model via OpenAI or Azure API."""
"""A client to interact with GPT models via OpenAI or Azure API.

Supports text and image prompts, connection checking, and configurable parameters.

Args:
endpoint (str): API endpoint URL.
api_key (str): API key for authentication.
model_name (str, optional): Model name to use.
api_version (str, optional): API version (for Azure).
check_connection (bool, optional): Whether to check API connection.
verbose (bool, optional): Enable verbose logging.

Example:
```sh
export ENDPOINT="https://yfb-openai-sweden.openai.azure.com"
export API_KEY="xxxxxx"
export API_VERSION="2025-03-01-preview"
export MODEL_NAME="yfb-gpt-4o-sweden"
```
```py
from embodied_gen.utils.gpt_clients import GPT_CLIENT

response = GPT_CLIENT.query("Describe the physics of a falling apple.")
response = GPT_CLIENT.query(
text_prompt="Describe the content in each image.",
image_base64=["path/to/image1.png", "path/to/image2.jpg"],
)
```
"""

def __init__(
self,
@ -82,6 +110,7 @@ class GPTclient:
stop=(stop_after_attempt(10) | stop_after_delay(30)),
)
def completion_with_backoff(self, **kwargs):
"""Performs a chat completion request with retry/backoff."""
return self.client.chat.completions.create(**kwargs)

def query(
@ -91,19 +120,16 @@ class GPTclient:
system_role: Optional[str] = None,
params: Optional[dict] = None,
) -> Optional[str]:
"""Queries the GPT model with a text and optional image prompts.
"""Queries the GPT model with text and optional image prompts.

Args:
text_prompt (str): The main text input that the model responds to.
image_base64 (Optional[List[str]]): A list of image base64 strings
or local image paths or PIL.Image to accompany the text prompt.
system_role (Optional[str]): Optional system-level instructions
that specify the behavior of the assistant.
params (Optional[dict]): Additional parameters for GPT setting.
text_prompt (str): Main text input.
image_base64 (Optional[list[str | Image.Image]], optional): List of image base64 strings, file paths, or PIL Images.
system_role (Optional[str], optional): System-level instructions.
params (Optional[dict], optional): Additional GPT parameters.

Returns:
Optional[str]: The response content generated by the model based on
the prompt. Returns `None` if an error occurs.
Optional[str]: Model response content, or None if error.
"""
if system_role is None:
system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
@ -177,7 +203,11 @@ class GPTclient:
return response

def check_connection(self) -> None:
"""Check whether the GPT API connection is working."""
"""Checks whether the GPT API connection is working.

Raises:
ConnectionError: If connection fails.
"""
try:
response = self.completion_with_backoff(
messages=[
@ -69,6 +69,40 @@ def render_asset3d(
no_index_file: bool = False,
with_mtl: bool = True,
) -> list[str]:
"""Renders a 3D mesh asset and returns output image paths.

Args:
mesh_path (str): Path to the mesh file.
output_root (str): Directory to save outputs.
distance (float, optional): Camera distance.
num_images (int, optional): Number of views to render.
elevation (list[float], optional): Camera elevation angles.
pbr_light_factor (float, optional): PBR lighting factor.
return_key (str, optional): Glob pattern for output images.
output_subdir (str, optional): Subdirectory for outputs.
gen_color_mp4 (bool, optional): Generate color MP4 video.
gen_viewnormal_mp4 (bool, optional): Generate view normal MP4.
gen_glonormal_mp4 (bool, optional): Generate global normal MP4.
no_index_file (bool, optional): Skip index file saving.
with_mtl (bool, optional): Use mesh material.

Returns:
list[str]: List of output image file paths.

Example:
```py
from embodied_gen.utils.process_media import render_asset3d

image_paths = render_asset3d(
mesh_path="path_to_mesh.obj",
output_root="path_to_save_dir",
num_images=6,
elevation=(30, -30),
output_subdir="renders",
no_index_file=True,
)
```
"""
input_args = dict(
mesh_path=mesh_path,
output_root=output_root,
@ -95,6 +129,13 @@ def render_asset3d(


def merge_images_video(color_images, normal_images, output_path) -> None:
"""Merges color and normal images into a video.

Args:
color_images (list[np.ndarray]): List of color images.
normal_images (list[np.ndarray]): List of normal images.
output_path (str): Path to save the output video.
"""
width = color_images[0].shape[1]
combined_video = [
np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
@ -108,7 +149,13 @@ def merge_images_video(color_images, normal_images, output_path) -> None:
def merge_video_video(
video_path1: str, video_path2: str, output_path: str
) -> None:
"""Merge two videos by the left half and the right half of the videos."""
"""Merges two videos by combining their left and right halves.

Args:
video_path1 (str): Path to first video.
video_path2 (str): Path to second video.
output_path (str): Path to save the merged video.
"""
clip1 = VideoFileClip(video_path1)
clip2 = VideoFileClip(video_path2)
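A one-line usage sketch for the video variant (file names are placeholders; the import path follows the other `process_media` helpers shown above, and by analogy with `merge_images_video` the left half comes from the first clip and the right half from the second):

```py
from embodied_gen.utils.process_media import merge_video_video

merge_video_video("color.mp4", "normal.mp4", "merged.mp4")
```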
@ -127,6 +174,16 @@ def filter_small_connected_components(
area_ratio: float,
connectivity: int = 8,
) -> np.ndarray:
"""Removes small connected components from a binary mask.

Args:
mask (Union[Image.Image, np.ndarray]): Input mask.
area_ratio (float): Minimum area ratio for components.
connectivity (int, optional): Connectivity for labeling.

Returns:
np.ndarray: Mask with small components removed.
"""
if isinstance(mask, Image.Image):
mask = np.array(mask)
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
@ -152,6 +209,16 @@ def filter_image_small_connected_components(
area_ratio: float = 10,
connectivity: int = 8,
) -> np.ndarray:
"""Removes small connected components from the alpha channel of an image.

Args:
image (Union[Image.Image, np.ndarray]): Input image.
area_ratio (float, optional): Minimum area ratio.
connectivity (int, optional): Connectivity for labeling.

Returns:
np.ndarray: Image with filtered alpha channel.
"""
if isinstance(image, Image.Image):
image = image.convert("RGBA")
image = np.array(image)
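A brief usage sketch for the alpha-channel variant (file names are placeholders; the import path follows the other `process_media` helpers shown above, and the default `area_ratio` is kept):

```py
from PIL import Image
from embodied_gen.utils.process_media import filter_image_small_connected_components

# Remove tiny alpha speckles left over from segmentation.
rgba = filter_image_small_connected_components(Image.open("object_rgba.png"))
Image.fromarray(rgba).save("object_rgba_clean.png")
```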
@ -169,6 +236,24 @@ def combine_images_to_grid(
target_wh: tuple[int, int] = (512, 512),
image_mode: str = "RGB",
) -> list[Image.Image]:
"""Combines multiple images into a grid.

Args:
images (list[str | Image.Image]): List of image paths or PIL Images.
cat_row_col (tuple[int, int], optional): Grid rows and columns.
target_wh (tuple[int, int], optional): Target image size.
image_mode (str, optional): Image mode.

Returns:
list[Image.Image]: List containing the grid image.

Example:
```py
from embodied_gen.utils.process_media import combine_images_to_grid
grid = combine_images_to_grid(["img1.png", "img2.png"])
grid[0].save("grid.png")
```
"""
n_images = len(images)
if n_images == 1:
return images
@ -196,6 +281,19 @@ def combine_images_to_grid(


class SceneTreeVisualizer:
"""Visualizes a scene tree layout using networkx and matplotlib.

Args:
layout_info (LayoutInfo): Layout information for the scene.

Example:
```py
from embodied_gen.utils.process_media import SceneTreeVisualizer
visualizer = SceneTreeVisualizer(layout_info)
visualizer.render(save_path="tree.png")
```
"""

def __init__(self, layout_info: LayoutInfo) -> None:
self.tree = layout_info.tree
self.relation = layout_info.relation
@ -274,6 +372,14 @@ class SceneTreeVisualizer:
dpi=300,
title: str = "Scene 3D Hierarchy Tree",
):
"""Renders the scene tree and saves to file.

Args:
save_path (str): Path to save the rendered image.
figsize (tuple, optional): Figure size.
dpi (int, optional): Image DPI.
title (str, optional): Plot image title.
"""
node_colors = [
self.role_colors[self._get_node_role(n)] for n in self.G.nodes
]
@ -350,6 +456,14 @@ class SceneTreeVisualizer:


def load_scene_dict(file_path: str) -> dict:
"""Loads a scene description dictionary from a file.

Args:
file_path (str): Path to the scene description file.

Returns:
dict: Mapping from scene ID to description.
"""
scene_dict = {}
with open(file_path, "r", encoding='utf-8') as f:
for line in f:
@ -363,12 +477,28 @@ def load_scene_dict(file_path: str) -> dict:


def is_image_file(filename: str) -> bool:
"""Checks if a filename is an image file.

Args:
filename (str): Filename to check.

Returns:
bool: True if image file, False otherwise.
"""
mime_type, _ = mimetypes.guess_type(filename)

return mime_type is not None and mime_type.startswith('image')


def parse_text_prompts(prompts: list[str]) -> list[str]:
"""Parses text prompts from a list or file.

Args:
prompts (list[str]): List of prompts or a file path.

Returns:
list[str]: List of parsed prompts.
"""
if len(prompts) == 1 and prompts[0].endswith(".txt"):
with open(prompts[0], "r") as f:
prompts = [
@ -386,13 +516,18 @@ def alpha_blend_rgba(
"""Alpha blends a foreground RGBA image over a background RGBA image.

Args:
fg_image: Foreground image. Can be a file path (str), a PIL Image,
or a NumPy ndarray.
bg_image: Background image. Can be a file path (str), a PIL Image,
or a NumPy ndarray.
fg_image: Foreground image (str, PIL Image, or ndarray).
bg_image: Background image (str, PIL Image, or ndarray).

Returns:
A PIL Image representing the alpha-blended result in RGBA mode.
Image.Image: Alpha-blended RGBA image.

Example:
```py
from embodied_gen.utils.process_media import alpha_blend_rgba
result = alpha_blend_rgba("fg.png", "bg.png")
result.save("blended.png")
```
"""
if isinstance(fg_image, str):
fg_image = Image.open(fg_image)
@ -421,13 +556,11 @@ def check_object_edge_truncated(
"""Checks if a binary object mask is truncated at the image edges.

Args:
mask: A 2D binary NumPy array where nonzero values indicate the object region.
edge_threshold: Number of pixels from each image edge to consider for truncation.
Defaults to 5.
mask (np.ndarray): 2D binary mask.
edge_threshold (int, optional): Edge pixel threshold.

Returns:
True if the object is fully enclosed (not truncated).
False if the object touches or crosses any image boundary.
bool: True if object is fully enclosed, False if truncated.
"""
top = mask[:edge_threshold, :].any()
bottom = mask[-edge_threshold:, :].any()
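To make the convention concrete, a standalone toy example of the same edge test (not taken from the repository):

```py
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
mask[20:40, 20:40] = 1  # object well inside the frame

t = 5  # edge_threshold
touches_edge = (
    mask[:t, :].any() or mask[-t:, :].any() or mask[:, :t].any() or mask[:, -t:].any()
)
print(not touches_edge)  # True -> fully enclosed, i.e. not truncated
```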
@ -440,6 +573,22 @@ def check_object_edge_truncated(
def vcat_pil_images(
images: list[Image.Image], image_mode: str = "RGB"
) -> Image.Image:
"""Vertically concatenates a list of PIL images.

Args:
images (list[Image.Image]): List of images.
image_mode (str, optional): Image mode.

Returns:
Image.Image: Vertically concatenated image.

Example:
```py
from embodied_gen.utils.process_media import vcat_pil_images
img = vcat_pil_images([Image.open("a.png"), Image.open("b.png")])
img.save("vcat.png")
```
"""
widths, heights = zip(*(img.size for img in images))
total_height = sum(heights)
max_width = max(widths)

@ -69,6 +69,21 @@ def load_actor_from_urdf(
update_mass: bool = False,
scale: float | np.ndarray = 1.0,
) -> sapien.pysapien.Entity:
"""Load a sapien actor from a URDF file and add it to the scene.

Args:
scene (sapien.Scene | ManiSkillScene): The simulation scene.
file_path (str): Path to the URDF file.
pose (sapien.Pose | None): Initial pose of the actor.
env_idx (int): Environment index for multi-env setup.
use_static (bool): Whether the actor is static.
update_mass (bool): Whether to update the actor's mass from URDF.
scale (float | np.ndarray): Scale factor for the actor.

Returns:
sapien.pysapien.Entity: The created actor entity.
"""

def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
if origin_tag is not None:
@ -154,14 +169,17 @@ def load_assets_from_layout_file(
init_quat: list[float] = [0, 0, 0, 1],
env_idx: int = None,
) -> dict[str, sapien.pysapien.Entity]:
"""Load assets from `EmbodiedGen` layout-gen output and create actors in the scene.
"""Load assets from an EmbodiedGen layout file and create sapien actors in the scene.

Args:
scene (sapien.Scene | ManiSkillScene): The SAPIEN or ManiSkill scene to load assets into.
layout (str): The layout file path.
z_offset (float): Offset to apply to the Z-coordinate of non-context objects.
init_quat (List[float]): Initial quaternion (x, y, z, w) for orientation adjustment.
env_idx (int): Environment index for multi-environment setup.
scene (ManiSkillScene | sapien.Scene): The sapien simulation scene.
layout (str): Path to the embodiedgen layout file.
z_offset (float): Z offset for non-context objects.
init_quat (list[float]): Initial quaternion for orientation.
env_idx (int): Environment index.

Returns:
dict[str, sapien.pysapien.Entity]: Mapping from object names to actor entities.
"""
asset_root = os.path.dirname(layout)
layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
@ -206,6 +224,19 @@ def load_mani_skill_robot(
control_mode: str = "pd_joint_pos",
backend_str: tuple[str, str] = ("cpu", "gpu"),
) -> BaseAgent:
"""Load a ManiSkill robot agent into the scene.

Args:
scene (sapien.Scene | ManiSkillScene): The simulation scene.
layout (LayoutInfo | str): Layout info or path to layout file.
control_freq (int): Control frequency.
robot_init_qpos_noise (float): Noise for initial joint positions.
control_mode (str): Robot control mode.
backend_str (tuple[str, str]): Simulation/render backend.

Returns:
BaseAgent: The loaded robot agent.
"""
from mani_skill.agents import REGISTERED_AGENTS
from mani_skill.envs.scene import ManiSkillScene
from mani_skill.envs.utils.system.backend import (
@ -278,14 +309,14 @@ def render_images(
]
] = None,
) -> dict[str, Image.Image]:
"""Render images from a given sapien camera.
"""Render images from a given SAPIEN camera.

Args:
camera (sapien.render.RenderCameraComponent): The camera to render from.
render_keys (List[str]): Types of images to render (e.g., Color, Segmentation).
camera (sapien.render.RenderCameraComponent): Camera to render from.
render_keys (list[str], optional): Types of images to render.

Returns:
Dict[str, Image.Image]: Dictionary of rendered images.
dict[str, Image.Image]: Dictionary of rendered images.
"""
if render_keys is None:
render_keys = [
@ -341,11 +372,33 @@ def render_images(


class SapienSceneManager:
"""A class to manage SAPIEN simulator."""
"""Manages SAPIEN simulation scenes, cameras, and rendering.

This class provides utilities for setting up scenes, adding cameras,
stepping simulation, and rendering images.

Attributes:
sim_freq (int): Simulation frequency.
ray_tracing (bool): Whether to use ray tracing.
device (str): Device for simulation.
renderer (sapien.SapienRenderer): SAPIEN renderer.
scene (sapien.Scene): Simulation scene.
cameras (list): List of camera components.
actors (dict): Mapping of actor names to entities.

Example see `embodied_gen/scripts/simulate_sapien.py`.
"""

def __init__(
self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
) -> None:
"""Initialize the scene manager.

Args:
sim_freq (int): Simulation frequency.
ray_tracing (bool): Enable ray tracing.
device (str): Device for simulation.
"""
self.sim_freq = sim_freq
self.ray_tracing = ray_tracing
self.device = device
@ -355,7 +408,11 @@ class SapienSceneManager:
self.actors: dict[str, sapien.pysapien.Entity] = {}

def _setup_scene(self) -> sapien.Scene:
"""Set up the SAPIEN scene with lighting and ground."""
"""Set up the SAPIEN scene with lighting and ground.

Returns:
sapien.Scene: The initialized scene.
"""
# Ray tracing settings
if self.ray_tracing:
sapien.render.set_camera_shader_dir("rt")
@ -397,6 +454,18 @@ class SapienSceneManager:
render_keys: list[str],
sim_steps_per_control: int = 1,
) -> dict:
"""Step the simulation and render images from cameras.

Args:
agent (BaseAgent): The robot agent.
action (torch.Tensor): Action to apply.
cameras (list): List of camera components.
render_keys (list[str]): Types of images to render.
sim_steps_per_control (int): Simulation steps per control.

Returns:
dict: Dictionary of rendered frames per camera.
"""
agent.set_action(action)
frames = defaultdict(list)
for _ in range(sim_steps_per_control):
@ -417,13 +486,13 @@ class SapienSceneManager:
image_hw: tuple[int, int],
fovy_deg: float,
) -> sapien.render.RenderCameraComponent:
"""Create a single camera in the scene.
"""Create a camera in the scene.

Args:
cam_name (str): Name of the camera.
pose (sapien.Pose): Camera pose p=(x, y, z), q=(w, x, y, z)
image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
fovy_deg (float): Field of view in degrees for cameras.
cam_name (str): Camera name.
pose (sapien.Pose): Camera pose.
image_hw (tuple[int, int]): Image resolution (height, width).
fovy_deg (float): Field of view in degrees.

Returns:
sapien.render.RenderCameraComponent: The created camera.
@ -456,15 +525,15 @@ class SapienSceneManager:
"""Initialize multiple cameras arranged in a circle.

Args:
num_cameras (int): Number of cameras to create.
radius (float): Radius of the camera circle.
height (float): Fixed Z-coordinate of the cameras.
target_pt (list[float]): 3D point (x, y, z) that cameras look at.
image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
fovy_deg (float): Field of view in degrees for cameras.
num_cameras (int): Number of cameras.
radius (float): Circle radius.
height (float): Camera height.
target_pt (list[float]): Target point to look at.
image_hw (tuple[int, int]): Image resolution.
fovy_deg (float): Field of view in degrees.

Returns:
List[sapien.render.RenderCameraComponent]: List of created cameras.
list[sapien.render.RenderCameraComponent]: List of cameras.
"""
angle_step = 2 * np.pi / num_cameras
world_up_vec = np.array([0.0, 0.0, 1.0])
@ -510,6 +579,19 @@ class SapienSceneManager:


class FrankaPandaGrasper(object):
"""Provides grasp planning and control for Franka Panda robot.

Attributes:
agent (BaseAgent): The robot agent.
robot: The robot instance.
control_freq (float): Control frequency.
control_timestep (float): Control timestep.
joint_vel_limits (float): Joint velocity limits.
joint_acc_limits (float): Joint acceleration limits.
finger_length (float): Length of gripper fingers.
planners: Motion planners for each environment.
"""

def __init__(
self,
agent: BaseAgent,
@ -518,6 +600,7 @@ class FrankaPandaGrasper(object):
joint_acc_limits: float = 1.0,
finger_length: float = 0.025,
) -> None:
"""Initialize the grasper."""
self.agent = agent
self.robot = agent.robot
self.control_freq = control_freq
@ -553,6 +636,15 @@ class FrankaPandaGrasper(object):
gripper_state: Literal[-1, 1],
n_step: int = 10,
) -> np.ndarray:
"""Generate gripper control actions.

Args:
gripper_state (Literal[-1, 1]): Desired gripper state.
n_step (int): Number of steps.

Returns:
np.ndarray: Array of gripper actions.
"""
qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
actions = []
for _ in range(n_step):
@ -571,6 +663,20 @@ class FrankaPandaGrasper(object):
action_key: str = "position",
env_idx: int = 0,
) -> np.ndarray:
"""Plan and execute motion to a target pose.

Args:
pose (sapien.Pose): Target pose.
control_timestep (float): Control timestep.
gripper_state (Literal[-1, 1]): Desired gripper state.
use_point_cloud (bool): Use point cloud for planning.
n_max_step (int): Max number of steps.
action_key (str): Key for action in result.
env_idx (int): Environment index.

Returns:
np.ndarray: Array of actions to reach the pose.
"""
result = self.planners[env_idx].plan_qpos_to_pose(
np.concatenate([pose.p, pose.q]),
self.robot.get_qpos().cpu().numpy()[0],
@ -608,6 +714,17 @@ class FrankaPandaGrasper(object):
offset: tuple[float, float, float] = [0, 0, -0.05],
env_idx: int = 0,
) -> np.ndarray:
"""Compute grasp actions for a target actor.

Args:
actor (sapien.pysapien.Entity): Target actor to grasp.
reach_target_only (bool): Only reach the target pose if True.
offset (tuple[float, float, float]): Offset for reach pose.
env_idx (int): Environment index.

Returns:
np.ndarray: Array of grasp actions.
"""
physx_rigid = actor.components[1]
mesh = get_component_mesh(physx_rigid, to_world_frame=True)
obb = mesh.bounding_box_oriented

@ -1 +1 @@
VERSION = "v0.1.5"
VERSION = "v0.1.6"

@ -27,14 +27,22 @@ from PIL import Image


class AestheticPredictor:
"""Aesthetic Score Predictor.
"""Aesthetic Score Predictor using CLIP and a pre-trained MLP.

Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main
Checkpoints from `https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main`.

Args:
clip_model_dir (str): Path to the directory of the CLIP model.
sac_model_path (str): Path to the pre-trained SAC model.
device (str): Device to use for computation ("cuda" or "cpu").
clip_model_dir (str, optional): Path to CLIP model directory.
sac_model_path (str, optional): Path to SAC model weights.
device (str, optional): Device for computation ("cuda" or "cpu").

Example:
```py
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
predictor = AestheticPredictor(device="cuda")
score = predictor.predict("image.png")
print("Aesthetic score:", score)
```
"""

def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
@ -109,7 +117,7 @@ class AestheticPredictor:
return model

def predict(self, image_path):
"""Predict the aesthetic score for a given image.
"""Predicts the aesthetic score for a given image.

Args:
image_path (str): Path to the image file.

@ -40,6 +40,16 @@ __all__ = [


class BaseChecker:
"""Base class for quality checkers using GPT clients.

Provides a common interface for querying and validating responses.
Subclasses must implement the `query` method.

Attributes:
prompt (str): The prompt used for queries.
verbose (bool): Whether to enable verbose logging.
"""

def __init__(self, prompt: str = None, verbose: bool = False) -> None:
self.prompt = prompt
self.verbose = verbose
@ -70,6 +80,15 @@ class BaseChecker:
def validate(
checkers: list["BaseChecker"], images_list: list[list[str]]
) -> list:
"""Validates a list of checkers against corresponding image lists.

Args:
checkers (list[BaseChecker]): List of checker instances.
images_list (list[list[str]]): List of image path lists.

Returns:
list: Validation results with overall outcome.
"""
assert len(checkers) == len(images_list)
results = []
overall_result = True
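A rough usage sketch of the batch helper; it assumes `validate` is exposed as a static method (its signature takes no `self`), and the image paths are placeholders:

```py
from embodied_gen.validators.quality_checkers import BaseChecker, ImageAestheticChecker

# One checker per image group; all groups are validated in a single call.
checkers = [ImageAestheticChecker(thresh=4.5), ImageAestheticChecker(thresh=5.0)]
images_list = [["renders/obj_a_0.png", "renders/obj_a_1.png"], ["renders/obj_b_0.png"]]

results = BaseChecker.validate(checkers, images_list)  # per-checker results plus an overall outcome
print(results)
```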
@ -192,7 +211,7 @@ class ImageSegChecker(BaseChecker):


class ImageAestheticChecker(BaseChecker):
"""A class for evaluating the aesthetic quality of images.
"""Evaluates the aesthetic quality of images using a CLIP-based predictor.

Attributes:
clip_model_dir (str): Path to the CLIP model directory.
@ -200,6 +219,14 @@ class ImageAestheticChecker(BaseChecker):
thresh (float): Threshold above which images are considered aesthetically acceptable.
verbose (bool): Whether to print detailed log messages.
predictor (AestheticPredictor): The model used to predict aesthetic scores.

Example:
```py
from embodied_gen.validators.quality_checkers import ImageAestheticChecker
checker = ImageAestheticChecker(thresh=4.5)
flag, score = checker(["image1.png", "image2.png"])
print("Aesthetic OK:", flag, "Score:", score)
```
"""

def __init__(
@ -227,6 +254,16 @@ class ImageAestheticChecker(BaseChecker):


class SemanticConsistChecker(BaseChecker):
"""Checks semantic consistency between text descriptions and segmented images.

Uses GPT to evaluate if the image matches the text in object type, geometry, and color.

Attributes:
gpt_client (GPTclient): GPT client for queries.
prompt (str): Prompt for consistency evaluation.
verbose (bool): Whether to enable verbose logging.
"""

def __init__(
self,
gpt_client: GPTclient,
@ -276,6 +313,16 @@ class SemanticConsistChecker(BaseChecker):


class TextGenAlignChecker(BaseChecker):
"""Evaluates alignment between text prompts and generated 3D asset images.

Assesses if the rendered images match the text description in category and geometry.

Attributes:
gpt_client (GPTclient): GPT client for queries.
prompt (str): Prompt for alignment evaluation.
verbose (bool): Whether to enable verbose logging.
"""

def __init__(
self,
gpt_client: GPTclient,
@ -489,6 +536,17 @@ class PanoHeightEstimator(object):


class SemanticMatcher(BaseChecker):
"""Matches query text to semantically similar scene descriptions.

Uses GPT to find the most similar scene IDs from a dictionary.

Attributes:
gpt_client (GPTclient): GPT client for queries.
prompt (str): Prompt for semantic matching.
verbose (bool): Whether to enable verbose logging.
seed (int): Random seed for selection.
"""

def __init__(
self,
gpt_client: GPTclient,
@ -543,6 +601,17 @@ class SemanticMatcher(BaseChecker):
def query(
self, text: str, context: dict, rand: bool = True, params: dict = None
) -> str:
"""Queries for semantically similar scene IDs.

Args:
text (str): Query text.
context (dict): Dictionary of scene descriptions.
rand (bool, optional): Whether to randomly select from top matches.
params (dict, optional): Additional GPT parameters.

Returns:
str: Matched scene ID.
"""
match_list = self.gpt_client.query(
self.prompt.format(context=context, text=text),
params=params,

@ -80,6 +80,31 @@ URDF_TEMPLATE = """


class URDFGenerator(object):
"""Generates URDF files for 3D assets with physical and semantic attributes.

Uses GPT to estimate object properties and generates a URDF file with mesh, friction, mass, and metadata.

Args:
gpt_client (GPTclient): GPT client for attribute estimation.
mesh_file_list (list[str], optional): Additional mesh files to copy.
prompt_template (str, optional): Prompt template for GPT queries.
attrs_name (list[str], optional): List of attribute names to include.
render_dir (str, optional): Directory for rendered images.
render_view_num (int, optional): Number of views to render.
decompose_convex (bool, optional): Whether to decompose mesh for collision.
rotate_xyzw (list[float], optional): Quaternion for mesh rotation.

Example:
```py
from embodied_gen.validators.urdf_convertor import URDFGenerator
from embodied_gen.utils.gpt_clients import GPT_CLIENT

urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
urdf_path = urdf_gen(mesh_path="mesh.obj", output_root="output_dir")
print("Generated URDF:", urdf_path)
```
"""

def __init__(
self,
gpt_client: GPTclient,
@ -168,6 +193,14 @@ class URDFGenerator(object):
self.rotate_xyzw = rotate_xyzw

def parse_response(self, response: str) -> dict[str, any]:
"""Parses GPT response to extract asset attributes.

Args:
response (str): GPT response string.

Returns:
dict[str, any]: Parsed attributes.
"""
lines = response.split("\n")
lines = [line.strip() for line in lines if line]
category = lines[0].split(": ")[1]
@ -207,11 +240,9 @@ class URDFGenerator(object):

Args:
input_mesh (str): Path to the input mesh file.
output_dir (str): Directory to store the generated URDF
and processed mesh.
attr_dict (dict): Dictionary containing attributes like height,
mass, and friction coefficients.
output_name (str, optional): Name for the generated URDF and robot.
output_dir (str): Directory to store the generated URDF and mesh.
attr_dict (dict): Dictionary of asset attributes.
output_name (str, optional): Name for the URDF and robot.

Returns:
str: Path to the generated URDF file.
@ -336,6 +367,16 @@ class URDFGenerator(object):
attr_root: str = ".//link/extra_info",
attr_name: str = "scale",
) -> float:
"""Extracts an attribute value from a URDF file.

Args:
urdf_path (str): Path to the URDF file.
attr_root (str, optional): XML path to attribute root.
attr_name (str, optional): Attribute name.

Returns:
float: Attribute value, or None if not found.
"""
if not os.path.exists(urdf_path):
raise FileNotFoundError(f"URDF file not found: {urdf_path}")

@ -358,6 +399,13 @@ class URDFGenerator(object):
def add_quality_tag(
urdf_path: str, results: list, output_path: str = None
) -> None:
"""Adds a quality tag to a URDF file.

Args:
urdf_path (str): Path to the URDF file.
results (list): List of [checker_name, result] pairs.
output_path (str, optional): Output file path.
"""
if output_path is None:
output_path = urdf_path

@ -382,6 +430,14 @@ class URDFGenerator(object):
logger.info(f"URDF files saved to {output_path}")

def get_estimated_attributes(self, asset_attrs: dict):
"""Calculates estimated attributes from asset properties.

Args:
asset_attrs (dict): Asset attributes.

Returns:
dict: Estimated attributes (height, mass, mu, category).
"""
estimated_attrs = {
"height": round(
(asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
@ -403,6 +459,18 @@ class URDFGenerator(object):
category: str = "unknown",
**kwargs,
):
"""Generates a URDF file for a mesh asset.

Args:
mesh_path (str): Path to mesh file.
output_root (str): Directory for outputs.
text_prompt (str, optional): Prompt for GPT.
category (str, optional): Asset category.
**kwargs: Additional attributes.

Returns:
str: Path to generated URDF file.
"""
if text_prompt is None or len(text_prompt) == 0:
text_prompt = self.prompt_template
text_prompt = text_prompt.format(category=category.lower())

@ -7,7 +7,7 @@ packages = ["embodied_gen"]

[project]
name = "embodied_gen"
version = "v0.1.5"
version = "v0.1.6"
readme = "README.md"
license = "Apache-2.0"
license-files = ["LICENSE", "NOTICE"]

@ -4,10 +4,9 @@ import pytest
from huggingface_hub import snapshot_download
from embodied_gen.data.asset_converter import (
AssetConverterFactory,
AssetType,
SimAssetMapper,
cvt_embodiedgen_asset_to_anysim,
)
from embodied_gen.utils.enum import AssetType, SimAssetMapper


@pytest.fixture(scope="session")
@ -77,7 +76,10 @@ def test_cvt_embodiedgen_asset_to_anysim(
):
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
urdf_files=[
"outputs/embodiedgen_assets/demo_assets/remote_control2/result/remote_control.urdf",
"outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf",
],
target_dirs=[
"outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd",
],
target_type=SimAssetMapper[simulator_name],
source_type=AssetType.MESH,