109 lines
3.3 KiB
Python
109 lines
3.3 KiB
Python
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import rerun as rr
|
|
import rerun.blueprint as rrb
|
|
import scipy.spatial.transform as st
|
|
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
from urdf_parser_py import urdf as urdf_parser
|
|
|
|
# URDF model of the SO101 arm, stored next to this script.
SO101_URDF = Path(__file__).resolve().parent.joinpath("assets", "SO101", "so101_new_calib.urdf")

# Joints animated by visualize_episode. The list position is used as the index
# into each frame's "observation.state" vector, so this order is assumed to
# match the dataset's state layout (NOTE(review): confirm against dataset meta).
DEFAULT_JOINT_NAMES = [
    "shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll",
    "gripper",
]
|
|
|
|
|
|
def visualize_episode(dataset_path: str | Path, episode_index: int):
|
|
# Init rerun env
|
|
bp = rrb.Blueprint(collapse_panels=True)
|
|
rr.init("play_lerobot", spawn=False, default_blueprint=bp)
|
|
rr.send_blueprint(bp, make_active=True)
|
|
rr.serve_grpc(server_memory_limit="80%")
|
|
|
|
# Visualize URDF
|
|
rr.log_file_from_path(file_path=SO101_URDF)
|
|
rr.set_time("frame", sequence=0)
|
|
|
|
# URDF Controller
|
|
urdf = urdf_parser.URDF.from_xml_file(SO101_URDF)
|
|
|
|
def get_joint_paths(urdf: urdf_parser.URDF, entity_prefix: str = "") -> dict[str, str]:
|
|
joint_paths = {}
|
|
|
|
def get_joint_path(joint: urdf_parser.Joint) -> str:
|
|
"""Return the entity path for the URDF joint."""
|
|
root_name = urdf.get_root()
|
|
joint_names = urdf.get_chain(root_name, joint.child)
|
|
return entity_prefix + ("/".join(joint_names)) if entity_prefix else "/" + ("/".join(joint_names))
|
|
|
|
for joint in urdf.joints:
|
|
joint_paths[joint.name] = get_joint_path(joint)
|
|
return joint_paths
|
|
|
|
joint_paths = get_joint_paths(urdf, entity_prefix="/so101_new_calib/")
|
|
|
|
ds = LeRobotDataset(repo_id="O24H/Vis", root=dataset_path)
|
|
|
|
frame_idx = range(
|
|
ds.episode_data_index["from"][episode_index].item(),
|
|
ds.episode_data_index["to"][episode_index].item(),
|
|
)
|
|
|
|
for i in frame_idx:
|
|
frame = ds[i]
|
|
|
|
# dt = 1 / ds.fps
|
|
# rr.set_time_seconds("timestamp", i * dt)
|
|
|
|
rr.set_time("frame", sequence=i + 1)
|
|
|
|
# Visualize robot state
|
|
for joint_index, joint_name in enumerate(JOINT_NAMES):
|
|
joint = urdf.joint_map[joint_name]
|
|
if joint.type == "revolute":
|
|
# 取得关节可以旋转的轴
|
|
axis = np.array(getattr(joint, "axis", [0, 0, 1])) # 默认轴为z轴
|
|
|
|
ang = frame["observation.state"][joint_index]
|
|
|
|
R = rr.RotationAxisAngle(axis=axis, degrees=ang)
|
|
|
|
entity_path = joint_paths[joint_name]
|
|
|
|
rr.log(
|
|
entity_path,
|
|
rr.Transform3D(rotation=R),
|
|
)
|
|
|
|
for k in ds.meta.camera_keys:
|
|
rr.log(
|
|
"/camera/" + k,
|
|
rr.Image((frame[k] * 255).permute(1, 2, 0).numpy().astype(np.uint8)),
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="Path to dataset")
    # Episodes are numbered from 1 on the command line. The old default of 0
    # became index -1 after the conversion below, i.e. the last episode.
    parser.add_argument("--episode", type=int, default=1, help="Episode to visualize (1-based)")
    args = parser.parse_args()
    if args.episode < 1:
        parser.error("--episode must be >= 1 (episodes are numbered from 1)")

    tic = time.perf_counter()
    # Convert the 1-based CLI number to the 0-based index visualize_episode expects.
    visualize_episode(dataset_path=args.dataset, episode_index=args.episode - 1)
    toc = time.perf_counter()
    print(f"Visualization server started in {toc - tic:.2f} seconds.")

    # Block forever so the process (and the server started by rr.serve_grpc)
    # stays up; Ctrl+C exits without a traceback.
    try:
        while True:
            time.sleep(3600)
    except KeyboardInterrupt:
        pass
|