# Lerobot/docker/merge.py
# Last modified: 2025-12-11 14:11:41 +08:00
# (Original listing: 137 lines, 4.5 KiB, Python)

import json
import os
import shutil
from pathlib import Path

from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load the merge job description. The path is fixed for the container this
# script runs in; its "merge" section provides "input_data_path" (a directory
# whose sub-directories are LeRobot datasets) and "output_data_path".
CONFIG_PATH = "/workspace/embolab/params/build_task.json"
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    task_config = json.load(f)

# Every sub-directory of the input directory is treated as one source dataset.
# Sorted so the merged episode order is reproducible: Path.iterdir() yields
# entries in arbitrary, filesystem-dependent order.
src_dataset_paths = sorted(
    p for p in Path(task_config["merge"]["input_data_path"]).iterdir() if p.is_dir()
)

# Feature check: the first readable dataset with matching keys defines the
# merged feature set (action, state and camera streams). Every other dataset
# must expose all of those keys, otherwise adding its frames to the target
# would fail midway through the merge. Only key presence is checked here, not
# dtype/shape equality.
features = {}
keys_to_check = ["action", "observation.state", "observation.images"]
readable_dataset_paths = []
for p in src_dataset_paths:
    try:
        dataset = LeRobotDataset(repo_id="O24H/Src", root=p)
    except Exception as e:
        # Same policy as the merge loop: unreadable datasets are skipped
        # instead of aborting the whole merge.
        print(f"Error while loading {p} ({e!r}), Skip...")
        continue
    readable_dataset_paths.append(p)
    if not features:
        features = {
            k: v for k, v in dataset.features.items() if k.startswith(tuple(keys_to_check))
        }
        continue
    for k in features:
        if k not in dataset.features:
            # Explicit raise rather than `assert`, which is stripped under `python -O`.
            raise ValueError(f"Feature key {k} not found in dataset {p}")

# Only datasets that opened successfully take part in the merge.
src_dataset_paths = readable_dataset_paths
if not features:
    raise RuntimeError(
        f"No readable source dataset with {keys_to_check} features under "
        f"{task_config['merge']['input_data_path']}"
    )
# Initialize target paths.
# The merge is written to a scratch directory first and moved to target_path
# once every source has been processed. Any existing output is removed up
# front so the final move does not place the scratch directory inside it.
target_path = Path(task_config["merge"]["output_data_path"])
# Guard the recursive delete below: an empty or "." output path resolves to
# the current working directory, and "/" to the filesystem root.
_resolved_target = target_path.resolve()
if _resolved_target in (Path.cwd().resolve(), Path(_resolved_target.anchor)):
    raise ValueError(f"Refusing to delete {_resolved_target}; check merge.output_data_path")
if target_path.is_symlink() or target_path.is_file():
    # Like `rm -rf` on a link/file: remove the entry itself only.
    target_path.unlink()
elif target_path.exists():
    # shutil.rmtree instead of `os.system("rm -rf ...")`: no shell is
    # involved, so paths with spaces or shell metacharacters are handled
    # safely, and failures raise instead of being silently ignored.
    shutil.rmtree(target_path)

tmp_target_path = "/tmp/lerobot_merge_temp"
# Clear leftovers from an interrupted earlier run so the target dataset is
# created in an empty directory and no stale files leak into the result.
if Path(tmp_target_path).exists():
    shutil.rmtree(tmp_target_path)
# Create the merged dataset in the scratch directory; use_videos is left at
# its default, and the merge loop copies the source videos into it.
#
# Benchmarked alternative: store cameras as plain images instead of videos.
# Set every "observation.images*" feature's dtype to "image", drop its "info"
# entry, and create with root=target_path, image_writer_processes=8,
# image_writer_threads=16, use_videos=False. That cut merge time from ~35 s
# to ~20 s per episode, but the output was ~40x larger (6M -> 260M).
#
# TODO: use the largest source dataset as the base rather than creating a new one.
target_create_kwargs = dict(
    repo_id="O24H/Target",
    fps=30,
    root=tmp_target_path,
    robot_type="so101_follower",
    features=features,
    image_writer_threads=16,
)
target = LeRobotDataset.create(**target_create_kwargs)
# Episodes with fewer frames than this are not merged.
MIN_EPISODE_LENGTH = 32
# Camera streams the target expects, computed once. They are taken from the
# target's own features rather than from each source frame, so extra cameras
# present in an individual source are left out of the frame instead of being
# passed to target.add_frame().
target_image_keys = [k for k in target.features if k.startswith("observation.images.")]

for p in src_dataset_paths:
    try:
        src = LeRobotDataset(repo_id="O24H/Src", root=p)
    except Exception as e:
        # `continue` is essential: without it the loop would proceed with an
        # undefined `src` (NameError) or the previous dataset's `src`, merging
        # that dataset a second time.
        print(f"Error while Processing: {p} ({e!r}), Skip...")
        continue
    for eps_idx in tqdm(range(src.num_episodes), desc=f"Processing episode in {p.name}"):
        frame_idx = range(
            src.episode_data_index["from"][eps_idx].item(),
            src.episode_data_index["to"][eps_idx].item(),
        )
        # No frame compression: every frame of the episode is kept. (An
        # action-delta filter with threshold 1e-2 was prototyped and disabled.)
        if len(frame_idx) < MIN_EPISODE_LENGTH:
            continue
        for o in frame_idx:
            batch = src[o]
            frame = {
                "action": batch["action"],
                "observation.state": batch["observation.state"],
            }
            for k in target_image_keys:
                frame[k] = batch[k].permute(1, 2, 0).contiguous()  # CHW -> HWC
            target.add_frame(frame, task=batch["task"])
        # Manually copy the source videos into the slot of the episode about
        # to be saved, to skip the encoding process.
        # NOTE(review): relies on save_episode() not re-encoding a video file
        # that already exists; confirm against the installed lerobot version.
        for key in target.meta.video_keys:
            src_video_path = src.root / src.meta.get_video_file_path(eps_idx, key)
            tgt_video_path = target.root / target.meta.get_video_file_path(target.num_episodes, key)
            if not src_video_path.is_file():
                # Previously `cp` failed silently here; report it and leave the
                # video to save_episode().
                print(f"Source video missing, not copied: {src_video_path}")
                continue
            tgt_video_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(src_video_path, tgt_video_path)
        target.save_episode()
        # Remove the per-frame image directories of the episode just saved.
        for key in target.meta.video_keys:
            img_dir = target._get_image_file_path(
                episode_index=target.num_episodes - 1, image_key=key, frame_index=0
            ).parent
            if img_dir.exists():
                print(f"Removing image dir: {img_dir}")
                shutil.rmtree(img_dir)

# Move the scratch dataset to its final location. The output path is expected
# to have been cleared before the merge; if it exists now, fail rather than
# nest the dataset inside it (which is what `mv` into a directory does).
if target_path.exists():
    raise FileExistsError(f"Output path {target_path} already exists; not overwriting it")
target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(tmp_target_path), str(target_path))
print(f"Merged dataset saved at {target_path}")