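# Lerobot/docker/merge.py
#
# Merge every LeRobot dataset directory found under the configured input path
# into one output dataset, dropping frames whose action does not change (by
# more than EPS) from the previous frame.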

import json
from pathlib import Path
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load the task description and collect every source dataset directory.
with open("/workspace/inputs/task.json", "r", encoding="utf-8") as f:
    task_config = json.load(f)

# Path.iterdir() yields entries in an arbitrary, filesystem-dependent order;
# sort so the episode order of the merged dataset is reproducible across runs.
src_dataset_paths = sorted(
    p for p in Path(task_config["train"]["input_data_path"]).iterdir() if p.is_dir()
)

# Idle-frame threshold: during the merge, a frame is kept only if at least one
# action component changed by more than EPS since the previous frame.
EPS = 1e-2
# Feature check: adopt the schema of the first source dataset, restricted to the
# keys this script merges, and require every other source dataset to provide the
# same keys with the same shapes. Validating everything up front makes an
# incompatible input fail here rather than part-way through a long merge.
keys_to_check = ["action", "observation.state", "observation.images"]
features = {}
reference_path = None  # dataset whose schema `features` was taken from
for p in src_dataset_paths:
    dataset = LeRobotDataset(repo_id="O24H/Src", root=p)
    if reference_path is None:
        reference_path = p
        features = {
            k: v for k, v in dataset.features.items() if k.startswith(tuple(keys_to_check))
        }
        continue
    for k, ref_feature in features.items():
        # Explicit raises instead of `assert`, which `python -O` strips.
        if k not in dataset.features:
            raise ValueError(f"Feature key {k} not found in dataset {p}")
        # Full dict equality was deliberately not enforced (per-dataset metadata
        # may legitimately differ), but the target holds a single shape per key,
        # so a shape mismatch can never be merged.
        src_shape = tuple(dataset.features[k].get("shape", ()))
        ref_shape = tuple(ref_feature.get("shape", ()))
        if src_shape != ref_shape:
            raise ValueError(
                f"Feature key {k} has shape {src_shape} in dataset {p}, "
                f"but {ref_shape} in dataset {reference_path}"
            )
if not features:
    raise ValueError(
        f"No source dataset provides any of {keys_to_check}; "
        f"checked: {[str(p) for p in src_dataset_paths]}"
    )
# Initialize the target dataset, replacing any previous output at target_path.
target_path = Path(task_config["train"]["output_data_path"])
if target_path.exists():
    import shutil

    # Refuse to wipe a directory that is, or contains, one of the inputs.
    resolved_target = target_path.resolve()
    for p in src_dataset_paths:
        resolved_src = p.resolve()
        if resolved_src == resolved_target or resolved_target in resolved_src.parents:
            raise ValueError(
                f"Output path {target_path} contains source dataset {p}; refusing to delete it"
            )
    # Remove without a shell (the old `os.system("rm -rf ...")` mis-split paths
    # containing spaces or shell metacharacters and ignored failures). Like
    # `rm -rf`, a symlink is removed itself rather than followed.
    if target_path.is_dir() and not target_path.is_symlink():
        shutil.rmtree(target_path)
    else:
        target_path.unlink()

# Store camera streams as encoded videos (default). Setting this to False stores
# every frame as an image instead; the author measured roughly 35 s -> 20 s per
# episode, at the cost of ~40x larger output (6 MB -> 260 MB).
USE_VIDEOS = True
target_features = features
if not USE_VIDEOS:
    # Build converted copies so the shared `features` dicts are left untouched.
    target_features = {}
    for key, feature in features.items():
        if key.startswith("observation.images"):
            feature = {fk: fv for fk, fv in feature.items() if fk != "info"}
            feature["dtype"] = "image"
        target_features[key] = feature

# [TODO] use the largest dataset as the base rather than creating a new one
# NOTE(review): fps and robot_type are hard-coded; confirm they match every
# source dataset, since frames are re-timed at this fps.
target = LeRobotDataset.create(
    repo_id="O24H/Target",
    fps=30,
    root=target_path,
    robot_type="so101_follower",
    features=target_features,
    image_writer_processes=8,
    image_writer_threads=16,
    use_videos=USE_VIDEOS,
)
# Episodes keeping fewer frames than this after idle-frame removal are skipped.
MIN_EPISODE_FRAMES = 32
# Camera streams to copy, taken from the target schema (not from each source
# batch) so a source with extra cameras cannot add keys the target lacks.
image_keys = [k for k in features if k.startswith("observation.images.")]

for p in src_dataset_paths:
    src = LeRobotDataset(repo_id="O24H/Src", root=p)
    for eps_idx in tqdm(range(src.num_episodes), desc=f"Processing episode in {p.name}"):
        start = src.episode_data_index["from"][eps_idx].item()
        end = src.episode_data_index["to"][eps_idx].item()
        # Read only the actions first. src[i] also materializes every camera
        # frame, which is wasted work (and memory) for frames that get dropped.
        # NOTE(review): assumes src.hf_dataset[i]["action"] is the same tensor
        # src[i]["action"] returns (LeRobot v2.x builds __getitem__ on
        # hf_dataset[idx]) -- confirm for the installed version.
        actions = [src.hf_dataset[i]["action"] for i in range(start, end)]
        # Keep frame i (i >= 1) when some action component moved by more than
        # EPS since frame i-1. The first frame has no predecessor and is never kept.
        keep_idx = [
            start + i
            for i in range(1, len(actions))
            if ((actions[i] - actions[i - 1]).abs() > EPS).any()
        ]
        compress_ratio = len(keep_idx) / max(end - start, 1)  # guard empty episodes
        # tqdm.write keeps the progress bar intact; a plain print breaks it.
        tqdm.write(f"Episode {eps_idx}: compress ratio {compress_ratio:.2f}")
        # Skip too short episodes after compression
        if len(keep_idx) < MIN_EPISODE_FRAMES:
            continue
        for idx in keep_idx:
            batch = src[idx]
            frame = {
                "action": batch["action"],
                "observation.state": batch["observation.state"],
            }
            for k in image_keys:
                frame[k] = batch[k].permute(1, 2, 0).contiguous()  # CHW -> HWC
            target.add_frame(frame, task=batch["task"])
        target.save_episode()