"""Merge several LeRobot datasets into one, dropping frames whose action did not change.

Reads ``/workspace/inputs/task.json``. Every sub-directory of
``train.input_data_path`` is opened as a source LeRobot dataset, and the
merged result is written to ``train.output_data_path``.
"""

import json
from pathlib import Path

from tqdm import tqdm

from lerobot.datasets.lerobot_dataset import LeRobotDataset

with open("/workspace/inputs/task.json", "r", encoding="utf-8") as f:
    task_config = json.load(f)

# Path.iterdir() yields children in arbitrary order; sort so the merged
# dataset's episode order is reproducible across runs and filesystems.
src_dataset_paths = sorted(
    i for i in Path(task_config["train"]["input_data_path"]).iterdir() if i.is_dir()
)
if not src_dataset_paths:
    raise FileNotFoundError(
        f"No dataset directories found in {task_config['train']['input_data_path']}"
    )

# Threshold on the per-component absolute action change between consecutive
# frames; used by the compression step when deciding which frames to keep.
EPS = 1e-2

# Feature Check
# The target schema is taken from the first source dataset (only keys with one
# of the prefixes below). Every other source must contain all of those keys.
# NOTE: only key presence is checked; exact spec equality
# (dataset.features[k] == features[k]) is not enforced.
features = None
keys_to_check = ("action", "observation.state", "observation.images")
for p in src_dataset_paths:
    dataset = LeRobotDataset(repo_id="O24H/Src", root=p)
    if features is None:
        features = {
            k: v for k, v in dataset.features.items() if k.startswith(keys_to_check)
        }
        continue
    for k in features:
        if k not in dataset.features:
            # Explicit raise instead of assert: asserts vanish under `python -O`.
            raise KeyError(f"Feature key {k} not found in dataset {p}")

# Initialize Target Dataset
target_path = Path(task_config["train"]["output_data_path"])
import shutil

# Start from a clean output location. shutil/pathlib replace the former
# `os.system(f"rm -rf {target_path}")`, which broke on paths with spaces and
# allowed shell injection through the configured output path.
if target_path.is_symlink() or target_path.is_file():
    target_path.unlink()
elif target_path.exists():
    shutil.rmtree(target_path)

# Storage note: frames can be stored as images instead of videos
# (~35s -> ~20s per episode, but ~40x larger on disk: 6M -> 260M). To do so,
# set every "observation.images.*" feature's "dtype" to "image", drop its
# "info" entry, and pass use_videos=False to LeRobotDataset.create().

# [TODO] use the largest dataset as the base rather than creating a new one
target = LeRobotDataset.create(
    repo_id="O24H/Target",
    fps=30,
    root=target_path,
    robot_type="so101_follower",
    features=features,
    image_writer_processes=8,
    image_writer_threads=16,
)

# Episodes with fewer surviving frames than this are dropped entirely.
MIN_EPISODE_FRAMES = 32

# Camera streams to copy. They are taken from the target schema rather than
# from each source sample: sources only have to contain the target's keys, so a
# source with an extra camera must not push an unknown feature into add_frame.
image_keys = [k for k in target.features if k.startswith("observation.images.")]

for p in src_dataset_paths:
    src = LeRobotDataset(repo_id="O24H/Src", root=p)
    for eps_idx in tqdm(range(src.num_episodes), desc=f"Processing episode in {p.name}"):
        frame_idx = range(
            src.episode_data_index["from"][eps_idx].item(),
            src.episode_data_index["to"][eps_idx].item(),
        )
        if not frame_idx:
            continue  # empty episode: nothing to keep (and avoids 0/0 below)
        eps_data = [src[i] for i in frame_idx]

        # Keep frame i (i >= 1) when any action component moved by more than
        # EPS relative to frame i - 1. Frame 0 has no predecessor and is never
        # kept.
        keep_idx = [
            i
            for i in range(1, len(eps_data))
            if ((eps_data[i]["action"] - eps_data[i - 1]["action"]).abs() > EPS).any()
        ]
        compress_ratio = len(keep_idx) / len(frame_idx)
        # tqdm.write keeps the progress bar intact, unlike a bare print.
        tqdm.write(f"Episode {eps_idx}: compress ratio {compress_ratio:.2f}")
        if len(keep_idx) < MIN_EPISODE_FRAMES:
            continue  # Skip too short episodes after compression

        for o in keep_idx:
            batch = eps_data[o]
            frame = {
                "action": batch["action"],
                "observation.state": batch["observation.state"],
            }
            for k in image_keys:
                frame[k] = batch[k].permute(1, 2, 0).contiguous()  # CHW -> HWC
            target.add_frame(frame, task=batch["task"])
        target.save_episode()