"""Replay a recorded LeRobot SO-101 episode in Rerun.

The SO-101 URDF is logged into a Rerun recording and its rotational joints are
animated from each frame's ``observation.state``. Every camera stream of the
dataset is logged as an image. The recording is served over gRPC so a viewer
can connect to it.
"""

from pathlib import Path

import numpy as np
import rerun as rr
import rerun.blueprint as rrb
import scipy.spatial.transform as st
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from urdf_parser_py import urdf as urdf_parser

SO101_URDF = Path(__file__).resolve().parent / "assets/SO101/so101_new_calib.urdf"

# URDF joint names, in the order their values appear in ``observation.state``.
# NOTE(review): assumes the dataset's state vector uses exactly this order —
# confirm against the dataset's ``observation.state`` feature names.
DEFAULT_JOINT_NAMES = [
    "shoulder_pan",
    "shoulder_lift",
    "elbow_flex",
    "wrist_flex",
    "wrist_roll",
    "gripper",
]

# Axis the URDF spec assigns to a joint whose <axis> element is omitted.
_URDF_DEFAULT_AXIS = (1.0, 0.0, 0.0)

# Joint types whose position is a rotation angle about the joint axis.
_ROTATIONAL_JOINT_TYPES = ("revolute", "continuous")


def _get_joint_paths(urdf: urdf_parser.URDF, entity_prefix: str = "") -> dict[str, str]:
    """Map each URDF joint name to the Rerun entity path of its child link.

    The path is the root-to-child kinematic chain returned by ``URDF.get_chain``
    (links and joints interleaved), joined with ``/``. It is prefixed with
    ``entity_prefix``, or with ``/`` when no prefix is given.
    """
    root_name = urdf.get_root()
    prefix = entity_prefix or "/"
    return {
        joint.name: prefix + "/".join(urdf.get_chain(root_name, joint.child))
        for joint in urdf.joints
    }


def _joint_transform(joint: urdf_parser.Joint, angle_deg: float) -> rr.Transform3D:
    """Return the parent-to-child transform of a rotational joint at ``angle_deg`` degrees.

    URDF kinematics give ``T = Trans(origin.xyz) @ Rot(origin.rpy) @ Rot(axis, angle)``.
    The full transform (translation and rotation) is returned because logging a
    ``Transform3D`` overwrites every transform component on the entity. Logging
    only the joint rotation would discard the joint's origin translation and rpy.

    Raises:
        ValueError: If the joint axis has zero length.
    """
    origin = joint.origin
    xyz = origin.xyz if origin is not None and origin.xyz is not None else [0.0, 0.0, 0.0]
    rpy = origin.rpy if origin is not None and origin.rpy is not None else [0.0, 0.0, 0.0]

    # `joint.axis` is None when <axis> is omitted; the URDF default is then (1, 0, 0).
    axis = np.asarray(joint.axis if joint.axis is not None else _URDF_DEFAULT_AXIS, dtype=float)
    norm = np.linalg.norm(axis)
    if norm == 0.0:
        raise ValueError(f"joint {joint.name!r} has a zero-length axis")

    # URDF rpy is fixed-axis roll/pitch/yaw (R = Rz @ Ry @ Rx), i.e. scipy's extrinsic "xyz".
    origin_rotation = st.Rotation.from_euler("xyz", rpy)
    joint_rotation = st.Rotation.from_rotvec(axis / norm * np.deg2rad(angle_deg))
    # `a * b` applies b first, giving the matrix product R_origin @ R_joint.
    rotation = origin_rotation * joint_rotation
    return rr.Transform3D(translation=xyz, mat3x3=rotation.as_matrix())


def visualize_episode(
    dataset_path: str | Path,
    episode_index: int,
    *,
    repo_id: str = "O24H/Vis",
    joint_names: list[str] | None = None,
) -> None:
    """Serve a Rerun recording that replays one episode of a LeRobot dataset.

    Args:
        dataset_path: Local root directory of the dataset.
        episode_index: 0-based index of the episode to replay.
        repo_id: Repository id passed to ``LeRobotDataset`` together with ``root``.
        joint_names: URDF joint names in ``observation.state`` order. Defaults to
            ``DEFAULT_JOINT_NAMES``.

    Raises:
        IndexError: If ``episode_index`` is not a valid episode of the dataset.
    """
    if joint_names is None:
        joint_names = DEFAULT_JOINT_NAMES

    # Load and validate the dataset before starting any server.
    ds = LeRobotDataset(repo_id=repo_id, root=dataset_path)
    episode_starts = ds.episode_data_index["from"]
    episode_ends = ds.episode_data_index["to"]
    num_episodes = len(episode_starts)
    if not 0 <= episode_index < num_episodes:
        # Without this check a negative index would silently wrap to a later episode.
        raise IndexError(
            f"episode_index {episode_index} out of range for {num_episodes} episodes"
        )

    # Init the Rerun recording and serve it over gRPC.
    bp = rrb.Blueprint(collapse_panels=True)
    rr.init("play_lerobot", spawn=False, default_blueprint=bp)
    rr.send_blueprint(bp, make_active=True)
    rr.serve_grpc(server_memory_limit="80%")

    # Log the robot model at frame 0; dataset frames are logged from sequence 1 on.
    rr.log_file_from_path(file_path=SO101_URDF)
    rr.set_time("frame", sequence=0)

    urdf = urdf_parser.URDF.from_xml_file(str(SO101_URDF))
    joint_paths = _get_joint_paths(urdf, entity_prefix="/so101_new_calib/")
    # Resolve the animated joints once, not on every frame.
    animated_joints = [
        (state_index, urdf.joint_map[name], joint_paths[name])
        for state_index, name in enumerate(joint_names)
        if urdf.joint_map[name].type in _ROTATIONAL_JOINT_TYPES
    ]

    for i in range(episode_starts[episode_index].item(), episode_ends[episode_index].item()):
        frame = ds[i]
        rr.set_time("frame", sequence=i + 1)

        # Visualize robot state.
        state = frame["observation.state"]
        for state_index, joint, entity_path in animated_joints:
            # NOTE(review): assumes joint positions are stored in degrees —
            # confirm against how the dataset was recorded.
            angle_deg = float(state[state_index])
            rr.log(entity_path, _joint_transform(joint, angle_deg))

        for key in ds.meta.camera_keys:
            # CHW float tensor -> HWC uint8; presumably values lie in [0, 1] — TODO confirm.
            image = (frame[key] * 255).permute(1, 2, 0).numpy().astype(np.uint8)
            rr.log("/camera/" + key, rr.Image(image))


if __name__ == "__main__":
    import argparse
    import time

    parser = argparse.ArgumentParser(description="Replay a LeRobot SO-101 episode in Rerun.")
    parser.add_argument("--dataset", type=str, required=True, help="Path to dataset")
    parser.add_argument("--episode", type=int, default=1, help="Episode to visualize (1-based)")
    args = parser.parse_args()
    if args.episode < 1:
        parser.error("--episode is 1-based and must be >= 1")

    tic = time.perf_counter()
    visualize_episode(dataset_path=args.dataset, episode_index=args.episode - 1)
    toc = time.perf_counter()
    print(f"Visualization server started in {toc - tic:.2f} seconds.")

    # Keep this process, which hosts the gRPC server, alive until interrupted.
    try:
        while True:
            time.sleep(3600)
    except KeyboardInterrupt:
        pass