"""Merge multiple LeRobot datasets into a single target dataset.

Reads source/output paths from ``/workspace/embolab/params/build_task.json``
(``merge.input_data_path`` / ``merge.output_data_path``).  Every source
dataset must expose the same feature keys.  Episodes are copied frame by
frame; the already-encoded episode videos are copied file-for-file so
``save_episode()`` can skip re-encoding.  Episodes shorter than 32 frames
are dropped.  The dataset is built in a temp directory and only moved to
the final output path on success.
"""
import json
import shutil
from pathlib import Path

from tqdm import tqdm

from lerobot.datasets.lerobot_dataset import LeRobotDataset

with open("/workspace/embolab/params/build_task.json", "r") as f:
    task_config = json.load(f)

src_dataset_paths = [
    p for p in Path(task_config["merge"]["input_data_path"]).iterdir() if p.is_dir()
]

# --- Feature check -----------------------------------------------------------
# Take the first dataset's (filtered) features as the reference schema and
# verify every other source dataset exposes at least the same keys.
features = {}
keys_to_check = ["action", "observation.state", "observation.images"]
for p in src_dataset_paths:
    dataset = LeRobotDataset(repo_id="O24H/Src", root=p)
    if not features:
        features = {
            k: v
            for k, v in dataset.features.items()
            if any(k.startswith(prefix) for prefix in keys_to_check)
        }
    else:
        for k in features:
            assert k in dataset.features, f"Feature key {k} not found in dataset {p}"

# --- Initialize target dataset -----------------------------------------------
target_path = Path(task_config["merge"]["output_data_path"])
if target_path.exists():
    # shutil.rmtree instead of `os.system("rm -rf ...")`: no shell involved,
    # handles spaces/special characters in the path, raises on failure.
    shutil.rmtree(target_path)

# Build into a temp dir first so a crash never leaves a half-written output.
# Clear any leftovers from a previous failed run before creating.
tmp_target_path = "/tmp/lerobot_merge_temp"
shutil.rmtree(tmp_target_path, ignore_errors=True)

# [TODO] use the largest dataset as the base rather than creating a new one
target = LeRobotDataset.create(
    repo_id="O24H/Target",
    fps=30,
    root=tmp_target_path,
    robot_type="so101_follower",
    features=features,
    image_writer_threads=16,
)

for p in src_dataset_paths:
    try:
        src = LeRobotDataset(repo_id="O24H/Src", root=p)
    except Exception:
        # BUG FIX: the original had no `continue` here, so after a load
        # failure the episode loop still ran — against the *previous*
        # iteration's `src` (or raised NameError on the first iteration).
        print("Error while Processing: ", p, ", Skip...")
        continue

    for eps_idx in tqdm(range(src.num_episodes), desc=f"Processing episode in {p.name}"):
        frame_idx = range(
            src.episode_data_index["from"][eps_idx].item(),
            src.episode_data_index["to"][eps_idx].item(),
        )
        keep_idx = frame_idx  # No Compression
        if len(keep_idx) < 32:
            continue  # Skip too short episodes after compression

        for o in keep_idx:
            batch = src[o]
            image_keys = [k for k in batch if k.startswith("observation.images.")]
            frame = {
                "action": batch["action"],
                "observation.state": batch["observation.state"],
            }
            for k in image_keys:
                frame[k] = batch[k].permute(1, 2, 0).contiguous()  # CHW -> HWC
            target.add_frame(frame, task=batch["task"])

        # Manually copy the already-encoded videos so the encoding step is
        # skipped.  Note: `target.num_episodes` is read *before*
        # save_episode(), i.e. it is the index the new episode will get.
        for key in target.meta.video_keys:
            src_video_path = src.root / src.meta.get_video_file_path(eps_idx, key)
            tgt_video_path = target.root / target.meta.get_video_file_path(
                target.num_episodes, key
            )
            tgt_video_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src_video_path, tgt_video_path)

        target.save_episode()

        # Remove the raw image dumps for the just-saved episode now that its
        # video exists (uses a private LeRobotDataset helper to locate them).
        for key in target.meta.video_keys:
            img_dir = target._get_image_file_path(
                episode_index=target.num_episodes - 1, image_key=key, frame_index=0
            ).parent
            if img_dir.exists():
                print(f"Removing image dir: {img_dir}")
                shutil.rmtree(img_dir)

# Publish: move the finished temp build to the final output path.
shutil.move(tmp_target_path, str(target_path))
print(f"Merged dataset saved at {target_path}")