commit 52d79bbc5ea4f663c959a965f0e092b6959360de Author: skyxz Date: Fri Oct 24 19:25:20 2025 +0800 update diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fb5683c --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +input/ +output/ +Temp/ +weights/ \ No newline at end of file diff --git a/RDT-1B/.dockerignore b/RDT-1B/.dockerignore new file mode 100644 index 0000000..1b91c52 --- /dev/null +++ b/RDT-1B/.dockerignore @@ -0,0 +1,2 @@ +input/* +output/* \ No newline at end of file diff --git a/RDT-1B/.gitignore b/RDT-1B/.gitignore new file mode 100644 index 0000000..924ce29 --- /dev/null +++ b/RDT-1B/.gitignore @@ -0,0 +1,7 @@ +processed_data/ +training_data/ +checkpoints/ +model_config/*.yml +wandb/* +!models/ +!data/ \ No newline at end of file diff --git a/RDT-1B/Dockerfile b/RDT-1B/Dockerfile new file mode 100644 index 0000000..2c9f0dc --- /dev/null +++ b/RDT-1B/Dockerfile @@ -0,0 +1,45 @@ + +FROM registry.d-robotics.cc/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04 +# ccr-29eug8s3-pub.cnc.bj.baidubce.com/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04 +WORKDIR /app + +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 +ENV TZ=Asia/Shanghai + +RUN sed -i 's/archive.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list && \ + sed -i 's/security.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list + +RUN apt-get update && apt-get install -y \ + software-properties-common \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update \ + && apt-get install -y \ + python3.10 \ + python3.10-dev \ + python3.10-distutils \ + libgl1-mesa-glx \ + libglib2.0-0 \ + wget \ + ffmpeg \ + libsm6 \ + libxext6 \ + && rm -rf /var/lib/apt/lists/* + +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 + +COPY . 
/app/ + +RUN python3 -m pip install --upgrade pip + +RUN pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121 + +RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip install packaging==24.0 + +RUN pip install flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + +RUN mkdir -p /app/dataset/input /app/dataset/output + +ENTRYPOINT ["bash", "finetune.sh"] diff --git a/RDT-1B/__init__.py b/RDT-1B/__init__.py new file mode 100644 index 0000000..d4b6770 --- /dev/null +++ b/RDT-1B/__init__.py @@ -0,0 +1 @@ +from .deploy_policy import * diff --git a/RDT-1B/assets/head.png b/RDT-1B/assets/head.png new file mode 100644 index 0000000..c0dd6e7 Binary files /dev/null and b/RDT-1B/assets/head.png differ diff --git a/RDT-1B/configs/__pycache__/state_vec.cpython-310.pyc b/RDT-1B/configs/__pycache__/state_vec.cpython-310.pyc new file mode 100644 index 0000000..ed70d3e Binary files /dev/null and b/RDT-1B/configs/__pycache__/state_vec.cpython-310.pyc differ diff --git a/RDT-1B/configs/base.yaml b/RDT-1B/configs/base.yaml new file mode 100644 index 0000000..01b47f3 --- /dev/null +++ b/RDT-1B/configs/base.yaml @@ -0,0 +1,71 @@ +common: + # The number of historical images + img_history_size: 2 + # The number of future actions to predict + action_chunk_size: 64 + # The number of cameras to be used in the model + num_cameras: 3 + # Dimension for state/action, we use the same space for both state and action + # This MUST be equal to configs/state_vec.py + state_dim: 128 + + +dataset: + # We will extract the data from raw dataset + # and store them in the disk buffer by producer + # When training, we will read the data + # randomly from the buffer by consumer + # The producer will replace the data which has been + # read by the consumer with new data + + # The path to the buffer (at least 400GB) + buf_path: /path/to/buffer + # The number of chunks in the buffer + buf_num_chunks: 512 + 
# The number of samples (step rather than episode) in each chunk + buf_chunk_size: 512 + + # We will filter the episodes with length less than `epsd_len_thresh_low` + epsd_len_thresh_low: 32 + # For those more than `epsd_len_thresh_high`, + # we will randomly sample `epsd_len_thresh_high` steps each time we load the episode + # to better balance the training datasets + epsd_len_thresh_high: 2048 + # How to fit the image size + image_aspect_ratio: pad + # Maximum number of language tokens + tokenizer_max_length: 1024 + +model: + # Config for condition adpators + lang_adaptor: mlp2x_gelu + img_adaptor: mlp2x_gelu + state_adaptor: mlp3x_gelu + lang_token_dim: 4096 + img_token_dim: 1152 + # Dim of action or proprioception vector + # A `state` refers to an action or a proprioception vector + state_token_dim: 128 + # Config for RDT structure + rdt: + # 1B: num_head 32 hidden_size 2048 + hidden_size: 2048 + depth: 28 + num_heads: 32 + cond_pos_embed_type: multimodal + # For noise scheduler + noise_scheduler: + type: ddpm + num_train_timesteps: 1000 + num_inference_timesteps: 5 + beta_schedule: squaredcos_cap_v2 # Critical choice + prediction_type: sample + clip_sample: False + # For EMA (params averaging) + # We do not use EMA currently + ema: + update_after_step: 0 + inv_gamma: 1.0 + power: 0.75 + min_value: 0.0 + max_value: 0.9999 diff --git a/RDT-1B/configs/calvin_rel_traj_location_bounds_task_ABC_D.json b/RDT-1B/configs/calvin_rel_traj_location_bounds_task_ABC_D.json new file mode 100644 index 0000000..ac1679b --- /dev/null +++ b/RDT-1B/configs/calvin_rel_traj_location_bounds_task_ABC_D.json @@ -0,0 +1,50 @@ +{ + "A": [ + [ + -0.2691913843154907, + -0.21995729207992554, + -0.182277649641037 + ], + [ + 0.35127854347229004, + 0.2769763469696045, + 0.17159393429756165 + ] + ], + "B": [ + [ + -0.2576896846294403, + -0.22244493663311005, + -0.20557966828346252 + ], + [ + 0.32854634523391724, + 0.2922680974006653, + 0.17373555898666382 + ] + ], + "C": [ + [ + 
-0.29205888509750366, + -0.24688798189163208, + -0.17577645182609558 + ], + [ + 0.25053921341896057, + 0.3277084231376648, + 0.16431939601898193 + ] + ], + "D": [ + [ + -0.25131964683532715, + -0.15233077108860016, + -0.13294968008995056 + ], + [ + 0.19209328293800354, + 0.19344553351402283, + 0.1370421051979065 + ] + ] +} \ No newline at end of file diff --git a/RDT-1B/configs/dataset_control_freq.json b/RDT-1B/configs/dataset_control_freq.json new file mode 100644 index 0000000..70b9f0e --- /dev/null +++ b/RDT-1B/configs/dataset_control_freq.json @@ -0,0 +1,65 @@ +{ + "fractal20220817_data": 3, + "taco_play": 15, + "jaco_play": 10, + "berkeley_cable_routing": 10, + "nyu_door_opening_surprising_effectiveness": 3, + "viola": 20, + "berkeley_autolab_ur5": 5, + "toto": 30, + "kuka": 10, + "language_table": 10, + "columbia_cairlab_pusht_real": 10, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20, + "nyu_rot_dataset_converted_externally_to_rlds":3, + "stanford_hydra_dataset_converted_externally_to_rlds": 10, + "austin_buds_dataset_converted_externally_to_rlds": 20, + "nyu_franka_play_dataset_converted_externally_to_rlds": 3, + "maniskill_dataset_converted_externally_to_rlds": 20, + "furniture_bench_dataset_converted_externally_to_rlds": 10, + "ucsd_kitchen_dataset_converted_externally_to_rlds": 2, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3, + "austin_sailor_dataset_converted_externally_to_rlds": 20, + "austin_sirius_dataset_converted_externally_to_rlds": 20, + "bc_z": 10, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": 10, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 10, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10, + "utokyo_xarm_bimanual_converted_externally_to_rlds": 10, + "berkeley_mvp_converted_externally_to_rlds": 5, + "berkeley_rpt_converted_externally_to_rlds": 30, + "kaist_nonprehensile_converted_externally_to_rlds": 10, + 
"stanford_mask_vit_converted_externally_to_rlds": 0, + "tokyo_u_lsmo_converted_externally_to_rlds": 10, + "dlr_sara_pour_converted_externally_to_rlds": 10, + "dlr_sara_grid_clamp_converted_externally_to_rlds": 10, + "dlr_edan_shared_control_converted_externally_to_rlds": 5, + "asu_table_top_converted_externally_to_rlds": 12.5, + "stanford_robocook_converted_externally_to_rlds": 5, + "eth_agent_affordances": 66.6, + "imperialcollege_sawyer_wrist_cam": 10, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20, + "uiuc_d3field": 1, + "utaustin_mutex": 20, + "berkeley_fanuc_manipulation": 10, + "cmu_play_fusion": 5, + "cmu_stretch": 10, + "berkeley_gnm_recon": 3, + "berkeley_gnm_cory_hall": 5, + "berkeley_gnm_sac_son": 10, + "robo_net": 1, + "roboturk_real_towercreation": 10, + "roboturk_real_laundrylayout": 10, + "roboturk_real_objectsearch": 10, + "aloha_mobile": 50, + "aloha_static": 50, + "roboset": 5, + "droid": 15, + "fmb": 10, + "dobbe": 30, + "qut_dexterous_manpulation": 30, + "agilex": 25, + "rh20t": 10, + "calvin": 30, + "bridgev2": 5 +} \ No newline at end of file diff --git a/RDT-1B/configs/dataset_img_keys.json b/RDT-1B/configs/dataset_img_keys.json new file mode 100644 index 0000000..6aede91 --- /dev/null +++ b/RDT-1B/configs/dataset_img_keys.json @@ -0,0 +1,575 @@ +{ + "fractal20220817_data": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[ + 1,0,0,0 + ] + }, + "taco_play": { + "image_keys": [ + "rgb_static", + "rgb_gripper", + "rgb_static", + "rgb_static" + ], + "image_mask":[ + 1,1,0,0 + ] + }, + "jaco_play": { + "image_keys": [ + "image", + "image_wrist", + "image_wrist", + "image_wrist" + ], + "image_mask":[ + 1,1,0,0 + ] + }, + "berkeley_cable_routing": { + "image_keys": [ + "image", + "wrist45_image", + "wrist225_image", + "top_image" + ], + "image_mask":[1,1,0,1] + }, + "nyu_door_opening_surprising_effectiveness": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + 
"image_mask":[1,0,0,0] + }, + "viola": { + "image_keys": [ + "agentview_rgb", + "eye_in_hand_rgb", + "eye_in_hand_rgb", + "eye_in_hand_rgb" + ], + "image_mask":[1,1,0,0] + }, + "berkeley_autolab_ur5": { + "image_keys": [ + "image", + "hand_image", + "hand_image", + "hand_image" + ], + "image_mask":[1,1,0,0] + }, + "toto": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "kuka": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "language_table": { + "image_keys": [ + "rgb", + "rgb", + "rgb", + "rgb" + ], + "image_mask":[1,0,0,0] + }, + "columbia_cairlab_pusht_real": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "nyu_rot_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "stanford_hydra_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + "austin_buds_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + "nyu_franka_play_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image_additional_view", + "image_additional_view", + "image_additional_view" + ], + "image_mask":[1,0,0,1] + }, + "maniskill_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + "furniture_bench_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + 
"ucsd_kitchen_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "austin_sailor_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + "austin_sirius_dataset_converted_externally_to_rlds": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + "bc_z": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": { + "image_keys": [ + "image", + "hand_image", + "hand_image", + "image2" + ], + "image_mask":[1,1,0,1] + }, + "utokyo_xarm_bimanual_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "berkeley_mvp_converted_externally_to_rlds": { + "image_keys": [ + "hand_image", + "hand_image", + "hand_image", + "hand_image" + ], + "image_mask":[0,1,0,0] + }, + "berkeley_rpt_converted_externally_to_rlds": { + "image_keys": [ + "hand_image", + "hand_image", + "hand_image", + "hand_image" + ], + "image_mask":[0,1,0,0] + }, + "kaist_nonprehensile_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "stanford_mask_vit_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + 
"image" + ], + "image_mask":[1,0,0,0] + }, + "tokyo_u_lsmo_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "dlr_sara_pour_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "dlr_sara_grid_clamp_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "dlr_edan_shared_control_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "asu_table_top_converted_externally_to_rlds": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "stanford_robocook_converted_externally_to_rlds": { + "image_keys": [ + "image_2", + "image_1", + "image_3", + "image_4" + ], + "image_mask":[1,0,0,1] + }, + "eth_agent_affordances": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "imperialcollege_sawyer_wrist_cam": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[0,1,0,0] + }, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + "uiuc_d3field": { + "image_keys": [ + "image_1", + "image_2", + "image_3", + "image_4" + ], + "image_mask":[1,0,0,1] + }, + "utaustin_mutex": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + "berkeley_fanuc_manipulation": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[1,1,0,0] + }, + "cmu_play_fusion": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "cmu_stretch": { + "image_keys": [ + "image", + "image", + "image", + 
"image" + ], + "image_mask":[1,0,0,0] + }, + "berkeley_gnm_recon": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "berkeley_gnm_cory_hall": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "berkeley_gnm_sac_son": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "robo_net": { + "image_keys": [ + "image", + "image1", + "image2", + "image2" + ], + "image_mask":[1,0,0,1] + }, + "roboturk_real_towercreation": { + "image_keys": [ + "top_rgb_frame", + "front_rgb_frame", + "front_rgb_frame", + "front_rgb_frame" + ], + "image_mask":[1,0,0,1] + }, + "roboturk_real_laundrylayout": { + "image_keys": [ + "top_rgb_frame", + "front_rgb_frame", + "front_rgb_frame", + "front_rgb_frame" + ], + "image_mask":[1,0,0,1] + }, + "roboturk_real_objectsearch": { + "image_keys": [ + "top_rgb_frame", + "front_rgb_frame", + "front_rgb_frame", + "front_rgb_frame" + ], + "image_mask":[1,0,0,1] + }, + "aloha_mobile": { + "image_keys": [ + "cam_high", + "cam_right_wrist", + "cam_left_wrist", + "cam_right_wrist" + ], + "image_mask":[1,1,1,0] + }, + "aloha_static": { + "image_keys": [ + "cam_high", + "cam_right_wrist", + "cam_left_wrist", + "cam_low" + ], + "image_mask":[1,1,1,1] + }, + "roboset": { + "image_keys": [ + "rgb_top", + "rgb_right", + "rgb_left", + "rgb_right" + ], + "image_mask":[1,1,1,0] + }, + "droid": { + "image_keys": [ + "exterior_image_1_left", + "wrist_image_left", + "wrist_image_left", + "exterior_image_2_left" + ], + "image_mask":[1,1,0,1] + }, + "fmb": { + "image_keys": [ + "image_side_1", + "image_wrist_1", + "image_wrist_1", + "image_side_2" + ], + "image_mask":[1,1,0,1] + }, + "dobbe": { + "image_keys": [ + "wrist_image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + "image_mask":[0,1,0,0] + }, + "qut_dexterous_manpulation": { + "image_keys": [ + "image", + "wrist_image", + "wrist_image", + "wrist_image" + ], + 
"image_mask":[1,1,0,0] + }, + "agilex": { + "image_keys": [ + "cam_high", + "cam_right_wrist", + "cam_left_wrist", + "cam_right_wrist" + ], + "image_mask":[1,1,1,0] + }, + "rh20t": { + "image_keys": [ + "image", + "image", + "image", + "image" + ], + "image_mask":[1,0,0,0] + }, + "calvin": { + "image_keys": [ + "rgb_static", + "rgb_gripper", + "rgb_gripper", + "rgb_gripper" + ], + "image_mask":[1,1,0,0] + }, + "bridgev2": { + "image_keys": [ + "images0", + "images0", + "images0", + "images0" + ], + "image_mask":[1,0,0,0] + } +} \ No newline at end of file diff --git a/RDT-1B/configs/dataset_stat.json b/RDT-1B/configs/dataset_stat.json new file mode 100644 index 0000000..fd08dfd --- /dev/null +++ b/RDT-1B/configs/dataset_stat.json @@ -0,0 +1,525 @@ +{ + "agilex": { + "dataset_name": "agilex", + "state_mean": [ + -0.0036545392947090432, + -0.2773659935760079, + 0.3147616748061523, + 0.3813313179910183, + 0.04028575944090457, + 0.034888520819083294, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "state_std": [ + 0.05763674563578847, + 0.2580181064167735, + 0.19785840483767897, + 0.05020347749331385, + 0.054529239104671424, + 0.05020521339363586, + 0.0, + 0.0, + 0.0, + 
0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "state_min": [ + -0.17447535196940103, + -0.5522612677680121, + -0.3340397516886393, + 0.21861712137858072, + -0.09725829230414497, + 0.003396739231215583, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "state_max": [ + 0.21961932712131077, + 0.30613206227620443, + 0.5444545321994357, + 0.4866888682047526, + 0.31486290825737845, + 
0.3355223337809245, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + } +} \ No newline at end of file diff --git a/RDT-1B/configs/finetune_datasets.json b/RDT-1B/configs/finetune_datasets.json new file mode 100644 index 0000000..851cd96 --- /dev/null +++ b/RDT-1B/configs/finetune_datasets.json @@ -0,0 +1,3 @@ +[ + "agilex" +] \ No newline at end of file diff --git a/RDT-1B/configs/finetune_sample_weights.json b/RDT-1B/configs/finetune_sample_weights.json new file mode 100644 index 0000000..d16b603 --- /dev/null +++ b/RDT-1B/configs/finetune_sample_weights.json @@ -0,0 +1,3 @@ +{ + "agilex": 100 +} \ No newline at end of file diff --git a/RDT-1B/configs/pretrain_datasets.json b/RDT-1B/configs/pretrain_datasets.json new file mode 100644 index 0000000..2b766f2 --- /dev/null +++ b/RDT-1B/configs/pretrain_datasets.json @@ -0,0 +1,48 @@ +[ + "fractal20220817_data", + "jaco_play", + "taco_play", + "berkeley_cable_routing", + "viola", + "berkeley_autolab_ur5", + "toto", + "nyu_door_opening_surprising_effectiveness", + "columbia_cairlab_pusht_real", + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds", + "austin_buds_dataset_converted_externally_to_rlds", + "kuka", + 
"utokyo_xarm_bimanual_converted_externally_to_rlds", + "stanford_hydra_dataset_converted_externally_to_rlds", + "maniskill_dataset_converted_externally_to_rlds", + "ucsd_kitchen_dataset_converted_externally_to_rlds", + "ucsd_pick_and_place_dataset_converted_externally_to_rlds", + "austin_sailor_dataset_converted_externally_to_rlds", + "austin_sirius_dataset_converted_externally_to_rlds", + "bc_z", + "utokyo_pr2_opening_fridge_converted_externally_to_rlds", + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", + "utokyo_xarm_pick_and_place_converted_externally_to_rlds", + "berkeley_mvp_converted_externally_to_rlds", + "berkeley_rpt_converted_externally_to_rlds", + "kaist_nonprehensile_converted_externally_to_rlds", + "tokyo_u_lsmo_converted_externally_to_rlds", + "dlr_sara_grid_clamp_converted_externally_to_rlds", + "stanford_robocook_converted_externally_to_rlds", + "imperialcollege_sawyer_wrist_cam", + "iamlab_cmu_pickup_insert_converted_externally_to_rlds", + "utaustin_mutex", + "berkeley_fanuc_manipulation", + "cmu_play_fusion", + "language_table", + "furniture_bench_dataset_converted_externally_to_rlds", + "droid", + "fmb", + "dobbe", + "qut_dexterous_manpulation", + "aloha_mobile", + "aloha_static", + "roboset", + "rh20t", + "calvin", + "bridgev2" +] \ No newline at end of file diff --git a/RDT-1B/configs/pretrain_sample_weights.json b/RDT-1B/configs/pretrain_sample_weights.json new file mode 100644 index 0000000..60f7777 --- /dev/null +++ b/RDT-1B/configs/pretrain_sample_weights.json @@ -0,0 +1,48 @@ +{ + "fractal20220817_data": 271, + "taco_play": 60, + "jaco_play": 33, + "berkeley_cable_routing": 8, + "nyu_door_opening_surprising_effectiveness": 10, + "viola": 12, + "berkeley_autolab_ur5": 32, + "toto": 32, + "kuka": 50, + "language_table": 100, + "columbia_cairlab_pusht_real": 12, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 55, + "stanford_hydra_dataset_converted_externally_to_rlds": 24, + 
"austin_buds_dataset_converted_externally_to_rlds": 7, + "maniskill_dataset_converted_externally_to_rlds": 174, + "furniture_bench_dataset_converted_externally_to_rlds": 71, + "ucsd_kitchen_dataset_converted_externally_to_rlds": 12, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": 37, + "austin_sailor_dataset_converted_externally_to_rlds": 15, + "austin_sirius_dataset_converted_externally_to_rlds": 24, + "bc_z": 208, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": 9, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 15, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10, + "utokyo_xarm_bimanual_converted_externally_to_rlds": 1, + "berkeley_mvp_converted_externally_to_rlds": 22, + "berkeley_rpt_converted_externally_to_rlds": 30, + "kaist_nonprehensile_converted_externally_to_rlds": 14, + "tokyo_u_lsmo_converted_externally_to_rlds": 7, + "dlr_sara_grid_clamp_converted_externally_to_rlds": 1, + "stanford_robocook_converted_externally_to_rlds": 50, + "imperialcollege_sawyer_wrist_cam": 13, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": 25, + "utaustin_mutex": 39, + "berkeley_fanuc_manipulation": 20, + "cmu_play_fusion": 24, + "droid": 303, + "fmb": 42, + "dobbe": 36, + "qut_dexterous_manpulation": 14, + "aloha_mobile": 150, + "aloha_static": 150, + "roboset": 135, + "rh20t": 331, + "calvin": 100, + "bridgev2": 224 +} \ No newline at end of file diff --git a/RDT-1B/configs/state_vec.py b/RDT-1B/configs/state_vec.py new file mode 100644 index 0000000..7d341a9 --- /dev/null +++ b/RDT-1B/configs/state_vec.py @@ -0,0 +1,126 @@ +STATE_VEC_IDX_MAPPING = { + # [0, 10): right arm joint positions + **{ + "arm_joint_{}_pos".format(i): i + for i in range(10) + }, + **{ + "right_arm_joint_{}_pos".format(i): i + for i in range(10) + }, + # [10, 15): right gripper joint positions + **{ + "gripper_joint_{}_pos".format(i): i + 10 + for i in range(5) + }, + **{ + "right_gripper_joint_{}_pos".format(i): i + 10 + for i in 
range(5) + }, + "gripper_open": 10, # alias of right_gripper_joint_0_pos + "right_gripper_open": 10, + # [15, 25): right arm joint velocities + **{ + "arm_joint_{}_vel".format(i): i + 15 + for i in range(10) + }, + **{ + "right_arm_joint_{}_vel".format(i): i + 15 + for i in range(10) + }, + # [25, 30): right gripper joint velocities + **{ + "gripper_joint_{}_vel".format(i): i + 25 + for i in range(5) + }, + **{ + "right_gripper_joint_{}_vel".format(i): i + 25 + for i in range(5) + }, + "gripper_open_vel": 25, # alias of right_gripper_joint_0_vel + "right_gripper_open_vel": 25, + # [30, 33): right end effector positions + "eef_pos_x": 30, + "right_eef_pos_x": 30, + "eef_pos_y": 31, + "right_eef_pos_y": 31, + "eef_pos_z": 32, + "right_eef_pos_z": 32, + # [33, 39): right end effector 6D pose + "eef_angle_0": 33, + "right_eef_angle_0": 33, + "eef_angle_1": 34, + "right_eef_angle_1": 34, + "eef_angle_2": 35, + "right_eef_angle_2": 35, + "eef_angle_3": 36, + "right_eef_angle_3": 36, + "eef_angle_4": 37, + "right_eef_angle_4": 37, + "eef_angle_5": 38, + "right_eef_angle_5": 38, + # [39, 42): right end effector velocities + "eef_vel_x": 39, + "right_eef_vel_x": 39, + "eef_vel_y": 40, + "right_eef_vel_y": 40, + "eef_vel_z": 41, + "right_eef_vel_z": 41, + # [42, 45): right end effector angular velocities + "eef_angular_vel_roll": 42, + "right_eef_angular_vel_roll": 42, + "eef_angular_vel_pitch": 43, + "right_eef_angular_vel_pitch": 43, + "eef_angular_vel_yaw": 44, + "right_eef_angular_vel_yaw": 44, + # [45, 50): reserved + # [50, 60): left arm joint positions + **{ + "left_arm_joint_{}_pos".format(i): i + 50 + for i in range(10) + }, + # [60, 65): left gripper joint positions + **{ + "left_gripper_joint_{}_pos".format(i): i + 60 + for i in range(5) + }, + "left_gripper_open": 60, # alias of left_gripper_joint_0_pos + # [65, 75): left arm joint velocities + **{ + "left_arm_joint_{}_vel".format(i): i + 65 + for i in range(10) + }, + # [75, 80): left gripper joint velocities + 
**{ + "left_gripper_joint_{}_vel".format(i): i + 75 + for i in range(5) + }, + "left_gripper_open_vel": 75, # alias of left_gripper_joint_0_vel + # [80, 83): left end effector positions + "left_eef_pos_x": 80, + "left_eef_pos_y": 81, + "left_eef_pos_z": 82, + # [83, 89): left end effector 6D pose + "left_eef_angle_0": 83, + "left_eef_angle_1": 84, + "left_eef_angle_2": 85, + "left_eef_angle_3": 86, + "left_eef_angle_4": 87, + "left_eef_angle_5": 88, + # [89, 92): left end effector velocities + "left_eef_vel_x": 89, + "left_eef_vel_y": 90, + "left_eef_vel_z": 91, + # [92, 95): left end effector angular velocities + "left_eef_angular_vel_roll": 92, + "left_eef_angular_vel_pitch": 93, + "left_eef_angular_vel_yaw": 94, + # [95, 100): reserved + # [100, 102): base linear velocities + "base_vel_x": 100, + "base_vel_y": 101, + # [102, 103): base angular velocities + "base_angular_vel": 102, + # [103, 128): reserved +} +STATE_VEC_LEN = 128 diff --git a/RDT-1B/configs/zero2.json b/RDT-1B/configs/zero2.json new file mode 100644 index 0000000..678e66b --- /dev/null +++ b/RDT-1B/configs/zero2.json @@ -0,0 +1,14 @@ +{ + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 2, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9 + } +} \ No newline at end of file diff --git a/RDT-1B/data/.gitignore b/RDT-1B/data/.gitignore new file mode 100644 index 0000000..aa87b46 --- /dev/null +++ b/RDT-1B/data/.gitignore @@ -0,0 +1,2 @@ +# Ignore data files +datasets diff --git a/RDT-1B/data/__pycache__/compute_dataset_stat_hdf5.cpython-310.pyc b/RDT-1B/data/__pycache__/compute_dataset_stat_hdf5.cpython-310.pyc new file mode 100644 index 0000000..e469ecb Binary files /dev/null and b/RDT-1B/data/__pycache__/compute_dataset_stat_hdf5.cpython-310.pyc differ diff --git a/RDT-1B/data/__pycache__/filelock.cpython-310.pyc 
b/RDT-1B/data/__pycache__/filelock.cpython-310.pyc new file mode 100644 index 0000000..198e67f Binary files /dev/null and b/RDT-1B/data/__pycache__/filelock.cpython-310.pyc differ diff --git a/RDT-1B/data/__pycache__/hdf5_vla_dataset.cpython-310.pyc b/RDT-1B/data/__pycache__/hdf5_vla_dataset.cpython-310.pyc new file mode 100644 index 0000000..c12bec3 Binary files /dev/null and b/RDT-1B/data/__pycache__/hdf5_vla_dataset.cpython-310.pyc differ diff --git a/RDT-1B/data/agilex/hdf5totfrecords.py b/RDT-1B/data/agilex/hdf5totfrecords.py new file mode 100644 index 0000000..bff05ac --- /dev/null +++ b/RDT-1B/data/agilex/hdf5totfrecords.py @@ -0,0 +1,154 @@ +import tensorflow as tf +import h5py +import os +import fnmatch +import shutil +from tqdm import tqdm +from multiprocessing import Pool +import numpy as np + + +def _bytes_feature(value): + """Returns a bytes_list from a string / byte.""" + if isinstance(value, type(tf.constant(0))): + value = value.numpy() # BytesList won't unpack a string from an EagerTensor. 
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _bool_feature(value): + """Returns a bool_list from a boolean.""" + return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)])) + + +def serialize_example( + action, + base_action, + qpos, + qvel, + cam_high, + cam_left_wrist, + cam_right_wrist, + instruction, + terminate_episode, +): + feature = { + "action": + _bytes_feature(tf.io.serialize_tensor(action)), + "base_action": + _bytes_feature(tf.io.serialize_tensor(base_action)), + "qpos": + _bytes_feature(tf.io.serialize_tensor(qpos)), + "qvel": + _bytes_feature(tf.io.serialize_tensor(qvel)), + "cam_high": + _bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_high.tobytes(), dtype=tf.string))), + "cam_left_wrist": + _bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_left_wrist.tobytes(), dtype=tf.string))), + "cam_right_wrist": + _bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_right_wrist.tobytes(), dtype=tf.string))), + "instruction": + _bytes_feature(instruction), + "terminate_episode": + _bool_feature(terminate_episode), + } + example_proto = tf.train.Example(features=tf.train.Features(feature=feature)) + return example_proto.SerializeToString() + + +def process_hdf5_file(args): + filepath, root_dir, out_dir = args + output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir)) + os.makedirs(output_dir, exist_ok=True) + filename = os.path.basename(filepath) + tfrecord_path = os.path.join(output_dir, filename.replace(".hdf5", ".tfrecord")) + + if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0: + return f"TFRecords already exist at {tfrecord_path}" + try: + with h5py.File(filepath, "r") as f, tf.io.TFRecordWriter(tfrecord_path) as writer: + num_episodes = f["action"].shape[0] + # Remove the first few still steps + EPS = 1e-2 + qpos = f["observations"]["qpos"][:] + # Get the idx of the first qpos whose delta exceeds the threshold + 
qpos_delta = np.abs(qpos - qpos[0:1]) + indices = np.where(np.any(qpos_delta > EPS, axis=1))[0] + if len(indices) > 0: + first_idx = indices[0] + else: + raise ValueError("Found no qpos that exceeds the threshold.") + + for i in range(first_idx - 1, num_episodes): + action = f["action"][i] + base_action = f["base_action"][i] + qpos = f["observations"]["qpos"][i] + qvel = f["observations"]["qvel"][i] + cam_high = f["observations"]["images"]["cam_high"][i] + cam_left_wrist = f["observations"]["images"]["cam_left_wrist"][i] + cam_right_wrist = f["observations"]["images"]["cam_right_wrist"][i] + instruction = f["instruction"][()] + terminate_episode = i == num_episodes - 1 + serialized_example = serialize_example( + action, + base_action, + qpos, + qvel, + cam_high, + cam_left_wrist, + cam_right_wrist, + instruction, + terminate_episode, + ) + writer.write(serialized_example) + except Exception as e: + with open("error_log.txt", "a") as f: + f.write(f"{filepath}\n") + print(f"error at {filepath}: {e}") + return f"TFRecords written to {tfrecord_path}" + + +def write_tfrecords(root_dir, out_dir): + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + hdf5_files = [] + for root, dirs, files in os.walk(root_dir): + if os.path.exists(os.path.join(root, "expanded_instruction_gpt-4-turbo.json")): + # copy the instruction file + target_path = os.path.join(out_dir, os.path.relpath(root, root_dir)) + os.makedirs(target_path, exist_ok=True) + shutil.copy(os.path.join(root, "expanded_instruction_gpt-4-turbo.json"), target_path) + elif os.path.exists(os.path.join(root, "expanded_instruction.json")): + print(root) + target_path = os.path.join(out_dir, os.path.relpath(root, root_dir)) + os.makedirs(target_path, exist_ok=True) + shutil.copy(os.path.join(root, "expanded_instruction.json"), target_path) + # rename into expanded_instruction_gpt-4-turbo.json + os.rename( + os.path.join( + out_dir, + os.path.relpath(root, root_dir), + "expanded_instruction.json", + ), + os.path.join( 
+ out_dir, + os.path.relpath(root, root_dir), + "expanded_instruction_gpt-4-turbo.json", + ), + ) + for filename in fnmatch.filter(files, "*.hdf5"): + filepath = os.path.join(root, filename) + hdf5_files.append((filepath, root_dir, out_dir)) + + with Pool(16) as pool: + max_count = len(hdf5_files) + with tqdm(total=max_count) as pbar: + for _ in pool.imap_unordered(process_hdf5_file, hdf5_files): + pbar.update(1) + + print(f"TFRecords written to {out_dir}") + + +root_dir = "../datasets/agilex/rdt_data/" +out_dir = "../datasets/agilex/tfrecords/" +write_tfrecords(root_dir, out_dir) diff --git a/RDT-1B/data/compute_dataset_stat.py b/RDT-1B/data/compute_dataset_stat.py new file mode 100644 index 0000000..9f79a40 --- /dev/null +++ b/RDT-1B/data/compute_dataset_stat.py @@ -0,0 +1,256 @@ +""" +This file will compute the min, max, mean, and standard deviation of each datasets +in `pretrain_datasets.json` or `pretrain_datasets.json`. +""" + +import json +import argparse +import os + +# from multiprocessing import Pool, Manager + +import tensorflow as tf +import numpy as np +from tqdm import tqdm + +from data.vla_dataset import VLADataset +from data.hdf5_vla_dataset import HDF5VLADataset +from data.preprocess import generate_json_state + + +# Process each dataset to get the statistics +@tf.autograph.experimental.do_not_convert +def process_dataset(name_dataset_pair): + # print(f"PID {os.getpid()} processing {name_dataset_pair[0]}") + dataset_iter = name_dataset_pair[1] + + MAX_EPISODES = 100000 + EPS = 1e-8 + # For debugging + # MAX_EPISODES = 10 + episode_cnt = 0 + state_sum = 0 + state_sum_sq = 0 + z_state_sum = 0 + z_state_sum_sq = 0 + state_cnt = 0 + nz_state_cnt = None + state_max = None + state_min = None + for episode in dataset_iter: + episode_cnt += 1 + if episode_cnt % 1000 == 0: + print(f"Processing episodes {episode_cnt}/{MAX_EPISODES}") + if episode_cnt > MAX_EPISODES: + break + episode_dict = episode["episode_dict"] + dataset_name = episode["dataset_name"] + + 
res_tup = generate_json_state(episode_dict, dataset_name) + states = res_tup[1] + + # Convert to numpy + states = states.numpy() + + # Zero the values that are close to zero + z_states = states.copy() + z_states[np.abs(states) <= EPS] = 0 + # Compute the non-zero count + if nz_state_cnt is None: + nz_state_cnt = np.zeros(states.shape[1]) + nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0) + + # Update statistics + state_sum += np.sum(states, axis=0) + state_sum_sq += np.sum(states**2, axis=0) + z_state_sum += np.sum(z_states, axis=0) + z_state_sum_sq += np.sum(z_states**2, axis=0) + state_cnt += states.shape[0] + if state_max is None: + state_max = np.max(states, axis=0) + state_min = np.min(states, axis=0) + else: + state_max = np.maximum(state_max, np.max(states, axis=0)) + state_min = np.minimum(state_min, np.min(states, axis=0)) + + # Add one to avoid division by zero + nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt)) + + result = { + "dataset_name": + name_dataset_pair[0], + "state_mean": (state_sum / state_cnt).tolist(), + "state_std": + np.sqrt( + np.maximum( + (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt), + np.zeros_like(state_sum_sq), + )).tolist(), + "state_min": + state_min.tolist(), + "state_max": + state_max.tolist(), + } + + return result + + +def process_hdf5_dataset(vla_dataset): + EPS = 1e-8 + episode_cnt = 0 + state_sum = 0 + state_sum_sq = 0 + z_state_sum = 0 + z_state_sum_sq = 0 + state_cnt = 0 + nz_state_cnt = None + state_max = None + state_min = None + for i in tqdm(range(len(vla_dataset))): + episode = vla_dataset.get_item(i, state_only=True) + episode_cnt += 1 + + states = episode["state"] + + # Zero the values that are close to zero + z_states = states.copy() + z_states[np.abs(states) <= EPS] = 0 + # Compute the non-zero count + if nz_state_cnt is None: + nz_state_cnt = np.zeros(states.shape[1]) + nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0) + + # Update 
statistics + state_sum += np.sum(states, axis=0) + state_sum_sq += np.sum(states**2, axis=0) + z_state_sum += np.sum(z_states, axis=0) + z_state_sum_sq += np.sum(z_states**2, axis=0) + state_cnt += states.shape[0] + if state_max is None: + state_max = np.max(states, axis=0) + state_min = np.min(states, axis=0) + else: + state_max = np.maximum(state_max, np.max(states, axis=0)) + state_min = np.minimum(state_min, np.min(states, axis=0)) + + # Add one to avoid division by zero + nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt)) + + result = { + "dataset_name": + vla_dataset.get_dataset_name(), + "state_mean": (state_sum / state_cnt).tolist(), + "state_std": + np.sqrt( + np.maximum( + (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt), + np.zeros_like(state_sum_sq), + )).tolist(), + "state_min": + state_min.tolist(), + "state_max": + state_max.tolist(), + } + + return result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Multiprocessing currently with bugs + # parser.add_argument('--n_workers', type=int, default=1, + # help="Number of parallel workers.") + parser.add_argument( + "--dataset_type", + type=str, + default="pretrain", + help="Whether to load the pretrain dataset or finetune dataset.", + ) + parser.add_argument( + "--save_path", + type=str, + default="configs/dataset_stat.json", + help="JSON file path to save the dataset statistics.", + ) + parser.add_argument( + "--skip_exist", + action="store_true", + help="Whether to skip the existing dataset statistics.", + ) + parser.add_argument( + "--hdf5_dataset", + action="store_true", + help="Whether to load the dataset from the HDF5 files.", + ) + args = parser.parse_args() + + if args.hdf5_dataset: + vla_dataset = HDF5VLADataset() + dataset_name = vla_dataset.get_dataset_name() + + try: + with open(args.save_path, "r") as f: + results = json.load(f) + except FileNotFoundError: + results = {} + if args.skip_exist and dataset_name 
in results: + print(f"Skipping existed {dataset_name} dataset statistics") + else: + print(f"Processing {dataset_name} dataset") + result = process_hdf5_dataset(vla_dataset) + results[result["dataset_name"]] = result + with open(args.save_path, "w") as f: + json.dump(results, f, indent=4) + print("All datasets have been processed.") + os._exit(0) + + vla_dataset = VLADataset(seed=0, dataset_type=args.dataset_type, repeat=False) + name_dataset_pairs = vla_dataset.name2dataset.items() + # num_workers = args.n_workers + + for name_dataset_pair in tqdm(name_dataset_pairs): + try: + with open(args.save_path, "r") as f: + results = json.load(f) + except FileNotFoundError: + results = {} + + if args.skip_exist and name_dataset_pair[0] in results: + print(f"Skipping existed {name_dataset_pair[0]} dataset statistics") + continue + print(f"Processing {name_dataset_pair[0]} dataset") + + result = process_dataset(name_dataset_pair) + + results[result["dataset_name"]] = result + + # Save the results in the json file after each dataset (for resume) + with open(args.save_path, "w") as f: + json.dump(results, f, indent=4) + + print("All datasets have been processed.") + + # with Manager() as manager: + # # Create shared dictionary and lock through the manager, accessible by all processes + # progress = manager.dict(processed=0, results={}) + # progress_lock = manager.Lock() + + # # Callback function to update progress + # def update_progress(result): + # with progress_lock: + # progress['processed'] += 1 + # print(f"{result['dataset_name']} - {progress['processed']}/{len(name_dataset_pairs)} datasets have been processed") + # # Append the result to the shared dictionary + # progress['results'][result["dataset_name"]] = result + + # with Pool(num_workers) as p: + # for name_dataset_pair in name_dataset_pairs: + # p.apply_async(process_dataset, args=(name_dataset_pair,), callback=update_progress) + + # # Close the pool and wait for the work to finish + # p.close() + # p.join() + + # 
# Save the results in the json file + # with open(args.save_path, 'w') as f: + # json.dump(progress['results'], f, indent=4) diff --git a/RDT-1B/data/compute_dataset_stat_hdf5.py b/RDT-1B/data/compute_dataset_stat_hdf5.py new file mode 100644 index 0000000..ffec48d --- /dev/null +++ b/RDT-1B/data/compute_dataset_stat_hdf5.py @@ -0,0 +1,112 @@ +""" +This file will compute the min, max, mean, and standard deviation of each datasets +in `pretrain_datasets.json` or `pretrain_datasets.json`. +""" + +import json +import argparse + +import numpy as np +from tqdm import tqdm + +from data.hdf5_vla_dataset import HDF5VLADataset + + +def process_hdf5_dataset(vla_dataset): + EPS = 1e-8 + episode_cnt = 0 + state_sum = 0 + state_sum_sq = 0 + z_state_sum = 0 + z_state_sum_sq = 0 + state_cnt = 0 + nz_state_cnt = None + state_max = None + state_min = None + for i in tqdm(range(len(vla_dataset))): + episode = vla_dataset.get_item(i, state_only=True) + episode_cnt += 1 + + states = episode["state"] + + # Zero the values that are close to zero + z_states = states.copy() + z_states[np.abs(states) <= EPS] = 0 + # Compute the non-zero count + if nz_state_cnt is None: + nz_state_cnt = np.zeros(states.shape[1]) + nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0) + + # Update statistics + state_sum += np.sum(states, axis=0) + state_sum_sq += np.sum(states**2, axis=0) + z_state_sum += np.sum(z_states, axis=0) + z_state_sum_sq += np.sum(z_states**2, axis=0) + state_cnt += states.shape[0] + if state_max is None: + state_max = np.max(states, axis=0) + state_min = np.min(states, axis=0) + else: + state_max = np.maximum(state_max, np.max(states, axis=0)) + state_min = np.minimum(state_min, np.min(states, axis=0)) + + # Add one to avoid division by zero + nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt)) + + result = { + "dataset_name": + vla_dataset.get_dataset_name(), + "state_mean": (state_sum / state_cnt).tolist(), + "state_std": + np.sqrt( + np.maximum( + (z_state_sum_sq 
/ nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt), + np.zeros_like(state_sum_sq), + )).tolist(), + "state_min": + state_min.tolist(), + "state_max": + state_max.tolist(), + } + + return result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--task_name", + type=str, + default="configs/dataset_stat.json", + help="JSON file path to save the dataset statistics.", + ) + parser.add_argument( + "--save_path", + type=str, + default="configs/dataset_stat.json", + help="JSON file path to save the dataset statistics.", + ) + parser.add_argument( + "--skip_exist", + action="store_true", + help="Whether to skip the existing dataset statistics.", + ) + args = parser.parse_args() + + vla_dataset = HDF5VLADataset(f"model_config/{args.task_name}.yml") + dataset_name = vla_dataset.get_dataset_name() + + try: + with open(args.save_path, "r") as f: + results = json.load(f) + except FileNotFoundError: + results = {} + if args.skip_exist and dataset_name in results: + print(f"Skipping existed {dataset_name} dataset statistics") + else: + print(f"Processing {dataset_name} dataset") + result = process_hdf5_dataset(vla_dataset) + results[result["dataset_name"]] = result + with open(args.save_path, "w") as f: + json.dump(results, f, indent=4) + print("All datasets have been processed.") diff --git a/RDT-1B/data/empty_lang_embed.pt b/RDT-1B/data/empty_lang_embed.pt new file mode 100644 index 0000000..39a3f6a Binary files /dev/null and b/RDT-1B/data/empty_lang_embed.pt differ diff --git a/RDT-1B/data/episode_transform.py b/RDT-1B/data/episode_transform.py new file mode 100644 index 0000000..4d69ed7 --- /dev/null +++ b/RDT-1B/data/episode_transform.py @@ -0,0 +1,398 @@ +import numpy as np +import tensorflow as tf +import yaml + +from data.preprocess import generate_json_state +from configs.state_vec import STATE_VEC_IDX_MAPPING + +# Read the config +with open("configs/base.yaml", "r") as file: + config = 
yaml.safe_load(file) +# Load some constants from the config +IMG_HISTORY_SIZE = config["common"]["img_history_size"] +if IMG_HISTORY_SIZE < 1: + raise ValueError("Config `img_history_size` must be at least 1.") +ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"] +if ACTION_CHUNK_SIZE < 1: + raise ValueError("Config `action_chunk_size` must be at least 1.") + + +@tf.function +def process_episode(epsd: dict, dataset_name: str, image_keys: list, image_mask: list) -> dict: + """ + Process an episode to extract the frames and the json content. + """ + # Frames of each camera + # Ugly code due to tf's poor compatibility + frames_0 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True) + frames_1 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True) + frames_2 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True) + frames_3 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True) + # Traverse the episode to collect... + for step in iter(epsd["steps"]): + # Parse the image + frames_0 = frames_0.write( + frames_0.size(), + tf.cond( + tf.equal(image_mask[0], 1), + lambda: step["observation"][image_keys[0]], + lambda: tf.zeros([0, 0, 0], dtype=tf.uint8), + ), + ) + # Very ugly code due to tf's poor compatibility + frames_1 = frames_1.write( + frames_1.size(), + tf.cond( + tf.equal(image_mask[1], 1), + lambda: step["observation"][image_keys[1]], + lambda: tf.zeros([0, 0, 0], dtype=tf.uint8), + ), + ) + frames_2 = frames_2.write( + frames_2.size(), + tf.cond( + tf.equal(image_mask[2], 1), + lambda: step["observation"][image_keys[2]], + lambda: tf.zeros([0, 0, 0], dtype=tf.uint8), + ), + ) + frames_3 = frames_3.write( + frames_3.size(), + tf.cond( + tf.equal(image_mask[3], 1), + lambda: step["observation"][image_keys[3]], + lambda: tf.zeros([0, 0, 0], dtype=tf.uint8), + ), + ) + + # Calculate the past_frames_0 for each step + # Each step has a window of previous frames with size IMG_HISTORY_SIZE + # Use the first state to pad the frames + # 
past_frames_0 will have shape (num_steps, IMG_HISTORY_SIZE, height, width, channels) + frames_0 = frames_0.stack() + first_frame = tf.expand_dims(frames_0[0], axis=0) + first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0) + padded_frames_0 = tf.concat([first_frame, frames_0], axis=0) + indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_0)[0] + IMG_HISTORY_SIZE) + past_frames_0 = tf.map_fn(lambda i: padded_frames_0[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8) + frames_0_time_mask = tf.ones([tf.shape(frames_0)[0]], dtype=tf.bool) + padded_frames_0_time_mask = tf.pad( + frames_0_time_mask, + [[IMG_HISTORY_SIZE - 1, 0]], + "CONSTANT", + constant_values=False, + ) + past_frames_0_time_mask = tf.map_fn( + lambda i: padded_frames_0_time_mask[i - IMG_HISTORY_SIZE:i], + indices, + dtype=tf.bool, + ) + + # For past_frames_1 + frames_1 = frames_1.stack() + first_frame = tf.expand_dims(frames_1[0], axis=0) + first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0) + padded_frames_1 = tf.concat([first_frame, frames_1], axis=0) + indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_1)[0] + IMG_HISTORY_SIZE) + past_frames_1 = tf.map_fn(lambda i: padded_frames_1[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8) + frames_1_time_mask = tf.ones([tf.shape(frames_1)[0]], dtype=tf.bool) + padded_frames_1_time_mask = tf.pad( + frames_1_time_mask, + [[IMG_HISTORY_SIZE - 1, 0]], + "CONSTANT", + constant_values=False, + ) + past_frames_1_time_mask = tf.map_fn( + lambda i: padded_frames_1_time_mask[i - IMG_HISTORY_SIZE:i], + indices, + dtype=tf.bool, + ) + + # For past_frames_2 + frames_2 = frames_2.stack() + first_frame = tf.expand_dims(frames_2[0], axis=0) + first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0) + padded_frames_2 = tf.concat([first_frame, frames_2], axis=0) + indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_2)[0] + IMG_HISTORY_SIZE) + past_frames_2 = tf.map_fn(lambda i: padded_frames_2[i - IMG_HISTORY_SIZE:i], indices, 
dtype=tf.uint8) + frames_2_time_mask = tf.ones([tf.shape(frames_2)[0]], dtype=tf.bool) + padded_frames_2_time_mask = tf.pad( + frames_2_time_mask, + [[IMG_HISTORY_SIZE - 1, 0]], + "CONSTANT", + constant_values=False, + ) + past_frames_2_time_mask = tf.map_fn( + lambda i: padded_frames_2_time_mask[i - IMG_HISTORY_SIZE:i], + indices, + dtype=tf.bool, + ) + + # For past_frames_3 + frames_3 = frames_3.stack() + first_frame = tf.expand_dims(frames_3[0], axis=0) + first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0) + padded_frames_3 = tf.concat([first_frame, frames_3], axis=0) + indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_3)[0] + IMG_HISTORY_SIZE) + past_frames_3 = tf.map_fn(lambda i: padded_frames_3[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8) + frames_3_time_mask = tf.ones([tf.shape(frames_3)[0]], dtype=tf.bool) + padded_frames_3_time_mask = tf.pad( + frames_3_time_mask, + [[IMG_HISTORY_SIZE - 1, 0]], + "CONSTANT", + constant_values=False, + ) + past_frames_3_time_mask = tf.map_fn( + lambda i: padded_frames_3_time_mask[i - IMG_HISTORY_SIZE:i], + indices, + dtype=tf.bool, + ) + + # Creat the ids for each step + step_id = tf.range(0, tf.shape(frames_0)[0]) + + return { + "dataset_name": dataset_name, + "episode_dict": epsd, + "step_id": step_id, + "past_frames_0": past_frames_0, + "past_frames_0_time_mask": past_frames_0_time_mask, + "past_frames_1": past_frames_1, + "past_frames_1_time_mask": past_frames_1_time_mask, + "past_frames_2": past_frames_2, + "past_frames_2_time_mask": past_frames_2_time_mask, + "past_frames_3": past_frames_3, + "past_frames_3_time_mask": past_frames_3_time_mask, + } + + +@tf.function +def bgr_to_rgb(epsd: dict): + """ + Convert BGR images to RGB images. 
+ """ + past_frames_0 = epsd["past_frames_0"] + past_frames_0 = tf.cond( + tf.equal(tf.shape(past_frames_0)[-1], 3), + lambda: tf.stack( + [past_frames_0[..., 2], past_frames_0[..., 1], past_frames_0[..., 0]], + axis=-1, + ), + lambda: past_frames_0, + ) + + past_frames_1 = epsd["past_frames_1"] + past_frames_1 = tf.cond( + tf.equal(tf.shape(past_frames_1)[-1], 3), + lambda: tf.stack( + [past_frames_1[..., 2], past_frames_1[..., 1], past_frames_1[..., 0]], + axis=-1, + ), + lambda: past_frames_1, + ) + + past_frames_2 = epsd["past_frames_2"] + past_frames_2 = tf.cond( + tf.equal(tf.shape(past_frames_2)[-1], 3), + lambda: tf.stack( + [past_frames_2[..., 2], past_frames_2[..., 1], past_frames_2[..., 0]], + axis=-1, + ), + lambda: past_frames_2, + ) + + past_frames_3 = epsd["past_frames_3"] + past_frames_3 = tf.cond( + tf.equal(tf.shape(past_frames_3)[-1], 3), + lambda: tf.stack( + [past_frames_3[..., 2], past_frames_3[..., 1], past_frames_3[..., 0]], + axis=-1, + ), + lambda: past_frames_3, + ) + + return { + "dataset_name": epsd["dataset_name"], + "episode_dict": epsd["episode_dict"], + "step_id": epsd["step_id"], + "past_frames_0": past_frames_0, + "past_frames_0_time_mask": epsd["past_frames_0_time_mask"], + "past_frames_1": past_frames_1, + "past_frames_1_time_mask": epsd["past_frames_1_time_mask"], + "past_frames_2": past_frames_2, + "past_frames_2_time_mask": epsd["past_frames_2_time_mask"], + "past_frames_3": past_frames_3, + "past_frames_3_time_mask": epsd["past_frames_3_time_mask"], + } + + +def flatten_episode(episode: dict) -> tf.data.Dataset: + """ + Flatten the episode to a list of steps. 
+ """ + episode_dict = episode["episode_dict"] + dataset_name = episode["dataset_name"] + + json_content, states, masks = generate_json_state(episode_dict, dataset_name) + + # Calculate the past_states for each step + # Each step has a window of previous states with size ACTION_CHUNK_SIZE + # Use the first state to pad the states + # past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim) + first_state = tf.expand_dims(states[0], axis=0) + first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0) + padded_states = tf.concat([first_state, states], axis=0) + indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE) + past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32) + states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool) + padded_states_time_mask = tf.pad( + states_time_mask, + [[ACTION_CHUNK_SIZE - 1, 0]], + "CONSTANT", + constant_values=False, + ) + past_states_time_mask = tf.map_fn( + lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i], + indices, + dtype=tf.bool, + ) + + # Calculate the future_states for each step + # Each step has a window of future states with size ACTION_CHUNK_SIZE + # Use the last state to pad the states + # future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim) + last_state = tf.expand_dims(states[-1], axis=0) + last_state = tf.repeat(last_state, ACTION_CHUNK_SIZE, axis=0) + padded_states = tf.concat([states, last_state], axis=0) + indices = tf.range(1, tf.shape(states)[0] + 1) + future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32) + states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool) + padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False) + future_states_time_mask = tf.map_fn( + lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE], + indices, + dtype=tf.bool, + ) + + # Calculate the mean 
and std for state + state_std = tf.math.reduce_std(states, axis=0, keepdims=True) + state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0) + state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True) + state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0) + + state_norm = tf.math.reduce_mean(tf.math.square(states), axis=0, keepdims=True) + state_norm = tf.math.sqrt(state_norm) + state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0) + + # Create a list of steps + step_data = [] + for i in range(tf.shape(states)[0]): + step_data.append({ + "step_id": episode["step_id"][i], + "json_content": json_content, + "state_chunk": past_states[i], + "state_chunk_time_mask": past_states_time_mask[i], + "action_chunk": future_states[i], + "action_chunk_time_mask": future_states_time_mask[i], + "state_vec_mask": masks[i], + "past_frames_0": episode["past_frames_0"][i], + "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i], + "past_frames_1": episode["past_frames_1"][i], + "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i], + "past_frames_2": episode["past_frames_2"][i], + "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i], + "past_frames_3": episode["past_frames_3"][i], + "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i], + "state_std": state_std[i], + "state_mean": state_mean[i], + "state_norm": state_norm[i], + }) + + return step_data + + +def flatten_episode_agilex(episode: dict) -> tf.data.Dataset: + """ + Flatten the episode to a list of steps. 
+ """ + episode_dict = episode["episode_dict"] + dataset_name = episode["dataset_name"] + + json_content, states, masks, acts = generate_json_state(episode_dict, dataset_name) + + # Calculate the past_states for each step + # Each step has a window of previous states with size ACTION_CHUNK_SIZE + # Use the first state to pad the states + # past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim) + first_state = tf.expand_dims(states[0], axis=0) + first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0) + padded_states = tf.concat([first_state, states], axis=0) + indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE) + past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32) + states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool) + padded_states_time_mask = tf.pad( + states_time_mask, + [[ACTION_CHUNK_SIZE - 1, 0]], + "CONSTANT", + constant_values=False, + ) + past_states_time_mask = tf.map_fn( + lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i], + indices, + dtype=tf.bool, + ) + + # NOTE bg the future states shall be actions + # Calculate the future_states for each step + # Each step has a window of future states with size ACTION_CHUNK_SIZE + # Use the last action to pad the states + # future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim) + last_act = tf.expand_dims(acts[-1], axis=0) + last_act = tf.repeat(last_act, ACTION_CHUNK_SIZE, axis=0) + padded_states = tf.concat([acts, last_act], axis=0) + # indices = tf.range(1, tf.shape(states)[0] + 1) + indices = tf.range(0, tf.shape(acts)[0]) # NOTE time 0 action = time 1 state + future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32) + states_time_mask = tf.ones([tf.shape(acts)[0]], dtype=tf.bool) + padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False) + future_states_time_mask = 
tf.map_fn( + lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE], + indices, + dtype=tf.bool, + ) + + # Calculate the std and mean for state + state_std = tf.math.reduce_std(states, axis=0, keepdims=True) + state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0) + state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True) + state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0) + + state_norm = tf.math.reduce_mean(tf.math.square(acts), axis=0, keepdims=True) + state_norm = tf.math.sqrt(state_norm) + state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0) + + # Create a list of steps + step_data = [] + for i in range(tf.shape(states)[0]): + step_data.append({ + "step_id": episode["step_id"][i], + "json_content": json_content, + "state_chunk": past_states[i], + "state_chunk_time_mask": past_states_time_mask[i], + "action_chunk": future_states[i], + "action_chunk_time_mask": future_states_time_mask[i], + "state_vec_mask": masks[i], + "past_frames_0": episode["past_frames_0"][i], + "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i], + "past_frames_1": episode["past_frames_1"][i], + "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i], + "past_frames_2": episode["past_frames_2"][i], + "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i], + "past_frames_3": episode["past_frames_3"][i], + "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i], + "state_std": state_std[i], + "state_mean": state_mean[i], + "state_norm": state_norm[i], + }) + + return step_data diff --git a/RDT-1B/data/filelock.py b/RDT-1B/data/filelock.py new file mode 100644 index 0000000..66b70d0 --- /dev/null +++ b/RDT-1B/data/filelock.py @@ -0,0 +1,25 @@ +import fcntl + + +class FileLock: + """ + A file lock class. 
+ """ + + def __init__(self, filename): + self.filename = filename + self.handle = None + + def acquire_read_lock(self): + self.handle = open(self.filename + ".lock", "r") + fcntl.flock(self.handle, fcntl.LOCK_SH | fcntl.LOCK_NB) + + def acquire_write_lock(self): + self.handle = open(self.filename + ".lock", "w") + fcntl.flock(self.handle, fcntl.LOCK_EX | fcntl.LOCK_NB) + + def release_lock(self): + if self.handle is not None: + fcntl.flock(self.handle, fcntl.LOCK_UN) + self.handle.close() + self.handle = None diff --git a/RDT-1B/data/hdf5_vla_dataset.py b/RDT-1B/data/hdf5_vla_dataset.py new file mode 100644 index 0000000..f3ec570 --- /dev/null +++ b/RDT-1B/data/hdf5_vla_dataset.py @@ -0,0 +1,372 @@ +import os +import fnmatch +import json + +import h5py +import yaml +import cv2 +import numpy as np + +from configs.state_vec import STATE_VEC_IDX_MAPPING + +class HDF5VLADataset: + """ + This class is used to sample episodes from the embododiment dataset + stored in HDF5. + """ + + def __init__(self, model_config_path) -> None: + # [Modify] The path to the HDF5 dataset directory + # Each HDF5 file contains one episode + with open(model_config_path, "r") as f: + model_config = yaml.safe_load(f) + HDF5_DIR = model_config["data_path"] + self.DATASET_NAME = "agilex" + + self.file_paths = [] + for root, _, files in os.walk(HDF5_DIR): + for filename in fnmatch.filter(files, "*.hdf5"): + file_path = os.path.join(root, filename) + self.file_paths.append(file_path) + + # Load the config + with open("configs/base.yaml", "r") as file: + config = yaml.safe_load(file) + self.CHUNK_SIZE = config["common"]["action_chunk_size"] + self.IMG_HISORY_SIZE = config["common"]["img_history_size"] + self.STATE_DIM = config["common"]["state_dim"] + + # Get each episode's len (use original length, not standardized length) + episode_lens = [] + for file_path in self.file_paths: + try: + with h5py.File(file_path, "r") as f: + qpos = f["observations"]["qpos"][:] + num_steps = qpos.shape[0] + 
episode_lens.append(num_steps) + except Exception as e: + print(f"Warning: Could not read {file_path}: {e}") + episode_lens.append(0) + self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens) + + def __len__(self): + return len(self.file_paths) + + def get_dataset_name(self): + return self.DATASET_NAME + + def get_item(self, index: int = None, state_only=False): + """Get a training sample at a random timestep. + + Args: + index (int, optional): the index of the episode. + If not provided, a random episode will be selected. + state_only (bool, optional): Whether to return only the state. + In this way, the sample will contain a complete trajectory rather + than a single timestep. Defaults to False. + + Returns: + sample (dict): a dictionary containing the training sample. + """ + while True: + if index is None: + file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights) + else: + file_path = self.file_paths[index] + valid, sample = (self.parse_hdf5_file(file_path) + if not state_only else self.parse_hdf5_file_state_only(file_path)) + if valid: + return sample + else: + index = np.random.randint(0, len(self.file_paths)) + + def parse_hdf5_file(self, file_path): + """[Modify] Parse a hdf5 file to generate a training sample at + a random timestep. + + Args: + file_path (str): the path to the hdf5 file + + Returns: + valid (bool): whether the episode is valid, which is useful for filtering. + If False, this episode will be dropped. + dict: a dictionary containing the training sample, + { + "meta": { + "dataset_name": str, # the name of your dataset. + "#steps": int, # the number of steps in the episode, + # also the total timesteps. + "instruction": str # the language instruction for this episode. + }, + "step_id": int, # the index of the sampled step, + # also the timestep t. + "state": ndarray, # state[t], (1, STATE_DIM). + "state_std": ndarray, # std(state[:]), (STATE_DIM,). 
+ "state_mean": ndarray, # mean(state[:]), (STATE_DIM,). + "state_norm": ndarray, # norm(state[:]), (STATE_DIM,). + "actions": ndarray, # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM). + "state_indicator", ndarray, # indicates the validness of each dim, (STATE_DIM,). + "cam_high": ndarray, # external camera image, (IMG_HISORY_SIZE, H, W, 3) + # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable. + "cam_high_mask": ndarray, # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array. + # For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False. + "cam_left_wrist": ndarray, # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3). + # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable. + "cam_left_wrist_mask": ndarray, + "cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3). + # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable. + # If only one wrist, make it right wrist, plz. + "cam_right_wrist_mask": ndarray + } or None if the episode is invalid. 
+ """ + with h5py.File(file_path, "r") as f: + qpos = f["observations"]["qpos"][:] + left_arm_dim = f["observations"]["left_arm_dim"][:] + right_arm_dim = f["observations"]["right_arm_dim"][:] + num_steps = qpos.shape[0] + action_dim = qpos + # [Optional] We drop too-short episode + # if num_steps < 128: + # return False, None + + # [Optional] We skip the first few still steps + EPS = 1e-2 + # Get the idx of the first qpos whose delta exceeds the threshold + qpos_delta = np.abs(qpos - qpos[0:1]) + indices = np.where(np.any(qpos_delta > EPS, axis=1))[0] + if len(indices) > 0: + first_idx = indices[0] + else: + raise ValueError("Found no qpos that exceeds the threshold.") + + # We randomly sample a timestep + step_id = np.random.randint(first_idx - 1, num_steps) + + # Load the instruction + dir_path = os.path.dirname(file_path) + + # with open(os.path.join(dir_path, 'instruction.json'), 'r') as f_instr: + # instruction_dict = json.load(f_instr) + # # We have 1/3 prob to use original instruction, + # # 1/3 to use simplified instruction, + # # and 1/3 to use expanded instruction. 
+ # instruction_type = np.random.choice([ + # 'instruction', 'expanded_instruction']) + # instruction = instruction_dict[instruction_type] + # if isinstance(instruction, list): + # instruction = np.random.choice(instruction) + + # You can also use precomputed language embeddings (recommended) + # instruction = "path/to/lang_embed.pt" + instructions_path = os.path.join(dir_path, "instructions") + instructions_names = [] + + for filename in os.listdir(instructions_path): + # 检查文件名是否以.pt结尾 + if filename.endswith(".pt"): + instructions_names.append(os.path.join(instructions_path, filename)) + instruction = np.random.choice(instructions_names) + # print(f"choose {instruction} file as instruction.") + # Assemble the meta + meta = { + "dataset_name": self.DATASET_NAME, + "#steps": num_steps, + "step_id": step_id, + "instruction": instruction, + } + + # Rescale gripper to [0, 1] + # qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]]) + # target_qpos = f["action"][step_id:step_id + self.CHUNK_SIZE] / np.array( + # [[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]]) + + qpos = qpos / np.array( + # [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]] + [[180, 180, 180, 180, 180, 180]] + ) + target_qpos = f['action'][step_id:step_id + self.CHUNK_SIZE] / np.array( + # [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]] + [[180, 180, 180, 180, 180, 180]] + ) + + # Parse the state and action + state = qpos[step_id:step_id + 1] + state_std = np.std(qpos, axis=0) + state_mean = np.mean(qpos, axis=0) + state_norm = np.sqrt(np.mean(qpos**2, axis=0)) + actions = target_qpos + if actions.shape[0] < self.CHUNK_SIZE: + # Pad the actions using the last action + actions = np.concatenate( + [ + actions, + np.tile(actions[-1:], (self.CHUNK_SIZE - actions.shape[0], 1)), + ], + axis=0, + ) + + # Fill the state/action into the unified vector + + def fill_in_state(values): + # Target indices corresponding to your state space + # In this 
example: 6 joints + 1 gripper for each arm + UNI_STATE_INDICES = [ + STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6) + # ] + [ + # STATE_VEC_IDX_MAPPING["right_gripper_open"] + ] + uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, )) + uni_vec[..., UNI_STATE_INDICES] = values + return uni_vec + + state = fill_in_state(state) + state_indicator = fill_in_state(np.ones_like(state_std)) + state_std = fill_in_state(state_std) + state_mean = fill_in_state(state_mean) + state_norm = fill_in_state(state_norm) + # If action's format is different from state's, + # you may implement fill_in_action() + actions = fill_in_state(actions) + + # Parse the images + def parse_img(key): + imgs = [] + for i in range(max(step_id - self.IMG_HISORY_SIZE + 1, 0), step_id + 1): + img_bits = f["observations"]["images"][key][i] + img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR) + imgs.append(img) + imgs = np.stack(imgs) + if imgs.shape[0] < self.IMG_HISORY_SIZE: + # Pad the images using the first image + imgs = np.concatenate( + [ + np.tile( + imgs[:1], + (self.IMG_HISORY_SIZE - imgs.shape[0], 1, 1, 1), + ), + imgs, + ], + axis=0, + ) + return imgs + + # `cam_high` is the external camera image + cam_high = parse_img("cam_high") + # For step_id = first_idx - 1, the valid_len should be one + valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE) + cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len) + # cam_left_wrist = parse_img("cam_left_wrist") + # cam_left_wrist_mask = cam_high_mask.copy() + cam_left_wrist = np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0))#parse_img('cam_right_wrist') + cam_left_wrist_mask = np.array([False] * self.IMG_HISORY_SIZE)#cam_high_mask.copy() + cam_right_wrist = parse_img("cam_right_wrist") + cam_right_wrist_mask = cam_high_mask.copy() # 使用相同的掩码逻辑 + + # Return the resulting sample + # For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0) + # 
E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist", + # if the left-wrist camera is unavailable on your robot + return True, { + "meta": meta, + "state": state, + "state_std": state_std, + "state_mean": state_mean, + "state_norm": state_norm, + "actions": actions, + "state_indicator": state_indicator, + "cam_high": cam_high, + "cam_high_mask": cam_high_mask, + "cam_left_wrist": cam_left_wrist, + "cam_left_wrist_mask": cam_left_wrist_mask, + "cam_right_wrist": cam_right_wrist, + "cam_right_wrist_mask": cam_right_wrist_mask, + } + + def parse_hdf5_file_state_only(self, file_path): + """[Modify] Parse a hdf5 file to generate a state trajectory. + + Args: + file_path (str): the path to the hdf5 file + + Returns: + valid (bool): whether the episode is valid, which is useful for filtering. + If False, this episode will be dropped. + dict: a dictionary containing the training sample, + { + "state": ndarray, # state[:], (T, STATE_DIM). + "action": ndarray, # action[:], (T, STATE_DIM). + } or None if the episode is invalid. 
+ """ + with h5py.File(file_path, "r") as f: + qpos = f["observations"]["qpos"][:] + left_arm_dim = f["observations"]["left_arm_dim"][:] + right_arm_dim = f["observations"]["right_arm_dim"][:] + + num_steps = qpos.shape[0] + # [Optional] We drop too-short episode + # if num_steps < 128: + # return False, None + + # [Optional] We skip the first few still steps + EPS = 1e-2 + # Get the idx of the first qpos whose delta exceeds the threshold + qpos_delta = np.abs(qpos - qpos[0:1]) + indices = np.where(np.any(qpos_delta > EPS, axis=1))[0] + if len(indices) > 0: + first_idx = indices[0] + else: + raise ValueError("Found no qpos that exceeds the threshold.") + + # Rescale gripper to [0, 1] + # qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]]) + # target_qpos = f["action"][:] / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]]) + + qpos = qpos / np.array( + # [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]] + [[180, 180, 180, 180, 180, 180]] + ) + target_qpos = f['action'][first_idx - 1:] / np.array( + # [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]] + [[180, 180, 180, 180, 180, 180]] + ) + # Parse the state and action + state = qpos[first_idx - 1:] + action = target_qpos[first_idx - 1:] + + # Standardize trajectory length to avoid batch size mismatch + # Use a fixed length (e.g., 128) or pad/truncate to match + target_length = 128 # You can adjust this value + if state.shape[0] > target_length: + # Truncate to target length + state = state[:target_length] + action = action[:target_length] + elif state.shape[0] < target_length: + # Pad with the last state/action + pad_length = target_length - state.shape[0] + state = np.concatenate([state, np.tile(state[-1:], (pad_length, 1))], axis=0) + action = np.concatenate([action, np.tile(action[-1:], (pad_length, 1))], axis=0) + + # Fill the state/action into the unified vector + def fill_in_state(values): + # Target indices corresponding to your state space + # 
In this example: 6 joints + 1 gripper for each arm + UNI_STATE_INDICES = [ + STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6) + # ] + [ + # STATE_VEC_IDX_MAPPING["right_gripper_open"] + ] + uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, )) + uni_vec[..., UNI_STATE_INDICES] = values + return uni_vec + + state = fill_in_state(state) + action = fill_in_state(action) + + # Return the resulting sample + return True, {"state": state, "action": action} + +if __name__ == "__main__": + ds = HDF5VLADataset() + for i in range(len(ds)): + print(f"Processing episode {i}/{len(ds)}...") + ds.get_item(i) \ No newline at end of file diff --git a/RDT-1B/data/preprocess.py b/RDT-1B/data/preprocess.py new file mode 100644 index 0000000..05df31e --- /dev/null +++ b/RDT-1B/data/preprocess.py @@ -0,0 +1,299 @@ +import json + +import tensorflow as tf +import yaml + +from data.preprocess_scripts import * +from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN +from data.utils import capitalize_and_period + +# The dataset without state +DATASET_NAMES_NO_STATE = [ + "nyu_door_opening_surprising_effectiveness", + "usc_cloth_sim_converted_externally_to_rlds", + "cmu_franka_exploration_dataset_converted_externally_to_rlds", + "imperialcollege_sawyer_wrist_cam", +] + +# Read the image keys of each dataset +with open("configs/dataset_img_keys.json", "r") as file: + IMAGE_KEYS = json.load(file) +# Read the config +with open("configs/base.yaml", "r") as file: + config = yaml.safe_load(file) + + +def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str, base_concat=None, base_format=None) -> tf.Tensor: + """ + Assemble the state/action vector from the arm and base. 
+ """ + state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32) + mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32) + + # Assemble the arm state + arm_concat = tf.cast(arm_concat, tf.float32) + arm_format = arm_format.split(",") + # Use the scatter_nd to avoid the duplicate indices + state_vec = tf.tensor_scatter_nd_update(state_vec, [[STATE_VEC_IDX_MAPPING[name]] for name in arm_format], + arm_concat) + mask_vec = tf.tensor_scatter_nd_update( + mask_vec, + [[STATE_VEC_IDX_MAPPING[name]] for name in arm_format], + tf.ones(len(arm_format), dtype=tf.float32), + ) + + # Assemble the base state if exists + if base_concat is not None: + base_concat = tf.cast(base_concat, tf.float32) + base_format = base_format.split(",") + state_vec = tf.tensor_scatter_nd_update( + state_vec, + [[STATE_VEC_IDX_MAPPING[name]] for name in base_format], + base_concat, + ) + mask_vec = tf.tensor_scatter_nd_update( + mask_vec, + [[STATE_VEC_IDX_MAPPING[name]] for name in base_format], + tf.ones(len(base_format), dtype=tf.float32), + ) + return state_vec, mask_vec + + +@tf.autograph.experimental.do_not_convert +def _generate_json_state_agilex(episode: dict, dataset_name: str): + """ + Generate the json dict and state for a given episode. 
+ """ + # Load some constants from the config + IMG_HISTORY_SIZE = config["common"]["img_history_size"] + if IMG_HISTORY_SIZE < 1: + raise ValueError("Config `img_history_size` must be at least 1.") + ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"] + if ACTION_CHUNK_SIZE < 1: + raise ValueError("Config `action_chunk_size` must be at least 1.") + + # Initialize the episode_metadata + episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None} + + # Check whether this episode has an 'END' + base_act = None + last_base_act = None + episode_states = [] + episode_acts = [] + episode_masks = [] + has_base = None + for step_id, step in enumerate(iter(episode["steps"])): + # Parse the action + action = step["action"] + if has_base is None: + has_base = "base_concat" in action + if has_base: + base_act = action["base_concat"] + + # Parse the state + state = step["observation"] + + arm_format = state["format"].numpy().decode("utf-8") + base_format = None + if has_base: + act_format = action["format"].numpy().decode("utf-8") + base_formate_idx = act_format.find("base") + base_format = act_format[base_formate_idx:] + + arm_state = state["arm_concat"] + base_state = None + if has_base: + if last_base_act is None: + base_state = base_act * 0 + else: + base_state = last_base_act + last_base_act = base_act + + # Assemble the state vector + state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format) + + act_vec, mask_vec = assemble_state_vec(action["arm_concat"], arm_format, base_state, base_format) + + episode_states.append(state_vec) + episode_masks.append(mask_vec) + episode_acts.append(act_vec) + + # Parse the task instruction + instr = step["observation"]["natural_language_instruction"] + instr = instr.numpy().decode("utf-8") + instr = capitalize_and_period(instr) + + # Write to the episode_metadata + if episode_metadata["instruction"] is None: + episode_metadata["instruction"] = instr + + 
episode_metadata["#steps"] = step_id + + episode_states = tf.stack(episode_states) + episode_masks = tf.stack(episode_masks) + episode_acts = tf.stack(episode_acts) + + return episode_metadata, episode_states, episode_masks, episode_acts + + +@tf.autograph.experimental.do_not_convert +def _generate_json_state(episode: dict, dataset_name: str): + """ + Generate the json dict and state for a given episode. + """ + # Load some constants from the config + IMG_HISTORY_SIZE = config["common"]["img_history_size"] + if IMG_HISTORY_SIZE < 1: + raise ValueError("Config `img_history_size` must be at least 1.") + ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"] + if ACTION_CHUNK_SIZE < 1: + raise ValueError("Config `action_chunk_size` must be at least 1.") + + # Initialize the episode_metadata + episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None} + + # Check whether this episode has an 'END' + base_act = None + last_base_act = None + episode_states = [] + episode_masks = [] + has_base = None + for step_id, step in enumerate(iter(episode["steps"])): + # Parse the action + action = step["action"] + if has_base is None: + has_base = "base_concat" in action + if has_base: + base_act = action["base_concat"] + + # Parse the state + state = step["observation"] + + arm_format = state["format"].numpy().decode("utf-8") + base_format = None + if has_base: + act_format = action["format"].numpy().decode("utf-8") + base_formate_idx = act_format.find("base") + base_format = act_format[base_formate_idx:] + + arm_state = state["arm_concat"] + base_state = None + if has_base: + if last_base_act is None: + base_state = base_act * 0 + else: + base_state = last_base_act + last_base_act = base_act + + # Assemble the state vector + state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format) + + episode_states.append(state_vec) + episode_masks.append(mask_vec) + + # Parse the task instruction + instr = 
step["observation"]["natural_language_instruction"] + instr = instr.numpy().decode("utf-8") + instr = capitalize_and_period(instr) + + # Write to the episode_metadata + if episode_metadata["instruction"] is None: + episode_metadata["instruction"] = instr + + episode_metadata["#steps"] = step_id + episode_states = tf.stack(episode_states) + episode_masks = tf.stack(episode_masks) + + return episode_metadata, episode_states, episode_masks + + +@tf.autograph.experimental.do_not_convert +def _generate_json_state_nostate_ds(episode: dict, dataset_name: str): + """ + Generate the json dict and state for an episode in the dataset without state. + If not state, we use the last action as current state. + """ + # Load some constants from the config + IMG_HISTORY_SIZE = config["common"]["img_history_size"] + if IMG_HISTORY_SIZE < 1: + raise ValueError("Config `img_history_size` must be at least 1.") + ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"] + if ACTION_CHUNK_SIZE < 1: + raise ValueError("Config `action_chunk_size` must be at least 1.") + + # Initialize the episode_metadata + episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None} + + last_base_act = None + last_arm_act = None + episode_states = [] + episode_masks = [] + has_base = None + for step_id, step in enumerate(iter(episode["steps"])): + # Parse the action + action = step["action"] + if has_base is None: + has_base = "base_concat" in action + if has_base: + base_act = action["base_concat"] + if last_base_act is None: + last_base_act = base_act * 0 # Initialize + + # Parse the arm action + arm_act = action["arm_concat"] + if last_arm_act is None: + last_arm_act = arm_act * 0 # Initialize + + # Parse the act format + # Action format as the state format + act_format = action["format"].numpy().decode("utf-8") + + # Assemble the state vector + if has_base: + last_act_concat = tf.concat([last_arm_act, last_base_act], axis=0) + else: + last_act_concat = last_arm_act + state_vec, 
mask_vec = assemble_state_vec(last_act_concat, act_format) + + episode_states.append(state_vec) + episode_masks.append(mask_vec) + + # Parse the task instruction + instr = step["observation"]["natural_language_instruction"] + instr = instr.numpy().decode("utf-8") + instr = capitalize_and_period(instr) + + # Write to the episode_metadata + if episode_metadata["instruction"] is None: + episode_metadata["instruction"] = instr + + # Update the last_arm_act and last_base_act + last_arm_act = arm_act + if has_base: + last_base_act = base_act + + episode_metadata["#steps"] = step_id + episode_states = tf.stack(episode_states) + episode_masks = tf.stack(episode_masks) + + return episode_metadata, episode_states, episode_masks + + +@tf.autograph.experimental.do_not_convert +def generate_json_state(episode: dict, dataset_name: str): + """ + Generate the json dict and state for an episode. + """ + if isinstance(dataset_name, tf.Tensor): + dataset_name = dataset_name.numpy().decode("utf-8") + + # Process each step in the episode + episode["steps"] = episode["steps"].map(globals()[dataset_name].process_step, ) + + if dataset_name == "agilex": + return _generate_json_state_agilex(episode, dataset_name) + + if dataset_name in DATASET_NAMES_NO_STATE: + return _generate_json_state_nostate_ds(episode, dataset_name) + + return _generate_json_state(episode, dataset_name) diff --git a/RDT-1B/data/producer.py b/RDT-1B/data/producer.py new file mode 100644 index 0000000..2c761de --- /dev/null +++ b/RDT-1B/data/producer.py @@ -0,0 +1,313 @@ +import time +import json +import os +import time +import argparse +import sys +import signal +import random +from multiprocessing import Process + +import numpy as np +import tensorflow as tf +import yaml + +from data.vla_dataset import VLADataset +from data.filelock import FileLock + +# Producer does not need GPU +tf.config.set_visible_devices([], "GPU") + +# Read the config +with open("configs/base.yaml", "r") as file: + config = 
yaml.safe_load(file) +# Load some constants from the config +BUF_PATH = config["dataset"]["buf_path"] +BUF_NUM_CHUNKS = config["dataset"]["buf_num_chunks"] +if BUF_NUM_CHUNKS < 1: + raise ValueError("Config `buf_num_chunks` must be at least 1.") +BUF_CHUNK_SIZE = config["dataset"]["buf_chunk_size"] +if BUF_CHUNK_SIZE < 1: + raise ValueError("Config `buf_chunk_size` must be at least 1.") + + +def get_dirty_item(chunk_dir): + """ + Get indexes of dirty items in a chunk. + """ + dirty_bit = read_dirty_bit(chunk_dir) + return np.where(dirty_bit)[0].tolist() + + +def get_clean_item(chunk_dir): + """ + Get indexes of clean items in a chunk. + """ + dirty_bit = read_dirty_bit(chunk_dir) + return np.where(1 - dirty_bit)[0].tolist() + + +def save_dirty_bit(chunk_dir, dirty_bit): + """ + Save the dirty bit to the chunk directory. + """ + time_stmp = time.time() + while time.time() - time_stmp < 10.0: + try: + file_path = os.path.join(chunk_dir, "dirty_bit") + lock = FileLock(file_path) + lock.acquire_write_lock() + with open(file_path, "wb") as file: + file.write(dirty_bit.tobytes()) + lock.release_lock() + return + except KeyboardInterrupt: + lock.release_lock() + raise KeyboardInterrupt + except BaseException: + lock.release_lock() + continue + # raise RuntimeError("Failed to save dirty bit.") + print("Failed to save dirty bit.") + + +def read_dirty_bit(chunk_dir): + """ + Read the dirty bit from the chunk directory. 
+ """ + # If error occurs, retry + time_stmp = time.time() + while time.time() - time_stmp < 10.0: + try: + file_path = os.path.join(chunk_dir, "dirty_bit") + lock = FileLock(file_path) + lock.acquire_read_lock() + with open(file_path, "rb") as file: + dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy() + lock.release_lock() + assert len(dirty_bit) == BUF_CHUNK_SIZE + return dirty_bit + except KeyboardInterrupt: + lock.release_lock() + raise KeyboardInterrupt + except BaseException: + lock.release_lock() + continue + # If failed to read the dirty bit, return all ones for robustness + return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8) + + +def save_sample(step_dict, chunk_dir, chunk_item_idx): + """ + Save a sample to the chunk directory. + """ + # Save the json content + time_stmp = time.time() + while time.time() - time_stmp < 10.0: + try: + locks = [] + json_content = step_dict["json_content"] + file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json") + lock = FileLock(file_path) + locks.append(lock) + lock.acquire_write_lock() + with open(file_path, "w") as file: + json.dump(json_content, file, indent=4) + lock.release_lock() + # Save all other tensors in a npz + file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz") + lock = FileLock(file_path) + locks.append(lock) + lock.acquire_write_lock() + with open(file_path, "wb") as file: + np.savez( + file, + step_id=step_dict["step_id"].numpy(), + state_chunk=step_dict["state_chunk"].numpy(), + state_chunk_time_mask=step_dict["state_chunk_time_mask"].numpy(), + action_chunk=step_dict["action_chunk"].numpy(), + action_chunk_time_mask=step_dict["action_chunk_time_mask"].numpy(), + state_vec_mask=step_dict["state_vec_mask"].numpy(), + past_frames_0=step_dict["past_frames_0"].numpy(), + past_frames_0_time_mask=step_dict["past_frames_0_time_mask"].numpy(), + past_frames_1=step_dict["past_frames_1"].numpy(), + past_frames_1_time_mask=step_dict["past_frames_1_time_mask"].numpy(), + 
past_frames_2=step_dict["past_frames_2"].numpy(), + past_frames_2_time_mask=step_dict["past_frames_2_time_mask"].numpy(), + past_frames_3=step_dict["past_frames_3"].numpy(), + past_frames_3_time_mask=step_dict["past_frames_3_time_mask"].numpy(), + state_std=step_dict["state_std"].numpy(), + state_mean=step_dict["state_mean"].numpy(), + state_norm=step_dict["state_norm"].numpy(), + ) + lock.release_lock() + return + except KeyboardInterrupt: + for lock in locks: + lock.release_lock() + raise KeyboardInterrupt + except BaseException: + for lock in locks: + lock.release_lock() + continue + # raise RuntimeError("Failed to save sample.") + print("Failed to save sample.") + + +def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type): + """ + Run the producer. + The producer will first fill up the buffer with samples. + Then it will keep replacing dirty samples + (i.e., samples that have been read by the consumer) + with new samples. + """ + vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type) + chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers + chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers + if fill_up: + print(f"Worker {worker_id}: Start filling up the buffer...") + elif clean_dirty: + # Only refresh the dirty bits + print(f"Worker {worker_id}: Start refreshing the dirty bits...") + for chunk_idx in range(chunk_start_idx, chunk_end_idx): + chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}") + dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8) + save_dirty_bit(chunk_dir, dirty_bit) + print(f"Worker {worker_id}: Refreshed the dirty bits.") + + fill_chunk_idx = chunk_start_idx + fill_chunk_item_idx = 0 + dirty_chunk_idx = chunk_start_idx + dirty_chunk_item_idxs = [] + time_stmp = time.time() + for episode_steps in vla_dataset: + for step in episode_steps: + if fill_up and fill_chunk_idx < chunk_end_idx: + # Fill up the buffer + chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}") + if 
fill_chunk_item_idx == 0: + # Create a new chunk + os.makedirs(chunk_dir, exist_ok=True) + # Write the dirty bit of size BUF_CHUNK_SIZE + dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8) + save_dirty_bit(chunk_dir, dirty_bit) + + # Save the sample + save_sample(step, chunk_dir, fill_chunk_item_idx) + + # print(f"Filled up chunk {fill_chunk_item_idx+1}/{BUF_CHUNK_SIZE} {fill_chunk_idx+1}/{BUF_NUM_CHUNKS}") + local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx + local_num_chunks = chunk_end_idx - chunk_start_idx + if (local_fill_chunk_idx % 10 == 0 + or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0: + print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx+1}/{local_num_chunks}") + fill_chunk_item_idx += 1 + if fill_chunk_item_idx == BUF_CHUNK_SIZE: + fill_chunk_idx += 1 + fill_chunk_item_idx = 0 + if fill_chunk_idx == BUF_NUM_CHUNKS: + print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...") + + else: + # Search for the dirty chunk to replace + while len(dirty_chunk_item_idxs) == 0: + dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}") + dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir) + # Print the dirty ratio + if time.time() - time_stmp > 2.0: + dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE + print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}") + time_stmp = time.time() + + if len(dirty_chunk_item_idxs) > 0: + # Lock the chunk + dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8) + save_dirty_bit(dirty_chunk_dir, dirty_bit) + + # Iterate over the chunks + dirty_chunk_idx += 1 + if dirty_chunk_idx == chunk_end_idx: + dirty_chunk_idx = chunk_start_idx + + # Replace the dirty item + dirty_item_idx = dirty_chunk_item_idxs.pop() + chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}") + # Save the sample + save_sample(step, chunk_dir, dirty_item_idx) + + # If we have replaced all dirty items in the chunk + if 
len(dirty_chunk_item_idxs) == 0: + # Unlock the chunk + dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8) + save_dirty_bit(dirty_chunk_dir, dirty_bit) + print(f"Worker {worker_id}: Replaced dirty chunk {dirty_chunk_idx}.") + + +if __name__ == "__main__": + # Args: n_workers, fill_up + parser = argparse.ArgumentParser() + parser.add_argument( + "--n_workers", + type=int, + default=2, + help="Number of parallel workers. It should be less than or equal to the number of chunks.", + ) + parser.add_argument( + "--fill_up", + action="store_true", + help="Whether to fill up the buffer before replacing dirty samples.", + ) + parser.add_argument( + "--clean_dirty", + action="store_true", + help= + "Whether to clean the dirty bits before replacing dirty samples. This option is ignored when `fill_up` is set.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed. If not set, the seed will be randomly generated.", + ) + parser.add_argument( + "--dataset_type", + type=str, + default="pretrain", + help="Whether to load the pretrain dataset or finetune dataset.", + ) + + # Run the producer + args = parser.parse_args() + if args.seed is not None: + print(f"Base seed: {args.seed}") + random.seed(args.seed) + + processes = [] + process_seeds = [random.randint(0, 2**32) for _ in range(args.n_workers)] + print(f"Process seeds: {process_seeds}") + + def signal_handler(sig, frame): + print("Ctrl+C received. 
Terminating child processes...") + for p in processes: + p.terminate() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + for worker_id in range(args.n_workers): + p = Process( + target=run_producer, + args=( + process_seeds[worker_id], + args.n_workers, + worker_id, + args.fill_up, + args.clean_dirty, + args.dataset_type, + ), + ) + p.start() + processes.append(p) + + for p in processes: + p.join() diff --git a/RDT-1B/data/utils.py b/RDT-1B/data/utils.py new file mode 100644 index 0000000..773a2fa --- /dev/null +++ b/RDT-1B/data/utils.py @@ -0,0 +1,242 @@ +import tensorflow as tf +import tensorflow_graphics.geometry.transformation.euler as tf_euler +import tensorflow_graphics.geometry.transformation.quaternion as tf_quat +import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as tf_rotmat + + +def dataset_to_path(dataset_name: str, dir_name: str) -> str: + """ + Return the path to the dataset. + """ + if (dataset_name == "robo_net" or dataset_name == "cmu_playing_with_food" or dataset_name == "droid"): + version = "1.0.0" + elif (dataset_name == "language_table" or dataset_name == "fmb" or dataset_name == "dobbe"): + version = "0.0.1" + elif dataset_name == "nyu_door_opening_surprising_effectiveness": + version = "" + elif dataset_name == "cmu_play_fusion": + version = "" + elif dataset_name == "berkeley_gnm_recon": + version = "" + else: + version = "0.1.0" + return f"{dir_name}/{dataset_name}/{version}" + + +def clean_task_instruction(task_instruction: tf.Tensor, replacements: dict) -> tf.Tensor: + """ + Clean up the natural language task instruction. 
+ """ + + # Create a function that applies all replacements + def apply_replacements(tensor): + for old, new in replacements.items(): + tensor = tf.strings.regex_replace(tensor, old, new) + return tensor + + # Apply the replacements and strip leading and trailing spaces + cleaned_task_instruction = apply_replacements(task_instruction) + cleaned_task_instruction = tf.strings.strip(cleaned_task_instruction) + return cleaned_task_instruction + + +def quaternion_to_euler(quaternion: tf.Tensor) -> tf.Tensor: + """ + Convert a quaternion (x, y, z, w) to Euler angles (roll, pitch, yaw). + The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention. + """ + # Normalize the quaternion + quaternion = tf.nn.l2_normalize(quaternion, axis=-1) + return tf_euler.from_quaternion(quaternion) + + +def euler_to_quaternion(euler: tf.Tensor) -> tf.Tensor: + """ + Convert Euler angles (roll, pitch, yaw) to a quaternion (x, y, z, w). + The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention. + """ + quaternion = tf_quat.from_euler(euler) + return tf.nn.l2_normalize(quaternion, axis=-1) + + +def rotation_matrix_to_euler(matrix: tf.Tensor) -> tf.Tensor: + """ + Convert a 3x3 rotation matrix to Euler angles (roll, pitch, yaw). + The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention. + """ + return tf_euler.from_rotation_matrix(matrix) + + +def rotation_matrix_to_quaternion(matrix: tf.Tensor) -> tf.Tensor: + """ + Convert a 3x3 rotation matrix to a quaternion (x, y, z, w). + """ + quaternion = tf_quat.from_rotation_matrix(matrix) + return tf.nn.l2_normalize(quaternion, axis=-1) + + +def euler_to_rotation_matrix(euler: tf.Tensor) -> tf.Tensor: + """ + Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix. + The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention. 
+ """ + return tf_rotmat.from_euler(euler) + + +def quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor: + """ + Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix. + """ + # Normalize the quaternion + quaternion = tf.nn.l2_normalize(quaternion, axis=-1) + return tf_rotmat.from_quaternion(quaternion) + + +def quaternion_to_rotation_matrix_wo_static_check(quaternion: tf.Tensor) -> tf.Tensor: + """ + Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix. + This function is used to make tensorflow happy. + """ + # Normalize the quaternion + quaternion = tf.nn.l2_normalize(quaternion, axis=-1) + + x = quaternion[..., 0] + y = quaternion[..., 1] + z = quaternion[..., 2] + w = quaternion[..., 3] + + tx = 2.0 * x + ty = 2.0 * y + tz = 2.0 * z + twx = tx * w + twy = ty * w + twz = tz * w + txx = tx * x + txy = ty * x + txz = tz * x + tyy = ty * y + tyz = tz * y + tzz = tz * z + matrix = tf.stack( + ( + 1.0 - (tyy + tzz), + txy - twz, + txz + twy, + txy + twz, + 1.0 - (txx + tzz), + tyz - twx, + txz - twy, + tyz + twx, + 1.0 - (txx + tyy), + ), + axis=-1, + ) # pyformat: disable + output_shape = tf.concat((tf.shape(input=quaternion)[:-1], (3, 3)), axis=-1) + return tf.reshape(matrix, shape=output_shape) + + +""" +Below is a continuous 6D rotation representation adapted from +On the Continuity of Rotation Representations in Neural Networks +https://arxiv.org/pdf/1812.07035.pdf +https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py +""" + + +def rotation_matrix_to_ortho6d(matrix: tf.Tensor) -> tf.Tensor: + """ + The orhto6d represents the first two column vectors a1 and a2 of the + rotation matrix: [ | , |, | ] + [ a1, a2, a3] + [ | , |, | ] + Input: (A1, ..., An, 3, 3) + Output: (A1, ..., An, 6) + """ + ortho6d = matrix[..., :, :2] + # Transpose the last two dimension + perm = list(range(len(ortho6d.shape))) + perm[-2], perm[-1] = perm[-1], perm[-2] + ortho6d = tf.transpose(ortho6d, perm) + # Flatten the last two 
dimension + ortho6d = tf.reshape(ortho6d, ortho6d.shape[:-2] + [6]) + return ortho6d + + +def rotation_matrix_to_ortho6d_1d(matrix: tf.Tensor) -> tf.Tensor: + """ + The orhto6d represents the first two column vectors a1 and a2 of the + rotation matrix: [ | , |, | ] + [ a1, a2, a3] + [ | , |, | ] + Input: (3, 3) + Output: (6,) + This function is used to make tensorflow happy. + """ + ortho6d = matrix[:, :2] + # Transpose the last two dimension + ortho6d = tf.transpose(ortho6d) + # Flatten the last two dimension + ortho6d = tf.reshape(ortho6d, [6]) + return ortho6d + + +def normalize_vector(v): + """ + v: (..., N) + """ + v_mag = tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)) + v_mag = tf.maximum(v_mag, 1e-8) + v_normalized = v / v_mag + + return v_normalized + + +def cross_product(u, v): + """ + u: (..., 3) + v: (..., 3) + u x v: (..., 3) + """ + i = u[..., 1] * v[..., 2] - u[..., 2] * v[..., 1] + j = u[..., 2] * v[..., 0] - u[..., 0] * v[..., 2] + k = u[..., 0] * v[..., 1] - u[..., 1] * v[..., 0] + out = tf.stack([i, j, k], axis=-1) + return out + + +def ortho6d_to_rotation_matrix(ortho6d: tf.Tensor) -> tf.Tensor: + """ + The orhto6d represents the first two column vectors a1 and a2 of the + rotation matrix: [ | , |, | ] + [ a1, a2, a3] + [ | , |, | ] + Input: (A1, ..., An, 6) + Output: (A1, ..., An, 3, 3) + """ + x_raw = ortho6d[..., 0:3] + y_raw = ortho6d[..., 3:6] + + x = normalize_vector(x_raw) + z = cross_product(x, y_raw) + z = normalize_vector(z) + y = cross_product(z, x) + + # Stack x, y, z to form the matrix + matrix = tf.stack([x, y, z], axis=-1) + return matrix + + +def capitalize_and_period(instr: str) -> str: + """ + Capitalize the first letter of a string and add a period to the end if it's not there. 
+ """ + if len(instr) > 0: + # if the first letter is not capital, make it so + if not instr[0].isupper(): + # if the first letter is not capital, make it so + instr = instr[0].upper() + instr[1:] + # add period to the end if it's not there + if instr[-1] != ".": + # add period to the end if it's not there + instr = instr + "." + return instr diff --git a/RDT-1B/data/vla_dataset.py b/RDT-1B/data/vla_dataset.py new file mode 100644 index 0000000..adad57c --- /dev/null +++ b/RDT-1B/data/vla_dataset.py @@ -0,0 +1,149 @@ +import json +import random + +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +import yaml + +from data.episode_transform import ( + process_episode, + flatten_episode, + flatten_episode_agilex, + bgr_to_rgb, +) +from data.utils import dataset_to_path +from data.preprocess_scripts import * + +# Producer does not need GPU +tf.config.set_visible_devices([], "GPU") + +OPENX_EMBOD_DIR = "data/datasets/openx_embod" + +DATASET_NAMES_NOOPENX = [ + "aloha_mobile", + "aloha_static", + "roboset", + "agilex", + "rh20t", + "calvin", + "bridgev2", +] + +# Read the config +with open("configs/base.yaml", "r") as file: + config = yaml.safe_load(file) +# Load some constants from the config +EPSD_LEN_THRESH_LOW = config["dataset"]["epsd_len_thresh_low"] +EPSD_LEN_THRESH_HIGH = config["dataset"]["epsd_len_thresh_high"] +# Read the image keys of each dataset +with open("configs/dataset_img_keys.json", "r") as file: + IMAGE_KEYS = json.load(file) + + +class VLADataset: + """ + This class is used to sample episodes from the embododiment dataset. 
+ """ + + def __init__(self, seed, dataset_type, repeat=True): + """ + seed: the random seed + dataset_type: 'pretrain' or 'finetune', which dataset to load + repeat: whether to repeat to infinite length + """ + dataset_names_cfg = ("configs/pretrain_datasets.json" + if dataset_type == "pretrain" else "configs/finetune_datasets.json") + with open(dataset_names_cfg, "r") as file: + DATASET_NAMES = json.load(file) + self.dataset_names = DATASET_NAMES + sample_weights_cfg = ("configs/pretrain_sample_weights.json" + if dataset_type == "pretrain" else "configs/finetune_sample_weights.json") + # Load the sample weights + with open(sample_weights_cfg, "r") as file: + SAMPLE_WEIGHTS = json.load(file) + self.openx_dir = OPENX_EMBOD_DIR + self.epsd_len_thresh_low = EPSD_LEN_THRESH_LOW + self.epsd_len_thresh_high = EPSD_LEN_THRESH_HIGH + self.repeat = repeat + + # Set the random seed + tf.random.set_seed(seed) + np.random.seed(seed) + + # Weights of the each dataset in the collection to sample from + sample_weights = [] + + self.name2dataset = {} + for dataset_name in self.dataset_names: + if dataset_name in DATASET_NAMES_NOOPENX: + dataset = globals()[dataset_name].load_dataset(seed) + else: + dataset_path = dataset_to_path(dataset_name, self.openx_dir) + dataset = tfds.builder_from_directory(builder_dir=dataset_path) + dataset = dataset.as_dataset(split="all", shuffle_files=True) + + # You can add filter for other datasets + if dataset_name == "kuka": + dataset = dataset.filter(lambda x: x["success"]) + elif dataset_name == "bc_z": + dataset = dataset.filter(lambda x: tf.math.greater( + next(iter(x["steps"]))["observation"]["episode_success"], + 0.5, + )) + elif (dataset_name == "ucsd_pick_and_place_dataset_converted_externally_to_rlds"): + dataset = dataset.filter(lambda x: x["episode_metadata"]["success"]) + elif (dataset_name == "utokyo_xarm_bimanual_converted_externally_to_rlds"): + # Only preserve the meaningful episodes + dataset = dataset.filter(lambda x: 
tf.math.equal( + next(iter(x["steps"]))["language_instruction"], + tf.constant("Unfold a wrinkled towel."), + )) + + # Note: use cache() will cause the unexpected crash + # dataset = dataset.map().cache().shuffle().repeat() + dataset = dataset.map(lambda x: process_episode( + x, + dataset_name, + IMAGE_KEYS[dataset_name]["image_keys"], + IMAGE_KEYS[dataset_name]["image_mask"], + )) + + # Change BGR to RGB if needed + if dataset_name == "fmb": + dataset = dataset.map(bgr_to_rgb) + + if self.repeat: + dataset = dataset.repeat() + self.name2dataset[dataset_name] = iter(dataset) + sample_weights.append(SAMPLE_WEIGHTS[dataset_name]) + # Normalize the sample weights + sample_weights = np.array(sample_weights) + self.sample_weights = sample_weights / np.sum(sample_weights) + + def __iter__(self): + """ + Sample batches of episodes for an epoch. + """ + while True: + dataset_name = np.random.choice(self.dataset_names, p=self.sample_weights) + episode = next(self.name2dataset[dataset_name]) + if dataset_name == "agilex": + episode_steps = flatten_episode_agilex(episode) + else: + episode_steps = flatten_episode(episode) + # Filter too short + if len(episode_steps) < self.epsd_len_thresh_low: + continue + # Randomly sample too long + if len(episode_steps) > self.epsd_len_thresh_high: + episode_steps = random.sample(episode_steps, self.epsd_len_thresh_high) + + yield episode_steps + + +if __name__ == "__main__": + dataset = VLADataset(0, "finetune") + for episode in dataset: + print(episode[0]) + break diff --git a/RDT-1B/finetune.sh b/RDT-1B/finetune.sh new file mode 100644 index 0000000..c289296 --- /dev/null +++ b/RDT-1B/finetune.sh @@ -0,0 +1,107 @@ +#!/bin/bash + +BEGIN_TIME=$(date +%s) + +CONFIG_NAME="Train_Config_Default" +CONFIG_FILE="model_config/$CONFIG_NAME.yml" +echo "CONFIG_FILE_PATH: $CONFIG_FILE" + +### ============Read Input Config and ReLoad Config YAML=================== +ln -s /home/qi.xiong/Temp/RDT-1B/input/weights ../weights + 
+TRAIN_CONFIG_FILE="input/config.json" +echo "TRAIN_CONFIG_FILE_PATH: $TRAIN_CONFIG_FILE" +python scripts/read_config.py "$TRAIN_CONFIG_FILE" "$CONFIG_FILE" + +### ============Read Input Config and ReLoad Config YAML=================== + +export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 +export NCCL_DEBUG=INFO +export NCCL_NVLS_ENABLE=0 +export NCCL_SOCKET_IFNAME=eth0 +# export TEXT_ENCODER_NAME="google/t5-v1_1-xxl" +export VISION_ENCODER_NAME="../weights/siglip-so400m-patch14-384" +export CFLAGS="-I/usr/include" +export LDFLAGS="-L/usr/lib/x86_64-linux-gnu" +export WANDB_PROJECT="RDT-1B" +export WANDB_DEFAULT_RUN_NAME=$CONFIG_NAME +export NCCL_P2P_DISABLE=1 +export NCCL_IB_DISABLE=1 + +# check if YAML exist +if [ ! -f "$CONFIG_FILE" ]; then + echo "Config file $CONFIG_FILE does not exist!" + exit 1 +fi + +PRETRAINED_MODEL_NAME=$(python scripts/read_yaml.py "$CONFIG_FILE" pretrained_model_name_or_path) +TRAIN_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" train_batch_size) +SAMPLE_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_batch_size) +MAX_TRAIN_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" max_train_steps) +CHECKPOINTING_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpointing_period) +SAMPLE_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_period) +CHECKPOINTS_TOTAL_LIMIT=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpoints_total_limit) +LR_SCHEDULER=$(python scripts/read_yaml.py "$CONFIG_FILE" lr_scheduler) +LEARNING_RATE=$(python scripts/read_yaml.py "$CONFIG_FILE" learning_rate) +DATALOADER_NUM_WORKERS=$(python scripts/read_yaml.py "$CONFIG_FILE" dataloader_num_workers) +DATASET_TYPE=$(python scripts/read_yaml.py "$CONFIG_FILE" dataset_type) +STATE_NOISE_SNR=$(python scripts/read_yaml.py "$CONFIG_FILE" state_noise_snr) +GRAD_ACCUM_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" gradient_accumulation_steps) +OUTPUT_DIR=$(python scripts/read_yaml.py 
"$CONFIG_FILE" checkpoint_path) +CUDA_USE=$(python scripts/read_yaml.py "$CONFIG_FILE" cuda_visible_device) + + +PRETRAINED_MODEL_NAME=$(echo "$PRETRAINED_MODEL_NAME" | tr -d '"') +CUDA_USE=$(echo "$CUDA_USE" | tr -d '"') +OUTPUT_DIR=$(echo "$OUTPUT_DIR" | tr -d '"') + +# create output path +if [ ! -d "$OUTPUT_DIR" ]; then + mkdir -p "$OUTPUT_DIR" + echo "Created output directory: $OUTPUT_DIR" +else + echo "Output directory already exists: $OUTPUT_DIR" +fi + +export CUDA_VISIBLE_DEVICES=$CUDA_USE + +python -m data.compute_dataset_stat_hdf5 --task_name $CONFIG_NAME + +accelerate launch --main_process_port=28499 main.py \ + --deepspeed="./configs/zero2.json" \ + --pretrained_model_name_or_path=$PRETRAINED_MODEL_NAME \ + --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \ + --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \ + --output_dir=$OUTPUT_DIR \ + --train_batch_size=$TRAIN_BATCH_SIZE \ + --sample_batch_size=$SAMPLE_BATCH_SIZE \ + --max_train_steps=$MAX_TRAIN_STEPS \ + --checkpointing_period=$CHECKPOINTING_PERIOD \ + --sample_period=$SAMPLE_PERIOD \ + --checkpoints_total_limit=$CHECKPOINTS_TOTAL_LIMIT \ + --lr_scheduler="constant" \ + --learning_rate=$LEARNING_RATE \ + --mixed_precision="bf16" \ + --dataloader_num_workers=$DATALOADER_NUM_WORKERS \ + --image_aug \ + --dataset_type="finetune" \ + --state_noise_snr=$STATE_NOISE_SNR \ + --load_from_hdf5 \ + --report_to=wandb \ + --precomp_lang_embed \ + --gradient_accumulation_steps=$GRAD_ACCUM_STEPS \ + --model_config_path=$CONFIG_FILE \ + --CONFIG_NAME=$CONFIG_NAME \ + --output_log_path=$OUTPUT_DIR/output.log + +END_TIME=$(date +%s) +RUNTIME=$((END_TIME - BEGIN_TIME)) +echo "Total runtime: $RUNTIME seconds" + +### ============Generate Output JSON=================== + +python scripts/generate_output_json.py "$TRAIN_CONFIG_FILE" "$OUTPUT_DIR" "$RUNTIME" + +### ============Generate Output JSON=================== + + diff --git 
a/RDT-1B/flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl b/RDT-1B/flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl new file mode 100644 index 0000000..566d96e Binary files /dev/null and b/RDT-1B/flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl differ diff --git a/RDT-1B/generate.sh b/RDT-1B/generate.sh new file mode 100644 index 0000000..608c34e --- /dev/null +++ b/RDT-1B/generate.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +model_name=${1} + +python ./model_config/_generate_model_config.py $model_name \ No newline at end of file diff --git a/RDT-1B/main.py b/RDT-1B/main.py new file mode 100644 index 0000000..0704ca7 --- /dev/null +++ b/RDT-1B/main.py @@ -0,0 +1,351 @@ +import argparse +import os +from train.train import train + +from accelerate.logging import get_logger + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Main script for training RDT.") + parser.add_argument( + "--model_config_path", + type=str, + default="model_config/sjoe_place_D435_100_finetune_config.yaml", + help= + "Path to the finetune data and model configuration file. Default is `model_config/sjoe_place_D435_100_finetune_config.yaml`.", + ) + parser.add_argument( + "--config_path", + type=str, + default="configs/base.yaml", + help="Path to the configuration file. 
Default is `configs/base.yaml`.", + ) + parser.add_argument( + "--deepspeed", + type=str, + default=None, + help= + "Enable DeepSpeed and pass the path to its config file or an already initialized DeepSpeed config dictionary", + ) + parser.add_argument( + "--pretrained_text_encoder_name_or_path", + type=str, + default=None, + help="Pretrained text encoder name or path if not the same as model_name", + ) + parser.add_argument( + "--pretrained_vision_encoder_name_or_path", + type=str, + default=None, + help="Pretrained vision encoder name or path if not the same as model_name", + ) + + parser.add_argument( + "--output_dir", + type=str, + default="checkpoints", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + + parser.add_argument( + "--load_from_hdf5", + action="store_true", + default=False, + help=("Whether to load the dataset directly from HDF5 files. " + "If False, the dataset will be loaded using producer-consumer pattern, " + "where the producer reads TFRecords and saves them to buffer, and the consumer reads from buffer."), + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=8, + help="Batch size (per device) for the sampling dataloader.", + ) + parser.add_argument( + "--num_sample_batches", + type=int, + default=2, + help="Number of batches to sample from the dataset.", + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_period", + type=int, + default=500, + help= + ("Save a checkpoint of the training state every X updates. 
Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions."), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help= + ("Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more details"), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=("Whether training should be resumed from a previous checkpoint. 
Use a path saved by" + ' `--checkpointing_period`, or `"latest"` to automatically select the last available checkpoint.'), + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help=( + "Path or name of a pretrained checkpoint to load the model from.\n", + " This can be either:\n" + " - a string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co, e.g., `robotics-diffusion-transformer/rdt-1b`,\n" + " - a path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method, e.g., `./my_model_directory/`.\n" + " - a path to model checkpoint (*.pt), .e.g, `my_model_directory/checkpoint-10000/pytorch_model/mp_rank_00_model_states.pt`" + " - `None` if you are randomly initializing model using configuration at `config_path`.", + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--cond_mask_prob", + type=float, + default=0.1, + help=("The probability to randomly mask the conditions (except states) during training. " + "If set to 0, the conditions are not masked."), + ) + parser.add_argument( + "--cam_ext_mask_prob", + type=float, + default=-1.0, + help=("The probability to randomly mask the external camera image during training. 
" + "If set to < 0, the external camera image is masked with the probability of `cond_mask_prob`."), + ) + parser.add_argument( + "--state_noise_snr", + type=float, + default=None, + help=("The signal-to-noise ratio (SNR, unit: dB) for adding noise to the states. " + "Default is None, which means no noise is added."), + ) + parser.add_argument( + "--image_aug", + action="store_true", + default=False, + help="Whether or not to apply image augmentation (ColorJitter, blur, noise, etc) to the input images.", + ) + parser.add_argument( + "--precomp_lang_embed", + action="store_true", + default=False, + help="Whether or not to use precomputed language embeddings.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
+ ), + ) + parser.add_argument( + "--alpha", + type=float, + default=0.9, + help="The moving average coefficient for each dataset's loss.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. 
Use `"all"` to report to all integrations.'), + ) + parser.add_argument( + "--sample_period", + type=int, + default=-1, + help=("Run sampling every X steps. During the sampling phase, the model will sample a trajectory" + " and report the error between the sampled trajectory and groud-truth trajectory" + " in the training batch."), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + ) + + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=("Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. 
More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"), + ) + + parser.add_argument( + "--dataset_type", + type=str, + default="pretrain", + required=False, + help="Whether to load the pretrain dataset or finetune dataset.", + ) + + parser.add_argument( + "--CONFIG_NAME", + type=str, + default="Null", + required=True, + ) + + parser.add_argument( + "--output_log_path", + type=str, + default="output/output.log", + help="The path to the output log file.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +if __name__ == "__main__": + logger = get_logger(__name__) + args = parse_args() + train(args, logger) diff --git a/RDT-1B/model.py b/RDT-1B/model.py new file mode 100644 index 0000000..e58d4b6 --- /dev/null +++ b/RDT-1B/model.py @@ -0,0 +1,269 @@ +#!/home/lin/software/miniconda3/envs/aloha/bin/python +# -- coding: UTF-8 +""" +#!/usr/bin/python3 +""" +from pathlib import Path + +# get current workspace +current_file = Path(__file__) + +import json +import sys + +parent_dir = current_file.parent +sys.path.append(str(parent_dir)) + +import os + +import argparse + +import threading +import time +import yaml +from collections import deque + +import numpy as np +import torch +from PIL import Image as PImage +import cv2 + +import sys, os + +# get current workspace +current_file = Path(__file__) +sys.path.append(os.path.join(current_file.parent, "models")) + +from scripts.agilex_model import create_model +from multimodal_encoder.t5_encoder import T5Embedder + +global_path = parent_dir.parent + + +class RDT: + + def __init__( + self, + pretrained_model_name_or_path, + task_name, + left_arm_dim, + right_arm_dim, + rdt_step, + ): + # set path + current_file = Path(__file__) + self.global_path = 
current_file.parent.parent + # load the config + self.config = { + "episode_len": 10000, # args.max_publish_step + "state_dim": left_arm_dim + 1 + right_arm_dim + + 1, # 14 dims action:[left joint angles,left gripper,right joint angles,right gripper] + "chunk_size": 64, # args.chunk_size + "camera_names": ["cam_high", "cam_right_wrist", "cam_left_wrist"], + } + # setup config + self.args = { + "max_publish_step": 10000, # Maximum number of action publishing steps + "seed": None, # Random seed + "ctrl_freq": 25, # The control frequency of the robot + "chunk_size": 64, # Action chunk size + # 'disable_puppet_arm': False, # Whether to disable the puppet arm + "config_path": os.path.join(self.global_path, "RDT/configs/base.yaml"), + "pretrained_model_name_or_path": pretrained_model_name_or_path, + } + + # Load rdt model + self.left_arm_dim, self.right_arm_dim = left_arm_dim, right_arm_dim + self.policy = self.make_policy(self.args) + self.max_publish_step = self.config["episode_len"] + self.chunk_size = self.config["chunk_size"] + self.task_name = task_name + self.observation_window = None + self.img_size = (640, 480) + self.set_language_embed() + self.rdt_step = rdt_step + + # set img_size + def set_img_size(self, img_size): + self.img_size = img_size + + def set_language_embed(self): + GPU = 0 + MODEL_PATH = os.path.join(self.global_path, "weights/RDT/t5-v1_1-xxl") + CONFIG_PATH = os.path.join(self.global_path, "RDT/configs/base.yaml") + with open(CONFIG_PATH, "r") as fp: + config = yaml.safe_load(fp) + device = torch.device(f"cuda:{GPU}") + text_embedder = T5Embedder( + from_pretrained=MODEL_PATH, + model_max_length=config["dataset"]["tokenizer_max_length"], + device=device, + use_offload_folder=None, + ) + self.tokenizer, self.text_encoder = text_embedder.tokenizer, text_embedder.model + self.text_encoder.eval() + + # set language randomly + def random_set_language(self, instruction=None): + assert instruction is not None, "Missing input instruction" + 
self.set_language_instruction(instruction) + + # encoding language + def set_language_instruction(self, language_instruction, save_dir=None, task_name=None): + assert ((save_dir is None) ^ (task_name is None)) == False, "input error" + + if os.path.isfile(language_instruction): + lang_dict = torch.load(language_instruction) + print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"") + self.lang_embeddings = lang_dict["embeddings"] + print("loading instruction from pre-embed path") + else: + device = next(self.text_encoder.parameters()).device + with torch.no_grad(): + tokens = self.tokenizer( + language_instruction, + return_tensors="pt", + padding="longest", + truncation=True, + )["input_ids"].to(device) + tokens = tokens.view(1, -1) + output = self.text_encoder(tokens) + pred = output.last_hidden_state.detach().cpu() + + if save_dir is not None: + save_path = os.path.join(save_dir, f"{task_name}.pt") + torch.save({ + "name": task_name, + "instruction": language_instruction, + "embeddings": pred, + }, save_path) + + del tokens, output + torch.cuda.empty_cache() + self.lang_embeddings = pred + + print(f"successfully set instruction: {language_instruction}") + + # Update the observation window buffer + def update_observation_window(self, img_arr, state): + # JPEG transformation + # Align with training + def jpeg_mapping(img): + if img is None: + return None + img = cv2.imencode(".jpg", img)[1].tobytes() + img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR) + return img + + def resize_img(img, size): + return cv2.resize(img, size) + + if self.observation_window is None: + self.observation_window = deque(maxlen=2) + + # Append the first dummy image + self.observation_window.append({ + "qpos": None, + "images": { + self.config["camera_names"][0]: None, + self.config["camera_names"][1]: None, + self.config["camera_names"][2]: None, + }, + }) + + img_front, img_right, img_left, puppet_arm = ( + img_arr[0], + 
img_arr[1], + img_arr[2], + state, + ) + # img resize + img_front = resize_img(img_front, self.img_size) + img_left = resize_img(img_left, self.img_size) + img_right = resize_img(img_right, self.img_size) + # img jprg encoding + img_front = jpeg_mapping(img_front) + img_left = jpeg_mapping(img_left) + img_right = jpeg_mapping(img_right) + + qpos = np.array(puppet_arm) + qpos = torch.from_numpy(qpos).float().cuda() + self.observation_window.append({ + "qpos": qpos, + "images": { + self.config["camera_names"][0]: img_front, + self.config["camera_names"][1]: img_right, + self.config["camera_names"][2]: img_left, + }, + }) + + def get_action(self, img_arr=None, state=None): + assert (img_arr is None) ^ (state is None) == False, "input error" + if (img_arr is not None) and (state is not None): + self.update_observation_window(img_arr, state) + + with torch.inference_mode(): + action_buffer = inference_fn(self.config, self.policy, self.lang_embeddings, self.observation_window).copy() + + return action_buffer + + def reset_obsrvationwindows(self): + self.lang_embeddings = None + self.observation_window = None + print("successfully unset obs and language intruction") + + # Initialize the model + def make_policy(self, args): + with open(args["config_path"], "r") as fp: + config_base_yaml = yaml.safe_load(fp) + args["config"] = config_base_yaml + args["config"]["arm_dim"] = { + "left_arm_dim": self.left_arm_dim, + "right_arm_dim": self.right_arm_dim, + } + # pretrained_text_encoder_name_or_path = "weights/RDT/t5-v1_1-xxl" + pretrained_vision_encoder_name_or_path = os.path.join(self.global_path, "weights/RDT/siglip-so400m-patch14-384") + model = create_model( + args=args["config"], + dtype=torch.bfloat16, + pretrained=args["pretrained_model_name_or_path"], + # pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path, + pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path, + control_frequency=args["ctrl_freq"], + ) + + return model + + 
# RDT inference
def inference_fn(config, policy, lang_embeddings, observation_window):
    """Run one step of RDT inference over the latest observations.

    Fix vs. previous version: the body was wrapped in a ``while True:`` loop
    that unconditionally returned on its first iteration, and computed an
    unused ``time.time()`` stamp; both removed — behavior is unchanged.

    Args:
        config: dict with at least ``"camera_names"``, an ordered list of the
            three camera keys in [front, right, left] order.
        policy: model exposing ``step(proprio=..., images=..., text_embeds=...)``
            that returns an action tensor shaped [1, chunk_size, action_dim]
            (e.g. [1, 64, 14] — left arm, left gripper, right arm, right gripper).
        lang_embeddings: pre-computed language-instruction embeddings, passed
            through to ``policy.step`` untouched.
        observation_window: sequence whose last two entries each hold
            ``{"qpos": tensor[action_dim], "images": {camera_name: ndarray|None}}``.

    Returns:
        numpy.ndarray of shape [chunk_size, action_dim]: the predicted action
        chunk with the batch dimension removed.
    """
    # Fetch the two most recent frames per camera in [front, right, left]
    # order, oldest frame first — this matches the training-time layout.
    image_arrs = [
        observation_window[-2]["images"][config["camera_names"][0]],
        observation_window[-2]["images"][config["camera_names"][1]],
        observation_window[-2]["images"][config["camera_names"][2]],
        observation_window[-1]["images"][config["camera_names"][0]],
        observation_window[-1]["images"][config["camera_names"][1]],
        observation_window[-1]["images"][config["camera_names"][2]],
    ]

    # Missing frames stay None; present ones are wrapped as PIL images.
    images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]

    # Latest proprioceptive state, unsqueezed to a batch of one: [1, action_dim].
    proprio = observation_window[-1]["qpos"].unsqueeze(0)

    # policy.step returns [1, chunk_size, action_dim]; drop the batch dim and
    # hand the caller a plain numpy array on the CPU.
    actions = policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy()
    return actions
"checkpoint_path": checkpoint_path, + "pretrained_model_name_or_path": "../weights/RDT/rdt-1b", + "cuda_visible_device": "...", # args.gpu_use, + "train_batch_size": 32, + "sample_batch_size": 64, + "max_train_steps": 20000, + "checkpointing_period": 2500, + "sample_period": 100, + "checkpoints_total_limit": 40, + "learning_rate": 1e-4, + "dataloader_num_workers": 8, + "state_noise_snr": 40, + "gradient_accumulation_steps": 1, + } + task_config_path = os.path.join("model_config/", f"{model_name}.yml") + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + time_comment = f"# Generated on {current_time}\n" + + with open(task_config_path, "w") as f: + f.write(time_comment) + yaml.dump(data, f, default_flow_style=False, sort_keys=False) + + if not os.path.exists(fintune_data_path): + os.makedirs(fintune_data_path) diff --git a/RDT-1B/models/__init__.py b/RDT-1B/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/RDT-1B/models/__pycache__/__init__.cpython-310.pyc b/RDT-1B/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..56b0228 Binary files /dev/null and b/RDT-1B/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/RDT-1B/models/__pycache__/ema_model.cpython-310.pyc b/RDT-1B/models/__pycache__/ema_model.cpython-310.pyc new file mode 100644 index 0000000..c05697a Binary files /dev/null and b/RDT-1B/models/__pycache__/ema_model.cpython-310.pyc differ diff --git a/RDT-1B/models/__pycache__/hub_mixin.cpython-310.pyc b/RDT-1B/models/__pycache__/hub_mixin.cpython-310.pyc new file mode 100644 index 0000000..6201def Binary files /dev/null and b/RDT-1B/models/__pycache__/hub_mixin.cpython-310.pyc differ diff --git a/RDT-1B/models/__pycache__/rdt_runner.cpython-310.pyc b/RDT-1B/models/__pycache__/rdt_runner.cpython-310.pyc new file mode 100644 index 0000000..f497200 Binary files /dev/null and b/RDT-1B/models/__pycache__/rdt_runner.cpython-310.pyc differ diff --git a/RDT-1B/models/ema_model.py 
b/RDT-1B/models/ema_model.py new file mode 100644 index 0000000..39637ae --- /dev/null +++ b/RDT-1B/models/ema_model.py @@ -0,0 +1,82 @@ +# Reference: DiffusionPolicy [https://github.com/real-stanford/diffusion_policy] + +import torch +from torch.nn.modules.batchnorm import _BatchNorm + + +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__(self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 2/3. + min_value (float): The minimum EMA decay rate. Default: 0. + """ + + self.averaged_model = model + self.averaged_model.eval() + self.averaged_model.requires_grad_(False) + + self.update_after_step = update_after_step + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + + self.decay = 0.0 + self.optimization_step = 0 + + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. 
+ """ + step = max(0, optimization_step - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma)**-self.power + + if step <= 0: + return 0.0 + + return max(self.min_value, min(value, self.max_value)) + + @torch.no_grad() + def step(self, new_model): + self.decay = self.get_decay(self.optimization_step) + + # old_all_dataptrs = set() + # for param in new_model.parameters(): + # data_ptr = param.data_ptr() + # if data_ptr != 0: + # old_all_dataptrs.add(data_ptr) + + all_dataptrs = set() + for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()): + for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)): + # iterative over immediate parameters only. + if isinstance(param, dict): + raise RuntimeError('Dict parameter not supported') + + # data_ptr = param.data_ptr() + # if data_ptr != 0: + # all_dataptrs.add(data_ptr) + + if isinstance(module, _BatchNorm): + # skip batchnorms + ema_param.copy_(param.to(dtype=ema_param.dtype).data) + elif not param.requires_grad: + ema_param.copy_(param.to(dtype=ema_param.dtype).data) + else: + ema_param.mul_(self.decay) + ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) + + # verify that iterating over module and then parameters is identical to parameters recursively. 
+ # assert old_all_dataptrs == all_dataptrs + self.optimization_step += 1 diff --git a/RDT-1B/models/hub_mixin.py b/RDT-1B/models/hub_mixin.py new file mode 100644 index 0000000..d5ccbda --- /dev/null +++ b/RDT-1B/models/hub_mixin.py @@ -0,0 +1,75 @@ +import os +from pathlib import Path +from typing import Dict, Optional, Union + +from huggingface_hub import PyTorchModelHubMixin +from huggingface_hub.constants import (PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE) +from huggingface_hub.file_download import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, is_torch_available + +if is_torch_available(): + import torch # type: ignore + + +class CompatiblePyTorchModelHubMixin(PyTorchModelHubMixin): + """Mixin class to load Pytorch models from the Hub.""" + + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights from a Pytorch model to a local directory.""" + # To bypass saving into safetensor by default + model_to_save = self.module if hasattr(self, "module") else self # type: ignore + torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME) + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: Optional[bool], + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", + strict: bool = False, + **model_kwargs, + ): + """Load Pytorch pretrained weights and return the loaded model.""" + model = cls(**model_kwargs) + if os.path.isdir(model_id): + print("Loading weights from local directory") + try: + model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) + return cls._load_as_safetensor(model, model_file, map_location, strict) + except FileNotFoundError: + model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) + return cls._load_as_pickle(model, model_file, map_location, strict) + else: + try: + model_file = 
hf_hub_download( + repo_id=model_id, + filename=SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + return cls._load_as_safetensor(model, model_file, map_location, strict) + except EntryNotFoundError: + model_file = hf_hub_download( + repo_id=model_id, + filename=PYTORCH_WEIGHTS_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + return cls._load_as_pickle(model, model_file, map_location, strict) diff --git a/RDT-1B/models/multimodal_encoder/__init__.py b/RDT-1B/models/multimodal_encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/RDT-1B/models/multimodal_encoder/__pycache__/__init__.cpython-310.pyc b/RDT-1B/models/multimodal_encoder/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..55cf973 Binary files /dev/null and b/RDT-1B/models/multimodal_encoder/__pycache__/__init__.cpython-310.pyc differ diff --git a/RDT-1B/models/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc b/RDT-1B/models/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc new file mode 100644 index 0000000..3689e84 Binary files /dev/null and b/RDT-1B/models/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc differ diff --git a/RDT-1B/models/multimodal_encoder/__pycache__/t5_encoder.cpython-310.pyc b/RDT-1B/models/multimodal_encoder/__pycache__/t5_encoder.cpython-310.pyc new file mode 100644 index 0000000..4cf4ba2 Binary files /dev/null and b/RDT-1B/models/multimodal_encoder/__pycache__/t5_encoder.cpython-310.pyc differ diff --git a/RDT-1B/models/multimodal_encoder/clip_encoder.py b/RDT-1B/models/multimodal_encoder/clip_encoder.py new file mode 100644 index 0000000..460ff16 --- /dev/null +++ 
b/RDT-1B/models/multimodal_encoder/clip_encoder.py @@ -0,0 +1,159 @@ +import torch +import torch.nn as nn + +from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig + + +class CLIPVisionTower(nn.Module): + + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + + if not delay_load: + self.load_model() + elif getattr(args, 'unfreeze_mm_vision_tower', False): + self.load_model() + else: + self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) + + def load_model(self, device_map=None): + if self.is_loaded: + print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + image_features = image_forward_outs.hidden_states[self.select_layer] + if self.select_feature == 'patch': + image_features = image_features[:, 1:] + elif self.select_feature == 'cls_patch': + image_features = image_features + else: + raise ValueError(f'Unexpected select feature: {self.select_feature}') + return image_features + + @torch.no_grad() + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), + output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), + 
output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + return self.config.hidden_size + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size)**2 + + +class CLIPVisionTowerS2(CLIPVisionTower): + + def __init__(self, vision_tower, args, delay_load=False): + super().__init__(vision_tower, args, delay_load) + + self.s2_scales = getattr(args, 's2_scales', '336,672,1008') + self.s2_scales = list(map(int, self.s2_scales.split(','))) + self.s2_scales.sort() + self.s2_split_size = self.s2_scales[0] + self.s2_image_size = self.s2_scales[-1] + + try: + from s2wrapper import forward as multiscale_forward + except ImportError: + raise ImportError( + 'Package s2wrapper not found! 
Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git' + ) + self.multiscale_forward = multiscale_forward + + # change resize/crop size in preprocessing to the largest image size in s2_scale + if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False): + self.image_processor.size['shortest_edge'] = self.s2_image_size + self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size + + def load_model(self, device_map=None): + if self.is_loaded: + print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.image_processor.size['shortest_edge'] = self.s2_image_size + self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size + + self.is_loaded = True + + @torch.no_grad() + def forward_feature(self, images): + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), + output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + return image_features + + @torch.no_grad() + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_feature = self.multiscale_forward(self.forward_feature, + image.unsqueeze(0), + img_sizes=self.s2_scales, + max_split_size=self.s2_split_size) + image_features.append(image_feature) + else: + image_features = self.multiscale_forward(self.forward_feature, + images, + img_sizes=self.s2_scales, + max_split_size=self.s2_split_size) + + return image_features + + @property + def hidden_size(self): + return self.config.hidden_size * len(self.s2_scales) diff --git a/RDT-1B/models/multimodal_encoder/dinov2_encoder.py 
b/RDT-1B/models/multimodal_encoder/dinov2_encoder.py new file mode 100644 index 0000000..a809698 --- /dev/null +++ b/RDT-1B/models/multimodal_encoder/dinov2_encoder.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoImageProcessor, AutoModel, Dinov2Model + + +class DinoV2VisionTower(nn.Module): + + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + + if not delay_load: + self.load_model() + elif getattr(args, 'unfreeze_mm_vision_tower', False): + self.load_model() + else: + self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) + + def load_model(self, device_map=None): + if self.is_loaded: + print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) + return + + self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) # FIXME: + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + image_features = image_forward_outs.last_hidden_state + if self.select_feature == 'patch': + image_features = image_features[:, 1:] # (B, 1369, 1536) + elif self.select_feature == 'cls_patch': + image_features = image_features # (B, 1, 1536) + else: + raise ValueError(f'Unexpected select feature: {self.select_feature}') + return image_features + + @torch.no_grad() + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = 
self.vision_tower(images.to(device=self.device, dtype=self.dtype)) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + return self.config.hidden_size + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size)**2 diff --git a/RDT-1B/models/multimodal_encoder/siglip_encoder.py b/RDT-1B/models/multimodal_encoder/siglip_encoder.py new file mode 100644 index 0000000..b49b3ed --- /dev/null +++ b/RDT-1B/models/multimodal_encoder/siglip_encoder.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from transformers import AutoConfig, SiglipImageProcessor, SiglipVisionModel + + +class SiglipVisionTower(nn.Module): + + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + + if not delay_load: + self.load_model() + elif getattr(args, 'unfreeze_mm_vision_tower', False): + self.load_model() + else: + self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) + + def load_model(self, device_map=None): + if self.is_loaded: + print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) + return + + self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, 
device_map=device_map) + self.vision_tower.eval() + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + if self.select_feature == 'patch': + image_features = image_forward_outs.last_hidden_state # (B, 729, 1536) + elif self.select_feature == 'cls_patch': + image_features = image_forward_outs.pooler_output # (B, 1, 1536) + else: + raise ValueError(f'Unexpected select feature: {self.select_feature}') + return image_features + + @torch.no_grad() + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype)) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + return self.config.hidden_size + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size)**2 diff --git a/RDT-1B/models/multimodal_encoder/t5_encoder.py b/RDT-1B/models/multimodal_encoder/t5_encoder.py new file mode 100644 index 0000000..a0dfb7d --- /dev/null +++ b/RDT-1B/models/multimodal_encoder/t5_encoder.py @@ -0,0 +1,111 @@ +import torch +from transformers import AutoTokenizer, T5EncoderModel + + +class T5Embedder: + # available_models = 
["google/t5-v1_1-xxl"] + + def __init__( + self, + device, + from_pretrained=None, + *, + cache_dir=None, + hf_token=None, + use_text_preprocessing=True, + t5_model_kwargs=None, + torch_dtype=None, + use_offload_folder=None, + model_max_length=120, + local_files_only=False, + ): + # from_pretrained="google/t5-v1_1-xxl" # zijian + self.device = torch.device(device) + self.torch_dtype = torch_dtype or torch.bfloat16 + self.cache_dir = cache_dir + + if t5_model_kwargs is None: + t5_model_kwargs = { + "low_cpu_mem_usage": True, + "torch_dtype": self.torch_dtype, + } + + if use_offload_folder is not None: + t5_model_kwargs["offload_folder"] = use_offload_folder + t5_model_kwargs["device_map"] = { + "shared": self.device, + "encoder.embed_tokens": self.device, + "encoder.block.0": self.device, + "encoder.block.1": self.device, + "encoder.block.2": self.device, + "encoder.block.3": self.device, + "encoder.block.4": self.device, + "encoder.block.5": self.device, + "encoder.block.6": self.device, + "encoder.block.7": self.device, + "encoder.block.8": self.device, + "encoder.block.9": self.device, + "encoder.block.10": self.device, + "encoder.block.11": self.device, + "encoder.block.12": "disk", + "encoder.block.13": "disk", + "encoder.block.14": "disk", + "encoder.block.15": "disk", + "encoder.block.16": "disk", + "encoder.block.17": "disk", + "encoder.block.18": "disk", + "encoder.block.19": "disk", + "encoder.block.20": "disk", + "encoder.block.21": "disk", + "encoder.block.22": "disk", + "encoder.block.23": "disk", + "encoder.final_layer_norm": "disk", + "encoder.dropout": "disk", + } + else: + t5_model_kwargs["device_map"] = { + "shared": self.device, + "encoder": self.device, + } + + self.use_text_preprocessing = use_text_preprocessing + self.hf_token = hf_token + + # assert from_pretrained in self.available_models + self.tokenizer = AutoTokenizer.from_pretrained( + from_pretrained, + model_max_length=model_max_length, + cache_dir=cache_dir, + 
local_files_only=local_files_only, + ) + self.model = T5EncoderModel.from_pretrained( + from_pretrained, + cache_dir=cache_dir, + local_files_only=local_files_only, + **t5_model_kwargs, + ).eval() + self.model_max_length = model_max_length + + def get_text_embeddings(self, texts): + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.model_max_length, + padding="longest", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + input_ids = text_tokens_and_mask["input_ids"].to(self.device) + attention_mask = text_tokens_and_mask["attention_mask"].to(self.device) + with torch.no_grad(): + text_encoder_embs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + )["last_hidden_state"].detach() + return text_encoder_embs, attention_mask + + +if __name__ == "__main__": + T5Embedder(from_pretrained="google/t5-v1_1-xxl", device='cuda:7') diff --git a/RDT-1B/models/rdt/__pycache__/blocks.cpython-310.pyc b/RDT-1B/models/rdt/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000..0137d01 Binary files /dev/null and b/RDT-1B/models/rdt/__pycache__/blocks.cpython-310.pyc differ diff --git a/RDT-1B/models/rdt/__pycache__/model.cpython-310.pyc b/RDT-1B/models/rdt/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000..69d1af3 Binary files /dev/null and b/RDT-1B/models/rdt/__pycache__/model.cpython-310.pyc differ diff --git a/RDT-1B/models/rdt/blocks.py b/RDT-1B/models/rdt/blocks.py new file mode 100644 index 0000000..cd48cc3 --- /dev/null +++ b/RDT-1B/models/rdt/blocks.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import math +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.jit import Final +from timm.models.vision_transformer import Attention, Mlp, RmsNorm, use_fused_attn + + +################################################################################# +# Embedding Layers for Timesteps and Condition Inptus # +################################################################################# +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=torch.bfloat16): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.dtype = dtype + + def timestep_embedding(self, t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * + torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding.to(self.dtype) + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +################################################################################# +# Cross Attention Layers # +################################################################################# +class CrossAttention(nn.Module): + """ + A cross-attention layer with flash attention. + """ + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0, + proj_drop: float = 0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = use_fused_attn() + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, c: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: + B, N, C = x.shape + _, L, _ = c.shape + q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + kv = self.kv(c).reshape(B, L, 2, self.num_heads, 
self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + # Prepare attn mask (B, L) to mask the conditioion + if mask is not None: + mask = mask.reshape(B, 1, 1, L) + mask = mask.expand(-1, -1, N, -1) + + if self.fused_attn: + x = F.scaled_dot_product_attention(query=q, + key=k, + value=v, + dropout_p=self.attn_drop.p if self.training else 0., + attn_mask=mask) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if mask is not None: + attn = attn.masked_fill_(mask.logical_not(), float('-inf')) + attn = attn.softmax(dim=-1) + if self.attn_drop.p > 0: + attn = self.attn_drop(attn) + x = attn @ v + + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + x = self.proj(x) + if self.proj_drop.p > 0: + x = self.proj_drop(x) + return x + + +################################################################################# +# RDT Block # +################################################################################# +class RDTBlock(nn.Module): + """ + A RDT block with cross-attention conditioning. 
+ """ + + def __init__(self, hidden_size, num_heads, **block_kwargs): + super().__init__() + self.norm1 = RmsNorm(hidden_size, eps=1e-6) + self.attn = Attention(dim=hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=True, + norm_layer=RmsNorm, + **block_kwargs) + self.cross_attn = CrossAttention(hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=True, + norm_layer=RmsNorm, + **block_kwargs) + + self.norm2 = RmsNorm(hidden_size, eps=1e-6) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.ffn = Mlp(in_features=hidden_size, hidden_features=hidden_size, act_layer=approx_gelu, drop=0) + self.norm3 = RmsNorm(hidden_size, eps=1e-6) + + def forward(self, x, c, mask=None): + origin_x = x + x = self.norm1(x) + x = self.attn(x) + x = x + origin_x + + origin_x = x + x = self.norm2(x) + x = self.cross_attn(x, c, mask) + x = x + origin_x + + origin_x = x + x = self.norm3(x) + x = self.ffn(x) + x = x + origin_x + + return x + + +class FinalLayer(nn.Module): + """ + The final layer of RDT. 
+ """ + + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = RmsNorm(hidden_size, eps=1e-6) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.ffn_final = Mlp(in_features=hidden_size, + hidden_features=hidden_size, + out_features=out_channels, + act_layer=approx_gelu, + drop=0) + + def forward(self, x): + x = self.norm_final(x) + x = self.ffn_final(x) + return x + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + if not isinstance(pos, np.ndarray): + pos = np.array(pos, dtype=np.float64) + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_nd_sincos_pos_embed_from_grid(embed_dim, grid_sizes): + """ + embed_dim: output dimension for each position + grid_sizes: the grids sizes in each dimension (K,). 
+ out: (grid_sizes[0], ..., grid_sizes[K-1], D) + """ + num_sizes = len(grid_sizes) + # For grid size of 1, we do not need to add any positional embedding + num_valid_sizes = len([x for x in grid_sizes if x > 1]) + emb = np.zeros(grid_sizes + (embed_dim, )) + # Uniformly divide the embedding dimension for each grid size + dim_for_each_grid = embed_dim // num_valid_sizes + # To make it even + if dim_for_each_grid % 2 != 0: + dim_for_each_grid -= 1 + valid_size_idx = 0 + for size_idx in range(num_sizes): + grid_size = grid_sizes[size_idx] + if grid_size <= 1: + continue + pos = np.arange(grid_size) + posemb_shape = [1] * len(grid_sizes) + [dim_for_each_grid] + posemb_shape[size_idx] = -1 + emb[..., valid_size_idx * dim_for_each_grid:(valid_size_idx + 1) * dim_for_each_grid] += \ + get_1d_sincos_pos_embed_from_grid(dim_for_each_grid, pos).reshape(posemb_shape) + valid_size_idx += 1 + return emb + + +def get_multimodal_cond_pos_embed(embed_dim, mm_cond_lens: OrderedDict, embed_modality=True): + """ + Generate position embeddings for multimodal conditions. + + mm_cond_lens: an OrderedDict containing + (modality name, modality token length) pairs. + For `"image"` modality, the value can be a multi-dimensional tuple. + If the length < 0, it means there is no position embedding for the modality or grid. + embed_modality: whether to embed the modality information. Default is True. 
+ """ + num_modalities = len(mm_cond_lens) + modality_pos_embed = np.zeros((num_modalities, embed_dim)) + if embed_modality: + # Get embeddings for various modalites + # We put it in the first half + modality_sincos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, torch.arange(num_modalities)) + modality_pos_embed[:, :embed_dim // 2] = modality_sincos_embed + # The second half is for position embeddings + pos_embed_dim = embed_dim // 2 + else: + # The whole embedding is for position embeddings + pos_embed_dim = embed_dim + + # Get embeddings for positions inside each modality + c_pos_emb = np.zeros((0, embed_dim)) + for idx, (modality, cond_len) in enumerate(mm_cond_lens.items()): + if modality == "image" and \ + (isinstance(cond_len, tuple) or isinstance(cond_len, list)): + all_grid_sizes = tuple([abs(x) for x in cond_len]) + embed_grid_sizes = tuple([x if x > 0 else 1 for x in cond_len]) + cond_sincos_embed = get_nd_sincos_pos_embed_from_grid(pos_embed_dim, embed_grid_sizes) + cond_pos_embed = np.zeros(all_grid_sizes + (embed_dim, )) + cond_pos_embed[..., -pos_embed_dim:] += cond_sincos_embed + cond_pos_embed = cond_pos_embed.reshape((-1, embed_dim)) + else: + cond_sincos_embed = get_1d_sincos_pos_embed_from_grid(pos_embed_dim, + torch.arange(cond_len if cond_len > 0 else 1)) + cond_pos_embed = np.zeros((abs(cond_len), embed_dim)) + cond_pos_embed[:, -pos_embed_dim:] += cond_sincos_embed + cond_pos_embed += modality_pos_embed[idx] + c_pos_emb = np.concatenate([c_pos_emb, cond_pos_embed], axis=0) + + return c_pos_emb diff --git a/RDT-1B/models/rdt/model.py b/RDT-1B/models/rdt/model.py new file mode 100644 index 0000000..f3a3649 --- /dev/null +++ b/RDT-1B/models/rdt/model.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
from collections import OrderedDict

import torch
import torch.nn as nn

from pathlib import Path
import sys, os
# get current workspace
current_file = Path(__file__)
sys.path.append(str(current_file.parent.parent))

from rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder, get_1d_sincos_pos_embed_from_grid,
                        get_multimodal_cond_pos_embed)


class RDT(nn.Module):
    """
    Class for Robotics Diffusion Transformers.
    """

    def __init__(self,
                 output_dim=128,
                 horizon=32,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 max_lang_cond_len=1024,
                 img_cond_len=4096,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super().__init__()
        self.horizon = horizon
        self.hidden_size = hidden_size
        self.max_lang_cond_len = max_lang_cond_len
        self.img_cond_len = img_cond_len
        self.dtype = dtype
        self.lang_pos_embed_config = lang_pos_embed_config
        self.img_pos_embed_config = img_pos_embed_config

        # Separate embedders for the diffusion timestep and the control frequency.
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype)
        self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype)

        # We will use trainable sin-cos embeddings
        # [timestep; ctrl_freq; state; action] -> horizon + 3 tokens
        self.x_pos_embed = nn.Parameter(torch.zeros(1, horizon + 3, hidden_size))
        # Language conditions
        self.lang_cond_pos_embed = nn.Parameter(torch.zeros(1, max_lang_cond_len, hidden_size))
        # Image conditions
        self.img_cond_pos_embed = nn.Parameter(torch.zeros(1, img_cond_len, hidden_size))

        self.blocks = nn.ModuleList([RDTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, output_dim)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init linears, load sin-cos position embeddings, zero the output layer."""

        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize pos_embed by sin-cos embedding
        x_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                    mm_cond_lens=OrderedDict([
                                                        ('timestep', 1),
                                                        ('ctrl_freq', 1),
                                                        ('state', 1),
                                                        ('action', self.horizon),
                                                    ]))
        self.x_pos_embed.data.copy_(torch.from_numpy(x_pos_embed).float().unsqueeze(0))

        if self.lang_pos_embed_config is None:
            lang_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size,
                                                                    torch.arange(self.max_lang_cond_len))
        else:
            lang_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                                mm_cond_lens=OrderedDict(self.lang_pos_embed_config),
                                                                embed_modality=False)
        self.lang_cond_pos_embed.data.copy_(torch.from_numpy(lang_cond_pos_embed).float().unsqueeze(0))

        if self.img_pos_embed_config is None:
            img_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.img_cond_len))
        else:
            img_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                               mm_cond_lens=OrderedDict(self.img_pos_embed_config),
                                                               embed_modality=False)
        self.img_cond_pos_embed.data.copy_(torch.from_numpy(img_cond_pos_embed).float().unsqueeze(0))

        # Initialize timestep and control freq embedding MLP
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02)

        # Initialize the final layer: zero-out the final linear layer
        nn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0)
        nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0)

        # Move all the params to given data type:
        self.to(self.dtype)

    def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
        """
        Forward pass of RDT.

        x: (B, T, D), state + action token sequence, T = horizon + 1,
            dimension D is assumed to be the same as the hidden size.
        freq: (B,), a scalar indicating control frequency.
        t: (B,) or (1,), diffusion timesteps.
        lang_c: (B, L_lang, D) or None, language condition tokens (variable length),
            dimension D is assumed to be the same as the hidden size.
        img_c: (B, L_img, D) or None, image condition tokens (fixed length),
            dimension D is assumed to be the same as the hidden size.
        lang_mask: (B, L_lang) or None, language condition mask (True for valid).
        img_mask: (B, L_img) or None, image condition mask (True for valid).
        """
        t = self.t_embedder(t).unsqueeze(1)  # (B, 1, D) or (1, 1, D)
        freq = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
        # Append timestep to the input tokens
        if t.shape[0] == 1:
            t = t.expand(x.shape[0], -1, -1)
        x = torch.cat([t, freq, x], dim=1)  # (B, T+2, D), i.e. horizon + 3 tokens

        # Add multimodal position embeddings
        x = x + self.x_pos_embed
        # Note the lang is of variable length
        lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
        img_c = img_c + self.img_cond_pos_embed

        # Forward pass: alternate cross-attention between language and image conditions
        conds = [lang_c, img_c]
        masks = [lang_mask, img_mask]
        for i, block in enumerate(self.blocks):
            c, mask = conds[i % 2], masks[i % 2]
            x = block(x, c, mask)  # (B, T+2, D)
        x = self.final_layer(x)  # (B, T+2, out_channels)

        # Only preserve the action tokens
        x = x[:, -self.horizon:]
        return x
diff --git a/RDT-1B/models/rdt_runner.py b/RDT-1B/models/rdt_runner.py
new file mode 100644
index 0000000..0156db7
--- /dev/null
+++ b/RDT-1B/models/rdt_runner.py
@@ -0,0 +1,246 @@
import re, sys, os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import \
    DPMSolverMultistepScheduler

from pathlib import Path
# get current workspace
current_file = Path(__file__)
sys.path.append(os.path.join(current_file.parent))
from hub_mixin import CompatiblePyTorchModelHubMixin
from rdt.model import RDT


class RDTRunner(nn.Module,
                CompatiblePyTorchModelHubMixin,
                repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):
    """Wraps the RDT backbone with condition adaptors and DDPM/DPM-Solver schedulers
    for training (compute_loss) and inference (predict_action)."""

    def __init__(self,
                 *,
                 action_dim,
                 pred_horizon,
                 config,
                 lang_token_dim,
                 img_token_dim,
                 state_token_dim,
                 max_lang_cond_len,
                 img_cond_len,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super(RDTRunner, self).__init__()
        # Create diffusion model
        hidden_size = config['rdt']['hidden_size']
        self.model = RDT(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=config['rdt']['depth'],
            num_heads=config['rdt']['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Create adpators for various conditional inputs
        self.lang_adaptor = self.build_condition_adapter(config['lang_adaptor'],
                                                         in_features=lang_token_dim,
                                                         out_features=hidden_size)
        self.img_adaptor = self.build_condition_adapter(config['img_adaptor'],
                                                        in_features=img_token_dim,
                                                        out_features=hidden_size)
        # A `state` refers to an action or a proprioception vector
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,  # state + state mask (indicator)
            out_features=hidden_size)

        # Create the noise scheduler: DDPM for training, DPM-Solver for fast sampling
        noise_scheduler_config = config['noise_scheduler']
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
            clip_sample=noise_scheduler_config['clip_sample'],
        )
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
        )

        self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.prediction_type = noise_scheduler_config['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        print("Diffusion params: %e" %
              sum([p.numel() for p in self.model.parameters()] + [p.numel() for p in self.lang_adaptor.parameters()] +
                  [p.numel()
                   for p in self.img_adaptor.parameters()] + [p.numel() for p in self.state_adaptor.parameters()]))

    def build_condition_adapter(self, projector_type, in_features, out_features):
        """Build a projection head: 'linear' or 'mlp<N>x_gelu' (N linears with GELU between)."""
        projector = None
        if projector_type == 'linear':
            projector = nn.Linear(in_features, out_features)
        else:
            mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
            if mlp_gelu_match:
                mlp_depth = int(mlp_gelu_match.group(1))
                modules = [nn.Linear(in_features, out_features)]
                for _ in range(1, mlp_depth):
                    modules.append(nn.GELU(approximate="tanh"))
                    modules.append(nn.Linear(out_features, out_features))
                projector = nn.Sequential(*modules)

        if projector is None:
            raise ValueError(f'Unknown projector type: {projector_type}')

        return projector

    def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, state_len, state_token_dim)

        return: adapted (..., hidden_size) for all input tokens
        '''
        adpated_lang = self.lang_adaptor(lang_tokens)
        adpated_img = self.img_adaptor(img_tokens)
        adpated_state = self.state_adaptor(state_tokens)

        return adpated_lang, adpated_img, adpated_state

    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond, state_traj, action_mask, ctrl_freqs):
        '''
        lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_cond: image conditional data, (batch_size, img_len, hidden_size).
        state_traj: (batch_size, 1, hidden_size), state trajectory.
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
            indicating the valid action dimensions.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim)
        '''
        device = state_traj.device
        dtype = state_traj.dtype
        # Start from pure Gaussian noise and iteratively denoise.
        noisy_action = torch.randn(size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
                                   dtype=dtype,
                                   device=device)
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        # Set step values
        self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)

        for t in self.noise_scheduler_sample.timesteps:
            # Prepare state-action trajectory: append the mask indicator channel
            action_traj = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(action_traj)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # Predict the model output
            model_output = self.model(state_action_traj,
                                      ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond,
                                      img_cond,
                                      lang_mask=lang_attn_mask)

            # Compute previous actions: x_t -> x_t-1
            noisy_action = self.noise_scheduler_sample.step(model_output, t, noisy_action).prev_sample
            noisy_action = noisy_action.to(state_traj.dtype)

        # Finally apply the action mask to mask invalid action dimensions
        noisy_action = noisy_action * action_mask

        return noisy_action

    # ========= Train ============
    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_gt, action_mask,
                     ctrl_freqs) -> torch.Tensor:
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: loss_value, a scalar tensor
        '''
        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device
        # Sample noise that we'll add to the actions
        noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)
        # Sample random diffusion timesteps
        timesteps = torch.randint(0, self.num_train_timesteps, (batch_size, ), device=device).long()
        # Add noise to the clean actions according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps)

        # Concatenate the state and action tokens to form the input sequence
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        # Append the action mask to the input sequence
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
        # Align the dimension with the hidden size
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj)
        # Predict the denoised result
        pred = self.model(state_action_traj, ctrl_freqs, timesteps, lang_cond, img_cond, lang_mask=lang_attn_mask)

        pred_type = self.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = action_gt
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")
        loss = F.mse_loss(pred, target)
        return loss

    # ========= Inference ============
    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_mask, ctrl_freqs):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim),
            which should be a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), predicted action sequence
        '''
        # Prepare the state and conditions
        state_tokens = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(lang_tokens, img_tokens, state_tokens)

        # Run sampling
        action_pred = self.conditional_sample(
            lang_cond,
            lang_attn_mask,
            img_cond,
            state_traj,
            action_mask,
            ctrl_freqs,
        )

        return action_pred

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # Training entry point: nn.Module __call__ dispatches to the loss.
        return self.compute_loss(*args, **kwargs)
diff --git a/RDT-1B/pretrain.sh b/RDT-1B/pretrain.sh
new file mode 100644
index 0000000..05a8066
--- /dev/null
+++ b/RDT-1B/pretrain.sh
@@ -0,0 +1,49 @@
#!/bin/bash

export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=bond0
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0

export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
export OUTPUT_DIR="./checkpoints/rdt-pretrain-1b"
export CFLAGS="-I/usr/include"
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
export CUTLASS_PATH="/path/to/cutlass"

export WANDB_PROJECT="robotics_diffusion_transformer"

if [ !
-d "$OUTPUT_DIR" ]; then + mkdir "$OUTPUT_DIR" + echo "Folder '$OUTPUT_DIR' created" +else + echo "Folder '$OUTPUT_DIR' already exists" +fi + +# For run in a single node/machine +# accelerate launch main.py \ +# --deepspeed="./configs/zero2.json" \ +# ... + +deepspeed --hostfile=hostfile.txt main.py \ + --deepspeed="./configs/zero2.json" \ + --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \ + --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \ + --output_dir=$OUTPUT_DIR \ + --train_batch_size=32 \ + --sample_batch_size=64 \ + --max_train_steps=1000000 \ + --checkpointing_period=1000 \ + --sample_period=500 \ + --checkpoints_total_limit=40 \ + --lr_scheduler="constant" \ + --learning_rate=1e-4 \ + --mixed_precision="bf16" \ + --dataloader_num_workers=8 \ + --dataset_type="pretrain" \ + --report_to=wandb + + # Use this to resume training from some previous checkpoint + # --resume_from_checkpoint="checkpoint-1000" \ diff --git a/RDT-1B/process_data_rdt.sh b/RDT-1B/process_data_rdt.sh new file mode 100644 index 0000000..c63fe0d --- /dev/null +++ b/RDT-1B/process_data_rdt.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +task_name=${1} +task_config=${2} +expert_data_num=${3} +gpu_id=${4} + +export CUDA_VISIBLE_DEVICES=${gpu_id} +python scripts/process_data.py $task_name $task_config $expert_data_num \ No newline at end of file diff --git a/RDT-1B/requirements.txt b/RDT-1B/requirements.txt new file mode 100644 index 0000000..63011cd --- /dev/null +++ b/RDT-1B/requirements.txt @@ -0,0 +1,23 @@ +numpy<2.0 +packaging==24.0 +wandb==0.17.0 +deepspeed==0.14.2 +accelerate==0.30.1 +diffusers==0.27.2 +timm==1.0.3 +transformers==4.41.0 +sentencepiece==0.2.0 +h5py==3.11.0 +opencv-python==4.9.0.80 +imgaug==0.4.0 +pytz>=2020.1 + +# requirements_data.txt +tfds-nightly==4.9.4.dev202402070044 +gsutil==5.27 +tensorflow==2.15.0.post1 +pillow==10.2.0 +pyyaml==6.0.1 +tensorflow-graphics==2021.12.3 +imageio==2.34.0 +imageio-ffmpeg==0.4.9 diff --git a/RDT-1B/scripts/agilex_inference.py 
b/RDT-1B/scripts/agilex_inference.py
new file mode 100644
index 0000000..6bd21bd
--- /dev/null
+++ b/RDT-1B/scripts/agilex_inference.py
@@ -0,0 +1,941 @@
#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -- coding: UTF-8
"""
#!/usr/bin/python3
"""

import argparse
import sys
import threading
import time
import yaml
from collections import deque

import numpy as np
import rospy
import torch
from cv_bridge import CvBridge
from geometry_msgs.msg import Twist
from nav_msgs.msg import Odometry
from PIL import Image as PImage
from sensor_msgs.msg import Image, JointState
from std_msgs.msg import Header
import cv2

from scripts.agilex_model import create_model

# sys.path.append("./")

# Camera keys in the order the policy was trained with.
CAMERA_NAMES = ["cam_high", "cam_right_wrist", "cam_left_wrist"]

# Rolling buffer of the last two observations (qpos + images); filled lazily.
observation_window = None

# Pre-computed language-instruction embeddings loaded in model_inference().
lang_embeddings = None

# debug
preload_images = None


# Initialize the model
def make_policy(args):
    """Load the YAML config into args.config and build the RDT policy model."""
    with open(args.config_path, "r") as fp:
        config = yaml.safe_load(fp)
    args.config = config

    # pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
    pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
    model = create_model(
        args=args.config,
        dtype=torch.bfloat16,
        pretrained=args.pretrained_model_name_or_path,
        # pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
        pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
        control_frequency=args.ctrl_freq,
    )

    return model


def set_seed(seed):
    """Seed torch and numpy RNGs for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)


# Interpolate the actions to make the robot move smoothly
def interpolate_action(args, prev_action, cur_action):
    """Linearly interpolate between prev_action and cur_action so that no joint
    moves more than its per-step limit (args.arm_steps_length, doubled for two arms).
    Returns an array of intermediate actions excluding prev_action."""
    steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
    diff = np.abs(cur_action - prev_action)
    step = np.ceil(diff / steps).astype(int)
    step = np.max(step)
    if step <= 1:
        return cur_action[np.newaxis, :]
    new_actions = np.linspace(prev_action, cur_action, step + 1)
    return new_actions[1:]


def get_config(args):
    """Assemble the runtime config dict consumed by the inference loop."""
    config = {
        "episode_len": args.max_publish_step,
        "state_dim": 14,
        "chunk_size": args.chunk_size,
        "camera_names": CAMERA_NAMES,
    }
    return config


# Get the observation from the ROS topic
def get_ros_observation(args, ros_operator):
    """Block until a synchronized frame is available; returns images + arm states."""
    rate = rospy.Rate(args.publish_rate)
    print_flag = True

    while True and not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if not result:
            if print_flag:
                print("syn fail when get_ros_observation")
                print_flag = False
            rate.sleep()
            continue
        print_flag = True
        (
            img_front,
            img_left,
            img_right,
            img_front_depth,
            img_left_depth,
            img_right_depth,
            puppet_arm_left,
            puppet_arm_right,
            robot_base,
        ) = result
        # print(f"sync success when get_ros_observation")
        return (img_front, img_left, img_right, puppet_arm_left, puppet_arm_right)


# Update the observation window buffer
def update_observation_window(args, config, ros_operator):
    # JPEG transformation
    # Align with training
    def jpeg_mapping(img):
        # Round-trip through JPEG so inference images match the training-time compression.
        img = cv2.imencode(".jpg", img)[1].tobytes()
        img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
        return img

    global observation_window
    if observation_window is None:
        observation_window = deque(maxlen=2)

        # Append the first dummy image
        observation_window.append({
            "qpos": None,
            "images": {
                config["camera_names"][0]: None,
                config["camera_names"][1]: None,
                config["camera_names"][2]: None,
            },
        })

    img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = (get_ros_observation(args, ros_operator))
    img_front = jpeg_mapping(img_front)
    img_left = jpeg_mapping(img_left)
    img_right = jpeg_mapping(img_right)

    # 14-dim proprioception: left arm joints followed by right arm joints.
    qpos = np.concatenate(
        (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)),
        axis=0,
    )
    qpos = torch.from_numpy(qpos).float().cuda()
    observation_window.append({
        "qpos": qpos,
        "images": {
            config["camera_names"][0]: img_front,
            config["camera_names"][1]: img_right,
            config["camera_names"][2]: img_left,
        },
    })


# RDT inference
def inference_fn(args, config, policy, t):
    """Run one policy step on the current observation window;
    returns an action chunk of shape (chunk_size, 14)."""
    global observation_window
    global lang_embeddings

    # print(f"Start inference_thread_fn: t={t}")
    while True and not rospy.is_shutdown():
        time1 = time.time()

        # fetch images in sequence [front, right, left]
        image_arrs = [
            observation_window[-2]["images"][config["camera_names"][0]],
            observation_window[-2]["images"][config["camera_names"][1]],
            observation_window[-2]["images"][config["camera_names"][2]],
            observation_window[-1]["images"][config["camera_names"][0]],
            observation_window[-1]["images"][config["camera_names"][1]],
            observation_window[-1]["images"][config["camera_names"][2]],
        ]

        # fetch debug images in sequence [front, right, left]
        # image_arrs = [
        #     preload_images[config['camera_names'][0]][max(t - 1, 0)],
        #     preload_images[config['camera_names'][2]][max(t - 1, 0)],
        #     preload_images[config['camera_names'][1]][max(t - 1, 0)],
        #     preload_images[config['camera_names'][0]][t],
        #     preload_images[config['camera_names'][2]][t],
        #     preload_images[config['camera_names'][1]][t]
        # ]
        # # encode the images
        # for i in range(len(image_arrs)):
        #     image_arrs[i] = cv2.imdecode(np.frombuffer(image_arrs[i], np.uint8), cv2.IMREAD_COLOR)
        # proprio = torch.from_numpy(preload_images['qpos'][t]).float().cuda()

        images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]

        # for i, pos in enumerate(['f', 'r', 'l'] * 2):
        #     images[i].save(f'{t}-{i}-{pos}.png')

        # get last qpos in shape [14, ]
        proprio = observation_window[-1]["qpos"]
        # unsqueeze to [1, 14]
        proprio = proprio.unsqueeze(0)

        # actions shaped as [1, 64, 14] in format [left, right]
        actions = (policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy())
        # print(f"inference_actions: {actions.squeeze()}")

        # print(f"Model inference time: {time.time() - time1} s")

        # print(f"Finish inference_thread_fn: t={t}")
        return actions


# Main loop for the manipulation task
def model_inference(args, config, ros_operator):
    """Load the policy and language embeddings, move arms to the start pose,
    then run the publish loop: re-infer an action chunk every chunk_size steps."""
    global lang_embeddings

    # Load rdt model
    policy = make_policy(args)

    lang_dict = torch.load(args.lang_embeddings_path)
    print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
    lang_embeddings = lang_dict["embeddings"]

    max_publish_step = config["episode_len"]
    chunk_size = config["chunk_size"]

    # Initialize position of the puppet arm
    left0 = [
        -0.00133514404296875,
        0.00209808349609375,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        3.557830810546875,
    ]
    right0 = [
        -0.00133514404296875,
        0.00438690185546875,
        0.034523963928222656,
        -0.053597450256347656,
        -0.00476837158203125,
        -0.00209808349609375,
        3.557830810546875,
    ]
    left1 = [
        -0.00133514404296875,
        0.00209808349609375,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3393220901489258,
    ]
    right1 = [
        -0.00133514404296875,
        0.00247955322265625,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3397035598754883,
    ]
    ros_operator.puppet_arm_publish_continuous(left0, right0)
    input("Press enter to continue")
    ros_operator.puppet_arm_publish_continuous(left1, right1)
    # Initialize the previous action to be the initial robot state
    pre_action = np.zeros(config["state_dim"])
    pre_action[:14] = np.array([
        -0.00133514404296875,
        0.00209808349609375,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3393220901489258,
    ] + [
        -0.00133514404296875,
        0.00247955322265625,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3397035598754883,
    ])
    action = None
    # Inference loop
    with torch.inference_mode():
        while True and not rospy.is_shutdown():
            # The current time step
            t = 0
            rate = rospy.Rate(args.publish_rate)

            action_buffer = np.zeros([chunk_size, config["state_dim"]])

            while t < max_publish_step and not rospy.is_shutdown():
                # Update observation window
                update_observation_window(args, config, ros_operator)

                # When coming to the end of the action chunk
                if t % chunk_size == 0:
                    # Start inference
                    action_buffer = inference_fn(args, config, policy, t).copy()

                raw_action = action_buffer[t % chunk_size]
                action = raw_action
                # Interpolate the original action sequence
                if args.use_actions_interpolation:
                    # print(f"Time {t}, pre {pre_action}, act {action}")
                    interp_actions = interpolate_action(args, pre_action, action)
                else:
                    interp_actions = action[np.newaxis, :]
                # Execute the interpolated actions one by one
                for act in interp_actions:
                    left_action = act[:7]
                    right_action = act[7:14]

                    if not args.disable_puppet_arm:
                        ros_operator.puppet_arm_publish(left_action,
                                                        right_action)  # puppet_arm_publish_continuous_thread

                    if args.use_robot_base:
                        vel_action = act[14:16]
                        ros_operator.robot_base_publish(vel_action)
                    rate.sleep()
                    # print(f"doing action: {act}")
                t += 1

                print("Published Step", t)
                pre_action = action.copy()


# ROS operator class
class RosOperator:
    """Owns all ROS pub/sub state: observation deques, arm/base publishers,
    and helpers to publish joint targets smoothly."""

    def __init__(self, args):
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        self.bridge = None
        self.puppet_arm_left_publisher = None
        self.puppet_arm_right_publisher = None
        self.robot_base_publisher = None
        self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_lock = None
        self.args = args
        self.init()
        self.init_ros()

    def init(self):
        self.bridge = CvBridge()
        self.img_left_deque = deque()
        self.img_right_deque = deque()
        self.img_front_deque = deque()
        self.img_left_depth_deque = deque()
        self.img_right_depth_deque = deque()
        self.img_front_depth_deque = deque()
        self.puppet_arm_left_deque = deque()
        self.puppet_arm_right_deque = deque()
        self.robot_base_deque = deque()
        # Lock starts acquired; releasing it elsewhere signals the continuous
        # publisher loop to stop (see puppet_arm_publish_continuous).
        self.puppet_arm_publish_lock = threading.Lock()
        self.puppet_arm_publish_lock.acquire()

    def puppet_arm_publish(self, left, right):
        """Publish one JointState target for each arm (same message object reused)."""
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()  # Set timestep
        joint_state_msg.name = [
            "joint0",
            "joint1",
            "joint2",
            "joint3",
            "joint4",
            "joint5",
            "joint6",
        ]  # Set the joint names
        joint_state_msg.position = left
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = right
        self.puppet_arm_right_publisher.publish(joint_state_msg)

    def robot_base_publish(self, vel):
        """Publish a base velocity command: vel[0] linear x, vel[1] angular z."""
        vel_msg = Twist()
        vel_msg.linear.x = vel[0]
        vel_msg.linear.y = 0
        vel_msg.linear.z = 0
        vel_msg.angular.x = 0
        vel_msg.angular.y = 0
        vel_msg.angular.z = vel[1]
        self.robot_base_publisher.publish(vel_msg)

    def puppet_arm_publish_continuous(self, left, right):
        """Step both arms toward the target joint values, limited per tick by
        args.arm_steps_length, until every joint has reached its target or the
        publish lock is released by another thread."""
        rate = rospy.Rate(self.args.publish_rate)
        left_arm = None
        right_arm = None
        # Wait for the first arm-state messages before moving.
        while True and not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break
        # Per-joint direction of travel toward the target.
        left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
        right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
        flag = True
        step = 0
        while flag and not rospy.is_shutdown():
            # A successful non-blocking acquire means the lock was released
            # elsewhere: treat it as a stop signal.
            if self.puppet_arm_publish_lock.acquire(False):
                return
            left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
            right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
            flag = False
            for i in range(len(left)):
                if left_diff[i] < self.args.arm_steps_length[i]:
                    left_arm[i] = left[i]
                else:
                    left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
                    flag = True
            for i in range(len(right)):
                if right_diff[i] < self.args.arm_steps_length[i]:
                    right_arm[i] = right[i]
                else:
                    right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
                    flag = True
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # Set the timestep
            joint_state_msg.name = [
                "joint0",
                "joint1",
                "joint2",
                "joint3",
                "joint4",
                "joint5",
                "joint6",
            ]  # Set the joint names
            joint_state_msg.position = left_arm
            self.puppet_arm_left_publisher.publish(joint_state_msg)
            joint_state_msg.position = right_arm
            self.puppet_arm_right_publisher.publish(joint_state_msg)
            step += 1
            print("puppet_arm_publish_continuous:", step)
            rate.sleep()

    def puppet_arm_publish_linear(self, left, right):
        """Move both arms to the targets along a fixed 100-step linear trajectory."""
        num_step = 100
        rate = rospy.Rate(200)

        left_arm = None
        right_arm = None

        # Wait for the first arm-state messages before planning the trajectory.
        while True and not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break

        traj_left_list = np.linspace(left_arm, left, num_step)
        traj_right_list = np.linspace(right_arm, right, num_step)

        for i in range(len(traj_left_list)):
            traj_left = traj_left_list[i]
            traj_right = traj_right_list[i]
            # Gripper joint jumps straight to the target instead of interpolating.
            traj_left[-1] = left[-1]
            traj_right[-1] = right[-1]
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # Set the timestamp
            joint_state_msg.name = [
                "joint0",
                "joint1",
                "joint2",
                "joint3",
                "joint4",
"joint5", + "joint6", + ] # 设置关节名称 + joint_state_msg.position = traj_left + self.puppet_arm_left_publisher.publish(joint_state_msg) + joint_state_msg.position = traj_right + self.puppet_arm_right_publisher.publish(joint_state_msg) + rate.sleep() + + def puppet_arm_publish_continuous_thread(self, left, right): + if self.puppet_arm_publish_thread is not None: + self.puppet_arm_publish_lock.release() + self.puppet_arm_publish_thread.join() + self.puppet_arm_publish_lock.acquire(False) + self.puppet_arm_publish_thread = None + self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right)) + self.puppet_arm_publish_thread.start() + + def get_frame(self): + if (len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or + (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 + or len(self.img_front_depth_deque) == 0))): + return False + if self.args.use_depth_image: + frame_time = min([ + self.img_left_deque[-1].header.stamp.to_sec(), + self.img_right_deque[-1].header.stamp.to_sec(), + self.img_front_deque[-1].header.stamp.to_sec(), + self.img_left_depth_deque[-1].header.stamp.to_sec(), + self.img_right_depth_deque[-1].header.stamp.to_sec(), + self.img_front_depth_deque[-1].header.stamp.to_sec(), + ]) + else: + frame_time = min([ + self.img_left_deque[-1].header.stamp.to_sec(), + self.img_right_deque[-1].header.stamp.to_sec(), + self.img_front_deque[-1].header.stamp.to_sec(), + ]) + + if (len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time): + return False + if (len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time): + return False + if (len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time): + return False + if (len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < 
frame_time): + return False + if (len(self.puppet_arm_right_deque) == 0 + or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 + or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 + or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 + or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_robot_base and (len(self.robot_base_deque) == 0 + or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time): + return False + + while self.img_left_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_deque.popleft() + img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), "passthrough") + + while self.img_right_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_deque.popleft() + img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), "passthrough") + + while self.img_front_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_deque.popleft() + img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), "passthrough") + + while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_left_deque.popleft() + puppet_arm_left = self.puppet_arm_left_deque.popleft() + + while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_right_deque.popleft() + puppet_arm_right = self.puppet_arm_right_deque.popleft() + + img_left_depth = None + if self.args.use_depth_image: + while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_depth_deque.popleft() + img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), "passthrough") + + img_right_depth = None 
+ if self.args.use_depth_image: + while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_depth_deque.popleft() + img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), "passthrough") + + img_front_depth = None + if self.args.use_depth_image: + while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_depth_deque.popleft() + img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), "passthrough") + + robot_base = None + if self.args.use_robot_base: + while self.robot_base_deque[0].header.stamp.to_sec() < frame_time: + self.robot_base_deque.popleft() + robot_base = self.robot_base_deque.popleft() + + return ( + img_front, + img_left, + img_right, + img_front_depth, + img_left_depth, + img_right_depth, + puppet_arm_left, + puppet_arm_right, + robot_base, + ) + + def img_left_callback(self, msg): + if len(self.img_left_deque) >= 2000: + self.img_left_deque.popleft() + self.img_left_deque.append(msg) + + def img_right_callback(self, msg): + if len(self.img_right_deque) >= 2000: + self.img_right_deque.popleft() + self.img_right_deque.append(msg) + + def img_front_callback(self, msg): + if len(self.img_front_deque) >= 2000: + self.img_front_deque.popleft() + self.img_front_deque.append(msg) + + def img_left_depth_callback(self, msg): + if len(self.img_left_depth_deque) >= 2000: + self.img_left_depth_deque.popleft() + self.img_left_depth_deque.append(msg) + + def img_right_depth_callback(self, msg): + if len(self.img_right_depth_deque) >= 2000: + self.img_right_depth_deque.popleft() + self.img_right_depth_deque.append(msg) + + def img_front_depth_callback(self, msg): + if len(self.img_front_depth_deque) >= 2000: + self.img_front_depth_deque.popleft() + self.img_front_depth_deque.append(msg) + + def puppet_arm_left_callback(self, msg): + if len(self.puppet_arm_left_deque) >= 2000: + self.puppet_arm_left_deque.popleft() + self.puppet_arm_left_deque.append(msg) 
+ + def puppet_arm_right_callback(self, msg): + if len(self.puppet_arm_right_deque) >= 2000: + self.puppet_arm_right_deque.popleft() + self.puppet_arm_right_deque.append(msg) + + def robot_base_callback(self, msg): + if len(self.robot_base_deque) >= 2000: + self.robot_base_deque.popleft() + self.robot_base_deque.append(msg) + + def init_ros(self): + rospy.init_node("joint_state_publisher", anonymous=True) + rospy.Subscriber( + self.args.img_left_topic, + Image, + self.img_left_callback, + queue_size=1000, + tcp_nodelay=True, + ) + rospy.Subscriber( + self.args.img_right_topic, + Image, + self.img_right_callback, + queue_size=1000, + tcp_nodelay=True, + ) + rospy.Subscriber( + self.args.img_front_topic, + Image, + self.img_front_callback, + queue_size=1000, + tcp_nodelay=True, + ) + if self.args.use_depth_image: + rospy.Subscriber( + self.args.img_left_depth_topic, + Image, + self.img_left_depth_callback, + queue_size=1000, + tcp_nodelay=True, + ) + rospy.Subscriber( + self.args.img_right_depth_topic, + Image, + self.img_right_depth_callback, + queue_size=1000, + tcp_nodelay=True, + ) + rospy.Subscriber( + self.args.img_front_depth_topic, + Image, + self.img_front_depth_callback, + queue_size=1000, + tcp_nodelay=True, + ) + rospy.Subscriber( + self.args.puppet_arm_left_topic, + JointState, + self.puppet_arm_left_callback, + queue_size=1000, + tcp_nodelay=True, + ) + rospy.Subscriber( + self.args.puppet_arm_right_topic, + JointState, + self.puppet_arm_right_callback, + queue_size=1000, + tcp_nodelay=True, + ) + rospy.Subscriber( + self.args.robot_base_topic, + Odometry, + self.robot_base_callback, + queue_size=1000, + tcp_nodelay=True, + ) + self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10) + self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, + JointState, + queue_size=10) + self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, 
queue_size=10) + + +def get_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--max_publish_step", + action="store", + type=int, + help="Maximum number of action publishing steps", + default=10000, + required=False, + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Random seed", + default=None, + required=False, + ) + + parser.add_argument( + "--img_front_topic", + action="store", + type=str, + help="img_front_topic", + default="/camera_f/color/image_raw", + required=False, + ) + parser.add_argument( + "--img_left_topic", + action="store", + type=str, + help="img_left_topic", + default="/camera_l/color/image_raw", + required=False, + ) + parser.add_argument( + "--img_right_topic", + action="store", + type=str, + help="img_right_topic", + default="/camera_r/color/image_raw", + required=False, + ) + + parser.add_argument( + "--img_front_depth_topic", + action="store", + type=str, + help="img_front_depth_topic", + default="/camera_f/depth/image_raw", + required=False, + ) + parser.add_argument( + "--img_left_depth_topic", + action="store", + type=str, + help="img_left_depth_topic", + default="/camera_l/depth/image_raw", + required=False, + ) + parser.add_argument( + "--img_right_depth_topic", + action="store", + type=str, + help="img_right_depth_topic", + default="/camera_r/depth/image_raw", + required=False, + ) + + parser.add_argument( + "--puppet_arm_left_cmd_topic", + action="store", + type=str, + help="puppet_arm_left_cmd_topic", + default="/master/joint_left", + required=False, + ) + parser.add_argument( + "--puppet_arm_right_cmd_topic", + action="store", + type=str, + help="puppet_arm_right_cmd_topic", + default="/master/joint_right", + required=False, + ) + parser.add_argument( + "--puppet_arm_left_topic", + action="store", + type=str, + help="puppet_arm_left_topic", + default="/puppet/joint_left", + required=False, + ) + parser.add_argument( + "--puppet_arm_right_topic", + action="store", + type=str, + 
help="puppet_arm_right_topic", + default="/puppet/joint_right", + required=False, + ) + + parser.add_argument( + "--robot_base_topic", + action="store", + type=str, + help="robot_base_topic", + default="/odom_raw", + required=False, + ) + parser.add_argument( + "--robot_base_cmd_topic", + action="store", + type=str, + help="robot_base_topic", + default="/cmd_vel", + required=False, + ) + parser.add_argument( + "--use_robot_base", + action="store_true", + help="Whether to use the robot base to move around", + default=False, + required=False, + ) + parser.add_argument( + "--publish_rate", + action="store", + type=int, + help="The rate at which to publish the actions", + default=30, + required=False, + ) + parser.add_argument( + "--ctrl_freq", + action="store", + type=int, + help="The control frequency of the robot", + default=25, + required=False, + ) + + parser.add_argument( + "--chunk_size", + action="store", + type=int, + help="Action chunk size", + default=64, + required=False, + ) + parser.add_argument( + "--arm_steps_length", + action="store", + type=float, + help="The maximum change allowed for each joint per timestep", + default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], + required=False, + ) + + parser.add_argument( + "--use_actions_interpolation", + action="store_true", + help="Whether to interpolate the actions if the difference is too large", + default=False, + required=False, + ) + parser.add_argument( + "--use_depth_image", + action="store_true", + help="Whether to use depth images", + default=False, + required=False, + ) + + parser.add_argument( + "--disable_puppet_arm", + action="store_true", + help="Whether to disable the puppet arm. 
This is useful for safely debugging", + default=False, + ) + + parser.add_argument( + "--config_path", + type=str, + default="configs/base.yaml", + help="Path to the config file", + ) + # parser.add_argument('--cfg_scale', type=float, default=2.0, + # help='the scaling factor used to modify the magnitude of the control features during denoising') + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + help="Name or path to the pretrained model", + ) + + parser.add_argument( + "--lang_embeddings_path", + type=str, + required=True, + help="Path to the pre-encoded language instruction embeddings", + ) + + args = parser.parse_args() + return args + + +def main(): + args = get_arguments() + ros_operator = RosOperator(args) + if args.seed is not None: + set_seed(args.seed) + config = get_config(args) + model_inference(args, config, ros_operator) + + +if __name__ == "__main__": + main() diff --git a/RDT-1B/scripts/agilex_model.py b/RDT-1B/scripts/agilex_model.py new file mode 100644 index 0000000..753fcf8 --- /dev/null +++ b/RDT-1B/scripts/agilex_model.py @@ -0,0 +1,344 @@ +import os, sys + +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from configs.state_vec import STATE_VEC_IDX_MAPPING + +from pathlib import Path + +# get current workspace +current_file = Path(__file__) +sys.path.append(os.path.join(current_file.parent.parent, "models")) +sys.path.append(os.path.join(current_file.parent.parent, "models")) + +from multimodal_encoder.siglip_encoder import SiglipVisionTower +from multimodal_encoder.t5_encoder import T5Embedder +from rdt_runner import RDTRunner + +# The indices that the raw vector should be mapped to in the unified action vector +# AGILEX_STATE_INDICES = [ +# STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(1) +# ] + [ +# STATE_VEC_IDX_MAPPING["left_gripper_open"] +# ] + [ +# STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(1) +# ] + [ +# 
STATE_VEC_IDX_MAPPING[f"right_gripper_open"] +# ] +# AGILEX_STATE_INDICES = None + + +# Create the RDT model +def create_model(args, **kwargs): + left_arm_dim, right_arm_dim = ( + args["arm_dim"]["left_arm_dim"], + args["arm_dim"]["right_arm_dim"], + ) + AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] + for i in range(left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] + + [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] + for i in range(right_arm_dim)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]]) + model = RoboticDiffusionTransformerModel(args, **kwargs) + pretrained = kwargs.get("pretrained", None) + if pretrained is not None and os.path.isfile(pretrained): + model.load_pretrained_weights(pretrained) + + return model + + +class RoboticDiffusionTransformerModel(object): + """A wrapper for the RDT model, which handles + 1. Model initialization + 2. Encodings of instructions + 3. Model inference + """ + + def __init__( + self, + args, + device="cuda", + dtype=torch.bfloat16, + image_size=None, + control_frequency=25, + pretrained=None, + pretrained_vision_encoder_name_or_path=None, + ): + self.args = args + self.dtype = dtype + self.image_size = image_size + self.device = device + self.control_frequency = control_frequency + # We do not use the text encoder due to limited GPU memory + # self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path) + self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path) + self.policy = self.get_policy(pretrained) + self.left_arm_dim, self.right_arm_dim = ( + args["arm_dim"]["left_arm_dim"], + args["arm_dim"]["right_arm_dim"], + ) + + self.reset() + + def get_policy(self, pretrained): + """Initialize the model.""" + # Initialize model with arguments + if pretrained is None or os.path.isfile(pretrained): + img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] * + 
self.vision_model.num_patches) + + _model = RDTRunner( + action_dim=self.args["common"]["state_dim"], + pred_horizon=self.args["common"]["action_chunk_size"], + config=self.args["model"], + lang_token_dim=self.args["model"]["lang_token_dim"], + img_token_dim=self.args["model"]["img_token_dim"], + state_token_dim=self.args["model"]["state_token_dim"], + max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"], + img_cond_len=img_cond_len, + img_pos_embed_config=[ + # No initial pos embed in the last grid size + # since we've already done in ViT + ( + "image", + ( + self.args["common"]["img_history_size"], + self.args["common"]["num_cameras"], + -self.vision_model.num_patches, + ), + ), + ], + lang_pos_embed_config=[ + # Similarly, no initial pos embed for language + ("lang", -self.args["dataset"]["tokenizer_max_length"]), + ], + dtype=self.dtype, + ) + else: + _model = RDTRunner.from_pretrained(pretrained) + + return _model + + def get_text_encoder(self, pretrained_text_encoder_name_or_path): + text_embedder = T5Embedder( + from_pretrained=pretrained_text_encoder_name_or_path, + model_max_length=self.args["dataset"]["tokenizer_max_length"], + device=self.device, + ) + tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model + return tokenizer, text_encoder + + def get_vision_encoder(self, pretrained_vision_encoder_name_or_path): + vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None) + image_processor = vision_encoder.image_processor + return image_processor, vision_encoder + + def reset(self): + """Set model to evaluation mode.""" + device = self.device + weight_dtype = self.dtype + self.policy.eval() + # self.text_model.eval() + self.vision_model.eval() + + self.policy = self.policy.to(device, dtype=weight_dtype) + # self.text_model = self.text_model.to(device, dtype=weight_dtype) + self.vision_model = self.vision_model.to(device, dtype=weight_dtype) + + def load_pretrained_weights(self, 
pretrained=None): + if pretrained is None: + return + print(f"Loading weights from {pretrained}") + filename = os.path.basename(pretrained) + if filename.endswith(".pt"): + checkpoint = torch.load(pretrained) + self.policy.load_state_dict(checkpoint["module"]) + elif filename.endswith(".safetensors"): + from safetensors.torch import load_model + + load_model(self.policy, pretrained) + else: + raise NotImplementedError(f"Unknown checkpoint format: {pretrained}") + + def encode_instruction(self, instruction, device="cuda"): + """Encode string instruction to latent embeddings. + + Args: + instruction: a string of instruction + device: a string of device + + Returns: + pred: a tensor of latent embeddings of shape (text_max_length, 512) + """ + tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest", + truncation=True)["input_ids"].to(device) + + tokens = tokens.view(1, -1) + with torch.no_grad(): + pred = self.text_model(tokens).last_hidden_state.detach() + + return pred + + def _format_joint_to_state(self, joints): + """ + Format the joint proprioception into the unified action vector. + + Args: + joints (torch.Tensor): The joint proprioception to be formatted. + qpos ([B, N, 14]). + + Returns: + state (torch.Tensor): The formatted vector for RDT ([B, N, 128]). 
+ """ + AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] + for i in range(self.left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] + + [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] + for i in range(self.right_arm_dim)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]]) + # Rescale the gripper to the range of [0, 1] + joints = joints / torch.tensor( + [[[1 for i in range(self.left_arm_dim + 1 + self.right_arm_dim + 1)]]], + device=joints.device, + dtype=joints.dtype, + ) + + B, N, _ = joints.shape + state = torch.zeros( + (B, N, self.args["model"]["state_token_dim"]), + device=joints.device, + dtype=joints.dtype, + ) + # Fill into the unified state vector + state[:, :, AGILEX_STATE_INDICES] = joints + # Assemble the mask indicating each dimension's availability + state_elem_mask = torch.zeros( + (B, self.args["model"]["state_token_dim"]), + device=joints.device, + dtype=joints.dtype, + ) + state_elem_mask[:, AGILEX_STATE_INDICES] = 1 + return state, state_elem_mask + + def _unformat_action_to_joint(self, action): + """ + Unformat the unified action vector into the joint action to be executed. + + Args: + action (torch.Tensor): The unified action vector to be unformatted. + ([B, N, 128]) + + Returns: + joints (torch.Tensor): The unformatted robot joint action. + qpos ([B, N, 14]). 
+ """ + AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] + for i in range(self.left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] + + [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] + for i in range(self.right_arm_dim)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]]) + action_indices = AGILEX_STATE_INDICES + joints = action[:, :, action_indices] + + # Rescale the gripper back to the action range + # Note that the action range and proprioception range are different + # for Mobile ALOHA robot + joints = joints * torch.tensor( + [[[1 for i in range(self.left_arm_dim + 1 + self.right_arm_dim + 1)]]], + device=joints.device, + dtype=joints.dtype, + ) + + return joints + + @torch.no_grad() + def step(self, proprio, images, text_embeds): + """ + Predict the next action chunk given the + proprioceptive states, images, and instruction embeddings. + + Args: + proprio: proprioceptive states + images: RGB images, the order should be + [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1}, + ext_{t}, right_wrist_{t}, left_wrist_{t}] + text_embeds: instruction embeddings + + Returns: + action: predicted action + """ + device = self.device + dtype = self.dtype + + # The background image used for padding + background_color = np.array([int(x * 255) for x in self.image_processor.image_mean], + dtype=np.uint8).reshape(1, 1, 3) + background_image = (np.ones( + ( + self.image_processor.size["height"], + self.image_processor.size["width"], + 3, + ), + dtype=np.uint8, + ) * background_color) + + # Preprocess the images by order and encode them + image_tensor_list = [] + for image in images: + if image is None: + # Replace it with the background image + image = Image.fromarray(background_image) + + if self.image_size is not None: + image = transforms.Resize(self.data_args.image_size)(image) + + if self.args["dataset"].get("auto_adjust_image_brightness", False): + pixel_values = list(image.getdata()) + average_brightness = sum(sum(pixel) for pixel in 
pixel_values) / (len(pixel_values) * 255.0 * 3) + if average_brightness <= 0.15: + image = transforms.ColorJitter(brightness=(1.75, 1.75))(image) + + if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad": + + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + image_tensor_list.append(image) + + image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype) + + image_embeds = self.vision_model(image_tensor).detach() + image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0) + + # Prepare the proprioception states and the control frequency + joints = proprio.to(device).unsqueeze(0) # (1, 1, 14) + states, state_elem_mask = self._format_joint_to_state(joints) # (1, 1, 128), (1, 128) + states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype) + states = states[:, -1:, :] # (1, 1, 128) + ctrl_freqs = torch.tensor([self.control_frequency]).to(device) + + text_embeds = text_embeds.to(device, dtype=dtype) + + # Predict the next action chunk given the inputs + trajectory = self.policy.predict_action( + lang_tokens=text_embeds, + lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device), + img_tokens=image_embeds, + state_tokens=states, + action_mask=state_elem_mask.unsqueeze(1), + ctrl_freqs=ctrl_freqs, + ) + trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32) + + return 
trajectory diff --git a/RDT-1B/scripts/encode_lang.py b/RDT-1B/scripts/encode_lang.py new file mode 100644 index 0000000..e725a54 --- /dev/null +++ b/RDT-1B/scripts/encode_lang.py @@ -0,0 +1,53 @@ +import os + +import torch +import yaml + +from models.multimodal_encoder.t5_encoder import T5Embedder + +GPU = 0 +MODEL_PATH = "google/t5-v1_1-xxl" +CONFIG_PATH = "configs/base.yaml" +SAVE_DIR = "outs/" + +# Modify this to your task name and instruction +TASK_NAME = "handover_pan" +INSTRUCTION = "Pick up the black marker on the right and put it into the packaging box on the left." + +# Note: if your GPU VRAM is less than 24GB, +# it is recommended to enable offloading by specifying an offload directory. +OFFLOAD_DIR = ( + None # Specify your offload directory here, ensuring the directory exists. +) + + +def main(): + with open(CONFIG_PATH, "r") as fp: + config = yaml.safe_load(fp) + + device = torch.device(f"cuda:{GPU}") + text_embedder = T5Embedder( + from_pretrained=MODEL_PATH, + model_max_length=config["dataset"]["tokenizer_max_length"], + device=device, + use_offload_folder=OFFLOAD_DIR, + ) + tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model + + tokens = tokenizer(INSTRUCTION, return_tensors="pt", padding="longest", truncation=True)["input_ids"].to(device) + + tokens = tokens.view(1, -1) + with torch.no_grad(): + pred = text_encoder(tokens).last_hidden_state.detach().cpu() + + save_path = os.path.join(SAVE_DIR, f"{TASK_NAME}.pt") + # We save the embeddings in a dictionary format + torch.save({"name": TASK_NAME, "instruction": INSTRUCTION, "embeddings": pred}, save_path) + + print( + f'"{INSTRUCTION}" from "{TASK_NAME}" is encoded by "{MODEL_PATH}" into shape {pred.shape} and saved to "{save_path}"' + ) + + +if __name__ == "__main__": + main() diff --git a/RDT-1B/scripts/encode_lang_batch_once.py b/RDT-1B/scripts/encode_lang_batch_once.py new file mode 100644 index 0000000..1573b2d --- /dev/null +++ b/RDT-1B/scripts/encode_lang_batch_once.py @@ 
-0,0 +1,57 @@ +import os +import json +import argparse +import torch +import yaml +from tqdm import tqdm + +from models.multimodal_encoder.t5_encoder import T5Embedder + + +def encode_lang( + DATA_FILE_PATH, + TARGET_DIR, + GPU, + desc_type="seen", + tokenizer=None, + text_encoder=None, +): + current_dir = os.path.dirname(__file__) + + with open(os.path.join(current_dir, "../configs/base.yaml"), "r") as fp: + config = yaml.safe_load(fp) + + device = torch.device(f"cuda:{GPU}") + if tokenizer is None or text_encoder is None: + text_embedder = T5Embedder( + from_pretrained=os.path.join(current_dir, "../../weights/RDT/t5-v1_1-xxl"), + model_max_length=config["dataset"]["tokenizer_max_length"], + device=device, + use_offload_folder=None, + ) + tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model + + with open(DATA_FILE_PATH, "r") as f_instr: + instruction_dict = json.load(f_instr) + + instructions = instruction_dict[desc_type] + + # Encode the instructions + tokenized_res = tokenizer(instructions, return_tensors="pt", padding="longest", truncation=True) + tokens = tokenized_res["input_ids"].to(device) + attn_mask = tokenized_res["attention_mask"].to(device) + + with torch.no_grad(): + text_embeds = (text_encoder(input_ids=tokens, attention_mask=attn_mask)["last_hidden_state"].detach().cpu()) + + attn_mask = attn_mask.cpu().bool() + if not os.path.exists(f"{TARGET_DIR}/instructions"): + os.makedirs(f"{TARGET_DIR}/instructions") + # Save the embeddings for training use + for i in range(len(instructions)): + text_embed = text_embeds[i][attn_mask[i]] + save_path = os.path.join(TARGET_DIR, f"instructions/lang_embed_{i}.pt") + # print("encoded instructions save_path:",save_path) + torch.save(text_embed, save_path) + + return tokenizer, text_encoder diff --git a/RDT-1B/scripts/generate_output_json.py b/RDT-1B/scripts/generate_output_json.py new file mode 100644 index 0000000..dd06c01 --- /dev/null +++ b/RDT-1B/scripts/generate_output_json.py @@ -0,0 +1,84 @@ 
import json
import os
import sys
import re


def extract_metrics_from_log(log_file_path):
    """Scan a training log for printed metric dicts and return the best values.

    Each matching line looks like:
        {'agilex_sample_mse': ..., 'agilex_sample_l2err': ...,
         'overall_avg_sample_mse': ..., 'overall_avg_sample_l2err': ...}

    Returns:
        A 4-tuple (best_agilex_mse, best_agilex_l2err, best_overall_mse,
        best_overall_l2err) — each the minimum over all matches — or a tuple
        of four Nones when the file is unreadable or contains no metrics.
    """
    all_metrics = []
    pattern = re.compile(
        r"\{'agilex_sample_mse':\s*([0-9.eE+-]+),\s*'agilex_sample_l2err':\s*([0-9.eE+-]+),\s*'overall_avg_sample_mse':\s*([0-9.eE+-]+),\s*'overall_avg_sample_l2err':\s*([0-9.eE+-]+)\}"
    )
    try:
        with open(log_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                m = pattern.search(line)
                if m:
                    metrics = (
                        float(m.group(1)),
                        float(m.group(2)),
                        float(m.group(3)),
                        float(m.group(4))
                    )
                    all_metrics.append(metrics)
                    print(f"Find Metrics: agilex_sample_mse={metrics[0]}, agilex_sample_l2err={metrics[1]}, "
                          f"overall_avg_sample_mse={metrics[2]}, overall_avg_sample_l2err={metrics[3]}")
    except Exception as e:
        print(f"Failed to read log: {e}")
        return (None, None, None, None)

    if not all_metrics:
        print("No metrics found in the log file")
        return (None, None, None, None)

    print(f"\nTotal {len(all_metrics)} metrics found in the log file")

    # The best value for each metric is taken independently (lower is better).
    best_agilex_mse = min(m[0] for m in all_metrics)
    best_agilex_l2err = min(m[1] for m in all_metrics)
    best_overall_mse = min(m[2] for m in all_metrics)
    best_overall_l2err = min(m[3] for m in all_metrics)

    print(f"\nBest metrics:")
    print(f"  agilex_sample_mse: {best_agilex_mse}")
    print(f"  agilex_sample_l2err: {best_agilex_l2err}")
    print(f"  overall_avg_sample_mse: {best_overall_mse}")
    print(f"  overall_avg_sample_l2err: {best_overall_l2err}")

    return (best_agilex_mse, best_agilex_l2err, best_overall_mse, best_overall_l2err)


def generate_output_json(input_config_file, output_dir, runtime):
    """Summarize a finished run into ``<output_dir>/output.json``.

    Args:
        input_config_file: path to the task config JSON (task_id, gpu_id,
            model_name or train.model, ...).
        output_dir: run output directory; must contain ``output.log``.
        runtime: wall-clock runtime string/number recorded verbatim.
    """
    with open(input_config_file, 'r') as f:
        config = json.load(f)

    log_file = os.path.join(output_dir, 'output.log')
    agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err = extract_metrics_from_log(log_file)

    if None in [agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err]:
        print("Warning: Some metrics are missing in the log file.")

    output_json = {
        "task_id": config.get("task_id"),
        "model_type": "RDT-1B",
        # Prefer an explicit top-level model_name; fall back to train.model.
        "model_name": config.get("model_name") if "model_name" in config else config.get("train", {}).get("model"),
        "gpu_id": config.get("gpu_id"),
        "runtime": runtime,
        "log_path": log_file,
        "output_dir": output_dir,
        "model_path": os.path.join(output_dir, 'pytorch_model.bin'),
        "metrics": {
            "agilex_sample_mse": agilex_sample_mse,
            "agilex_sample_l2err": agilex_sample_l2err,
            "overall_avg_sample_mse": overall_avg_sample_mse,
            "overall_avg_sample_l2err": overall_avg_sample_l2err
        }
    }

    # Write output.json with pretty-printing; ensure_ascii=False keeps any
    # non-ASCII text readable, and missing metrics serialize as null.
    output_json_path = os.path.join(output_dir, 'output.json')
    with open(output_json_path, 'w') as f:
        json.dump(output_json, f, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    if len(sys.argv) != 4:
        # The original usage string had its argument placeholders mangled.
        print("Usage: python generate_output_json.py <config_json> <output_dir> <runtime>")
        sys.exit(1)
    generate_output_json(sys.argv[1], sys.argv[2], sys.argv[3])
class RoboticDiffusionTransformerModel(object):
    """A wrapper for the RDT model, which handles
    1. Model initialization
    2. Encodings of instructions
    3. Model inference
    """

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained_text_encoder_name_or_path=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        """Build the policy together with its text and vision encoders.

        Args:
            args: nested config dict with "common", "model" and "dataset"
                sections (see configs/base.yaml).
            device: torch device for all sub-modules.
            dtype: parameter/compute dtype.
            image_size: optional resize target for input images; None keeps
                the image processor's native size.
            control_frequency: robot control frequency handed to the policy.
            pretrained_text_encoder_name_or_path: T5 checkpoint location.
            pretrained_vision_encoder_name_or_path: SigLIP checkpoint location.
        """
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy()

        # Per-dimension statistics used to (de)normalize states and actions.
        self.state_min = torch.tensor(DATA_STAT["state_min"]).to(device)
        self.state_max = torch.tensor(DATA_STAT["state_max"]).to(device)
        self.action_min = torch.tensor(DATA_STAT["action_min"]).to(device)
        self.action_max = torch.tensor(DATA_STAT["action_max"]).to(device)

        self.reset()

    def get_policy(self):
        """Initialize the RDT policy network from the config."""
        # Total number of image-condition tokens: history x cameras x patches.
        img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                        self.vision_model.num_patches)

        _model = RDTRunner(
            action_dim=self.args["common"]["state_dim"],
            pred_horizon=self.args["common"]["action_chunk_size"],
            config=self.args["model"],
            lang_token_dim=self.args["model"]["lang_token_dim"],
            img_token_dim=self.args["model"]["img_token_dim"],
            state_token_dim=self.args["model"]["state_token_dim"],
            max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
            img_cond_len=img_cond_len,
            img_pos_embed_config=[
                # No initial pos embed in the last grid size
                # since we've already done in ViT
                (
                    "image",
                    (
                        self.args["common"]["img_history_size"],
                        self.args["common"]["num_cameras"],
                        -self.vision_model.num_patches,
                    ),
                ),
            ],
            lang_pos_embed_config=[
                # Similarly, no initial pos embed for language
                ("lang", -self.args["dataset"]["tokenizer_max_length"]),
            ],
            dtype=self.dtype,
        )

        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        """Load the T5 tokenizer and encoder used for instructions."""
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        return text_embedder.tokenizer, text_embedder.model

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        """Load the SigLIP vision tower and its image processor."""
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        return vision_encoder.image_processor, vision_encoder

    def reset(self):
        """Set all sub-modules to eval mode and move them to the target
        device/dtype."""
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        self.text_model.eval()
        self.vision_model.eval()

        self.policy = self.policy.to(device, dtype=weight_dtype)
        self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        """Load policy weights from a .pt (DeepSpeed-style, under "module")
        or .safetensors checkpoint; no-op when ``pretrained`` is None."""
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            # map_location avoids failing on hosts without the saving device;
            # load_state_dict then copies onto the policy's current device.
            checkpoint = torch.load(pretrained, map_location="cpu")
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model

            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        """Encode string instruction to latent embeddings.

        Args:
            instruction: a string of instruction
            device: a string of device

        Returns:
            pred: a tensor of latent embeddings of shape (text_max_length, 512)
        """
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)

        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()

        return pred

    def _format_joint_to_state(self, joints):
        """
        Format the robot joint state into the unified state vector.

        Args:
            joints (torch.Tensor): The joint state to be formatted,
                qpos ([B, N, 8] — 7 arm joints + gripper, matching
                MANISKILL_INDICES/DATA_STAT).

        Returns:
            state (torch.Tensor): The formatted state for RDT ([B, N, 128]).
            state_elem_mask (torch.Tensor): mask of valid dims ([B, 128]).
        """
        # Normalize joints to [-1, 1] using the dataset statistics.
        joints = (joints - self.state_min) / (self.state_max - self.state_min) * 2 - 1
        B, N, _ = joints.shape
        state = torch.zeros(
            (B, N, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        # Scatter the joint values into the unified 128-dim state vector.
        state[:, :, MANISKILL_INDICES] = joints
        state_elem_mask = torch.zeros(
            (B, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        state_elem_mask[:, MANISKILL_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        """Extract the robot's action dims from the unified vector and
        denormalize them back to the physical action space."""
        action_indices = MANISKILL_INDICES
        joints = action[:, :, action_indices]

        joints = (joints + 1) / 2 * (self.action_max - self.action_min) + self.action_min

        return joints

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        """
        Args:
            proprio: proprioceptive states
            images: RGB images (PIL images or None; None entries are replaced
                by a background image)
            text_embeds: instruction embeddings

        Returns:
            action: predicted action trajectory (denormalized joints, float32)
        """
        device = self.device
        dtype = self.dtype

        background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                    dtype=np.uint8).reshape(1, 1, 3)
        background_image = (np.ones(
            (
                self.image_processor.size["height"],
                self.image_processor.size["width"],
                3,
            ),
            dtype=np.uint8,
        ) * background_color)

        image_tensor_list = []
        for image in images:
            if image is None:
                # Replace it with the background image
                image = Image.fromarray(background_image)

            if self.image_size is not None:
                # BUG FIX: original referenced the nonexistent
                # self.data_args.image_size, raising AttributeError.
                image = transforms.Resize(self.image_size)(image)

            if self.args["dataset"].get("auto_adjust_image_brightness", False):
                pixel_values = list(image.getdata())
                average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                if average_brightness <= 0.15:
                    image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

            if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

                def expand2square(pil_img, background_color):
                    # Pad the shorter side with the background color so the
                    # image becomes square without distortion.
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result

                image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            image_tensor_list.append(image)

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)

        image_embeds = self.vision_model(image_tensor).detach()
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # Current proprioceptive state (single step of history).
        joints = proprio.to(device).unsqueeze(0)
        states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        states = states[:, -1:, :]  # (1, 1, 128)
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)

        return trajectory
import sys

sys.path.append("./")

import os
import h5py
import numpy as np
import pickle
import cv2
import argparse
import yaml
from scripts.encode_lang_batch_once import encode_lang


def load_hdf5(dataset_path):
    """Load one raw episode HDF5 file.

    Returns (left_gripper, left_arm, right_gripper, right_arm, image_dict)
    where image_dict maps camera name -> encoded RGB frames.
    NOTE(review): exits the whole process when the file is missing — callers
    rely on this best-effort behavior, so it is kept.
    """
    if not os.path.isfile(dataset_path):
        print(f"Dataset does not exist at \n{dataset_path}\n")
        exit()

    with h5py.File(dataset_path, "r") as root:
        left_gripper, left_arm = (
            root["/joint_action/left_gripper"][()],
            root["/joint_action/left_arm"][()],
        )
        right_gripper, right_arm = (
            root["/joint_action/right_gripper"][()],
            root["/joint_action/right_arm"][()],
        )
        image_dict = dict()
        for cam_name in root[f"/observation/"].keys():
            image_dict[cam_name] = root[f"/observation/{cam_name}/rgb"][()]

    return left_gripper, left_arm, right_gripper, right_arm, image_dict


def images_encoding(imgs):
    """JPEG-encode a list of images.

    Returns (encoded_bytes_list, max_len) where max_len is the longest
    encoding in bytes — used by the caller as the fixed-length HDF5 "S" dtype
    width (h5py pads shorter entries itself, so no manual padding is needed;
    the original's unused padding loop was dead code and has been removed).
    """
    encode_data = []
    max_len = 0
    for img in imgs:
        success, encoded_image = cv2.imencode(".jpg", img)
        jpeg_data = encoded_image.tobytes()
        encode_data.append(jpeg_data)
        max_len = max(max_len, len(jpeg_data))
    return encode_data, max_len


def get_task_config(task_name):
    """Read ./task_config/<task_name>.yml and return it as a dict."""
    with open(f"./task_config/{task_name}.yml", "r", encoding="utf-8") as f:
        args = yaml.load(f.read(), Loader=yaml.FullLoader)
    return args


def data_transform(path, episode_num, save_path):
    """Convert raw episodes under ``path`` into the training HDF5 layout.

    For each of the first ``episode_num`` episodes: decode and resize the
    three camera streams to 640x480, build (qpos, action) pairs where the
    action at step j is the state at step j (shifted by one), and write one
    ``episode_{i}/episode_{i}.hdf5`` per episode under ``save_path``.

    Returns the number of episodes processed.
    """
    begin = 0
    folders = os.listdir(path)
    assert episode_num <= len(folders), "data num not enough"

    os.makedirs(save_path, exist_ok=True)

    for i in range(episode_num):
        left_gripper_all, left_arm_all, right_gripper_all, right_arm_all, image_dict = (load_hdf5(
            os.path.join(path, f"episode{i}.hdf5")))
        qpos = []
        actions = []
        cam_high = []
        cam_right_wrist = []
        cam_left_wrist = []
        left_arm_dim = []
        right_arm_dim = []

        for j in range(0, left_gripper_all.shape[0]):

            left_gripper, left_arm, right_gripper, right_arm = (
                left_gripper_all[j],
                left_arm_all[j],
                right_gripper_all[j],
                right_arm_all[j],
            )

            # Unified joint state: left arm + left gripper + right arm + right gripper.
            state = np.concatenate((left_arm, [left_gripper], right_arm, [right_gripper]), axis=0)
            state = state.astype(np.float32)

            # All but the last step contribute an observation.
            if j != left_gripper_all.shape[0] - 1:

                qpos.append(state)

                camera_high_bits = image_dict["head_camera"][j]
                camera_high = cv2.imdecode(np.frombuffer(camera_high_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_high_resized = cv2.resize(camera_high, (640, 480))
                cam_high.append(camera_high_resized)

                camera_right_wrist_bits = image_dict["right_camera"][j]
                camera_right_wrist = cv2.imdecode(np.frombuffer(camera_right_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_right_wrist_resized = cv2.resize(camera_right_wrist, (640, 480))
                cam_right_wrist.append(camera_right_wrist_resized)

                camera_left_wrist_bits = image_dict["left_camera"][j]
                camera_left_wrist = cv2.imdecode(np.frombuffer(camera_left_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_left_wrist_resized = cv2.resize(camera_left_wrist, (640, 480))
                cam_left_wrist.append(camera_left_wrist_resized)

            # All but the first step contribute an action (the next state).
            if j != 0:
                action = state
                actions.append(action)
                left_arm_dim.append(left_arm.shape[0])
                right_arm_dim.append(right_arm.shape[0])

        os.makedirs(os.path.join(save_path, f"episode_{i}"), exist_ok=True)
        hdf5path = os.path.join(save_path, f"episode_{i}/episode_{i}.hdf5")

        with h5py.File(hdf5path, "w") as f:
            f.create_dataset("action", data=np.array(actions))
            obs = f.create_group("observations")
            obs.create_dataset("qpos", data=np.array(qpos))
            obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))
            obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))
            image = obs.create_group("images")
            cam_high_enc, len_high = images_encoding(cam_high)
            cam_right_wrist_enc, len_right = images_encoding(cam_right_wrist)
            cam_left_wrist_enc, len_left = images_encoding(cam_left_wrist)
            image.create_dataset("cam_high", data=cam_high_enc, dtype=f"S{len_high}")
            image.create_dataset("cam_right_wrist", data=cam_right_wrist_enc, dtype=f"S{len_right}")
            image.create_dataset("cam_left_wrist", data=cam_left_wrist_enc, dtype=f"S{len_left}")

        begin += 1
        print(f"proccess {i} success!")

    return begin


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process some episodes.")
    parser.add_argument("task_name", type=str)
    parser.add_argument("task_config", type=str)
    parser.add_argument("expert_data_num", type=int)
    args = parser.parse_args()

    task_name = args.task_name
    task_config = args.task_config
    expert_data_num = args.expert_data_num

    load_dir = os.path.join("../../data", str(task_name), str(task_config), "data")

    print(f"read data from path: {load_dir}")
    begin = data_transform(
        load_dir,
        expert_data_num,
        f"./processed_data/{task_name}-{task_config}-{expert_data_num}",
    )
    # Reuse one tokenizer/encoder across all episodes to avoid reloading T5.
    tokenizer, text_encoder = None, None
    for idx in range(expert_data_num):
        print(f"Processing Language: {idx}", end="\r")
        data_file_path = (f"../../data/{task_name}/{task_config}/instructions/episode{idx}.json")
        target_dir = (f"processed_data/{task_name}-{task_config}-{expert_data_num}/episode_{idx}")
        tokenizer, text_encoder = encode_lang(
            DATA_FILE_PATH=data_file_path,
            TARGET_DIR=target_dir,
            GPU=0,
            desc_type="seen",
            tokenizer=tokenizer,
            text_encoder=text_encoder,
        )
import json
import yaml
import sys


def read_config(config_file, yaml_file):
    """Merge task settings from a JSON config into a training YAML config.

    Reads ``config_file`` (JSON: task_id, gpu_id, train.{model,
    input_data_path, checkpoint_path, batch_size, epochs}) and rewrites
    ``yaml_file`` in place with the derived training parameters.
    """
    with open(config_file, 'r') as f:
        json_config = json.load(f)
    with open(yaml_file, 'r') as f:
        yaml_config = yaml.load(f, Loader=yaml.FullLoader)

    yaml_config["model"] = json_config["train"]["model"] + json_config["task_id"]
    yaml_config["data_path"] = json_config["train"]["input_data_path"] + "/data"
    yaml_config["checkpoint_path"] = json_config["train"]["checkpoint_path"] + "/" + json_config["task_id"]
    yaml_config["pretrained_model_name_or_path"] = json_config["train"]["input_data_path"] + "/weights/rdt-1b"
    yaml_config["cuda_visible_device"] = str(json_config["gpu_id"])
    print(f"cuda_visible_device: {yaml_config['cuda_visible_device']}")
    yaml_config["train_batch_size"] = int(json_config["train"]["batch_size"])
    # Sampling uses a doubled batch since it needs no backward pass.
    yaml_config["sample_batch_size"] = int(json_config["train"]["batch_size"]) * 2
    yaml_config["max_train_steps"] = int(json_config["train"]["epochs"])
    # Checkpoint roughly 10 times over the run (floor division idiom;
    # NOTE(review): yields 0 for epochs < 10, same as the original — confirm
    # the trainer tolerates a zero period).
    yaml_config["checkpointing_period"] = int(json_config["train"]["epochs"]) // 10
    yaml_config["sample_period"] = 200
    yaml_config["checkpoints_total_limit"] = 50

    with open(yaml_file, 'w') as f:
        yaml.dump(yaml_config, f, default_flow_style=False)

    print("Config YAML file updated successfully")


if __name__ == "__main__":
    read_config(sys.argv[1], sys.argv[2])
def get_clean_item(chunk_dir):
    """
    Get indexes of clean items in a chunk.
    """
    dirty_bit = read_dirty_bit(chunk_dir)
    # Clean entries are those whose dirty bit is 0.
    return np.where(1 - dirty_bit)[0].tolist()


def save_dirty_bit(chunk_dir, dirty_bit):
    """
    Save the dirty bit to the chunk directory.

    Retries on any error for up to 10 seconds before giving up.
    """
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        # BUG FIX: `lock` must be pre-bound — if FileLock() itself raises,
        # the original handlers referenced an unbound local.
        lock = None
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                file.write(dirty_bit.tobytes())
            lock.release_lock()
            return
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise KeyboardInterrupt
        except BaseException:
            if lock is not None:
                lock.release_lock()
            continue
    raise RuntimeError("Failed to save dirty bit.")


def read_dirty_bit(chunk_dir):
    """
    Read the dirty bit from the chunk directory.

    Retries on any error for up to 10 seconds before giving up.
    """
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        lock = None
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_read_lock()
            with open(file_path, "rb") as file:
                dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
            lock.release_lock()
            assert len(dirty_bit) > 0
            return dirty_bit
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise KeyboardInterrupt
        except BaseException:
            if lock is not None:
                lock.release_lock()
            continue
    raise RuntimeError("Failed to read dirty bit.")
+ """ + + def __init__( + self, + model_config_path, + config, + tokenizer, + image_processor, + num_cameras, + img_history_size, + image_size=None, + auto_adjust_image_brightness=False, + image_aug=False, + dataset_type="pretrain", + cond_mask_prob=0.1, + cam_ext_mask_prob=-1.0, + state_noise_snr=None, + use_hdf5=False, + use_precomp_lang_embed=False, + ): + super(VLAConsumerDataset, self).__init__() + + # Load the control frequency for each dataset + with open("configs/dataset_control_freq.json", "r") as fp: + self.control_freq = json.load(fp) + # Load the dataset names + dataset_names_cfg = ("configs/pretrain_datasets.json" + if dataset_type == "pretrain" else "configs/finetune_datasets.json") + with open(dataset_names_cfg, "r") as file: + DATASET_NAMES = json.load(file) + # Create the mapping between dataset name and id + self.dataset_name2id = {name: i for i, name in enumerate(DATASET_NAMES)} + self.dataset_id2name = {i: name for i, name in enumerate(DATASET_NAMES)} + + self.image_processor = image_processor + self.model_config_path = model_config_path + self.buffer_dir = config["buf_path"] + self.num_chunks = config["buf_num_chunks"] + self.chunk_size = config["buf_chunk_size"] + self.tokenizer_max_length = config["tokenizer_max_length"] + self.image_aspect_ratio = config["image_aspect_ratio"] + self.state_noise_snr = state_noise_snr + self.num_cameras = num_cameras + self.img_history_size = img_history_size + self.cond_mask_prob = cond_mask_prob + self.cam_ext_mask_prob = cam_ext_mask_prob + self.use_hdf5 = use_hdf5 + self.hdf5_dataset = None + if use_hdf5: + self.hdf5_dataset = HDF5VLADataset(self.model_config_path) + self.use_precomp_lang_embed = use_precomp_lang_embed + if use_precomp_lang_embed: + self.empty_lang_embed = torch.load("data/empty_lang_embed.pt") + + # Load dataset stat + with open("configs/dataset_stat.json", "r") as f: + dataset_stat = json.load(f) + self.dataset_stat = dataset_stat + + self.tokenizer = tokenizer + self.image_size = 
image_size + self.auto_adjust_image_brightness = auto_adjust_image_brightness + self.image_aug = image_aug + + self.last_content = None + self.last_meta = None + + def get_dataset_name2id(self): + return self.dataset_name2id + + def get_dataset_id2name(self): + return self.dataset_id2name + + @staticmethod + def pairwise(iterable): + a = iter(iterable) + return zip(a, a) + + @staticmethod + def _load_data_from_chunk(chunk_dir, chunk_item_idx): + # If error occurs, retry + time_stmp = time.time() + while time.time() - time_stmp < 10.0: + try: + locks = [] + file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json") + lock = FileLock(file_path) + locks.append(lock) + lock.acquire_read_lock() + with open(file_path, "r") as file: + json_content = json.load(file) + lock.release_lock() + file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz") + lock = FileLock(file_path) + locks.append(lock) + lock.acquire_read_lock() + with open(file_path, "rb") as file: + sample_dict = np.load(file) + meta = tuple(sample_dict.values()) + lock.release_lock() + return json_content, meta + except KeyboardInterrupt: + for lock in locks: + lock.release_lock() + raise KeyboardInterrupt + except BaseException: + for lock in locks: + lock.release_lock() + continue + raise RuntimeError("Failed to load sample.") + + def __len__(self) -> int: + if self.use_hdf5: + return len(self.hdf5_dataset) + else: + return self.num_chunks * self.chunk_size + + def _safe_load(self, index): + read_chunk_item_indices = [] + # Start searching from a random chunk + read_chunk_idx = index // self.chunk_size + while len(read_chunk_item_indices) == 0: + read_chunk_dir = os.path.join(self.buffer_dir, f"chunk_{read_chunk_idx}") + try: + read_chunk_item_indices = get_clean_item(read_chunk_dir) + except BaseException as e: + # Print the error info + print("Error catched when searching a clean chunk:", e) + traceback.print_exc() + read_chunk_item_indices = [] + read_chunk_idx = (read_chunk_idx 
+ 1) % self.num_chunks + + # read_chunk_item_index = random.choice(read_chunk_item_indices) + # read_chunk_item_index = read_chunk_item_indices.pop() + random_item_index = index % len(read_chunk_item_indices) + read_chunk_item_index = read_chunk_item_indices[random_item_index] + + # Modify the dirty bit + try: + dirty_bit = read_dirty_bit(read_chunk_dir) + dirty_bit[read_chunk_item_index] = 1 + save_dirty_bit(read_chunk_dir, dirty_bit) + except BaseException as e: + # Print the error info + print("Error catched when modifying the dirty bit:", e) + traceback.print_exc() + + # load the sample + try: + content, meta = self._load_data_from_chunk(read_chunk_dir, read_chunk_item_index) + self.last_content, self.last_meta = content, meta + except BaseException as e: + # Print the error info + print("Error catched when loading sample:", e) + traceback.print_exc() + + # If failed to load the data, return the last loaded data for robustness + content, meta = self.last_content, self.last_meta + + return (content, *meta) + + def __getitem__(self, index): + # For robustness, we will try to load the data until we succeed + while True: + data_dict = None + try: + if self.use_hdf5: + res = self.hdf5_dataset.get_item() + content = res["meta"] + states = res["state"] + actions = res["actions"] + state_elem_mask = res["state_indicator"] + image_metas = [ + res["cam_high"], + res["cam_high_mask"], + res["cam_right_wrist"], + res["cam_right_wrist_mask"], + res["cam_left_wrist"], + res["cam_left_wrist_mask"], + ] + state_std = res["state_std"] + state_mean = res["state_mean"] + state_norm = res["state_norm"] + else: + ( + content, + _, + states, + _, + actions, + _, + state_elem_mask, + *image_metas, + state_std, + state_mean, + state_norm, + ) = self._safe_load(index) + + data_dict = {} + data_dict["dataset_name"] = content["dataset_name"] + data_dict["data_idx"] = self.dataset_name2id[data_dict["dataset_name"]] + data_dict["ctrl_freq"] = (self.control_freq[data_dict["dataset_name"]] + 
if random.random() > self.cond_mask_prob else 0) + + if self.state_noise_snr is not None: + states += np.random.normal( + 0.0, + state_std / np.sqrt(10**(self.state_noise_snr / 10)), + states.shape, + ) + ds_state_mean = np.array(self.dataset_stat[data_dict["dataset_name"]]["state_mean"]) + ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1)) + # Randomly mask the states by the mean state + data_dict["states"] = (states if random.random() > self.cond_mask_prob else ds_state_mean) + data_dict["actions"] = actions + data_dict["state_elem_mask"] = (state_elem_mask if random.random() > self.cond_mask_prob else + np.zeros_like(state_elem_mask)) + + # Stat for the episode that the step belongs to + data_dict["state_norm"] = state_norm + + # We replace the invalid images with the background image + # and also randomly mask images by the background image + background_color = np.array( + [int(x * 255) for x in self.image_processor.image_mean], + dtype=np.uint8, + ).reshape(1, 1, 3) + background_image = (np.ones( + ( + self.image_processor.size["height"], + self.image_processor.size["width"], + 3, + ), + dtype=np.uint8, + ) * background_color) + + image_metas = list(self.pairwise(image_metas)) + mask_probs = [self.cond_mask_prob] * self.num_cameras + if self.cam_ext_mask_prob >= 0.0: + mask_probs[0] = self.cam_ext_mask_prob + rearranged_images = [] + for i in range(self.img_history_size): + for j in range(self.num_cameras): + images, image_mask = image_metas[j] + image, valid = images[i], image_mask[i] + if (valid and (math.prod(image.shape) > 0) and (random.random() > mask_probs[j])): + rearranged_images.append((image, True)) + else: + rearranged_images.append((background_image.copy(), False)) + + preprocessed_images = [] + processor = self.image_processor + for image, valid in rearranged_images: + image = Image.fromarray(image) + if self.image_size is not None: + image = transforms.Resize(self.image_size)(image) # (1008, 336) + # assert image.height == 336, 
"We haven't prepare for training with images of different resolutions." + + if valid and self.auto_adjust_image_brightness: + pixel_values = list(image.getdata()) + average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3) + if average_brightness <= 0.15: + image = transforms.ColorJitter(brightness=(1.75, 1.75))(image) + + # Only apply image augmentation to 50% of the images + if valid and self.image_aug and (random.random() > 0.5): + aug_type = random.choice(["corrput_only", "color_only", "both"]) + if aug_type != "corrput_only": + image = transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5, + hue=0.03)(image) + if aug_type != "color_only": + image = image_corrupt(image) + + if self.image_aspect_ratio == "pad": + + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + preprocessed_images.append(image) + data_dict["images"] = preprocessed_images + + if self.use_precomp_lang_embed: + if content["instruction"][-1] == ".": + content["instruction"] = content["instruction"][:-1] + data_dict["lang_embed"] = (torch.load(content["instruction"]) + if random.random() > self.cond_mask_prob else self.empty_lang_embed) + else: + instruction = (content["instruction"] if random.random() > self.cond_mask_prob else "") + data_dict["input_ids"] = self.tokenizer( + instruction, + return_tensors="pt", + padding="longest", + truncation=False, + ).input_ids[0] + + assert ( + len(data_dict["input_ids"]) <= 
self.tokenizer_max_length + ), f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}." + + for k, v in data_dict.items(): + if isinstance(v, np.ndarray): + data_dict[k] = torch.from_numpy(v) + + for k, v in data_dict.items(): + assert not isinstance(v, np.ndarray), f"key: {k}, value: {v}" + # data_dict[k] = torch.from_numpy(v) + + return data_dict + except BaseException as e: + # Print the error info + if data_dict is not None: + print( + f"Error catched when processing sample from {data_dict.get('dataset_name')}:", + e, + ) + else: + print(f"Error catched when processing sample:", e) + traceback.print_exc() + # Try incresing the index + index = (index + 1) % len(self) + + +class DataCollatorForVLAConsumerDataset(object): + """Collate examples for supervised training.""" + + def __init__(self, tokenizer: transformers.PreTrainedTokenizer) -> None: + self.tokenizer = tokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + batch = { + "states": [], + "actions": [], + "state_elem_mask": [], + "state_norm": [], + "images": [], + "data_indices": [], + "ctrl_freqs": [], + } + input_ids = [] + lang_embeds = [] + lang_embed_lens = [] + + for instance in instances: + # Convert all the numpy arrays to tensor + keys_to_check = [ + "states", + "actions", + "state_elem_mask", + "state_norm", + ] + for key in keys_to_check: + if isinstance(instance[key], torch.Tensor): + item = instance[key] + else: + item = torch.from_numpy(instance[key]) + batch[key].append(item) + + if "input_ids" in instance: + input_ids.append(instance["input_ids"]) + else: + lang_embeds.append(instance["lang_embed"]) + lang_embed_lens.append(instance["lang_embed"].shape[0]) + + batch["images"].append(torch.stack(instance["images"], dim=0)) + batch["data_indices"].append(instance["data_idx"]) + batch["ctrl_freqs"].append(instance["ctrl_freq"]) + + keys_to_stack = ["states", "actions", "state_elem_mask", "state_norm", 
"images"] + for key in keys_to_stack: + batch[key] = torch.stack(batch[key], dim=0) + + batch["ctrl_freqs"] = torch.tensor(batch["ctrl_freqs"]) + + if len(input_ids) > 0: + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + batch["input_ids"] = input_ids + batch["lang_attn_mask"] = input_ids.ne(self.tokenizer.pad_token_id) + else: + lang_embeds = torch.nn.utils.rnn.pad_sequence(lang_embeds, batch_first=True, padding_value=0) + input_lang_attn_mask = torch.zeros(lang_embeds.shape[0], lang_embeds.shape[1], dtype=torch.bool) + for i, l in enumerate(lang_embed_lens): + input_lang_attn_mask[i, :l] = True + batch["lang_embeds"] = lang_embeds + batch["lang_attn_mask"] = input_lang_attn_mask + + return batch diff --git a/RDT-1B/train/image_corrupt.py b/RDT-1B/train/image_corrupt.py new file mode 100644 index 0000000..583ef29 --- /dev/null +++ b/RDT-1B/train/image_corrupt.py @@ -0,0 +1,45 @@ +import warnings + +warnings.simplefilter(action="ignore", category=FutureWarning) + +import numpy as np + +np.bool = np.bool_ +import imgaug.augmenters as iaa +from PIL import Image + +# Define our sequence of augmentation steps that will be applied to every image. +seq = iaa.Sequential( + [ + # Execute one of the following noise augmentations + iaa.OneOf([ + iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5), + iaa.AdditiveLaplaceNoise(scale=(0.0, 0.05 * 255), per_channel=0.5), + iaa.AdditivePoissonNoise(lam=(0.0, 0.05 * 255), per_channel=0.5), + ]), + # Execute one or none of the following blur augmentations + iaa.SomeOf( + (0, 1), + [ + iaa.OneOf([ + iaa.GaussianBlur((0, 3.0)), + iaa.AverageBlur(k=(2, 7)), + iaa.MedianBlur(k=(3, 11)), + ]), + iaa.MotionBlur(k=(3, 36)), + ], + ), + ], + # do all of the above augmentations in random order + random_order=True, +) + + +def image_corrupt(image: Image): + image_arr = np.array(image) + image_arr = image_arr[None, ...] 
@torch.no_grad()
def log_sample_res(
    text_encoder,
    vision_encoder,
    rdt,
    args,
    accelerator,
    weight_dtype,
    dataset_id2name,
    dataloader,
    logger,
):
    """Evaluate the policy on a few batches and return per-dataset metrics.

    Runs `rdt.predict_action` on up to ``args.num_sample_batches`` batches,
    computes element-masked MSE and normalized L2 errors, gathers them
    across processes, and returns a dict mapping metric names (per-dataset
    and overall averages) to rounded floats.  Restores the model to train
    mode before returning.
    """
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logger.info(f"Running sampling for {args.num_sample_batches} batches...")

        rdt.eval()

        loss_for_log = defaultdict(float)
        loss_counter = defaultdict(int)
        # Track the batches actually evaluated: the dataloader may be
        # shorter than args.num_sample_batches, and averaging by the
        # requested count would understate the overall losses.
        num_batches_done = 0
        for step, batch in enumerate(dataloader):
            if step >= args.num_sample_batches:
                break
            num_batches_done += 1

            data_indices = batch["data_indices"]
            ctrl_freqs = batch["ctrl_freqs"]
            state_norm = batch["state_norm"].to(dtype=weight_dtype)
            images = batch["images"].to(dtype=weight_dtype)
            states = batch["states"].to(dtype=weight_dtype)
            # We only use the last state as input
            states = states[:, -1:, :]
            actions = batch["actions"].to(dtype=weight_dtype)
            state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)

            batch_size, _, C, H, W = images.shape
            image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
            image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

            lang_attn_mask = batch["lang_attn_mask"]
            text_embeds = (batch["lang_embeds"].to(dtype=weight_dtype) if args.precomp_lang_embed else text_encoder(
                input_ids=batch["input_ids"], attention_mask=lang_attn_mask)["last_hidden_state"].detach())

            pred_actions = rdt.predict_action(
                lang_tokens=text_embeds,
                lang_attn_mask=lang_attn_mask,
                img_tokens=image_embeds,
                state_tokens=states,
                action_mask=state_elem_mask.unsqueeze(1),
                ctrl_freqs=ctrl_freqs,
            )

            # Broadcast the per-step masks/statistics over the prediction horizon.
            num_steps = pred_actions.shape[1]
            expanded_state_elem_mask = (state_elem_mask.unsqueeze(1).tile((1, num_steps, 1)).float())
            expanded_state_norm = (state_norm.unsqueeze(1).tile((1, num_steps, 1)).float())

            loss = F.mse_loss(pred_actions, actions, reduction="none").float()

            mse_loss_per_entry = (loss * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)
            # Normalize the L2 error by per-episode state scale (+eps).
            l2_loss_per_entry = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss_per_entry = (l2_loss_per_entry * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)

            # Modern tensor construction: build on the right device/dtype
            # instead of the legacy torch.LongTensor(...).to(...) pattern.
            dataset_indices, mse_losses, l2_losses = accelerator.gather_for_metrics((
                torch.tensor(data_indices, dtype=torch.long, device=pred_actions.device),
                mse_loss_per_entry,
                l2_loss_per_entry,
            ))
            dataset_indices = dataset_indices.tolist()
            if accelerator.is_main_process:
                for loss_suffix, losses in zip(["_sample_mse", "_sample_l2err"], [mse_losses, l2_losses]):
                    for dataset_idx, loss_tensor in zip(dataset_indices, losses):
                        loss_name = dataset_id2name[dataset_idx] + loss_suffix
                        loss_for_log[loss_name] += loss_tensor.item()
                        loss_counter[loss_name] += 1

            mse_loss = (loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            loss_for_log["overall_avg_sample_mse"] += accelerator.gather(mse_loss).mean().item()

            l2_loss = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss = (l2_loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            loss_for_log["overall_avg_sample_l2err"] += accelerator.gather(l2_loss).mean().item()

        for name in loss_for_log:
            if name in ["overall_avg_sample_mse", "overall_avg_sample_l2err"]:
                # Fix: average over the batches actually evaluated, not the
                # requested num_sample_batches (they differ when the
                # dataloader is exhausted early).
                loss_for_log[name] = round(loss_for_log[name] / max(num_batches_done, 1), 4)
            else:
                loss_for_log[name] = round(loss_for_log[name] / loss_counter[name], 4)

        rdt.train()
        torch.cuda.empty_cache()

    return dict(loss_for_log)
def save_model_card(repo_id: str, base_model: str = "", repo_folder=None):
    """Write a Hugging Face model card (README.md) into *repo_folder*.

    Args:
        repo_id: Name of the Hub repository the card describes.
        base_model: Identifier of the checkpoint this model was derived from.
            (Previously defaulted to the ``str`` type object — now an empty
            string.)
        repo_folder: Local directory in which README.md is created.
    """
    # NOTE: this local was previously named `yaml`, shadowing the imported
    # `yaml` module within this function; renamed for clarity.
    card_header = f"""
---
license: mit
base_model: {base_model}
language:
- en
pipeline_tag: robotics
library_name: transformers
tags:
- robotics
- pytorch
- multimodal
- pretraining
- vla
- diffusion
- rdt
---
"""
    model_card = f"""
# RDT - {repo_id}

This is a RDT model derived from {base_model}. The weights were trained using [RDT](https://rdt-robotics.github.io/rdt-robotics/).
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(card_header + model_card)
def train(args, logger):
    """Finetune an RDT policy with HF Accelerate (optionally DeepSpeed).

    Reads the dataset/model configs, builds the text/vision encoders and the
    RDT runner (from a pretrained checkpoint or from scratch), sets up EMA,
    optimizer, dataloaders and LR scheduler, optionally resumes from a
    checkpoint, then runs the training loop with periodic checkpointing,
    sampling-based evaluation and final (optional) Hub upload.
    """
    # Read the config
    with open(args.config_path, "r") as fp:
        config = yaml.safe_load(fp)

    with open(args.model_config_path, "r") as f:
        model_config = yaml.safe_load(f)
    # The checkpoint path from the model config overrides the output dir.
    args.output_dir = model_config["checkpoint_path"]
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
    accelerator = Accelerator(
        deepspeed_plugin=(DeepSpeedPlugin(hf_ds_config=args.deepspeed) if args.deepspeed is not None else None),
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        filename=args.output_log_path,
        filemode='w',
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
                token=args.hub_token,
            ).repo_id

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if args.precomp_lang_embed:
        # Language embeddings are precomputed offline; no tokenizer needed.
        tokenizer, text_encoder = None, None
    else:
        text_embedder = T5Embedder(
            from_pretrained=args.pretrained_text_encoder_name_or_path,
            model_max_length=config["dataset"]["tokenizer_max_length"],
            device=accelerator.device,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    vision_encoder = SiglipVisionTower(vision_tower=args.pretrained_vision_encoder_name_or_path, args=None)
    image_processor = vision_encoder.image_processor

    # Load from a pretrained checkpoint (a directory => HF-format checkpoint;
    # a single file is handled further below via torch.load).
    if args.pretrained_model_name_or_path is not None and not os.path.isfile(args.pretrained_model_name_or_path):
        logger.info("Constructing model from pretrained checkpoint.")
        rdt = RDTRunner.from_pretrained(args.pretrained_model_name_or_path)
    else:
        logger.info("Constructing model from provided config.")
        # Calculate the image condition length
        img_cond_len = (config["common"]["img_history_size"] * config["common"]["num_cameras"] *
                        vision_encoder.num_patches)
        rdt = RDTRunner(
            action_dim=config["common"]["state_dim"],
            pred_horizon=config["common"]["action_chunk_size"],
            config=config["model"],
            lang_token_dim=config["model"]["lang_token_dim"],
            img_token_dim=config["model"]["img_token_dim"],
            state_token_dim=config["model"]["state_token_dim"],
            max_lang_cond_len=config["dataset"]["tokenizer_max_length"],
            img_cond_len=img_cond_len,
            img_pos_embed_config=[
                # No initial pos embed in the last grid size
                # since we've already done in ViT
                (
                    "image",
                    (
                        config["common"]["img_history_size"],
                        config["common"]["num_cameras"],
                        -vision_encoder.num_patches,
                    ),
                ),
            ],
            lang_pos_embed_config=[
                # Similarly, no initial pos embed for language
                ("lang", -config["dataset"]["tokenizer_max_length"]),
            ],
            dtype=weight_dtype,
        )

    ema_rdt = copy.deepcopy(rdt)
    ema_model = EMAModel(
        ema_rdt,
        update_after_step=config["model"]["ema"]["update_after_step"],
        inv_gamma=config["model"]["ema"]["inv_gamma"],
        power=config["model"]["ema"]["power"],
        min_value=config["model"]["ema"]["min_value"],
        max_value=config["model"]["ema"]["max_value"],
    )

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    # which ensure saving model in huggingface format (config.json + pytorch_model.bin)
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for model in models:
                model_to_save = model.module if hasattr(model, "module") else model  # type: ignore
                if isinstance(model_to_save, type(accelerator.unwrap_model(rdt))):
                    model_to_save.save_pretrained(output_dir)

    accelerator.register_save_state_pre_hook(save_model_hook)

    if args.gradient_checkpointing:
        # TODO:
        raise NotImplementedError("Gradient checkpointing is not yet implemented.")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size *
                              accelerator.num_processes)

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    params_to_optimize = rdt.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = VLAConsumerDataset(
        model_config_path=args.model_config_path,  # TODO
        config=config["dataset"],
        tokenizer=tokenizer,
        image_processor=image_processor,
        num_cameras=config["common"]["num_cameras"],
        img_history_size=config["common"]["img_history_size"],
        dataset_type=args.dataset_type,
        image_aug=args.image_aug,
        cond_mask_prob=args.cond_mask_prob,
        cam_ext_mask_prob=args.cam_ext_mask_prob,
        state_noise_snr=args.state_noise_snr,
        use_hdf5=args.load_from_hdf5,
        use_precomp_lang_embed=args.precomp_lang_embed,
    )
    # The sampling dataset disables every stochastic corruption so evaluation
    # is deterministic with respect to the data pipeline.
    sample_dataset = VLAConsumerDataset(
        model_config_path=args.model_config_path,  # TODO
        config=config["dataset"],
        tokenizer=tokenizer,
        image_processor=image_processor,
        num_cameras=config["common"]["num_cameras"],
        img_history_size=config["common"]["img_history_size"],
        dataset_type=args.dataset_type,
        image_aug=False,
        cond_mask_prob=0,
        cam_ext_mask_prob=-1,
        state_noise_snr=None,
        use_hdf5=args.load_from_hdf5,
        use_precomp_lang_embed=args.precomp_lang_embed,
    )

    data_collator = DataCollatorForVLAConsumerDataset(tokenizer)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        persistent_workers=True,
    )
    sample_dataloader = torch.utils.data.DataLoader(
        sample_dataset,
        batch_size=args.sample_batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        persistent_workers=True,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler = (accelerator.prepare(
        rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler))

    ema_rdt.to(accelerator.device, dtype=weight_dtype)

    if text_encoder is not None:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    if vision_encoder is not None:
        vision_encoder.vision_tower.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers(
            "VLA",
            config=vars(args),
            init_kwargs={"wandb": {
                "name": f"RoboTwin_RDT_{args.CONFIG_NAME}",
            }},
        )

    # Train!
    total_batch_size = (args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps)

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Load from a pretrained checkpoint (single-file, e.g. DeepSpeed export).
    if (args.resume_from_checkpoint is None and args.pretrained_model_name_or_path is not None
            and os.path.isfile(args.pretrained_model_name_or_path)):
        # Since EMA is deprecated, we do not load EMA from the pretrained checkpoint
        logger.info("Loading from a pretrained checkpoint.")
        checkpoint = torch.load(args.pretrained_model_name_or_path)
        rdt.module.load_state_dict(checkpoint["module"])

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            try:
                accelerator.load_state(os.path.join(args.output_dir, path))  # load_module_strict=False
            except Exception:
                # Fix: narrowed from a bare `except:` so SystemExit /
                # KeyboardInterrupt are not swallowed here.
                # Fall back to loading deepspeed's raw state_dict.
                logger.info("Resuming training state failed. Attempting to only load from model checkpoint.")
                checkpoint = torch.load(
                    os.path.join(
                        args.output_dir,
                        path,
                        "pytorch_model",
                        "mp_rank_00_model_states.pt",
                    ))
                rdt.module.load_state_dict(checkpoint["module"])

            load_model(ema_rdt, os.path.join(args.output_dir, path, "ema", "model.safetensors"))
            global_step = int(path.split("-")[1])

        resume_global_step = global_step * args.gradient_accumulation_steps
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(global_step, args.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    loss_for_log = {}
    for epoch in range(first_epoch, args.num_train_epochs):

        rdt.train()

        # Set the progress_bar to correct position
        if args.resume_from_checkpoint and epoch == first_epoch:
            progress_bar.update(resume_step // args.gradient_accumulation_steps)

        # Forward and backward...
        for batch in train_dataloader:
            with accelerator.accumulate(rdt):
                images = batch["images"].to(dtype=weight_dtype)
                states = batch["states"].to(dtype=weight_dtype)  # (B, T, D_a)
                # We only use the last state as input
                states = states[:, -1:, :]
                actions = batch["actions"].to(dtype=weight_dtype)
                state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)
                ctrl_freqs = batch["ctrl_freqs"]

                with torch.no_grad():
                    batch_size, _, C, H, W = images.shape
                    image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
                    image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

                    lang_attn_mask = batch["lang_attn_mask"]
                    text_embeds = (batch["lang_embeds"].to(
                        dtype=weight_dtype) if args.precomp_lang_embed else text_encoder(
                            input_ids=batch["input_ids"], attention_mask=lang_attn_mask)["last_hidden_state"].detach())

                state_elem_mask = state_elem_mask.unsqueeze(1)
                loss = rdt(
                    lang_tokens=text_embeds,
                    lang_attn_mask=lang_attn_mask,
                    img_tokens=image_embeds,
                    state_tokens=states,
                    action_gt=actions,
                    action_mask=state_elem_mask,
                    ctrl_freqs=ctrl_freqs,
                )

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = rdt.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            ema_model.step(accelerator.unwrap_model(rdt))

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_period == 0:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    ema_save_path = os.path.join(save_path, "ema")
                    accelerator.save_model(ema_rdt, ema_save_path)
                    logger.info(f"Saved state to {save_path}")

                if args.sample_period > 0 and global_step % args.sample_period == 0:
                    sample_loss_for_log = log_sample_res(
                        text_encoder,
                        vision_encoder,
                        rdt,  # We do not use EMA currently
                        args,
                        accelerator,
                        weight_dtype,
                        sample_dataset.get_dataset_id2name(),
                        sample_dataloader,
                        logger,
                    )
                    logger.info(sample_loss_for_log)
                    accelerator.log(sample_loss_for_log, step=global_step)

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            logs.update(loss_for_log)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(rdt).save_pretrained(args.output_dir)
        ema_save_path = os.path.join(args.output_dir, "ema")
        accelerator.save_model(ema_rdt, ema_save_path)

        logger.info(f"Saved Model to {args.output_dir}")

        if args.push_to_hub:
            save_model_card(
                repo_id,
                base_model=args.pretrained_model_name_or_path,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                token=args.hub_token,
                allow_patterns=["pytorch_model.bin", "*.json", "*.md"],
                # ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()