RoboTwin_image/script/pkl2zarr_dp3.py
2025-07-02 03:13:07 +00:00

99 lines
3.8 KiB
Python

import pdb, pickle, os
import numpy as np
import open3d as o3d
from copy import deepcopy
import zarr, shutil
import argparse
def _load_episode(episode_dir, visualize_pcd=False):
    """Read the consecutively numbered frame pickles of one episode.

    Frames are read from ``<episode_dir>/0.pkl``, ``1.pkl``, ... and reading
    stops at the first index whose file does not exist.

    Args:
        episode_dir: Directory holding the per-frame ``.pkl`` files.
        visualize_pcd: When true, show every frame in an Open3D window built
            from the frame's ``data['pcd']`` entry (debugging aid).

    Returns:
        A ``(point_clouds, end_poses, joint_actions)`` tuple of lists with one
        entry per frame, in frame order.

    Raises:
        KeyError: If a frame lacks one of the keys read below.
    """
    point_clouds, end_poses, joint_actions = [], [], []
    frame_idx = 0
    while True:
        frame_path = os.path.join(episode_dir, f'{frame_idx}.pkl')
        if not os.path.exists(frame_path):
            break

        # NOTE: pickle.load can execute arbitrary code; only run this script
        # on trusted, locally generated episode data.
        with open(frame_path, 'rb') as file:
            data = pickle.load(file)

        point_clouds.append(data['pointcloud'][:, :])
        end_poses.append(data['endpose'])
        joint_actions.append(data['joint_action'])

        if visualize_pcd:
            cloud = o3d.geometry.PointCloud()
            cloud.points = o3d.utility.Vector3dVector(data['pcd']['points'])
            cloud.colors = o3d.utility.Vector3dVector(data['pcd']['colors'])
            o3d.visualization.draw_geometries([cloud])

        frame_idx += 1

    return point_clouds, end_poses, joint_actions


def main():
    """Convert the pickled episodes of one task into a zarr store.

    Command-line arguments:
        task_name: Task whose frames live under ``./data/<task_name>_pkl``.
        episode_number: Maximum number of episodes (``episode0``,
            ``episode1``, ...) to convert.

    The store is written to
    ``./policy/3D-Diffusion-Policy/3D-Diffusion-Policy/data/<task_name>_<episode_number>.zarr``
    with these datasets:

    * ``data/point_cloud`` - stacked ``data['pointcloud']`` of every frame
    * ``data/tcp_action``  - stacked ``data['endpose']``
    * ``data/state`` and ``data/action`` - both the stacked
      ``data['joint_action']``
    * ``meta/episode_ends`` - cumulative frame count at the end of each episode

    Raises:
        FileNotFoundError: If no frame could be loaded. An already existing
            output store is left untouched in that case.
    """
    parser = argparse.ArgumentParser(
        description='Convert pickled episodes of a task into a zarr dataset.')
    parser.add_argument('task_name', type=str,
                        help='task name; frames are read from ./data/<task_name>_pkl')
    parser.add_argument('episode_number', type=int,
                        help='maximum number of episodes to convert')
    args = parser.parse_args()

    task_name = args.task_name
    num = args.episode_number
    visualize_pcd = False  # set to True to inspect each frame's point cloud

    load_dir = f'./data/{task_name}_pkl'
    save_dir = f'./policy/3D-Diffusion-Policy/3D-Diffusion-Policy/data/{task_name}_{num}.zarr'

    point_cloud_arrays, action_arrays, joint_action_arrays = [], [], []
    episode_ends_arrays = []
    total_count = 0
    current_ep = 0

    # Episodes are consumed in order; stop at the first missing directory or
    # once the requested number of episodes has been read.
    while current_ep < num:
        episode_dir = os.path.join(load_dir, f'episode{current_ep}')
        if not os.path.isdir(episode_dir):
            break
        print(f'processing episode: {current_ep + 1} / {num}', end='\r')

        point_clouds, end_poses, joint_actions = _load_episode(
            episode_dir, visualize_pcd)
        point_cloud_arrays.extend(point_clouds)
        action_arrays.extend(end_poses)
        joint_action_arrays.extend(joint_actions)

        total_count += len(point_clouds)
        # Each entry is the running frame total after that episode.
        episode_ends_arrays.append(total_count)
        current_ep += 1

    print()  # finish the '\r' progress line

    if total_count == 0:
        # Fail with a clear message before touching any existing output,
        # instead of an IndexError on ``.shape[1]`` further down.
        raise FileNotFoundError(
            f'no episode frames found under {load_dir!r} '
            f'(requested {num} episode(s))')
    if current_ep < num:
        print(f'warning: only {current_ep} of the {num} requested episodes '
              f'were found in {load_dir}')

    episode_ends_arrays = np.array(episode_ends_arrays)
    action_arrays = np.array(action_arrays)
    point_cloud_arrays = np.array(point_cloud_arrays)
    joint_action_arrays = np.array(joint_action_arrays)
    # The state dataset stores the same joint_action values as the action
    # dataset, so one array serves both.
    state_arrays = joint_action_arrays

    # Replace a previous export only now that the new data loaded successfully.
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)
    zarr_root = zarr.group(save_dir)
    zarr_data = zarr_root.create_group('data')
    zarr_meta = zarr_root.create_group('meta')

    compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=1)
    # Chunk along the frame axis in blocks of 100; the other axes stay whole.
    action_chunk_size = (100, action_arrays.shape[1])
    state_chunk_size = (100, state_arrays.shape[1])
    joint_chunk_size = (100, joint_action_arrays.shape[1])
    point_cloud_chunk_size = (100, point_cloud_arrays.shape[1], point_cloud_arrays.shape[2])

    zarr_data.create_dataset('point_cloud', data=point_cloud_arrays, chunks=point_cloud_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
    zarr_data.create_dataset('tcp_action', data=action_arrays, chunks=action_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
    zarr_data.create_dataset('state', data=state_arrays, chunks=state_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
    zarr_data.create_dataset('action', data=joint_action_arrays, chunks=joint_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
    zarr_meta.create_dataset('episode_ends', data=episode_ends_arrays, dtype='int64', overwrite=True, compressor=compressor)
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()