# Convert recorded RealSense "pick orange" episodes into a single LeRobot
# dataset, dropping static frames and augmenting each kept frame with a
# depth image computed by a remote depth-estimation service.
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
import numpy as np
|
|
|
|
# Root directory the converted dataset is written to.
target_path = "docker/inputs/pick_orange_rs_depth"

# (source dataset root, natural-language task description) pairs to merge.
# All five recordings share the same task string.
_TASK = "Pick the orange to the plate"
src_ds = [
    (f"datasets/realsense_output/pick_orange_{i}", _TASK) for i in range(5)
]
|
|
|
|
# Joint order shared by the action vector and the proprioceptive state vector.
_JOINT_NAMES = [
    "shoulder_pan.pos",
    "shoulder_lift.pos",
    "elbow_flex.pos",
    "wrist_flex.pos",
    "wrist_roll.pos",
    "gripper.pos",
]


def _joint_feature():
    # Fresh dict and list per call so the two entries share no mutable state.
    return {"dtype": "float32", "shape": (6,), "names": list(_JOINT_NAMES)}


def _image_feature():
    # 480x640 RGB frame stored as an image feature.
    return {
        "dtype": "image",
        "shape": [480, 640, 3],
        "names": ["height", "width", "channels"],
    }


# Feature schema for one SO-101 follower arm plus an RGB camera and a
# depth image rendered as a 3-channel image.
SINGLE_ARM_FEATURES = {
    "action": _joint_feature(),
    "observation.state": _joint_feature(),
    "observation.images.front": _image_feature(),
    "observation.images.front_depth": _image_feature(),
}
|
|
|
|
from cloud_helper import Client

# Depth-estimation service: each RGB frame is posted to its "get_depth"
# endpoint and a per-pixel depth map comes back (see the main loop below).
# NOTE(review): assumes a server is already listening on this port — the
# script fails at call time otherwise; confirm before running.
client = Client(host="localhost", port=50000)
# client = Client(host="120.48.81.132", port=50000)  # remote server alternative

import os.path as osp
from os import system
from tqdm import tqdm
|
|
|
|
import shutil

# Start from a clean slate: a leftover output directory would otherwise make
# LeRobotDataset.create() fail or mix old episodes with the new ones.
# shutil.rmtree replaces the original `system(f"rm -rf {target_path}")`:
# portable, no shell involved, and raises instead of failing silently.
if osp.exists(target_path):
    shutil.rmtree(target_path)

# Empty target dataset that the conversion loop below fills frame by frame.
target = LeRobotDataset.create(
    repo_id="O24H/Target",
    fps=30,
    root=target_path,
    robot_type="so101_follower",
    features=SINGLE_ARM_FEATURES,
)
|
|
|
|
|
|
# A per-joint action delta below this threshold counts as "no motion".
EPS = 1e-3

for src_path, task in src_ds:
    src = LeRobotDataset(
        repo_id="O24H/Src",
        root=src_path,
    )

    for eps_idx in range(src.num_episodes):
        # Global frame indices [from, to) belonging to this episode.
        indices = list(
            range(
                src.episode_data_index["from"][eps_idx].item(),
                src.episode_data_index["to"][eps_idx].item(),
            )
        )
        if not indices:
            # Guard: an empty episode would divide by zero below.
            continue

        # Action delta of each frame w.r.t. the previous frame OF THE SAME
        # episode. The original code only excluded global index 0, so for
        # every episode after the first it diffed across the episode
        # boundary, and it then zipped the diffs misaligned by one frame;
        # both defects are fixed here by iterating indices[1:] throughout.
        diffs = [src[i]["action"] - src[i - 1]["action"] for i in indices[1:]]

        # Keep a frame iff at least one joint moved more than EPS since the
        # previous frame — this drops static stretches of the recording.
        # (The first frame of an episode has no predecessor and is dropped.)
        keep_idx = [i for i, d in zip(indices[1:], diffs) if (d.abs() > EPS).any()]

        compress_ratio = len(keep_idx) / len(indices)
        print(f"Episode {eps_idx}: compress ratio {compress_ratio:.2f}")

        if not keep_idx:
            continue

        for i in tqdm(keep_idx):
            batch = src[i]

            # CHW float tensor -> HWC numpy array for the depth service.
            front_img = batch["observation.images.front"].permute(1, 2, 0).contiguous().numpy()
            front_depth = client.call_endpoint("get_depth", front_img)

            # Depth map (H, W), presumably normalized to [0, 1] — TODO
            # confirm the service's output range. Broadcast to three
            # channels and scale to uint8 so it fits the "image" feature.
            front_depth_ = (front_depth[:, :, np.newaxis] * np.ones(3) * 255.0).astype(np.uint8)

            target.add_frame(
                {
                    "action": batch["action"],
                    "observation.state": batch["observation.state"],
                    "observation.images.front": front_img,
                    "observation.images.front_depth": front_depth_,
                },
                task=task,
            )

        target.save_episode()