import fnmatch
import os
import shutil
from multiprocessing import Pool

import h5py
import numpy as np
import tensorflow as tf
from tqdm import tqdm
def _bytes_feature(value):
    """Return a bytes_list Feature wrapping a single string / bytes value.

    Accepts either a raw Python bytes/str or an eager TF tensor (e.g. the
    output of ``tf.io.serialize_tensor``), which is unwrapped via ``.numpy()``
    because BytesList won't unpack a string from an EagerTensor.
    """
    if tf.is_tensor(value):
        # tf.is_tensor is the documented check; the original built a throwaway
        # tf.constant(0) on every call just to obtain the EagerTensor type.
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _bool_feature(value):
    """Encode a boolean as a single-element int64_list Feature (True -> 1)."""
    as_int = int(value)
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[as_int]))
def serialize_example(
    action,
    base_action,
    qpos,
    qvel,
    cam_high,
    cam_left_wrist,
    cam_right_wrist,
    instruction,
    terminate_episode,
):
    """Serialize one timestep into a ``tf.train.Example`` byte string.

    Array-like fields go through ``tf.io.serialize_tensor``; camera frames are
    stored as their raw bytes (``ndarray.tobytes()``) wrapped in a string
    tensor; ``terminate_episode`` becomes a 0/1 int64 flag.
    """

    def _tensor_feature(tensor):
        # Serialize an array/tensor value into a bytes feature.
        return _bytes_feature(tf.io.serialize_tensor(tensor))

    def _image_feature(frame):
        # Store the frame's raw bytes as a serialized string tensor.
        raw = tf.convert_to_tensor(frame.tobytes(), dtype=tf.string)
        return _bytes_feature(tf.io.serialize_tensor(raw))

    feature = {
        "action": _tensor_feature(action),
        "base_action": _tensor_feature(base_action),
        "qpos": _tensor_feature(qpos),
        "qvel": _tensor_feature(qvel),
        "cam_high": _image_feature(cam_high),
        "cam_left_wrist": _image_feature(cam_left_wrist),
        "cam_right_wrist": _image_feature(cam_right_wrist),
        "instruction": _bytes_feature(instruction),
        "terminate_episode": _bool_feature(terminate_episode),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()
def process_hdf5_file(args):
    """Convert one HDF5 episode file into a TFRecord file.

    Args:
        args: Tuple of ``(filepath, root_dir, out_dir)``. The output file
            mirrors the input's location relative to ``root_dir`` under
            ``out_dir``, with the ``.hdf5`` suffix replaced by ``.tfrecord``.

    Returns:
        A human-readable status string (already-exists / written / error).
    """
    filepath, root_dir, out_dir = args
    output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir))
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.basename(filepath)
    tfrecord_path = os.path.join(output_dir, filename.replace(".hdf5", ".tfrecord"))

    # Skip files already converted by a previous run (non-empty output only).
    if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
        return f"TFRecords already exist at {tfrecord_path}"
    try:
        with h5py.File(filepath, "r") as f, tf.io.TFRecordWriter(tfrecord_path) as writer:
            num_steps = f["action"].shape[0]
            # Remove the leading still steps: find the first timestep whose
            # joint positions deviate from the initial pose by more than EPS.
            EPS = 1e-2
            qpos_all = f["observations"]["qpos"][:]
            qpos_delta = np.abs(qpos_all - qpos_all[0:1])
            indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
            if len(indices) > 0:
                first_idx = indices[0]
            else:
                raise ValueError("Found no qpos that exceeds the threshold.")

            # Loop-invariant: the instruction is per-episode, so read it once
            # instead of once per timestep.
            instruction = f["instruction"][()]

            # first_idx >= 1 (step 0 never differs from itself), so this
            # starts one step before the first moving frame.
            for i in range(first_idx - 1, num_steps):
                action = f["action"][i]
                base_action = f["base_action"][i]
                qpos = f["observations"]["qpos"][i]
                qvel = f["observations"]["qvel"][i]
                cam_high = f["observations"]["images"]["cam_high"][i]
                cam_left_wrist = f["observations"]["images"]["cam_left_wrist"][i]
                cam_right_wrist = f["observations"]["images"]["cam_right_wrist"][i]
                terminate_episode = i == num_steps - 1
                serialized_example = serialize_example(
                    action,
                    base_action,
                    qpos,
                    qvel,
                    cam_high,
                    cam_left_wrist,
                    cam_right_wrist,
                    instruction,
                    terminate_episode,
                )
                writer.write(serialized_example)
    except Exception as e:
        # A partially written (non-empty) tfrecord would be skipped as
        # "already converted" on the next run, so drop it before bailing.
        if os.path.exists(tfrecord_path):
            os.remove(tfrecord_path)
        # Use a distinct handle name: the original shadowed the h5py file `f`.
        with open("error_log.txt", "a") as log:
            log.write(f"{filepath}\n")
        print(f"error at {filepath}: {e}")
        return f"error at {filepath}: {e}"
    return f"TFRecords written to {tfrecord_path}"
def write_tfrecords(root_dir, out_dir):
    """Mirror ``root_dir`` into ``out_dir``: copy instruction JSON files and
    convert every ``.hdf5`` episode to a TFRecord in parallel.

    Args:
        root_dir: Directory tree containing ``.hdf5`` episodes and optional
            ``expanded_instruction*.json`` files.
        out_dir: Destination root; the relative layout of ``root_dir`` is
            preserved beneath it.
    """
    os.makedirs(out_dir, exist_ok=True)

    hdf5_files = []
    for root, dirs, files in os.walk(root_dir):
        target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
        gpt4_json = os.path.join(root, "expanded_instruction_gpt-4-turbo.json")
        plain_json = os.path.join(root, "expanded_instruction.json")
        if os.path.exists(gpt4_json):
            # Copy the instruction file alongside the converted episodes.
            os.makedirs(target_path, exist_ok=True)
            shutil.copy(gpt4_json, target_path)
        elif os.path.exists(plain_json):
            print(root)
            os.makedirs(target_path, exist_ok=True)
            # Copy straight to the gpt-4-turbo name instead of copy + rename:
            # one step, and it overwrites cleanly on every platform (os.rename
            # raises on Windows when the destination already exists).
            shutil.copy(
                plain_json,
                os.path.join(target_path, "expanded_instruction_gpt-4-turbo.json"),
            )
        for filename in fnmatch.filter(files, "*.hdf5"):
            hdf5_files.append((os.path.join(root, filename), root_dir, out_dir))

    # Fan the conversions out over worker processes; the progress bar advances
    # as results come back in completion order.
    with Pool(16) as pool:
        with tqdm(total=len(hdf5_files)) as pbar:
            for _ in pool.imap_unordered(process_hdf5_file, hdf5_files):
                pbar.update(1)

    print(f"TFRecords written to {out_dir}")
if __name__ == "__main__":
    # The guard is required: multiprocessing workers re-import this module
    # under the "spawn" start method, and without it every child process
    # would recursively kick off the whole conversion again.
    root_dir = "../datasets/agilex/rdt_data/"
    out_dir = "../datasets/agilex/tfrecords/"
    write_tfrecords(root_dir, out_dir)