import json
|
|
|
|
from pathlib import Path
|
|
from tqdm import tqdm
|
|
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
|
|
|
|
# Load the task configuration that drives all input/output locations.
task_config = json.loads(Path("/workspace/inputs/task.json").read_text())

# Every immediate subdirectory of the input root is treated as one source dataset.
src_dataset_paths = [
    child
    for child in Path(task_config["train"]["input_data_path"]).iterdir()
    if child.is_dir()
]

# Per-joint action delta below which two consecutive frames count as static.
EPS = 1e-2
|
|
|
|
# Feature Check: adopt the feature schema of the first source dataset, then
# verify that every remaining dataset exposes at least the same feature keys.
features = {}
keys_to_check = ["action", "observation.state", "observation.images"]
for p in src_dataset_paths:
    ds = LeRobotDataset(repo_id="O24H/Src", root=p)
    if features:
        for k in features:
            assert k in ds.features, f"Feature key {k} not found in dataset {p}"
        # NOTE(review): full per-key schema equality across datasets is
        # deliberately not enforced here (a strict equality assert was disabled).
    else:
        features = {
            k: spec
            for k, spec in ds.features.items()
            if k.startswith(tuple(keys_to_check))
        }
|
|
|
|
# Initialize Target Dataset
target_path = Path(task_config["train"]["output_data_path"])

# Remove any previous output so LeRobotDataset.create starts from a clean slate.
if target_path.exists():
    import shutil

    # shutil.rmtree instead of `os.system(f"rm -rf {target_path}")`: no shell is
    # involved, so paths containing spaces or shell metacharacters are handled
    # safely, and failures raise an exception instead of being silently ignored.
    shutil.rmtree(target_path)
|
|
|
|
### Alternative (disabled below): store all data as raw images rather than videos.
### Roughly halves per-episode processing time (~35 s -> ~20 s), but storage
### grows ~40x (about 6 MB -> 260 MB per episode).
|
|
# for i in features.keys():
|
|
# if i.startswith("observation.images"):
|
|
# if not features[i]["dtype"] == "image":
|
|
# features[i]["dtype"] = "image"
|
|
# try:
|
|
# features[i].pop("info")
|
|
# except KeyError:
|
|
# pass
|
|
# target = LeRobotDataset.create(
|
|
# repo_id="O24H/Target",
|
|
# fps=30,
|
|
# root=target_path,
|
|
# robot_type="so101_follower",
|
|
# features=features,
|
|
# image_writer_processes=8,
|
|
# image_writer_threads=16,
|
|
# use_videos=False
|
|
# )
|
|
|
|
# [TODO] use the largest dataset as the base rather than creating a new one
# Create an empty target dataset that reuses the (filtered) feature schema of
# the sources. Default storage backend is used (video-backed); see the disabled
# image-mode variant above for the speed/space trade-off.
target = LeRobotDataset.create(
    repo_id="O24H/Target",
    fps=30,  # assumes every source dataset was recorded at 30 fps — TODO confirm
    root=target_path,
    robot_type="so101_follower",
    features=features,
    image_writer_processes=8,  # parallel writer processes to keep ingestion fast
    image_writer_threads=16,
)
|
|
|
|
# Compress every source episode (dropping near-static frames) and append the
# surviving frames to the target dataset, one saved episode per source episode.
for p in src_dataset_paths:
    src = LeRobotDataset(repo_id="O24H/Src", root=p)

    for eps_idx in tqdm(range(src.num_episodes), desc=f"Processing episode in {p.name}"):
        frame_idx = range(
            src.episode_data_index["from"][eps_idx].item(),
            src.episode_data_index["to"][eps_idx].item(),
        )
        if not frame_idx:
            # Robustness: an empty episode would otherwise crash the
            # compress-ratio division below with ZeroDivisionError.
            print(f"Episode {eps_idx}: empty, skipped")
            continue

        eps_data = [src[i] for i in frame_idx]  # idiomatic src[i] over src.__getitem__(i)

        # Per-frame action deltas; keep a frame when any joint moved by more
        # than EPS relative to the immediately preceding frame.
        diff_actions = [eps_data[i]["action"] - eps_data[i - 1]["action"] for i in range(1, len(eps_data))]
        keep_idx = [i + 1 for i, a in enumerate(diff_actions) if (a.abs() > EPS).any()]
        # NOTE(review): index 0 can never appear in keep_idx, so the first frame
        # of every episode is always dropped — confirm this is intended.

        compress_ratio = len(keep_idx) / len(frame_idx)
        print(f"Episode {eps_idx}: compress ratio {compress_ratio:.2f}")

        # Skip episodes that become too short after compression.
        if len(keep_idx) < 32:
            continue

        for o in keep_idx:
            batch = eps_data[o]

            frame = {
                "action": batch["action"],
                "observation.state": batch["observation.state"],
            }
            # Camera tensors are converted CHW -> HWC — presumably the layout
            # add_frame expects; `.contiguous()` materializes the transposed
            # memory layout before it reaches the image writer.
            for k in batch:
                if k.startswith("observation.images."):
                    frame[k] = batch[k].permute(1, 2, 0).contiguous()

            target.add_frame(frame, task=batch["task"])

        target.save_episode()
|