# (file metadata: 372 lines, 17 KiB, Python)
import os
|
|
import fnmatch
|
|
import json
|
|
|
|
import h5py
|
|
import yaml
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
|
|
|
class HDF5VLADataset:
    """
    This class is used to sample episodes from the embodiment dataset
    stored in HDF5. Each HDF5 file under the configured data directory
    holds exactly one episode; episodes are sampled with probability
    proportional to their length.
    """
def __init__(self, model_config_path) -> None:
|
|
# [Modify] The path to the HDF5 dataset directory
|
|
# Each HDF5 file contains one episode
|
|
with open(model_config_path, "r") as f:
|
|
model_config = yaml.safe_load(f)
|
|
HDF5_DIR = model_config["data_path"]
|
|
self.DATASET_NAME = "agilex"
|
|
|
|
self.file_paths = []
|
|
for root, _, files in os.walk(HDF5_DIR):
|
|
for filename in fnmatch.filter(files, "*.hdf5"):
|
|
file_path = os.path.join(root, filename)
|
|
self.file_paths.append(file_path)
|
|
|
|
# Load the config
|
|
with open("configs/base.yaml", "r") as file:
|
|
config = yaml.safe_load(file)
|
|
self.CHUNK_SIZE = config["common"]["action_chunk_size"]
|
|
self.IMG_HISORY_SIZE = config["common"]["img_history_size"]
|
|
self.STATE_DIM = config["common"]["state_dim"]
|
|
|
|
# Get each episode's len (use original length, not standardized length)
|
|
episode_lens = []
|
|
for file_path in self.file_paths:
|
|
try:
|
|
with h5py.File(file_path, "r") as f:
|
|
qpos = f["observations"]["qpos"][:]
|
|
num_steps = qpos.shape[0]
|
|
episode_lens.append(num_steps)
|
|
except Exception as e:
|
|
print(f"Warning: Could not read {file_path}: {e}")
|
|
episode_lens.append(0)
|
|
self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens)
|
|
|
|
def __len__(self):
|
|
return len(self.file_paths)
|
|
|
|
def get_dataset_name(self):
|
|
return self.DATASET_NAME
|
|
|
|
def get_item(self, index: int = None, state_only=False):
|
|
"""Get a training sample at a random timestep.
|
|
|
|
Args:
|
|
index (int, optional): the index of the episode.
|
|
If not provided, a random episode will be selected.
|
|
state_only (bool, optional): Whether to return only the state.
|
|
In this way, the sample will contain a complete trajectory rather
|
|
than a single timestep. Defaults to False.
|
|
|
|
Returns:
|
|
sample (dict): a dictionary containing the training sample.
|
|
"""
|
|
while True:
|
|
if index is None:
|
|
file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
|
|
else:
|
|
file_path = self.file_paths[index]
|
|
valid, sample = (self.parse_hdf5_file(file_path)
|
|
if not state_only else self.parse_hdf5_file_state_only(file_path))
|
|
if valid:
|
|
return sample
|
|
else:
|
|
index = np.random.randint(0, len(self.file_paths))
|
|
|
|
def parse_hdf5_file(self, file_path):
|
|
"""[Modify] Parse a hdf5 file to generate a training sample at
|
|
a random timestep.
|
|
|
|
Args:
|
|
file_path (str): the path to the hdf5 file
|
|
|
|
Returns:
|
|
valid (bool): whether the episode is valid, which is useful for filtering.
|
|
If False, this episode will be dropped.
|
|
dict: a dictionary containing the training sample,
|
|
{
|
|
"meta": {
|
|
"dataset_name": str, # the name of your dataset.
|
|
"#steps": int, # the number of steps in the episode,
|
|
# also the total timesteps.
|
|
"instruction": str # the language instruction for this episode.
|
|
},
|
|
"step_id": int, # the index of the sampled step,
|
|
# also the timestep t.
|
|
"state": ndarray, # state[t], (1, STATE_DIM).
|
|
"state_std": ndarray, # std(state[:]), (STATE_DIM,).
|
|
"state_mean": ndarray, # mean(state[:]), (STATE_DIM,).
|
|
"state_norm": ndarray, # norm(state[:]), (STATE_DIM,).
|
|
"actions": ndarray, # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
|
|
"state_indicator", ndarray, # indicates the validness of each dim, (STATE_DIM,).
|
|
"cam_high": ndarray, # external camera image, (IMG_HISORY_SIZE, H, W, 3)
|
|
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
|
"cam_high_mask": ndarray, # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array.
|
|
# For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False.
|
|
"cam_left_wrist": ndarray, # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
|
|
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
|
"cam_left_wrist_mask": ndarray,
|
|
"cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
|
|
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
|
# If only one wrist, make it right wrist, plz.
|
|
"cam_right_wrist_mask": ndarray
|
|
} or None if the episode is invalid.
|
|
"""
|
|
with h5py.File(file_path, "r") as f:
|
|
qpos = f["observations"]["qpos"][:]
|
|
left_arm_dim = f["observations"]["left_arm_dim"][:]
|
|
right_arm_dim = f["observations"]["right_arm_dim"][:]
|
|
num_steps = qpos.shape[0]
|
|
action_dim = qpos
|
|
# [Optional] We drop too-short episode
|
|
# if num_steps < 128:
|
|
# return False, None
|
|
|
|
# [Optional] We skip the first few still steps
|
|
EPS = 1e-2
|
|
# Get the idx of the first qpos whose delta exceeds the threshold
|
|
qpos_delta = np.abs(qpos - qpos[0:1])
|
|
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
|
|
if len(indices) > 0:
|
|
first_idx = indices[0]
|
|
else:
|
|
raise ValueError("Found no qpos that exceeds the threshold.")
|
|
|
|
# We randomly sample a timestep
|
|
step_id = np.random.randint(first_idx - 1, num_steps)
|
|
|
|
# Load the instruction
|
|
dir_path = os.path.dirname(file_path)
|
|
|
|
# with open(os.path.join(dir_path, 'instruction.json'), 'r') as f_instr:
|
|
# instruction_dict = json.load(f_instr)
|
|
# # We have 1/3 prob to use original instruction,
|
|
# # 1/3 to use simplified instruction,
|
|
# # and 1/3 to use expanded instruction.
|
|
# instruction_type = np.random.choice([
|
|
# 'instruction', 'expanded_instruction'])
|
|
# instruction = instruction_dict[instruction_type]
|
|
# if isinstance(instruction, list):
|
|
# instruction = np.random.choice(instruction)
|
|
|
|
# You can also use precomputed language embeddings (recommended)
|
|
# instruction = "path/to/lang_embed.pt"
|
|
instructions_path = os.path.join(dir_path, "instructions")
|
|
instructions_names = []
|
|
|
|
for filename in os.listdir(instructions_path):
|
|
# 检查文件名是否以.pt结尾
|
|
if filename.endswith(".pt"):
|
|
instructions_names.append(os.path.join(instructions_path, filename))
|
|
instruction = np.random.choice(instructions_names)
|
|
# print(f"choose {instruction} file as instruction.")
|
|
# Assemble the meta
|
|
meta = {
|
|
"dataset_name": self.DATASET_NAME,
|
|
"#steps": num_steps,
|
|
"step_id": step_id,
|
|
"instruction": instruction,
|
|
}
|
|
|
|
# Rescale gripper to [0, 1]
|
|
# qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])
|
|
# target_qpos = f["action"][step_id:step_id + self.CHUNK_SIZE] / np.array(
|
|
# [[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])
|
|
|
|
qpos = qpos / np.array(
|
|
# [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]
|
|
[[180, 180, 180, 180, 180, 180]]
|
|
)
|
|
target_qpos = f['action'][step_id:step_id + self.CHUNK_SIZE] / np.array(
|
|
# [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]]
|
|
[[180, 180, 180, 180, 180, 180]]
|
|
)
|
|
|
|
# Parse the state and action
|
|
state = qpos[step_id:step_id + 1]
|
|
state_std = np.std(qpos, axis=0)
|
|
state_mean = np.mean(qpos, axis=0)
|
|
state_norm = np.sqrt(np.mean(qpos**2, axis=0))
|
|
actions = target_qpos
|
|
if actions.shape[0] < self.CHUNK_SIZE:
|
|
# Pad the actions using the last action
|
|
actions = np.concatenate(
|
|
[
|
|
actions,
|
|
np.tile(actions[-1:], (self.CHUNK_SIZE - actions.shape[0], 1)),
|
|
],
|
|
axis=0,
|
|
)
|
|
|
|
# Fill the state/action into the unified vector
|
|
|
|
def fill_in_state(values):
|
|
# Target indices corresponding to your state space
|
|
# In this example: 6 joints + 1 gripper for each arm
|
|
UNI_STATE_INDICES = [
|
|
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
|
|
# ] + [
|
|
# STATE_VEC_IDX_MAPPING["right_gripper_open"]
|
|
]
|
|
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
|
|
uni_vec[..., UNI_STATE_INDICES] = values
|
|
return uni_vec
|
|
|
|
state = fill_in_state(state)
|
|
state_indicator = fill_in_state(np.ones_like(state_std))
|
|
state_std = fill_in_state(state_std)
|
|
state_mean = fill_in_state(state_mean)
|
|
state_norm = fill_in_state(state_norm)
|
|
# If action's format is different from state's,
|
|
# you may implement fill_in_action()
|
|
actions = fill_in_state(actions)
|
|
|
|
# Parse the images
|
|
def parse_img(key):
|
|
imgs = []
|
|
for i in range(max(step_id - self.IMG_HISORY_SIZE + 1, 0), step_id + 1):
|
|
img_bits = f["observations"]["images"][key][i]
|
|
img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)
|
|
imgs.append(img)
|
|
imgs = np.stack(imgs)
|
|
if imgs.shape[0] < self.IMG_HISORY_SIZE:
|
|
# Pad the images using the first image
|
|
imgs = np.concatenate(
|
|
[
|
|
np.tile(
|
|
imgs[:1],
|
|
(self.IMG_HISORY_SIZE - imgs.shape[0], 1, 1, 1),
|
|
),
|
|
imgs,
|
|
],
|
|
axis=0,
|
|
)
|
|
return imgs
|
|
|
|
# `cam_high` is the external camera image
|
|
cam_high = parse_img("cam_high")
|
|
# For step_id = first_idx - 1, the valid_len should be one
|
|
valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
|
|
cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len)
|
|
# cam_left_wrist = parse_img("cam_left_wrist")
|
|
# cam_left_wrist_mask = cam_high_mask.copy()
|
|
cam_left_wrist = np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0))#parse_img('cam_right_wrist')
|
|
cam_left_wrist_mask = np.array([False] * self.IMG_HISORY_SIZE)#cam_high_mask.copy()
|
|
cam_right_wrist = parse_img("cam_right_wrist")
|
|
cam_right_wrist_mask = cam_high_mask.copy() # 使用相同的掩码逻辑
|
|
|
|
# Return the resulting sample
|
|
# For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)
|
|
# E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist",
|
|
# if the left-wrist camera is unavailable on your robot
|
|
return True, {
|
|
"meta": meta,
|
|
"state": state,
|
|
"state_std": state_std,
|
|
"state_mean": state_mean,
|
|
"state_norm": state_norm,
|
|
"actions": actions,
|
|
"state_indicator": state_indicator,
|
|
"cam_high": cam_high,
|
|
"cam_high_mask": cam_high_mask,
|
|
"cam_left_wrist": cam_left_wrist,
|
|
"cam_left_wrist_mask": cam_left_wrist_mask,
|
|
"cam_right_wrist": cam_right_wrist,
|
|
"cam_right_wrist_mask": cam_right_wrist_mask,
|
|
}
|
|
|
|
def parse_hdf5_file_state_only(self, file_path):
|
|
"""[Modify] Parse a hdf5 file to generate a state trajectory.
|
|
|
|
Args:
|
|
file_path (str): the path to the hdf5 file
|
|
|
|
Returns:
|
|
valid (bool): whether the episode is valid, which is useful for filtering.
|
|
If False, this episode will be dropped.
|
|
dict: a dictionary containing the training sample,
|
|
{
|
|
"state": ndarray, # state[:], (T, STATE_DIM).
|
|
"action": ndarray, # action[:], (T, STATE_DIM).
|
|
} or None if the episode is invalid.
|
|
"""
|
|
with h5py.File(file_path, "r") as f:
|
|
qpos = f["observations"]["qpos"][:]
|
|
left_arm_dim = f["observations"]["left_arm_dim"][:]
|
|
right_arm_dim = f["observations"]["right_arm_dim"][:]
|
|
|
|
num_steps = qpos.shape[0]
|
|
# [Optional] We drop too-short episode
|
|
# if num_steps < 128:
|
|
# return False, None
|
|
|
|
# [Optional] We skip the first few still steps
|
|
EPS = 1e-2
|
|
# Get the idx of the first qpos whose delta exceeds the threshold
|
|
qpos_delta = np.abs(qpos - qpos[0:1])
|
|
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
|
|
if len(indices) > 0:
|
|
first_idx = indices[0]
|
|
else:
|
|
raise ValueError("Found no qpos that exceeds the threshold.")
|
|
|
|
# Rescale gripper to [0, 1]
|
|
# qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])
|
|
# target_qpos = f["action"][:] / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])
|
|
|
|
qpos = qpos / np.array(
|
|
# [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]
|
|
[[180, 180, 180, 180, 180, 180]]
|
|
)
|
|
target_qpos = f['action'][first_idx - 1:] / np.array(
|
|
# [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]]
|
|
[[180, 180, 180, 180, 180, 180]]
|
|
)
|
|
# Parse the state and action
|
|
state = qpos[first_idx - 1:]
|
|
action = target_qpos[first_idx - 1:]
|
|
|
|
# Standardize trajectory length to avoid batch size mismatch
|
|
# Use a fixed length (e.g., 128) or pad/truncate to match
|
|
target_length = 128 # You can adjust this value
|
|
if state.shape[0] > target_length:
|
|
# Truncate to target length
|
|
state = state[:target_length]
|
|
action = action[:target_length]
|
|
elif state.shape[0] < target_length:
|
|
# Pad with the last state/action
|
|
pad_length = target_length - state.shape[0]
|
|
state = np.concatenate([state, np.tile(state[-1:], (pad_length, 1))], axis=0)
|
|
action = np.concatenate([action, np.tile(action[-1:], (pad_length, 1))], axis=0)
|
|
|
|
# Fill the state/action into the unified vector
|
|
def fill_in_state(values):
|
|
# Target indices corresponding to your state space
|
|
# In this example: 6 joints + 1 gripper for each arm
|
|
UNI_STATE_INDICES = [
|
|
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
|
|
# ] + [
|
|
# STATE_VEC_IDX_MAPPING["right_gripper_open"]
|
|
]
|
|
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
|
|
uni_vec[..., UNI_STATE_INDICES] = values
|
|
return uni_vec
|
|
|
|
state = fill_in_state(state)
|
|
action = fill_in_state(action)
|
|
|
|
# Return the resulting sample
|
|
return True, {"state": state, "action": action}
|
|
|
|
if __name__ == "__main__":
    import sys

    # BUGFIX: HDF5VLADataset.__init__ requires a model-config path; the
    # original call passed no argument and always raised a TypeError.
    if len(sys.argv) < 2:
        raise SystemExit(f"Usage: {sys.argv[0]} <model_config_path>")
    ds = HDF5VLADataset(sys.argv[1])
    # Smoke-test every episode by sampling one item from each.
    for i in range(len(ds)):
        print(f"Processing episode {i}/{len(ds)}...")
        ds.get_item(i)