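# Lerobot/docker/add_depth.py
# Timestamp: 2025-12-11 14:11:41 +08:00
#
# (File-viewer metadata: 110 lines, 3.3 KiB, Python.)
#
# Re-packs the pick_orange RealSense recordings into a new LeRobot dataset,
# adding a depth image (obtained from a "get_depth" service) for every kept frame.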

from lerobot.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
# Root directory of the dataset this script writes.
target_path = "docker/inputs/pick_orange_rs_depth"

# Source recordings as (dataset root, task description) pairs:
# pick_orange_0 .. pick_orange_4, all sharing the same language instruction.
src_ds = [
    (f"datasets/realsense_output/pick_orange_{run}", "Pick the orange to the plate")
    for run in range(5)
]
# Feature schema handed to LeRobotDataset.create:
#   - "action" and "observation.state": six float32 values named "<joint>.pos";
#   - "observation.images.front" and "observation.images.front_depth": image
#     entries of shape [480, 640, 3] (the depth map is stored as 3 channels).
# Each key gets its own freshly built dict/list objects, so no two entries share
# a mutable value.
SINGLE_ARM_FEATURES = {
    **{
        key: {
            "dtype": "float32",
            "shape": (6,),
            "names": [
                f"{joint}.pos"
                for joint in (
                    "shoulder_pan",
                    "shoulder_lift",
                    "elbow_flex",
                    "wrist_flex",
                    "wrist_roll",
                    "gripper",
                )
            ],
        }
        for key in ("action", "observation.state")
    },
    **{
        key: {
            "dtype": "image",
            "shape": [480, 640, 3],
            "names": ["height", "width", "channels"],
        }
        for key in ("observation.images.front", "observation.images.front_depth")
    },
}
from cloud_helper import Client
import os

# Client for the service exposing the "get_depth" endpoint used below.
# Defaults to a server on localhost:50000; set DEPTH_SERVER_HOST /
# DEPTH_SERVER_PORT to target another machine (previously done by editing this
# line, e.g. host="120.48.81.132", port=50000) without touching the code.
client = Client(
    host=os.environ.get("DEPTH_SERVER_HOST", "localhost"),
    port=int(os.environ.get("DEPTH_SERVER_PORT", "50000")),
)
import os.path as osp
from os import system
from tqdm import tqdm
import os
import shutil

# Start from a clean output location: anything already at `target_path` is
# removed before the new dataset is created there.
# Done with shutil/os instead of `rm -rf` through the shell, so a path containing
# spaces or shell metacharacters cannot expand into unintended deletions, and
# failures raise instead of being silently ignored.
if osp.isdir(target_path) and not osp.islink(target_path):
    shutil.rmtree(target_path)
elif osp.lexists(target_path):
    # Plain file or symlink: remove the entry itself (a symlink's target is untouched).
    os.remove(target_path)

target = LeRobotDataset.create(
    repo_id="O24H/Target",
    fps=30,
    root=target_path,
    robot_type="so101_follower",
    features=SINGLE_ARM_FEATURES,
)
EPS = 1e-3  # minimum absolute change in any action component that counts as motion


def _moving_frame_ids(frame_ids, actions, eps=EPS):
    """Return the ids of frames, after the first, whose action changed.

    ``actions[k]`` is the action of frame ``frame_ids[k]``. Frame k (k >= 1) is
    selected when any component of ``actions[k] - actions[k - 1]`` exceeds
    ``eps`` in absolute value. Differences are taken only inside the given
    sequence, so nothing is compared across an episode boundary.
    """
    if len(frame_ids) < 2:
        return []
    acts = np.stack([np.asarray(a) for a in actions])
    moved = (np.abs(np.diff(acts, axis=0)) > eps).any(axis=1)
    return [fid for fid, changed in zip(frame_ids[1:], moved) if changed]


def _depth_to_rgb_uint8(depth):
    """Replicate a single-channel depth map into a 3-channel uint8 image.

    Assumes ``depth`` is (H, W) and normalized to [0, 1] (it is scaled by 255)
    — TODO confirm against the get_depth service. NaN/inf and out-of-range
    values are clamped to [0, 1] so the uint8 cast cannot wrap around.
    """
    d = np.nan_to_num(np.asarray(depth, dtype=np.float64), nan=0.0, posinf=1.0, neginf=0.0)
    d = np.clip(d, 0.0, 1.0) * 255.0
    return np.repeat(d[:, :, np.newaxis], 3, axis=2).astype(np.uint8)


# Copy every source episode into `target`, dropping frames whose action did not
# change and attaching a depth image to each frame that is kept.
for src_path, task in src_ds:
    src = LeRobotDataset(
        repo_id="O24H/Src",
        root=src_path,
    )
    for eps_idx in range(src.num_episodes):
        frame_ids = range(
            src.episode_data_index["from"][eps_idx].item(),
            src.episode_data_index["to"][eps_idx].item(),
        )
        # Load each frame's action once (the previous code loaded every frame twice).
        # NOTE(review): src[i] also loads the frame's images; reading only the
        # action column from the underlying table would be cheaper — confirm API.
        actions = [src[i]["action"] for i in frame_ids]
        moving_ids = _moving_frame_ids(frame_ids, actions)
        # Keep the episode's first frame as its starting point; skip the episode
        # entirely when the robot never moves.
        keep_ids = [frame_ids[0], *moving_ids] if moving_ids else []
        compress_ratio = len(keep_ids) / len(frame_ids) if len(frame_ids) else 0.0
        print(f"Episode {eps_idx}: compress ratio {compress_ratio:.2f}")
        if not keep_ids:
            continue
        for frame_id in tqdm(keep_ids):
            batch = src[frame_id]
            # assumes the image tensor is (C, H, W); reorder to (H, W, C) numpy
            front_img = batch["observation.images.front"].permute(1, 2, 0).contiguous().numpy()
            front_depth = client.call_endpoint("get_depth", front_img)
            frame = {
                "action": batch["action"],
                "observation.state": batch["observation.state"],
                "observation.images.front": front_img,
                "observation.images.front_depth": _depth_to_rgb_uint8(front_depth),
            }
            target.add_frame(frame, task=task)
        target.save_episode()