61 lines
1.9 KiB
Python
61 lines
1.9 KiB
Python
|
|
from tqdm import tqdm
|
|
from cloud_helper import Client
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
|
|
# Policy-server address and dataset location; edit these for your setup.
HOST = "127.0.0.1"
PORT = 50000
REPO_ID = "Foo/Bar"
DATASET_ROOT = "./docker/inputs/pick_orange_vggt"  # Modify this path accordingly

EPS = 1e-2  # Threshold for action change (only used when skip_static=True)


def to_uint8_hwc(image):
    """Convert a channel-first float image tensor to an HWC ``uint8`` array.

    NOTE(review): assumes ``image`` is a CPU tensor laid out (C, H, W) with
    values nominally in [0, 1] -- confirm for every camera stream.

    Values are scaled by 255 and truncated on the cast (same as before). They
    are first clipped to [0, 255] so that out-of-range pixels saturate instead
    of wrapping around when converted to uint8.
    """
    hwc = image.permute(1, 2, 0).numpy()
    return np.clip(hwc * 255, 0, 255).astype("uint8")


def build_request(sample):
    """Build the ``get_actions`` payload for one dataset sample.

    Reads ``observation.state``, ``observation.images.front`` and ``task``
    from ``sample``. No wrist image is sent; the wrist slot is filled with
    zeros shaped like the front image.
    """
    front = to_uint8_hwc(sample["observation.images.front"])
    return {
        "observation": {
            "images.front": front,
            "images.wrist": np.zeros_like(front),  # Placeholder for wrist image
            "state": sample["observation.state"].numpy(),
        },
        "instruction": sample["task"],
    }


def iter_samples(dataset, frame_indices, *, skip_static=False, eps=EPS):
    """Yield ``dataset[i]`` for each index in ``frame_indices``.

    With ``skip_static=True``, the whole span is loaded first. A frame is
    yielded only if some component of its ``action`` differs from the previous
    frame's by more than ``eps``, so the first frame is always dropped. This
    mode holds every sample of the span in memory at once.
    """
    if not skip_static:
        for i in frame_indices:
            yield dataset[i]
        return
    samples = [dataset[i] for i in frame_indices]
    for prev, cur in zip(samples, samples[1:]):
        if ((cur["action"] - prev["action"]).abs() > eps).any():
            yield cur


def main(*, skip_static=False):
    """Send every selected frame of every episode to the policy server.

    For each frame, print the dataset's target action next to element 0 of
    the action chunk returned by the ``get_actions`` endpoint.
    """
    client = Client(host=HOST, port=PORT)
    src = LeRobotDataset(repo_id=REPO_ID, root=DATASET_ROOT)

    for eps_idx in tqdm(range(src.num_episodes), desc="Processing episodes: "):
        # Global dataset indices spanned by this episode.
        frame_idx = range(
            src.episode_data_index["from"][eps_idx].item(),
            src.episode_data_index["to"][eps_idx].item(),
        )
        for raw_batch in iter_samples(src, frame_idx, skip_static=skip_static, eps=EPS):
            action_chunk = client.call_endpoint("get_actions", build_request(raw_batch))
            target_action = raw_batch["action"]
            predicted_action = action_chunk[0]
            # tqdm.write keeps the progress bar intact, unlike a bare print().
            tqdm.write(f"Target action: {target_action}, Predicted action: {predicted_action}")


if __name__ == "__main__":
    main()