Lerobot/docker/smolvla_dataloader.py
2025-12-11 14:11:41 +08:00

61 lines
1.9 KiB
Python

from tqdm import tqdm
from cloud_helper import Client
from lerobot.datasets.lerobot_dataset import LeRobotDataset
import cv2
import numpy as np
# Client for the action-inference endpoint that the replay loop below queries.
client = Client(port=50000, host="127.0.0.1")

# Dataset whose recorded frames are replayed through that endpoint.
src = LeRobotDataset(
    root="./docker/inputs/pick_orange_vggt",  # Modify this path accordingly
    repo_id="Foo/Bar",
)

# Threshold on the absolute per-dimension change in action between
# consecutive frames.
EPS = 1e-2
# When True, replay only frames whose action differs from the previous frame's
# by more than EPS in at least one dimension; when False, replay every frame.
SKIP_STATIC_FRAMES = False


def _frames_to_replay(frames):
    """Return the absolute dataset indices of the frames to send to the server.

    Args:
        frames: ``range`` of absolute frame indices making up one episode.

    Returns:
        ``frames`` unchanged when ``SKIP_STATIC_FRAMES`` is False. Otherwise a
        list of the indices whose ``"action"`` differs from the preceding
        frame's by more than ``EPS`` in any element; the episode's first frame
        has no predecessor and is therefore never kept in that mode.
    """
    if not SKIP_STATIC_FRAMES:
        return frames
    # Loads every frame once just to read its action; the replay loop loads
    # the kept frames again.
    actions = [src[i]["action"] for i in frames]
    return [
        idx
        for idx, prev, cur in zip(frames[1:], actions, actions[1:])
        if ((cur - prev).abs() > EPS).any()
    ]


def _chw_to_hwc_uint8(image):
    """Convert a channel-first float image tensor to an HWC ``uint8`` array.

    The input is assumed to be scaled to [0, 1] (implied by the ``* 255``
    conversion; TODO confirm for every camera key). Values are rounded to the
    nearest integer and clipped to [0, 255]. A bare ``astype("uint8")`` would
    truncate (0.999 * 255 -> 254) and would cast out-of-range values
    unpredictably instead of saturating them.
    """
    hwc = image.permute(1, 2, 0).numpy()
    return np.clip(np.rint(hwc * 255.0), 0, 255).astype(np.uint8)


def _build_request(raw_batch):
    """Build the ``get_actions`` payload for one dataset frame.

    Only the front camera image is sent. ``images.wrist`` is an all-zero
    image with the front image's shape and dtype, used as a placeholder.
    """
    image_front = _chw_to_hwc_uint8(raw_batch["observation.images.front"])
    return {
        "observation": {
            "images.front": image_front,
            "images.wrist": np.zeros_like(image_front),  # Placeholder for wrist image
            "state": raw_batch["observation.state"].numpy(),
        },
        "instruction": raw_batch["task"],
    }


for eps_idx in tqdm(range(src.num_episodes), desc="Processing episodes: "):
    # Absolute frame indices of this episode; the ``to`` bound is exclusive.
    frame_idx = range(
        src.episode_data_index["from"][eps_idx].item(),
        src.episode_data_index["to"][eps_idx].item(),
    )
    keep_idx = _frames_to_replay(frame_idx)
    if SKIP_STATIC_FRAMES and len(frame_idx) > 0:
        print(f"Episode {eps_idx}: compress ratio {len(keep_idx) / len(frame_idx):.2f}")
    for o in keep_idx:
        raw_batch = src[o]
        action_chunk = client.call_endpoint("get_actions", _build_request(raw_batch))
        # Only the first action of the returned chunk is compared against the
        # dataset's recorded action for this frame.
        target_action = raw_batch["action"]
        predicted_action = action_chunk[0]
        print(f"Target action: {target_action}, Predicted action: {predicted_action}")