Lerobot/docker/visualize/visualize.py
2025-12-11 14:11:41 +08:00

109 lines
3.3 KiB
Python

from pathlib import Path
import numpy as np
import rerun as rr
import rerun.blueprint as rrb
import scipy.spatial.transform as st
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from urdf_parser_py import urdf as urdf_parser
# Absolute path of the SO-101 robot description, stored under ./assets next to this script.
SO101_URDF = Path(__file__).resolve().parent.joinpath("assets", "SO101", "so101_new_calib.urdf")

# URDF joint names that are animated from the dataset.
# NOTE(review): presumably listed in the same order as the values of
# ``observation.state`` — verify against the dataset's feature names.
DEFAULT_JOINT_NAMES = [
    "shoulder_pan",
    "shoulder_lift",
    "elbow_flex",
    "wrist_flex",
    "wrist_roll",
    "gripper",
]
def _get_joint_paths(urdf: "urdf_parser.URDF", entity_prefix: str = "") -> dict[str, str]:
    """Map every URDF joint name to the rerun entity path of that joint.

    The path is the chain returned by ``URDF.get_chain`` from the root link to
    the joint's child link, joined with ``/`` and prefixed by ``entity_prefix``
    (or by a bare ``/`` when no prefix is given).
    """
    root_name = urdf.get_root()  # loop-invariant: look it up once
    prefix = entity_prefix if entity_prefix else "/"
    return {
        joint.name: prefix + "/".join(urdf.get_chain(root_name, joint.child))
        for joint in urdf.joints
    }


def _joint_transform(joint: "urdf_parser.Joint", angle_deg: float) -> tuple[list[float], np.ndarray]:
    """Return ``(translation, 3x3 rotation matrix)`` of ``joint`` at ``angle_deg``.

    URDF places the child frame at the joint's ``<origin>`` (xyz translation
    plus fixed-axis roll/pitch/yaw) and then rotates it about ``<axis>``
    (expressed in the joint frame) by the joint position. The resulting
    rotation is therefore ``R_origin * R_axis(angle)``. Logging only the axis
    rotation would drop the origin offset and orientation of the link.
    """
    origin = joint.origin
    xyz = list(origin.xyz) if origin is not None and origin.xyz is not None else [0.0, 0.0, 0.0]
    rpy = list(origin.rpy) if origin is not None and origin.rpy is not None else [0.0, 0.0, 0.0]

    # urdf_parser_py stores a missing <axis> as None; the URDF spec default is (1, 0, 0).
    axis = np.asarray(joint.axis if joint.axis is not None else [1.0, 0.0, 0.0], dtype=float)
    norm = np.linalg.norm(axis)
    if norm > 0:
        axis = axis / norm  # from_rotvec encodes the angle in the vector length

    # Lowercase "xyz" = extrinsic (fixed-axis) rotations, matching URDF's rpy convention.
    rotation = st.Rotation.from_euler("xyz", rpy) * st.Rotation.from_rotvec(np.deg2rad(angle_deg) * axis)
    return xyz, rotation.as_matrix()


def visualize_episode(dataset_path: str | Path, episode_index: int):
    """Stream one LeRobot episode of an SO-101 arm to a rerun gRPC server.

    Starts a rerun recording served over gRPC and logs the SO-101 URDF at
    timeline ``frame`` 0. Then, for every dataset frame of the episode
    (timeline ``frame`` 1..N, counted within the episode), it logs:

    * the transform of each rotating joint listed in ``DEFAULT_JOINT_NAMES``;
    * every camera image of the frame under ``/camera/<key>``.

    Args:
        dataset_path: Local root directory of the LeRobot dataset.
        episode_index: Zero-based episode index into ``episode_data_index``.
            Negative values count from the end, as with Python indexing.

    NOTE(review): assumes ``observation.state`` holds joint positions in
    degrees, ordered like ``DEFAULT_JOINT_NAMES`` — confirm against the
    robot's calibration mode (the gripper may use a 0-100 range instead).
    """
    # Init rerun env
    bp = rrb.Blueprint(collapse_panels=True)
    rr.init("play_lerobot", spawn=False, default_blueprint=bp)
    rr.send_blueprint(bp, make_active=True)
    rr.serve_grpc(server_memory_limit="80%")

    # Log the robot model once; frame 0 holds its rest pose.
    rr.log_file_from_path(file_path=SO101_URDF)
    rr.set_time("frame", sequence=0)

    # Resolve the animated joints once instead of on every frame.
    urdf = urdf_parser.URDF.from_xml_file(str(SO101_URDF))
    joint_paths = _get_joint_paths(urdf, entity_prefix="/so101_new_calib/")
    rotating_joints = [
        (state_index, urdf.joint_map[name])
        for state_index, name in enumerate(DEFAULT_JOINT_NAMES)
        if urdf.joint_map[name].type in ("revolute", "continuous")
    ]

    ds = LeRobotDataset(repo_id="O24H/Vis", root=dataset_path)
    start = ds.episode_data_index["from"][episode_index].item()
    stop = ds.episode_data_index["to"][episode_index].item()

    # Number frames from 1 within the episode so the timeline starts right after the rest pose.
    for frame_number, dataset_index in enumerate(range(start, stop), start=1):
        frame = ds[dataset_index]
        rr.set_time("frame", sequence=frame_number)

        # Robot state -> joint transforms (origin composed with the axis rotation).
        state = frame["observation.state"].tolist()
        for state_index, joint in rotating_joints:
            translation, rotation = _joint_transform(joint, state[state_index])
            rr.log(
                joint_paths[joint.name],
                rr.Transform3D(translation=translation, mat3x3=rotation),
            )

        # Camera images: CHW float -> HWC uint8; clamp so out-of-range values cannot wrap around.
        for k in ds.meta.camera_keys:
            image = (frame[k].clamp(0, 1) * 255).permute(1, 2, 0).numpy().astype(np.uint8)
            rr.log("/camera/" + k, rr.Image(image))
if __name__ == "__main__":
import argparse
import time
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, required=True, help="Path to dataset")
parser.add_argument("--episode", type=int, default=0, help="Episode to visualize")
args = parser.parse_args()
tic = time.time()
visualize_episode(dataset_path=args.dataset, episode_index=args.episode - 1) # args.episode start from 1
toc = time.time()
print(f"Visualization server started in {toc - tic:.2f} seconds.")
while True:
time.sleep(3600)