167 lines
5.6 KiB
Python
167 lines
5.6 KiB
Python
# Project EmbodiedGen
|
|
#
|
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
# implied. See the License for the specific language governing
|
|
# permissions and limitations under the License.
|
|
|
|
|
|
# The ManiSkill monkey patch is applied before the remaining imports —
# presumably so the patch takes effect before any mani_skill module is
# loaded below. NOTE(review): confirm the patch indeed must precede the
# `mani_skill` import; do not reorder these lines.
from embodied_gen.utils.monkey_patches import monkey_patch_maniskill

monkey_patch_maniskill()

import json
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Literal

import gymnasium as gym
import numpy as np
import torch
import tyro
from mani_skill.utils.wrappers import RecordEpisode
from tqdm import tqdm

# Imported for its side effect: registers the PickEmbodiedGen gym
# environment(s) used by `gym.make` below.
import embodied_gen.envs.pick_embodiedgen
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
from embodied_gen.utils.log import logger
from embodied_gen.utils.simulation import FrankaPandaGrasper
|
|
|
|
|
|
@dataclass
class ParallelSimConfig:
    """CLI parameters for Parallel Sapien simulation.

    Attribute docstrings below double as per-flag help text when this
    dataclass is parsed from the command line via ``tyro.cli``.
    """

    # Environment configuration
    layout_file: str
    """Path to the layout JSON file"""
    output_dir: str
    """Directory to save recorded videos"""
    gym_env_name: str = "PickEmbodiedGen-v1"
    """Name of the Gym environment to use"""
    num_envs: int = 4
    """Number of parallel environments"""
    render_mode: Literal["rgb_array", "hybrid"] = "hybrid"
    """Rendering mode: rgb_array or hybrid"""
    enable_shadow: bool = True
    """Whether to enable shadows in rendering"""
    control_mode: str = "pd_joint_pos"
    """Control mode for the agent"""

    # Recording configuration
    max_steps_per_video: int = 1000
    """Maximum steps to record per video"""
    save_trajectory: bool = False
    """Whether to save trajectory data"""

    # Simulation parameters
    seed: int = 0
    """Random seed for environment reset"""
    warmup_steps: int = 50
    """Number of warmup steps before action computation"""
    reach_target_only: bool = True
    """Whether to only reach target without full action"""

    # Camera settings
    camera_eye: list[float] = field(default_factory=lambda: [0.9, 0.0, 1.1])
    """Camera eye position [x, y, z] in global coordinate system"""
    camera_target_pt: list[float] = field(
        default_factory=lambda: [0.0, 0.0, 0.9]
    )
    """Camera target (look-at) point [x, y, z] in global coordinate system"""
    image_hw: list[int] = field(default_factory=lambda: [256, 256])
    """Rendered image height and width [height, width]"""
    fovy_deg: float = 75
    """Camera vertical field of view in degrees"""
|
|
|
|
|
|
def _plan_grasp_actions(env, grasper, layout_data, cfg):
    """Plan a grasp-reach action sequence per manipulated object per env.

    Returns a dict mapping object node name -> list (len == num_envs) of
    planned action sequences; an entry is None when planning returned None
    for that env.
    """
    actions = defaultdict(list)
    manipulated = layout_data.relation[Scene3DItemEnum.MANIPULATED_OBJS.value]
    for env_idx in range(env.num_envs):
        actors = env.unwrapped.env_actors[f"env{env_idx}"]
        for node in manipulated:
            action = grasper.compute_grasp_action(
                actor=actors[node]._objs[0],
                # Respect the config flag (the original hard-coded True,
                # silently ignoring cfg.reach_target_only).
                reach_target_only=cfg.reach_target_only,
                env_idx=env_idx,
            )
            actions[node].append(action)
    return actions


def _execute_actions(env, actions, default_action):
    """Execute the planned actions for each manipulated object in each env.

    Envs without a plan for a given object (or whose plan is shorter than
    the longest one) hold `default_action` for the remaining steps.
    """
    for node in actions:
        # Longest plan length across envs; 0 when no env has a plan.
        max_env_steps = max(
            (len(a) for a in actions[node] if a is not None), default=0
        )

        # (steps, num_envs, action_dim) buffer pre-filled with the default
        # hold pose, then overwritten per-env with the planned prefix.
        action_tensor = np.ones(
            (max_env_steps, env.num_envs, env.action_space.shape[-1])
        )
        action_tensor *= default_action[None, ...]
        for env_idx, action in enumerate(actions[node]):
            if action is None:
                continue
            action_tensor[: len(action), env_idx, :] = action

        for step in tqdm(range(max_env_steps), desc=f"Grasping: {node}"):
            action = torch.Tensor(action_tensor[step]).to(env.unwrapped.device)
            env.unwrapped.agent.set_action(action)
            env.step(action)


def entrypoint(**kwargs):
    """Run a parallel grasping simulation and record episode videos.

    When called with no keyword arguments, the configuration is parsed from
    the command line via ``tyro``; otherwise ``kwargs`` are forwarded to
    :class:`ParallelSimConfig`.
    """
    # `kwargs` from **kwargs is never None — truthiness check suffices.
    cfg = ParallelSimConfig(**kwargs) if kwargs else tyro.cli(ParallelSimConfig)

    env = gym.make(
        cfg.gym_env_name,
        num_envs=cfg.num_envs,
        render_mode=cfg.render_mode,
        enable_shadow=cfg.enable_shadow,
        layout_file=cfg.layout_file,
        control_mode=cfg.control_mode,
        camera_cfg=dict(
            camera_eye=cfg.camera_eye,
            camera_target_pt=cfg.camera_target_pt,
            image_hw=cfg.image_hw,
            fovy_deg=cfg.fovy_deg,
        ),
    )
    # Wrap so rollouts are recorded to cfg.output_dir.
    env = RecordEpisode(
        env,
        cfg.output_dir,
        max_steps_per_video=cfg.max_steps_per_video,
        save_trajectory=cfg.save_trajectory,
    )
    env.reset(seed=cfg.seed)

    # Hold the arm at its initial joint positions while physics settles.
    default_action = env.unwrapped.agent.init_qpos[:, :8]
    for _ in tqdm(range(cfg.warmup_steps), desc="SIM Warmup"):
        env.step(default_action)

    grasper = FrankaPandaGrasper(
        env.unwrapped.agent,
        env.unwrapped.sim_config.control_freq,
    )

    # Context manager closes the file deterministically (the original
    # `json.load(open(...))` leaked the handle).
    with open(cfg.layout_file, "r") as f:
        layout_data = LayoutInfo.from_dict(json.load(f))

    actions = _plan_grasp_actions(env, grasper, layout_data, cfg)
    _execute_actions(env, actions, default_action)

    env.close()
    logger.info(f"Results saved in {cfg.output_dir}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: entrypoint() parses ParallelSimConfig from argv via tyro.
    entrypoint()
|