Compare commits

...

2 Commits

Author SHA1 Message Date
skyxz
3168acad93 update 2025-10-24 19:27:07 +08:00
skyxz
52d79bbc5e update 2025-10-24 19:25:20 +08:00
82 changed files with 9244 additions and 0 deletions

7
.gitignore vendored
View File

@ -1,3 +1,9 @@
input/
output/
Temp/
weights/
# ---> Python # ---> Python
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
@ -168,3 +174,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/

2
RDT-1B/.dockerignore Normal file
View File

@ -0,0 +1,2 @@
input/*
output/*

7
RDT-1B/.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
processed_data/
training_data/
checkpoints/
model_config/*.yml
wandb/*
!models/
!data/

45
RDT-1B/Dockerfile Normal file
View File

@ -0,0 +1,45 @@
# Training image for RDT-1B fine-tuning.
# Base: NVIDIA CUDA 11.8 + cuDNN 8 development image on Ubuntu 22.04.
FROM registry.d-robotics.cc/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
# Alternative mirror for the same base image:
# ccr-29eug8s3-pub.cnc.bj.baidubce.com/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
WORKDIR /app
# Non-interactive apt installs, unbuffered Python logs, Shanghai timezone.
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV TZ=Asia/Shanghai
# Switch apt to the Tsinghua mirrors for faster package downloads.
RUN sed -i 's/archive.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list && \
    sed -i 's/security.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list
# Install Python 3.10 from the deadsnakes PPA plus the shared libraries
# commonly required by OpenCV / ffmpeg based image and video handling.
RUN apt-get update && apt-get install -y \
    software-properties-common \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update \
    && apt-get install -y \
    python3.10 \
    python3.10-dev \
    python3.10-distutils \
    libgl1-mesa-glx \
    libglib2.0-0 \
    wget \
    ffmpeg \
    libsm6 \
    libxext6 \
    && rm -rf /var/lib/apt/lists/*
# Make `python3` resolve to the freshly installed 3.10.
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
COPY . /app/
RUN python3 -m pip install --upgrade pip
# NOTE(review): torch wheels are pulled from the cu121 index while the base
# image ships CUDA 11.8 — confirm the intended CUDA / wheel pairing.
RUN pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
# Pin packaging before installing the prebuilt flash-attention wheel
# (built for torch 2.1 / CUDA 12 / cp310; shipped inside the build context).
RUN pip install packaging==24.0
RUN pip install flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
RUN mkdir -p /app/dataset/input /app/dataset/output
# The container runs the fine-tuning script by default.
ENTRYPOINT ["bash", "finetune.sh"]

1
RDT-1B/__init__.py Normal file
View File

@ -0,0 +1 @@
from .deploy_policy import *

BIN
RDT-1B/assets/head.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 726 KiB

Binary file not shown.

71
RDT-1B/configs/base.yaml Normal file
View File

@ -0,0 +1,71 @@
common:
# The number of historical images
img_history_size: 2
# The number of future actions to predict
action_chunk_size: 64
# The number of cameras to be used in the model
num_cameras: 3
# Dimension for state/action, we use the same space for both state and action
# This MUST be equal to configs/state_vec.py
state_dim: 128
dataset:
# We will extract the data from raw dataset
# and store them in the disk buffer by producer
# When training, we will read the data
# randomly from the buffer by consumer
# The producer will replace the data which has been
# read by the consumer with new data
# The path to the buffer (at least 400GB)
buf_path: /path/to/buffer
# The number of chunks in the buffer
buf_num_chunks: 512
# The number of samples (step rather than episode) in each chunk
buf_chunk_size: 512
# We will filter the episodes with length less than `epsd_len_thresh_low`
epsd_len_thresh_low: 32
# For those more than `epsd_len_thresh_high`,
# we will randomly sample `epsd_len_thresh_high` steps each time we load the episode
# to better balance the training datasets
epsd_len_thresh_high: 2048
# How to fit the image size
image_aspect_ratio: pad
# Maximum number of language tokens
tokenizer_max_length: 1024
model:
# Config for condition adaptors
lang_adaptor: mlp2x_gelu
img_adaptor: mlp2x_gelu
state_adaptor: mlp3x_gelu
lang_token_dim: 4096
img_token_dim: 1152
# Dim of action or proprioception vector
# A `state` refers to an action or a proprioception vector
state_token_dim: 128
# Config for RDT structure
rdt:
# 1B: num_head 32 hidden_size 2048
hidden_size: 2048
depth: 28
num_heads: 32
cond_pos_embed_type: multimodal
# For noise scheduler
noise_scheduler:
type: ddpm
num_train_timesteps: 1000
num_inference_timesteps: 5
beta_schedule: squaredcos_cap_v2 # Critical choice
prediction_type: sample
clip_sample: False
# For EMA (params averaging)
# We do not use EMA currently
ema:
update_after_step: 0
inv_gamma: 1.0
power: 0.75
min_value: 0.0
max_value: 0.9999

View File

@ -0,0 +1,50 @@
{
"A": [
[
-0.2691913843154907,
-0.21995729207992554,
-0.182277649641037
],
[
0.35127854347229004,
0.2769763469696045,
0.17159393429756165
]
],
"B": [
[
-0.2576896846294403,
-0.22244493663311005,
-0.20557966828346252
],
[
0.32854634523391724,
0.2922680974006653,
0.17373555898666382
]
],
"C": [
[
-0.29205888509750366,
-0.24688798189163208,
-0.17577645182609558
],
[
0.25053921341896057,
0.3277084231376648,
0.16431939601898193
]
],
"D": [
[
-0.25131964683532715,
-0.15233077108860016,
-0.13294968008995056
],
[
0.19209328293800354,
0.19344553351402283,
0.1370421051979065
]
]
}

View File

@ -0,0 +1,65 @@
{
"fractal20220817_data": 3,
"taco_play": 15,
"jaco_play": 10,
"berkeley_cable_routing": 10,
"nyu_door_opening_surprising_effectiveness": 3,
"viola": 20,
"berkeley_autolab_ur5": 5,
"toto": 30,
"kuka": 10,
"language_table": 10,
"columbia_cairlab_pusht_real": 10,
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20,
"nyu_rot_dataset_converted_externally_to_rlds":3,
"stanford_hydra_dataset_converted_externally_to_rlds": 10,
"austin_buds_dataset_converted_externally_to_rlds": 20,
"nyu_franka_play_dataset_converted_externally_to_rlds": 3,
"maniskill_dataset_converted_externally_to_rlds": 20,
"furniture_bench_dataset_converted_externally_to_rlds": 10,
"ucsd_kitchen_dataset_converted_externally_to_rlds": 2,
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3,
"austin_sailor_dataset_converted_externally_to_rlds": 20,
"austin_sirius_dataset_converted_externally_to_rlds": 20,
"bc_z": 10,
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": 10,
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 10,
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
"utokyo_xarm_bimanual_converted_externally_to_rlds": 10,
"berkeley_mvp_converted_externally_to_rlds": 5,
"berkeley_rpt_converted_externally_to_rlds": 30,
"kaist_nonprehensile_converted_externally_to_rlds": 10,
"stanford_mask_vit_converted_externally_to_rlds": 0,
"tokyo_u_lsmo_converted_externally_to_rlds": 10,
"dlr_sara_pour_converted_externally_to_rlds": 10,
"dlr_sara_grid_clamp_converted_externally_to_rlds": 10,
"dlr_edan_shared_control_converted_externally_to_rlds": 5,
"asu_table_top_converted_externally_to_rlds": 12.5,
"stanford_robocook_converted_externally_to_rlds": 5,
"eth_agent_affordances": 66.6,
"imperialcollege_sawyer_wrist_cam": 10,
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20,
"uiuc_d3field": 1,
"utaustin_mutex": 20,
"berkeley_fanuc_manipulation": 10,
"cmu_play_fusion": 5,
"cmu_stretch": 10,
"berkeley_gnm_recon": 3,
"berkeley_gnm_cory_hall": 5,
"berkeley_gnm_sac_son": 10,
"robo_net": 1,
"roboturk_real_towercreation": 10,
"roboturk_real_laundrylayout": 10,
"roboturk_real_objectsearch": 10,
"aloha_mobile": 50,
"aloha_static": 50,
"roboset": 5,
"droid": 15,
"fmb": 10,
"dobbe": 30,
"qut_dexterous_manpulation": 30,
"agilex": 25,
"rh20t": 10,
"calvin": 30,
"bridgev2": 5
}

View File

@ -0,0 +1,575 @@
{
"fractal20220817_data": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[
1,0,0,0
]
},
"taco_play": {
"image_keys": [
"rgb_static",
"rgb_gripper",
"rgb_static",
"rgb_static"
],
"image_mask":[
1,1,0,0
]
},
"jaco_play": {
"image_keys": [
"image",
"image_wrist",
"image_wrist",
"image_wrist"
],
"image_mask":[
1,1,0,0
]
},
"berkeley_cable_routing": {
"image_keys": [
"image",
"wrist45_image",
"wrist225_image",
"top_image"
],
"image_mask":[1,1,0,1]
},
"nyu_door_opening_surprising_effectiveness": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"viola": {
"image_keys": [
"agentview_rgb",
"eye_in_hand_rgb",
"eye_in_hand_rgb",
"eye_in_hand_rgb"
],
"image_mask":[1,1,0,0]
},
"berkeley_autolab_ur5": {
"image_keys": [
"image",
"hand_image",
"hand_image",
"hand_image"
],
"image_mask":[1,1,0,0]
},
"toto": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"kuka": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"language_table": {
"image_keys": [
"rgb",
"rgb",
"rgb",
"rgb"
],
"image_mask":[1,0,0,0]
},
"columbia_cairlab_pusht_real": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"nyu_rot_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"stanford_hydra_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"austin_buds_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"nyu_franka_play_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"image_additional_view",
"image_additional_view",
"image_additional_view"
],
"image_mask":[1,0,0,1]
},
"maniskill_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"furniture_bench_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"austin_sailor_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"austin_sirius_dataset_converted_externally_to_rlds": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"bc_z": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
"image_keys": [
"image",
"hand_image",
"hand_image",
"image2"
],
"image_mask":[1,1,0,1]
},
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"berkeley_mvp_converted_externally_to_rlds": {
"image_keys": [
"hand_image",
"hand_image",
"hand_image",
"hand_image"
],
"image_mask":[0,1,0,0]
},
"berkeley_rpt_converted_externally_to_rlds": {
"image_keys": [
"hand_image",
"hand_image",
"hand_image",
"hand_image"
],
"image_mask":[0,1,0,0]
},
"kaist_nonprehensile_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"stanford_mask_vit_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"tokyo_u_lsmo_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"dlr_sara_pour_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"dlr_sara_grid_clamp_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"dlr_edan_shared_control_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"asu_table_top_converted_externally_to_rlds": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"stanford_robocook_converted_externally_to_rlds": {
"image_keys": [
"image_2",
"image_1",
"image_3",
"image_4"
],
"image_mask":[1,0,0,1]
},
"eth_agent_affordances": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"imperialcollege_sawyer_wrist_cam": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[0,1,0,0]
},
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"uiuc_d3field": {
"image_keys": [
"image_1",
"image_2",
"image_3",
"image_4"
],
"image_mask":[1,0,0,1]
},
"utaustin_mutex": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"berkeley_fanuc_manipulation": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"cmu_play_fusion": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"cmu_stretch": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"berkeley_gnm_recon": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"berkeley_gnm_cory_hall": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"berkeley_gnm_sac_son": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"robo_net": {
"image_keys": [
"image",
"image1",
"image2",
"image2"
],
"image_mask":[1,0,0,1]
},
"roboturk_real_towercreation": {
"image_keys": [
"top_rgb_frame",
"front_rgb_frame",
"front_rgb_frame",
"front_rgb_frame"
],
"image_mask":[1,0,0,1]
},
"roboturk_real_laundrylayout": {
"image_keys": [
"top_rgb_frame",
"front_rgb_frame",
"front_rgb_frame",
"front_rgb_frame"
],
"image_mask":[1,0,0,1]
},
"roboturk_real_objectsearch": {
"image_keys": [
"top_rgb_frame",
"front_rgb_frame",
"front_rgb_frame",
"front_rgb_frame"
],
"image_mask":[1,0,0,1]
},
"aloha_mobile": {
"image_keys": [
"cam_high",
"cam_right_wrist",
"cam_left_wrist",
"cam_right_wrist"
],
"image_mask":[1,1,1,0]
},
"aloha_static": {
"image_keys": [
"cam_high",
"cam_right_wrist",
"cam_left_wrist",
"cam_low"
],
"image_mask":[1,1,1,1]
},
"roboset": {
"image_keys": [
"rgb_top",
"rgb_right",
"rgb_left",
"rgb_right"
],
"image_mask":[1,1,1,0]
},
"droid": {
"image_keys": [
"exterior_image_1_left",
"wrist_image_left",
"wrist_image_left",
"exterior_image_2_left"
],
"image_mask":[1,1,0,1]
},
"fmb": {
"image_keys": [
"image_side_1",
"image_wrist_1",
"image_wrist_1",
"image_side_2"
],
"image_mask":[1,1,0,1]
},
"dobbe": {
"image_keys": [
"wrist_image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[0,1,0,0]
},
"qut_dexterous_manpulation": {
"image_keys": [
"image",
"wrist_image",
"wrist_image",
"wrist_image"
],
"image_mask":[1,1,0,0]
},
"agilex": {
"image_keys": [
"cam_high",
"cam_right_wrist",
"cam_left_wrist",
"cam_right_wrist"
],
"image_mask":[1,1,1,0]
},
"rh20t": {
"image_keys": [
"image",
"image",
"image",
"image"
],
"image_mask":[1,0,0,0]
},
"calvin": {
"image_keys": [
"rgb_static",
"rgb_gripper",
"rgb_gripper",
"rgb_gripper"
],
"image_mask":[1,1,0,0]
},
"bridgev2": {
"image_keys": [
"images0",
"images0",
"images0",
"images0"
],
"image_mask":[1,0,0,0]
}
}

View File

@ -0,0 +1,525 @@
{
"agilex": {
"dataset_name": "agilex",
"state_mean": [
-0.0036545392947090432,
-0.2773659935760079,
0.3147616748061523,
0.3813313179910183,
0.04028575944090457,
0.034888520819083294,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
],
"state_std": [
0.05763674563578847,
0.2580181064167735,
0.19785840483767897,
0.05020347749331385,
0.054529239104671424,
0.05020521339363586,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
],
"state_min": [
-0.17447535196940103,
-0.5522612677680121,
-0.3340397516886393,
0.21861712137858072,
-0.09725829230414497,
0.003396739231215583,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
],
"state_max": [
0.21961932712131077,
0.30613206227620443,
0.5444545321994357,
0.4866888682047526,
0.31486290825737845,
0.3355223337809245,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
]
}
}

View File

@ -0,0 +1,3 @@
[
"agilex"
]

View File

@ -0,0 +1,3 @@
{
"agilex": 100
}

View File

@ -0,0 +1,48 @@
[
"fractal20220817_data",
"jaco_play",
"taco_play",
"berkeley_cable_routing",
"viola",
"berkeley_autolab_ur5",
"toto",
"nyu_door_opening_surprising_effectiveness",
"columbia_cairlab_pusht_real",
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds",
"austin_buds_dataset_converted_externally_to_rlds",
"kuka",
"utokyo_xarm_bimanual_converted_externally_to_rlds",
"stanford_hydra_dataset_converted_externally_to_rlds",
"maniskill_dataset_converted_externally_to_rlds",
"ucsd_kitchen_dataset_converted_externally_to_rlds",
"ucsd_pick_and_place_dataset_converted_externally_to_rlds",
"austin_sailor_dataset_converted_externally_to_rlds",
"austin_sirius_dataset_converted_externally_to_rlds",
"bc_z",
"utokyo_pr2_opening_fridge_converted_externally_to_rlds",
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds",
"utokyo_xarm_pick_and_place_converted_externally_to_rlds",
"berkeley_mvp_converted_externally_to_rlds",
"berkeley_rpt_converted_externally_to_rlds",
"kaist_nonprehensile_converted_externally_to_rlds",
"tokyo_u_lsmo_converted_externally_to_rlds",
"dlr_sara_grid_clamp_converted_externally_to_rlds",
"stanford_robocook_converted_externally_to_rlds",
"imperialcollege_sawyer_wrist_cam",
"iamlab_cmu_pickup_insert_converted_externally_to_rlds",
"utaustin_mutex",
"berkeley_fanuc_manipulation",
"cmu_play_fusion",
"language_table",
"furniture_bench_dataset_converted_externally_to_rlds",
"droid",
"fmb",
"dobbe",
"qut_dexterous_manpulation",
"aloha_mobile",
"aloha_static",
"roboset",
"rh20t",
"calvin",
"bridgev2"
]

View File

@ -0,0 +1,48 @@
{
"fractal20220817_data": 271,
"taco_play": 60,
"jaco_play": 33,
"berkeley_cable_routing": 8,
"nyu_door_opening_surprising_effectiveness": 10,
"viola": 12,
"berkeley_autolab_ur5": 32,
"toto": 32,
"kuka": 50,
"language_table": 100,
"columbia_cairlab_pusht_real": 12,
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 55,
"stanford_hydra_dataset_converted_externally_to_rlds": 24,
"austin_buds_dataset_converted_externally_to_rlds": 7,
"maniskill_dataset_converted_externally_to_rlds": 174,
"furniture_bench_dataset_converted_externally_to_rlds": 71,
"ucsd_kitchen_dataset_converted_externally_to_rlds": 12,
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": 37,
"austin_sailor_dataset_converted_externally_to_rlds": 15,
"austin_sirius_dataset_converted_externally_to_rlds": 24,
"bc_z": 208,
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": 9,
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 15,
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
"utokyo_xarm_bimanual_converted_externally_to_rlds": 1,
"berkeley_mvp_converted_externally_to_rlds": 22,
"berkeley_rpt_converted_externally_to_rlds": 30,
"kaist_nonprehensile_converted_externally_to_rlds": 14,
"tokyo_u_lsmo_converted_externally_to_rlds": 7,
"dlr_sara_grid_clamp_converted_externally_to_rlds": 1,
"stanford_robocook_converted_externally_to_rlds": 50,
"imperialcollege_sawyer_wrist_cam": 13,
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": 25,
"utaustin_mutex": 39,
"berkeley_fanuc_manipulation": 20,
"cmu_play_fusion": 24,
"droid": 303,
"fmb": 42,
"dobbe": 36,
"qut_dexterous_manpulation": 14,
"aloha_mobile": 150,
"aloha_static": 150,
"roboset": 135,
"rh20t": 331,
"calvin": 100,
"bridgev2": 224
}

126
RDT-1B/configs/state_vec.py Normal file
View File

@ -0,0 +1,126 @@
# Mapping from semantic names of state-vector entries to their indices in
# the unified 128-dim state/action vector. Layout: right-arm quantities in
# [0, 50), left-arm quantities in [50, 100), base velocities in [100, 103);
# the remaining slots are reserved. Names without a side prefix are aliases
# of the corresponding right-arm entries.
STATE_VEC_IDX_MAPPING = {
    # [0, 10): right arm joint positions
    **{f"arm_joint_{i}_pos": i for i in range(10)},
    **{f"right_arm_joint_{i}_pos": i for i in range(10)},
    # [10, 15): right gripper joint positions
    **{f"gripper_joint_{i}_pos": 10 + i for i in range(5)},
    **{f"right_gripper_joint_{i}_pos": 10 + i for i in range(5)},
    "gripper_open": 10,  # alias of right_gripper_joint_0_pos
    "right_gripper_open": 10,
    # [15, 25): right arm joint velocities
    **{f"arm_joint_{i}_vel": 15 + i for i in range(10)},
    **{f"right_arm_joint_{i}_vel": 15 + i for i in range(10)},
    # [25, 30): right gripper joint velocities
    **{f"gripper_joint_{i}_vel": 25 + i for i in range(5)},
    **{f"right_gripper_joint_{i}_vel": 25 + i for i in range(5)},
    "gripper_open_vel": 25,  # alias of right_gripper_joint_0_vel
    "right_gripper_open_vel": 25,
    # [30, 33): right end effector positions
    **{f"eef_pos_{ax}": 30 + k for k, ax in enumerate("xyz")},
    **{f"right_eef_pos_{ax}": 30 + k for k, ax in enumerate("xyz")},
    # [33, 39): right end effector 6D pose
    **{f"eef_angle_{i}": 33 + i for i in range(6)},
    **{f"right_eef_angle_{i}": 33 + i for i in range(6)},
    # [39, 42): right end effector velocities
    **{f"eef_vel_{ax}": 39 + k for k, ax in enumerate("xyz")},
    **{f"right_eef_vel_{ax}": 39 + k for k, ax in enumerate("xyz")},
    # [42, 45): right end effector angular velocities
    **{f"eef_angular_vel_{ax}": 42 + k for k, ax in enumerate(("roll", "pitch", "yaw"))},
    **{f"right_eef_angular_vel_{ax}": 42 + k for k, ax in enumerate(("roll", "pitch", "yaw"))},
    # [45, 50): reserved
    # [50, 60): left arm joint positions
    **{f"left_arm_joint_{i}_pos": 50 + i for i in range(10)},
    # [60, 65): left gripper joint positions
    **{f"left_gripper_joint_{i}_pos": 60 + i for i in range(5)},
    "left_gripper_open": 60,  # alias of left_gripper_joint_0_pos
    # [65, 75): left arm joint velocities
    **{f"left_arm_joint_{i}_vel": 65 + i for i in range(10)},
    # [75, 80): left gripper joint velocities
    **{f"left_gripper_joint_{i}_vel": 75 + i for i in range(5)},
    "left_gripper_open_vel": 75,  # alias of left_gripper_joint_0_vel
    # [80, 83): left end effector positions
    **{f"left_eef_pos_{ax}": 80 + k for k, ax in enumerate("xyz")},
    # [83, 89): left end effector 6D pose
    **{f"left_eef_angle_{i}": 83 + i for i in range(6)},
    # [89, 92): left end effector velocities
    **{f"left_eef_vel_{ax}": 89 + k for k, ax in enumerate("xyz")},
    # [92, 95): left end effector angular velocities
    **{f"left_eef_angular_vel_{ax}": 92 + k for k, ax in enumerate(("roll", "pitch", "yaw"))},
    # [95, 100): reserved
    # [100, 102): base linear velocities
    "base_vel_x": 100,
    "base_vel_y": 101,
    # [102, 103): base angular velocities
    "base_angular_vel": 102,
    # [103, 128): reserved
}
# Total length of the unified state/action vector.
STATE_VEC_LEN = 128

14
RDT-1B/configs/zero2.json Normal file
View File

@ -0,0 +1,14 @@
{
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9
}
}

2
RDT-1B/data/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
# Ignore data files
datasets

Binary file not shown.

View File

@ -0,0 +1,154 @@
import tensorflow as tf
import h5py
import os
import fnmatch
import shutil
from tqdm import tqdm
from multiprocessing import Pool
import numpy as np
def _bytes_feature(value):
    """Return a TF Feature holding *value* as a single-element bytes_list."""
    tensor_type = type(tf.constant(0))
    if isinstance(value, tensor_type):
        # BytesList cannot unpack a string straight from an EagerTensor.
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _bool_feature(value):
    """Encode a boolean as a single-element int64_list TF Feature."""
    as_int = int(value)
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[as_int]))
def serialize_example(
    action,
    base_action,
    qpos,
    qvel,
    cam_high,
    cam_left_wrist,
    cam_right_wrist,
    instruction,
    terminate_episode,
):
    """Pack one episode step into a serialized ``tf.train.Example`` string.

    Numeric arrays are serialized as tensors; camera frames are stored as
    their raw bytes wrapped in a string tensor; ``terminate_episode``
    marks the final step of an episode.
    """

    def _tensor_feature(tensor):
        # Serialize an array/tensor into a bytes Feature.
        return _bytes_feature(tf.io.serialize_tensor(tensor))

    def _image_feature(frame):
        # Store the raw frame bytes inside a string tensor.
        raw = tf.convert_to_tensor(frame.tobytes(), dtype=tf.string)
        return _bytes_feature(tf.io.serialize_tensor(raw))

    feature = {
        "action": _tensor_feature(action),
        "base_action": _tensor_feature(base_action),
        "qpos": _tensor_feature(qpos),
        "qvel": _tensor_feature(qvel),
        "cam_high": _image_feature(cam_high),
        "cam_left_wrist": _image_feature(cam_left_wrist),
        "cam_right_wrist": _image_feature(cam_right_wrist),
        "instruction": _bytes_feature(instruction),
        "terminate_episode": _bool_feature(terminate_episode),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()
def process_hdf5_file(args):
    """Convert one HDF5 episode file into a TFRecord file.

    Args:
        args: (filepath, root_dir, out_dir) tuple, packed so the function
            can be mapped over a multiprocessing pool. The output mirrors
            the file's position relative to ``root_dir`` under ``out_dir``.

    Returns:
        A short status string for logging (skipped / written / failed).

    Leading still frames — steps whose qpos has not yet deviated from the
    initial pose — are trimmed, keeping one frame of margin before the
    first detected motion.
    """
    filepath, root_dir, out_dir = args
    output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir))
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.basename(filepath)
    tfrecord_path = os.path.join(output_dir, filename.replace(".hdf5", ".tfrecord"))
    # Skip files that were already fully converted in a previous run.
    if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
        return f"TFRecords already exist at {tfrecord_path}"
    try:
        with h5py.File(filepath, "r") as f, tf.io.TFRecordWriter(tfrecord_path) as writer:
            # Number of steps in this episode (one Example per step).
            num_steps = f["action"].shape[0]
            # Remove the first few still steps.
            EPS = 1e-2
            qpos = f["observations"]["qpos"][:]
            # Index of the first step whose qpos deviates from the initial pose.
            qpos_delta = np.abs(qpos - qpos[0:1])
            indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
            if len(indices) > 0:
                first_idx = indices[0]
            else:
                raise ValueError("Found no qpos that exceeds the threshold.")
            # Keep one frame of margin before the first motion
            # (first_idx >= 1 since the delta at step 0 is exactly zero).
            for i in range(first_idx - 1, num_steps):
                action = f["action"][i]
                base_action = f["base_action"][i]
                qpos_i = f["observations"]["qpos"][i]
                qvel = f["observations"]["qvel"][i]
                cam_high = f["observations"]["images"]["cam_high"][i]
                cam_left_wrist = f["observations"]["images"]["cam_left_wrist"][i]
                cam_right_wrist = f["observations"]["images"]["cam_right_wrist"][i]
                instruction = f["instruction"][()]
                terminate_episode = i == num_steps - 1
                serialized_example = serialize_example(
                    action,
                    base_action,
                    qpos_i,
                    qvel,
                    cam_high,
                    cam_left_wrist,
                    cam_right_wrist,
                    instruction,
                    terminate_episode,
                )
                writer.write(serialized_example)
    except Exception as e:
        with open("error_log.txt", "a") as log_file:
            log_file.write(f"{filepath}\n")
        print(f"error at {filepath}: {e}")
        # Delete the partial TFRecord so a rerun does not mistake it for
        # a finished conversion (the skip check above tests size > 0).
        if os.path.exists(tfrecord_path):
            os.remove(tfrecord_path)
        return f"Failed to convert {filepath}"
    return f"TFRecords written to {tfrecord_path}"
def write_tfrecords(root_dir, out_dir):
    """Mirror *root_dir* into *out_dir*, converting every ``*.hdf5`` episode.

    Instruction JSON files found next to the episodes are copied over as
    well, always under the canonical name
    ``expanded_instruction_gpt-4-turbo.json`` (plain
    ``expanded_instruction.json`` files are renamed on copy). Conversion
    runs in a pool of 16 worker processes with a progress bar.

    Args:
        root_dir: directory tree containing the raw HDF5 episodes.
        out_dir: destination root for the mirrored TFRecord tree.
    """
    # exist_ok avoids the check-then-create race of the old exists/makedirs pair.
    os.makedirs(out_dir, exist_ok=True)
    instr_canonical = "expanded_instruction_gpt-4-turbo.json"
    hdf5_files = []
    for root, dirs, files in os.walk(root_dir):
        target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
        if os.path.exists(os.path.join(root, instr_canonical)):
            # Copy the instruction file alongside the converted episodes.
            os.makedirs(target_path, exist_ok=True)
            shutil.copy(os.path.join(root, instr_canonical), target_path)
        elif os.path.exists(os.path.join(root, "expanded_instruction.json")):
            print(root)
            os.makedirs(target_path, exist_ok=True)
            # Copy directly under the canonical name (replaces the old
            # copy-then-rename two-step).
            shutil.copy(
                os.path.join(root, "expanded_instruction.json"),
                os.path.join(target_path, instr_canonical),
            )
        for filename in fnmatch.filter(files, "*.hdf5"):
            filepath = os.path.join(root, filename)
            hdf5_files.append((filepath, root_dir, out_dir))
    with Pool(16) as pool:
        with tqdm(total=len(hdf5_files)) as pbar:
            for _ in pool.imap_unordered(process_hdf5_file, hdf5_files):
                pbar.update(1)
    print(f"TFRecords written to {out_dir}")
# Convert the agilex RDT episodes to TFRecords when run as a script.
# The __main__ guard is required here: write_tfrecords spawns a
# multiprocessing Pool, and spawn-start platforms re-import this module
# in every worker — without the guard that would recurse into new pools.
if __name__ == "__main__":
    root_dir = "../datasets/agilex/rdt_data/"
    out_dir = "../datasets/agilex/tfrecords/"
    write_tfrecords(root_dir, out_dir)

View File

@ -0,0 +1,256 @@
"""
This file will compute the min, max, mean, and standard deviation of each dataset
listed in `pretrain_datasets.json` (or the corresponding finetune datasets file).
"""
import json
import argparse
import os
# from multiprocessing import Pool, Manager
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from data.vla_dataset import VLADataset
from data.hdf5_vla_dataset import HDF5VLADataset
from data.preprocess import generate_json_state
# Process each dataset to get the statistics
@tf.autograph.experimental.do_not_convert
def process_dataset(name_dataset_pair):
    """Accumulate per-dimension state statistics over one TF dataset.

    Args:
        name_dataset_pair: (dataset_name, episode_iterator) pair; each
            yielded episode is a dict with "episode_dict" and
            "dataset_name" keys, which generate_json_state turns into a
            (steps, state_dim) state array.

    Returns:
        dict with "dataset_name" and per-dimension "state_mean",
        "state_std", "state_min", "state_max" lists.
    """
    # print(f"PID {os.getpid()} processing {name_dataset_pair[0]}")
    dataset_iter = name_dataset_pair[1]
    MAX_EPISODES = 100000
    EPS = 1e-8
    # For debugging
    # MAX_EPISODES = 10
    episode_cnt = 0
    # Running sums over all steps; the z_* variants use copies in which
    # near-zero entries have been clamped to exactly 0.
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None  # per-dimension count of entries with |x| > EPS
    state_max = None
    state_min = None
    for episode in dataset_iter:
        episode_cnt += 1
        if episode_cnt % 1000 == 0:
            print(f"Processing episodes {episode_cnt}/{MAX_EPISODES}")
        if episode_cnt > MAX_EPISODES:
            break
        episode_dict = episode["episode_dict"]
        dataset_name = episode["dataset_name"]
        res_tup = generate_json_state(episode_dict, dataset_name)
        states = res_tup[1]
        # Convert to numpy
        states = states.numpy()
        # Zero the values that are close to zero
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0
        # Compute the non-zero count
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
        # Update statistics
        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))
    # Clamp counts to at least one to avoid division by zero below.
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
    result = {
        "dataset_name":
        name_dataset_pair[0],
        "state_mean": (state_sum / state_cnt).tolist(),
        # Std computed over the non-zero entries only, using the zeroed
        # sums and the mixed state_cnt/nz_state_cnt rescaling; clamped at
        # 0 before the sqrt to guard against negative rounding residue.
        "state_std":
        np.sqrt(
            np.maximum(
                (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                np.zeros_like(state_sum_sq),
            )).tolist(),
        "state_min":
        state_min.tolist(),
        "state_max":
        state_max.tolist(),
    }
    return result
def process_hdf5_dataset(vla_dataset):
    """Compute per-dimension state statistics for an HDF5-backed dataset.

    Iterates every episode of ``vla_dataset`` via
    ``get_item(i, state_only=True)`` and accumulates mean, std, min and max
    over all timesteps of all episodes. The std is computed over the
    near-zero-suppressed states, normalized by the per-dimension non-zero
    count, so mostly-zero dimensions are not underestimated.

    Args:
        vla_dataset: object exposing ``__len__``,
            ``get_item(index, state_only=True) -> {"state": (T, D) ndarray}``
            and ``get_dataset_name()``.

    Returns:
        dict: ``dataset_name`` plus ``state_mean``/``state_std``/``state_min``/
        ``state_max`` as plain Python lists of length D.

    Raises:
        ValueError: if the dataset contains no episodes.
    """
    # The progress bar is optional: fall back to a plain range if tqdm is
    # absent so this pure computation has no hard dependency on it.
    try:
        from tqdm import tqdm as _progress
    except ImportError:
        def _progress(x):
            return x
    EPS = 1e-8
    if len(vla_dataset) == 0:
        # Previously this fell through to `None.tolist()` / a 0/0 division.
        raise ValueError("Cannot compute statistics for an empty dataset.")
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None
    state_max = None
    state_min = None
    for i in _progress(range(len(vla_dataset))):
        episode = vla_dataset.get_item(i, state_only=True)
        states = episode["state"]
        # Zero the values that are close to zero so they do not bias the std.
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0
        # Per-dimension count of significant (non-zero) entries.
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
        # Update the running sums.
        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))
    # Clamp to one to avoid division by zero for all-zero dimensions.
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
    return {
        "dataset_name": vla_dataset.get_dataset_name(),
        "state_mean": (state_sum / state_cnt).tolist(),
        "state_std": np.sqrt(
            np.maximum(
                (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                np.zeros_like(state_sum_sq),
            )).tolist(),
        "state_min": state_min.tolist(),
        "state_max": state_max.tolist(),
    }
if __name__ == "__main__":
    # CLI entry point: compute per-dataset state statistics and persist them
    # to a JSON file, one dataset at a time.
    # NOTE: parallel processing via multiprocessing was removed (it was buggy);
    # datasets are handled sequentially and saved after each one for resume.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        help="Whether to load the pretrain dataset or finetune dataset.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="configs/dataset_stat.json",
        help="JSON file path to save the dataset statistics.",
    )
    parser.add_argument(
        "--skip_exist",
        action="store_true",
        help="Whether to skip the existing dataset statistics.",
    )
    parser.add_argument(
        "--hdf5_dataset",
        action="store_true",
        help="Whether to load the dataset from the HDF5 files.",
    )
    args = parser.parse_args()

    def _load_results(path):
        # Re-read the stats file each time so interrupted runs can resume.
        try:
            with open(path, "r") as fp:
                return json.load(fp)
        except FileNotFoundError:
            return {}

    def _save_results(path, results):
        with open(path, "w") as fp:
            json.dump(results, fp, indent=4)

    if args.hdf5_dataset:
        vla_dataset = HDF5VLADataset()
        dataset_name = vla_dataset.get_dataset_name()
        results = _load_results(args.save_path)
        if args.skip_exist and dataset_name in results:
            print(f"Skipping existed {dataset_name} dataset statistics")
        else:
            print(f"Processing {dataset_name} dataset")
            result = process_hdf5_dataset(vla_dataset)
            results[result["dataset_name"]] = result
            _save_results(args.save_path, results)
        print("All datasets have been processed.")
        # Hard exit: presumably avoids hanging on background dataset threads
        # at interpreter shutdown -- confirm before replacing with sys.exit().
        os._exit(0)

    vla_dataset = VLADataset(seed=0, dataset_type=args.dataset_type, repeat=False)
    for name, dataset in tqdm(vla_dataset.name2dataset.items()):
        results = _load_results(args.save_path)
        if args.skip_exist and name in results:
            print(f"Skipping existed {name} dataset statistics")
            continue
        print(f"Processing {name} dataset")
        result = process_dataset((name, dataset))
        results[result["dataset_name"]] = result
        # Save the results after each dataset so a crashed run can be resumed.
        _save_results(args.save_path, results)
    print("All datasets have been processed.")

View File

@ -0,0 +1,112 @@
"""
This file will compute the min, max, mean, and standard deviation of each dataset
in `pretrain_datasets.json` or `finetune_datasets.json`.
"""
import json
import argparse
import numpy as np
from tqdm import tqdm
from data.hdf5_vla_dataset import HDF5VLADataset
def process_hdf5_dataset(vla_dataset):
    """Compute per-dimension state statistics for an HDF5-backed dataset.

    Iterates every episode of ``vla_dataset`` via
    ``get_item(i, state_only=True)`` and accumulates mean, std, min and max
    over all timesteps of all episodes. The std is computed over the
    near-zero-suppressed states, normalized by the per-dimension non-zero
    count, so mostly-zero dimensions are not underestimated.

    Args:
        vla_dataset: object exposing ``__len__``,
            ``get_item(index, state_only=True) -> {"state": (T, D) ndarray}``
            and ``get_dataset_name()``.

    Returns:
        dict: ``dataset_name`` plus ``state_mean``/``state_std``/``state_min``/
        ``state_max`` as plain Python lists of length D.

    Raises:
        ValueError: if the dataset contains no episodes.
    """
    # The progress bar is optional: fall back to a plain range if tqdm is
    # absent so this pure computation has no hard dependency on it.
    try:
        from tqdm import tqdm as _progress
    except ImportError:
        def _progress(x):
            return x
    EPS = 1e-8
    if len(vla_dataset) == 0:
        # Previously this fell through to `None.tolist()` / a 0/0 division.
        raise ValueError("Cannot compute statistics for an empty dataset.")
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None
    state_max = None
    state_min = None
    for i in _progress(range(len(vla_dataset))):
        episode = vla_dataset.get_item(i, state_only=True)
        states = episode["state"]
        # Zero the values that are close to zero so they do not bias the std.
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0
        # Per-dimension count of significant (non-zero) entries.
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
        # Update the running sums.
        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))
    # Clamp to one to avoid division by zero for all-zero dimensions.
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
    return {
        "dataset_name": vla_dataset.get_dataset_name(),
        "state_mean": (state_sum / state_cnt).tolist(),
        "state_std": np.sqrt(
            np.maximum(
                (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                np.zeros_like(state_sum_sq),
            )).tolist(),
        "state_min": state_min.tolist(),
        "state_max": state_max.tolist(),
    }
if __name__ == "__main__":
    # CLI entry point: compute state statistics for the single HDF5 dataset
    # described by model_config/<task_name>.yml and merge them into save_path.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_name",
        type=str,
        default="configs/dataset_stat.json",
        # NOTE(review): this default looks copy-pasted from --save_path; the
        # derived config path "model_config/<default>.yml" cannot exist, so in
        # practice --task_name must always be passed explicitly.
        help="Task name used to locate the model config at model_config/<task_name>.yml.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="configs/dataset_stat.json",
        help="JSON file path to save the dataset statistics.",
    )
    parser.add_argument(
        "--skip_exist",
        action="store_true",
        help="Whether to skip the existing dataset statistics.",
    )
    args = parser.parse_args()

    vla_dataset = HDF5VLADataset(f"model_config/{args.task_name}.yml")
    dataset_name = vla_dataset.get_dataset_name()
    # Load previous statistics (if any) so reruns can skip finished datasets.
    try:
        with open(args.save_path, "r") as f:
            results = json.load(f)
    except FileNotFoundError:
        results = {}
    if args.skip_exist and dataset_name in results:
        print(f"Skipping existed {dataset_name} dataset statistics")
    else:
        print(f"Processing {dataset_name} dataset")
        result = process_hdf5_dataset(vla_dataset)
        results[result["dataset_name"]] = result
        with open(args.save_path, "w") as f:
            json.dump(results, f, indent=4)
    print("All datasets have been processed.")

Binary file not shown.

View File

@ -0,0 +1,398 @@
import numpy as np
import tensorflow as tf
import yaml
from data.preprocess import generate_json_state
from configs.state_vec import STATE_VEC_IDX_MAPPING
# Read the global experiment config.
# NOTE(review): this runs at import time, so the module can only be imported
# from a working directory containing configs/base.yaml -- confirm intended.
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
# Load some constants from the config.
# Number of past camera frames stacked per step (must be >= 1).
IMG_HISTORY_SIZE = config["common"]["img_history_size"]
if IMG_HISTORY_SIZE < 1:
    raise ValueError("Config `img_history_size` must be at least 1.")
# Length of the per-step state/action window (must be >= 1).
ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
if ACTION_CHUNK_SIZE < 1:
    raise ValueError("Config `action_chunk_size` must be at least 1.")
@tf.function
def process_episode(epsd: dict, dataset_name: str, image_keys: list, image_mask: list) -> dict:
    """
    Process an episode to extract the frames and the json content.

    For each of the four camera streams (`image_keys[0..3]`) this collects
    every frame of the episode and builds, per step, a sliding history window
    of the last IMG_HISTORY_SIZE frames (left-padded by repeating the first
    frame) together with a boolean time mask that is False on padded
    positions. Cameras whose `image_mask` entry is not 1 are replaced by
    empty `[0, 0, 0]`-shaped uint8 tensors.

    Args:
        epsd: raw episode dict with an iterable ``epsd["steps"]``; each step
            provides ``step["observation"][image_key]`` uint8 image tensors.
        dataset_name: name of the source dataset, passed through unchanged.
        image_keys: four observation keys, one per camera stream.
        image_mask: four 0/1 flags; 0 disables the corresponding camera.

    Returns:
        dict with the original episode, per-step ids, and per-camera
        ``past_frames_k`` (num_steps, IMG_HISTORY_SIZE, H, W, C) plus
        ``past_frames_k_time_mask`` (num_steps, IMG_HISTORY_SIZE) entries.
    """
    # Frames of each camera.
    # Ugly code due to tf's poor compatibility: TensorArrays cannot be kept in
    # Python collections inside a tf.function, so the four streams are
    # spelled out by hand instead of looped over.
    frames_0 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_1 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_2 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_3 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    # Traverse the episode to collect the frames of every step.
    for step in iter(epsd["steps"]):
        # Parse the image; disabled cameras get an empty placeholder tensor.
        frames_0 = frames_0.write(
            frames_0.size(),
            tf.cond(
                tf.equal(image_mask[0], 1),
                lambda: step["observation"][image_keys[0]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        # Very ugly code due to tf's poor compatibility (see note above).
        frames_1 = frames_1.write(
            frames_1.size(),
            tf.cond(
                tf.equal(image_mask[1], 1),
                lambda: step["observation"][image_keys[1]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_2 = frames_2.write(
            frames_2.size(),
            tf.cond(
                tf.equal(image_mask[2], 1),
                lambda: step["observation"][image_keys[2]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_3 = frames_3.write(
            frames_3.size(),
            tf.cond(
                tf.equal(image_mask[3], 1),
                lambda: step["observation"][image_keys[3]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
    # Calculate the past_frames_0 for each step.
    # Each step has a window of previous frames with size IMG_HISTORY_SIZE.
    # Use the first frame to (left-)pad the windows of the earliest steps.
    # past_frames_0 will have shape (num_steps, IMG_HISTORY_SIZE, height, width, channels)
    frames_0 = frames_0.stack()
    first_frame = tf.expand_dims(frames_0[0], axis=0)
    first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
    padded_frames_0 = tf.concat([first_frame, frames_0], axis=0)
    indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_0)[0] + IMG_HISTORY_SIZE)
    past_frames_0 = tf.map_fn(lambda i: padded_frames_0[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
    # Matching mask: False marks positions that are padding, not real frames.
    frames_0_time_mask = tf.ones([tf.shape(frames_0)[0]], dtype=tf.bool)
    padded_frames_0_time_mask = tf.pad(
        frames_0_time_mask,
        [[IMG_HISTORY_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_frames_0_time_mask = tf.map_fn(
        lambda i: padded_frames_0_time_mask[i - IMG_HISTORY_SIZE:i],
        indices,
        dtype=tf.bool,
    )
    # For past_frames_1 (same windowing as camera 0).
    frames_1 = frames_1.stack()
    first_frame = tf.expand_dims(frames_1[0], axis=0)
    first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
    padded_frames_1 = tf.concat([first_frame, frames_1], axis=0)
    indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_1)[0] + IMG_HISTORY_SIZE)
    past_frames_1 = tf.map_fn(lambda i: padded_frames_1[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
    frames_1_time_mask = tf.ones([tf.shape(frames_1)[0]], dtype=tf.bool)
    padded_frames_1_time_mask = tf.pad(
        frames_1_time_mask,
        [[IMG_HISTORY_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_frames_1_time_mask = tf.map_fn(
        lambda i: padded_frames_1_time_mask[i - IMG_HISTORY_SIZE:i],
        indices,
        dtype=tf.bool,
    )
    # For past_frames_2 (same windowing as camera 0).
    frames_2 = frames_2.stack()
    first_frame = tf.expand_dims(frames_2[0], axis=0)
    first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
    padded_frames_2 = tf.concat([first_frame, frames_2], axis=0)
    indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_2)[0] + IMG_HISTORY_SIZE)
    past_frames_2 = tf.map_fn(lambda i: padded_frames_2[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
    frames_2_time_mask = tf.ones([tf.shape(frames_2)[0]], dtype=tf.bool)
    padded_frames_2_time_mask = tf.pad(
        frames_2_time_mask,
        [[IMG_HISTORY_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_frames_2_time_mask = tf.map_fn(
        lambda i: padded_frames_2_time_mask[i - IMG_HISTORY_SIZE:i],
        indices,
        dtype=tf.bool,
    )
    # For past_frames_3 (same windowing as camera 0).
    frames_3 = frames_3.stack()
    first_frame = tf.expand_dims(frames_3[0], axis=0)
    first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
    padded_frames_3 = tf.concat([first_frame, frames_3], axis=0)
    indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_3)[0] + IMG_HISTORY_SIZE)
    past_frames_3 = tf.map_fn(lambda i: padded_frames_3[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
    frames_3_time_mask = tf.ones([tf.shape(frames_3)[0]], dtype=tf.bool)
    padded_frames_3_time_mask = tf.pad(
        frames_3_time_mask,
        [[IMG_HISTORY_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_frames_3_time_mask = tf.map_fn(
        lambda i: padded_frames_3_time_mask[i - IMG_HISTORY_SIZE:i],
        indices,
        dtype=tf.bool,
    )
    # Create the ids for each step.
    step_id = tf.range(0, tf.shape(frames_0)[0])
    return {
        "dataset_name": dataset_name,
        "episode_dict": epsd,
        "step_id": step_id,
        "past_frames_0": past_frames_0,
        "past_frames_0_time_mask": past_frames_0_time_mask,
        "past_frames_1": past_frames_1,
        "past_frames_1_time_mask": past_frames_1_time_mask,
        "past_frames_2": past_frames_2,
        "past_frames_2_time_mask": past_frames_2_time_mask,
        "past_frames_3": past_frames_3,
        "past_frames_3_time_mask": past_frames_3_time_mask,
    }
@tf.function
def bgr_to_rgb(epsd: dict):
    """
    Convert BGR images to RGB images.

    Each ``past_frames_k`` entry is channel-reversed when its last dimension
    is 3; zero-channel placeholder tensors (disabled cameras) are passed
    through unchanged. All other entries of the episode dict are forwarded
    as-is.
    """

    def _swap_channels(frames):
        # Reverse the channel order only when a real 3-channel image is
        # present; disabled cameras carry (..., 0)-shaped placeholders that
        # must not be touched.
        return tf.cond(
            tf.equal(tf.shape(frames)[-1], 3),
            lambda: tf.stack(
                [frames[..., 2], frames[..., 1], frames[..., 0]],
                axis=-1,
            ),
            lambda: frames,
        )

    return {
        "dataset_name": epsd["dataset_name"],
        "episode_dict": epsd["episode_dict"],
        "step_id": epsd["step_id"],
        "past_frames_0": _swap_channels(epsd["past_frames_0"]),
        "past_frames_0_time_mask": epsd["past_frames_0_time_mask"],
        "past_frames_1": _swap_channels(epsd["past_frames_1"]),
        "past_frames_1_time_mask": epsd["past_frames_1_time_mask"],
        "past_frames_2": _swap_channels(epsd["past_frames_2"]),
        "past_frames_2_time_mask": epsd["past_frames_2_time_mask"],
        "past_frames_3": _swap_channels(epsd["past_frames_3"]),
        "past_frames_3_time_mask": epsd["past_frames_3_time_mask"],
    }
def flatten_episode(episode: dict) -> list:
    """
    Flatten the episode to a list of steps.

    Each step carries a past-state window, a future-state ("action") window,
    the per-camera image histories produced by `process_episode`, and the
    per-episode state statistics (std/mean/RMS-norm) repeated onto every step.

    Args:
        episode: dict produced by `process_episode` (optionally after
            `bgr_to_rgb`), with `episode_dict`, `dataset_name`, `step_id`
            and the `past_frames_k(_time_mask)` entries.

    Returns:
        list of per-timestep dicts (one per step of the episode).
    """
    episode_dict = episode["episode_dict"]
    dataset_name = episode["dataset_name"]
    json_content, states, masks = generate_json_state(episode_dict, dataset_name)
    # Calculate the past_states for each step.
    # Each step has a window of previous states with size ACTION_CHUNK_SIZE.
    # Use the first state to pad the states.
    # past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
    first_state = tf.expand_dims(states[0], axis=0)
    first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0)
    padded_states = tf.concat([first_state, states], axis=0)
    indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
    past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
    # Matching time mask: False marks padded (pre-episode) positions.
    states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(
        states_time_mask,
        [[ACTION_CHUNK_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
        indices,
        dtype=tf.bool,
    )
    # Calculate the future_states for each step.
    # Each step has a window of future states with size ACTION_CHUNK_SIZE.
    # Use the last state to pad the states.
    # future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
    last_state = tf.expand_dims(states[-1], axis=0)
    last_state = tf.repeat(last_state, ACTION_CHUNK_SIZE, axis=0)
    padded_states = tf.concat([states, last_state], axis=0)
    indices = tf.range(1, tf.shape(states)[0] + 1)
    future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
    future_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
        indices,
        dtype=tf.bool,
    )
    # Calculate the mean and std for state, repeated onto every step.
    state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
    state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
    state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
    state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)
    # state_norm is the per-dimension RMS of the states.
    state_norm = tf.math.reduce_mean(tf.math.square(states), axis=0, keepdims=True)
    state_norm = tf.math.sqrt(state_norm)
    state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)
    # Create a list of steps.
    # NOTE(review): a Python `for` over `tf.shape(...)` only works when this
    # function executes eagerly (or under autograph) -- confirm call sites.
    step_data = []
    for i in range(tf.shape(states)[0]):
        step_data.append({
            "step_id": episode["step_id"][i],
            "json_content": json_content,
            "state_chunk": past_states[i],
            "state_chunk_time_mask": past_states_time_mask[i],
            "action_chunk": future_states[i],
            "action_chunk_time_mask": future_states_time_mask[i],
            "state_vec_mask": masks[i],
            "past_frames_0": episode["past_frames_0"][i],
            "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
            "past_frames_1": episode["past_frames_1"][i],
            "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
            "past_frames_2": episode["past_frames_2"][i],
            "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
            "past_frames_3": episode["past_frames_3"][i],
            "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
            "state_std": state_std[i],
            "state_mean": state_mean[i],
            "state_norm": state_norm[i],
        })
    return step_data
def flatten_episode_agilex(episode: dict) -> list:
    """
    Flatten the episode to a list of steps.

    Variant of `flatten_episode` for the AgileX embodiment: here
    `generate_json_state` also returns the recorded actions, and the future
    window ("action_chunk") is built from those actions instead of from the
    next states; `state_norm` is likewise the RMS of the actions.
    """
    episode_dict = episode["episode_dict"]
    dataset_name = episode["dataset_name"]
    json_content, states, masks, acts = generate_json_state(episode_dict, dataset_name)
    # Calculate the past_states for each step.
    # Each step has a window of previous states with size ACTION_CHUNK_SIZE.
    # Use the first state to pad the states.
    # past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
    first_state = tf.expand_dims(states[0], axis=0)
    first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0)
    padded_states = tf.concat([first_state, states], axis=0)
    indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
    past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(
        states_time_mask,
        [[ACTION_CHUNK_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
        indices,
        dtype=tf.bool,
    )
    # NOTE: for this embodiment the future "states" are the recorded actions.
    # Calculate the future_states for each step.
    # Each step has a window of future states with size ACTION_CHUNK_SIZE.
    # Use the last action to pad the states.
    # future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
    last_act = tf.expand_dims(acts[-1], axis=0)
    last_act = tf.repeat(last_act, ACTION_CHUNK_SIZE, axis=0)
    padded_states = tf.concat([acts, last_act], axis=0)
    # indices = tf.range(1, tf.shape(states)[0] + 1)
    indices = tf.range(0, tf.shape(acts)[0])  # NOTE time 0 action = time 1 state
    future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(acts)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
    future_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
        indices,
        dtype=tf.bool,
    )
    # Calculate the std and mean for state, repeated onto every step.
    state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
    state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
    state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
    state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)
    # Here state_norm is the per-dimension RMS of the *actions*.
    state_norm = tf.math.reduce_mean(tf.math.square(acts), axis=0, keepdims=True)
    state_norm = tf.math.sqrt(state_norm)
    state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)
    # Create a list of steps.
    # NOTE(review): a Python `for` over `tf.shape(...)` only works when this
    # function executes eagerly (or under autograph) -- confirm call sites.
    step_data = []
    for i in range(tf.shape(states)[0]):
        step_data.append({
            "step_id": episode["step_id"][i],
            "json_content": json_content,
            "state_chunk": past_states[i],
            "state_chunk_time_mask": past_states_time_mask[i],
            "action_chunk": future_states[i],
            "action_chunk_time_mask": future_states_time_mask[i],
            "state_vec_mask": masks[i],
            "past_frames_0": episode["past_frames_0"][i],
            "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
            "past_frames_1": episode["past_frames_1"][i],
            "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
            "past_frames_2": episode["past_frames_2"][i],
            "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
            "past_frames_3": episode["past_frames_3"][i],
            "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
            "state_std": state_std[i],
            "state_mean": state_mean[i],
            "state_norm": state_norm[i],
        })
    return step_data

25
RDT-1B/data/filelock.py Normal file
View File

@ -0,0 +1,25 @@
import fcntl
class FileLock:
    """
    A cooperative, non-blocking advisory file lock based on ``fcntl.flock``.

    Locks are taken on a sidecar file named ``<filename>.lock`` so the data
    file itself is never opened or truncated by the locking machinery. Both
    acquire methods use ``LOCK_NB`` and therefore raise ``BlockingIOError``
    immediately if the lock is already held by another process.
    """

    def __init__(self, filename):
        # Path of the protected file; the lock lives at `<filename>.lock`.
        self.filename = filename
        # Open handle to the lock file while a lock is held, else None.
        self.handle = None

    def acquire_read_lock(self):
        """Acquire a shared (read) lock without blocking.

        The lock file is opened in append mode so it is created when missing;
        opening with "r" used to raise FileNotFoundError whenever a reader
        arrived before any writer had created the lock file.
        """
        self.release_lock()  # do not leak a previously held handle
        self.handle = open(self.filename + ".lock", "a")
        fcntl.flock(self.handle, fcntl.LOCK_SH | fcntl.LOCK_NB)

    def acquire_write_lock(self):
        """Acquire an exclusive (write) lock without blocking."""
        self.release_lock()  # do not leak a previously held handle
        self.handle = open(self.filename + ".lock", "w")
        fcntl.flock(self.handle, fcntl.LOCK_EX | fcntl.LOCK_NB)

    def release_lock(self):
        """Release the lock (if held) and close the lock-file handle."""
        if self.handle is not None:
            fcntl.flock(self.handle, fcntl.LOCK_UN)
            self.handle.close()
            self.handle = None

    # Backward-compatible context-manager support (takes the write lock).
    def __enter__(self):
        self.acquire_write_lock()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release_lock()
        return False

View File

@ -0,0 +1,372 @@
import os
import fnmatch
import json
import h5py
import yaml
import cv2
import numpy as np
from configs.state_vec import STATE_VEC_IDX_MAPPING
class HDF5VLADataset:
"""
This class is used to sample episodes from the embododiment dataset
stored in HDF5.
"""
    def __init__(self, model_config_path) -> None:
        """Index all HDF5 episode files under the config's ``data_path``.

        Args:
            model_config_path (str): YAML file with (at least) a ``data_path``
                key pointing at a directory tree of ``*.hdf5`` episode files.
        """
        # [Modify] The path to the HDF5 dataset directory
        # Each HDF5 file contains one episode
        with open(model_config_path, "r") as f:
            model_config = yaml.safe_load(f)
        HDF5_DIR = model_config["data_path"]
        self.DATASET_NAME = "agilex"
        # Recursively collect every *.hdf5 file (one episode per file).
        self.file_paths = []
        for root, _, files in os.walk(HDF5_DIR):
            for filename in fnmatch.filter(files, "*.hdf5"):
                file_path = os.path.join(root, filename)
                self.file_paths.append(file_path)
        # Load the global config (chunk size, image history size, state dim).
        # NOTE(review): read relative to the CWD -- confirm callers run from
        # the repository root.
        with open("configs/base.yaml", "r") as file:
            config = yaml.safe_load(file)
        self.CHUNK_SIZE = config["common"]["action_chunk_size"]
        # NOTE: the "HISORY" misspelling is load-bearing -- other methods
        # read this exact attribute name.
        self.IMG_HISORY_SIZE = config["common"]["img_history_size"]
        self.STATE_DIM = config["common"]["state_dim"]
        # Get each episode's len (use original length, not standardized length)
        episode_lens = []
        for file_path in self.file_paths:
            try:
                with h5py.File(file_path, "r") as f:
                    qpos = f["observations"]["qpos"][:]
                    num_steps = qpos.shape[0]
                    episode_lens.append(num_steps)
            except Exception as e:
                # Unreadable files keep a zero weight so they are never sampled.
                print(f"Warning: Could not read {file_path}: {e}")
                episode_lens.append(0)
        # Sample episodes proportionally to their length.
        # NOTE(review): if no file is readable (or none exist), np.sum(...) is 0
        # and this divides by zero -- confirm at least one valid episode exists.
        self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens)
    def __len__(self):
        # One item per HDF5 episode file discovered at construction time.
        return len(self.file_paths)
    def get_dataset_name(self):
        """Return the fixed name of this embodiment dataset."""
        return self.DATASET_NAME
def get_item(self, index: int = None, state_only=False):
"""Get a training sample at a random timestep.
Args:
index (int, optional): the index of the episode.
If not provided, a random episode will be selected.
state_only (bool, optional): Whether to return only the state.
In this way, the sample will contain a complete trajectory rather
than a single timestep. Defaults to False.
Returns:
sample (dict): a dictionary containing the training sample.
"""
while True:
if index is None:
file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
else:
file_path = self.file_paths[index]
valid, sample = (self.parse_hdf5_file(file_path)
if not state_only else self.parse_hdf5_file_state_only(file_path))
if valid:
return sample
else:
index = np.random.randint(0, len(self.file_paths))
    def parse_hdf5_file(self, file_path):
        """[Modify] Parse a hdf5 file to generate a training sample at
            a random timestep.

        Args:
            file_path (str): the path to the hdf5 file

        Returns:
            valid (bool): whether the episode is valid, which is useful for filtering.
                If False, this episode will be dropped.
            dict: a dictionary containing the training sample,
                {
                    "meta": {
                        "dataset_name": str,    # the name of your dataset.
                        "#steps": int,          # the number of steps in the episode,
                                                # also the total timesteps.
                        "instruction": str      # the language instruction for this episode.
                    },
                    "step_id": int,             # the index of the sampled step,
                                                # also the timestep t.
                    "state": ndarray,           # state[t], (1, STATE_DIM).
                    "state_std": ndarray,       # std(state[:]), (STATE_DIM,).
                    "state_mean": ndarray,      # mean(state[:]), (STATE_DIM,).
                    "state_norm": ndarray,      # norm(state[:]), (STATE_DIM,).
                    "actions": ndarray,         # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
                    "state_indicator": ndarray, # indicates the validness of each dim, (STATE_DIM,).
                    "cam_high": ndarray,        # external camera image, (IMG_HISORY_SIZE, H, W, 3)
                                                # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
                    "cam_high_mask": ndarray,   # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array.
                                                # For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False.
                    "cam_left_wrist": ndarray,  # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
                                                # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
                    "cam_left_wrist_mask": ndarray,
                    "cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
                                                # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
                                                # If only one wrist, make it right wrist, plz.
                    "cam_right_wrist_mask": ndarray
                } or None if the episode is invalid.
        """
        with h5py.File(file_path, "r") as f:
            qpos = f["observations"]["qpos"][:]
            left_arm_dim = f["observations"]["left_arm_dim"][:]
            right_arm_dim = f["observations"]["right_arm_dim"][:]
            num_steps = qpos.shape[0]
            # NOTE(review): unused leftover; the name suggests it was meant to
            # hold the action dimensionality, not the qpos array itself.
            action_dim = qpos
            # [Optional] We drop too-short episode
            # if num_steps < 128:
            #     return False, None
            # [Optional] We skip the first few still steps
            EPS = 1e-2
            # Get the idx of the first qpos whose delta exceeds the threshold
            qpos_delta = np.abs(qpos - qpos[0:1])
            indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
            if len(indices) > 0:
                first_idx = indices[0]
            else:
                raise ValueError("Found no qpos that exceeds the threshold.")
            # We randomly sample a timestep.
            # NOTE(review): if first_idx == 0 the low bound is -1 and step_id
            # may be sampled as -1 -- confirm first_idx >= 1 is guaranteed.
            step_id = np.random.randint(first_idx - 1, num_steps)
            # Load the instruction
            dir_path = os.path.dirname(file_path)
            # with open(os.path.join(dir_path, 'instruction.json'), 'r') as f_instr:
            #     instruction_dict = json.load(f_instr)
            # # We have 1/3 prob to use original instruction,
            # # 1/3 to use simplified instruction,
            # # and 1/3 to use expanded instruction.
            # instruction_type = np.random.choice([
            #     'instruction', 'expanded_instruction'])
            # instruction = instruction_dict[instruction_type]
            # if isinstance(instruction, list):
            #     instruction = np.random.choice(instruction)
            # You can also use precomputed language embeddings (recommended)
            # instruction = "path/to/lang_embed.pt"
            instructions_path = os.path.join(dir_path, "instructions")
            instructions_names = []
            for filename in os.listdir(instructions_path):
                # Keep only the precomputed language-embedding files (*.pt).
                if filename.endswith(".pt"):
                    instructions_names.append(os.path.join(instructions_path, filename))
            # Pick one embedding file at random as this sample's instruction.
            instruction = np.random.choice(instructions_names)
            # print(f"choose {instruction} file as instruction.")
            # Assemble the meta
            meta = {
                "dataset_name": self.DATASET_NAME,
                "#steps": num_steps,
                "step_id": step_id,
                "instruction": instruction,
            }
            # Rescale gripper to [0, 1]
            # qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])
            # target_qpos = f["action"][step_id:step_id + self.CHUNK_SIZE] / np.array(
            #     [[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])
            # NOTE(review): dividing 6 joint values by 180 suggests degrees are
            # normalized to roughly [-1, 1]; confirm the units of the recording.
            qpos = qpos / np.array(
                # [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]
                [[180, 180, 180, 180, 180, 180]]
            )
            target_qpos = f['action'][step_id:step_id + self.CHUNK_SIZE] / np.array(
                # [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]]
                [[180, 180, 180, 180, 180, 180]]
            )
            # Parse the state and action
            state = qpos[step_id:step_id + 1]
            state_std = np.std(qpos, axis=0)
            state_mean = np.mean(qpos, axis=0)
            state_norm = np.sqrt(np.mean(qpos**2, axis=0))
            actions = target_qpos
            if actions.shape[0] < self.CHUNK_SIZE:
                # Pad the actions using the last action
                actions = np.concatenate(
                    [
                        actions,
                        np.tile(actions[-1:], (self.CHUNK_SIZE - actions.shape[0], 1)),
                    ],
                    axis=0,
                )

            # Fill the state/action into the unified vector
            def fill_in_state(values):
                # Scatter the 6 joint values into the unified STATE_DIM vector.
                # Target indices corresponding to your state space
                # In this example: 6 joints + 1 gripper for each arm
                UNI_STATE_INDICES = [
                    STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
                    # ] + [
                    # STATE_VEC_IDX_MAPPING["right_gripper_open"]
                ]
                uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
                uni_vec[..., UNI_STATE_INDICES] = values
                return uni_vec

            state = fill_in_state(state)
            state_indicator = fill_in_state(np.ones_like(state_std))
            state_std = fill_in_state(state_std)
            state_mean = fill_in_state(state_mean)
            state_norm = fill_in_state(state_norm)
            # If action's format is different from state's,
            # you may implement fill_in_action()
            actions = fill_in_state(actions)

            # Parse the images
            def parse_img(key):
                # Decode the compressed frames of the last IMG_HISORY_SIZE
                # steps up to and including step_id; left-pad with the first
                # decoded frame when the window starts before the episode.
                imgs = []
                for i in range(max(step_id - self.IMG_HISORY_SIZE + 1, 0), step_id + 1):
                    img_bits = f["observations"]["images"][key][i]
                    img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)
                    imgs.append(img)
                imgs = np.stack(imgs)
                if imgs.shape[0] < self.IMG_HISORY_SIZE:
                    # Pad the images using the first image
                    imgs = np.concatenate(
                        [
                            np.tile(
                                imgs[:1],
                                (self.IMG_HISORY_SIZE - imgs.shape[0], 1, 1, 1),
                            ),
                            imgs,
                        ],
                        axis=0,
                    )
                return imgs

            # `cam_high` is the external camera image
            cam_high = parse_img("cam_high")
            # For step_id = first_idx - 1, the valid_len should be one
            valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
            cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len)
            # cam_left_wrist = parse_img("cam_left_wrist")
            # cam_left_wrist_mask = cam_high_mask.copy()
            # The left wrist camera is unavailable on this robot: return the
            # documented zero-shape placeholder and an all-False mask.
            cam_left_wrist = np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0))  # parse_img('cam_right_wrist')
            cam_left_wrist_mask = np.array([False] * self.IMG_HISORY_SIZE)  # cam_high_mask.copy()
            cam_right_wrist = parse_img("cam_right_wrist")
            cam_right_wrist_mask = cam_high_mask.copy()  # same mask logic as cam_high
            # Return the resulting sample
            # For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)
            # E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist",
            # if the left-wrist camera is unavailable on your robot
            return True, {
                "meta": meta,
                "state": state,
                "state_std": state_std,
                "state_mean": state_mean,
                "state_norm": state_norm,
                "actions": actions,
                "state_indicator": state_indicator,
                "cam_high": cam_high,
                "cam_high_mask": cam_high_mask,
                "cam_left_wrist": cam_left_wrist,
                "cam_left_wrist_mask": cam_left_wrist_mask,
                "cam_right_wrist": cam_right_wrist,
                "cam_right_wrist_mask": cam_right_wrist_mask,
            }
def parse_hdf5_file_state_only(self, file_path):
"""[Modify] Parse a hdf5 file to generate a state trajectory.
Args:
file_path (str): the path to the hdf5 file
Returns:
valid (bool): whether the episode is valid, which is useful for filtering.
If False, this episode will be dropped.
dict: a dictionary containing the training sample,
{
"state": ndarray, # state[:], (T, STATE_DIM).
"action": ndarray, # action[:], (T, STATE_DIM).
} or None if the episode is invalid.
"""
with h5py.File(file_path, "r") as f:
qpos = f["observations"]["qpos"][:]
left_arm_dim = f["observations"]["left_arm_dim"][:]
right_arm_dim = f["observations"]["right_arm_dim"][:]
num_steps = qpos.shape[0]
# [Optional] We drop too-short episode
# if num_steps < 128:
# return False, None
# [Optional] We skip the first few still steps
EPS = 1e-2
# Get the idx of the first qpos whose delta exceeds the threshold
qpos_delta = np.abs(qpos - qpos[0:1])
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
if len(indices) > 0:
first_idx = indices[0]
else:
raise ValueError("Found no qpos that exceeds the threshold.")
# Rescale gripper to [0, 1]
# qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])
# target_qpos = f["action"][:] / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])
qpos = qpos / np.array(
# [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]
[[180, 180, 180, 180, 180, 180]]
)
target_qpos = f['action'][first_idx - 1:] / np.array(
# [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]]
[[180, 180, 180, 180, 180, 180]]
)
# Parse the state and action
state = qpos[first_idx - 1:]
action = target_qpos[first_idx - 1:]
# Standardize trajectory length to avoid batch size mismatch
# Use a fixed length (e.g., 128) or pad/truncate to match
target_length = 128 # You can adjust this value
if state.shape[0] > target_length:
# Truncate to target length
state = state[:target_length]
action = action[:target_length]
elif state.shape[0] < target_length:
# Pad with the last state/action
pad_length = target_length - state.shape[0]
state = np.concatenate([state, np.tile(state[-1:], (pad_length, 1))], axis=0)
action = np.concatenate([action, np.tile(action[-1:], (pad_length, 1))], axis=0)
# Fill the state/action into the unified vector
def fill_in_state(values):
# Target indices corresponding to your state space
# In this example: 6 joints + 1 gripper for each arm
UNI_STATE_INDICES = [
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
# ] + [
# STATE_VEC_IDX_MAPPING["right_gripper_open"]
]
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
uni_vec[..., UNI_STATE_INDICES] = values
return uni_vec
state = fill_in_state(state)
action = fill_in_state(action)
# Return the resulting sample
return True, {"state": state, "action": action}
if __name__ == "__main__":
    # Smoke test: materialize every sample in the dataset once.
    ds = HDF5VLADataset()
    total = len(ds)
    for episode_idx in range(total):
        print(f"Processing episode {episode_idx}/{total}...")
        ds.get_item(episode_idx)

299
RDT-1B/data/preprocess.py Normal file
View File

@ -0,0 +1,299 @@
import json
import tensorflow as tf
import yaml
from data.preprocess_scripts import *
from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN
from data.utils import capitalize_and_period
# Datasets that provide no proprioceptive state; for these the previous
# action is used as the current state (see `_generate_json_state_nostate_ds`).
DATASET_NAMES_NO_STATE = [
    "nyu_door_opening_surprising_effectiveness",
    "usc_cloth_sim_converted_externally_to_rlds",
    "cmu_franka_exploration_dataset_converted_externally_to_rlds",
    "imperialcollege_sawyer_wrist_cam",
]

# Read the image keys of each dataset
with open("configs/dataset_img_keys.json", "r") as file:
    IMAGE_KEYS = json.load(file)

# Read the config (shared sizes such as img_history_size / action_chunk_size)
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str, base_concat=None, base_format=None) -> tf.Tensor:
    """
    Assemble the state/action vector from the arm and base.

    Returns (state_vec, mask_vec): the unified STATE_VEC_LEN vector with the
    named entries filled in, and a 0/1 mask marking which entries are valid.
    """

    def _scatter(vec, mask, values, field_names):
        # Scatter `values` into the unified vector at the indices named by
        # `field_names`, and mark those indices as valid in `mask`.
        # tensor_scatter_nd_update avoids problems with duplicate indices.
        indices = [[STATE_VEC_IDX_MAPPING[name]] for name in field_names]
        vec = tf.tensor_scatter_nd_update(vec, indices, values)
        mask = tf.tensor_scatter_nd_update(
            mask,
            indices,
            tf.ones(len(field_names), dtype=tf.float32),
        )
        return vec, mask

    state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
    mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
    # Fill in the arm portion.
    state_vec, mask_vec = _scatter(
        state_vec, mask_vec, tf.cast(arm_concat, tf.float32), arm_format.split(","))
    # Fill in the base portion if present.
    if base_concat is not None:
        state_vec, mask_vec = _scatter(
            state_vec, mask_vec, tf.cast(base_concat, tf.float32), base_format.split(","))
    return state_vec, mask_vec
@tf.autograph.experimental.do_not_convert
def _generate_json_state_agilex(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for a given episode.

    Agilex variant: in addition to per-step state vectors and masks, it also
    assembles per-step action vectors from `action["arm_concat"]`.

    Returns:
        (episode_metadata, episode_states, episode_masks, episode_acts)
    """
    # Load some constants from the config
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    # Check whether this episode has an 'END'
    base_act = None
    last_base_act = None
    episode_states = []
    episode_acts = []
    episode_masks = []
    has_base = None  # determined lazily from the first step's action dict
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]

        # Parse the state
        state = step["observation"]
        arm_format = state["format"].numpy().decode("utf-8")
        base_format = None
        if has_base:
            # The base field names are the suffix of the action format
            # string starting at the first "base" entry.
            act_format = action["format"].numpy().decode("utf-8")
            base_formate_idx = act_format.find("base")
            base_format = act_format[base_formate_idx:]
        arm_state = state["arm_concat"]
        base_state = None
        if has_base:
            # Use the previous step's base action as the current base state
            # (zeros for the very first step).
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
            last_base_act = base_act

        # Assemble the state vector
        state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)
        # NOTE(review): this overwrites `mask_vec` with the action's mask;
        # state and action are assumed to share the same layout — confirm.
        act_vec, mask_vec = assemble_state_vec(action["arm_concat"], arm_format, base_state, base_format)
        episode_states.append(state_vec)
        episode_masks.append(mask_vec)
        episode_acts.append(act_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata (first step's instruction wins)
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr
    # NOTE(review): records the LAST step index (len - 1), not the count —
    # confirm downstream consumers expect this.
    episode_metadata["#steps"] = step_id
    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)
    episode_acts = tf.stack(episode_acts)
    return episode_metadata, episode_states, episode_masks, episode_acts
@tf.autograph.experimental.do_not_convert
def _generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for a given episode.

    Default variant for datasets that provide a proprioceptive state.

    Returns:
        (episode_metadata, episode_states, episode_masks)
    """
    # Load some constants from the config
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    # Check whether this episode has an 'END'
    base_act = None
    last_base_act = None
    episode_states = []
    episode_masks = []
    has_base = None  # determined lazily from the first step's action dict
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]

        # Parse the state
        state = step["observation"]
        arm_format = state["format"].numpy().decode("utf-8")
        base_format = None
        if has_base:
            # The base field names are the suffix of the action format
            # string starting at the first "base" entry.
            act_format = action["format"].numpy().decode("utf-8")
            base_formate_idx = act_format.find("base")
            base_format = act_format[base_formate_idx:]
        arm_state = state["arm_concat"]
        base_state = None
        if has_base:
            # Use the previous step's base action as the current base state
            # (zeros for the very first step).
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
            last_base_act = base_act

        # Assemble the state vector
        state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)
        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata (first step's instruction wins)
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr
    # NOTE(review): records the LAST step index (len - 1), not the count —
    # confirm downstream consumers expect this.
    episode_metadata["#steps"] = step_id
    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)
    return episode_metadata, episode_states, episode_masks
@tf.autograph.experimental.do_not_convert
def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode in the dataset without state.
    If not state, we use the last action as current state.

    Returns:
        (episode_metadata, episode_states, episode_masks)
    """
    # Load some constants from the config
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    last_base_act = None
    last_arm_act = None
    episode_states = []
    episode_masks = []
    has_base = None  # determined lazily from the first step's action dict
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]
            if last_base_act is None:
                last_base_act = base_act * 0  # Initialize
        # Parse the arm action
        arm_act = action["arm_concat"]
        if last_arm_act is None:
            last_arm_act = arm_act * 0  # Initialize

        # Parse the act format
        # Action format as the state format
        act_format = action["format"].numpy().decode("utf-8")

        # Assemble the state vector from the PREVIOUS step's action
        if has_base:
            last_act_concat = tf.concat([last_arm_act, last_base_act], axis=0)
        else:
            last_act_concat = last_arm_act
        state_vec, mask_vec = assemble_state_vec(last_act_concat, act_format)
        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata (first step's instruction wins)
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

        # Update the last_arm_act and last_base_act
        last_arm_act = arm_act
        if has_base:
            last_base_act = base_act
    # NOTE(review): records the LAST step index (len - 1), not the count —
    # confirm downstream consumers expect this.
    episode_metadata["#steps"] = step_id
    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)
    return episode_metadata, episode_states, episode_masks
@tf.autograph.experimental.do_not_convert
def generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode.

    Dispatches to the variant matching this dataset's layout after running
    the dataset-specific per-step preprocessing.
    """
    if isinstance(dataset_name, tf.Tensor):
        dataset_name = dataset_name.numpy().decode("utf-8")
    # Run the dataset-specific per-step preprocessing first.
    episode["steps"] = episode["steps"].map(globals()[dataset_name].process_step, )
    # Pick the handler for this dataset's layout.
    if dataset_name == "agilex":
        handler = _generate_json_state_agilex
    elif dataset_name in DATASET_NAMES_NO_STATE:
        handler = _generate_json_state_nostate_ds
    else:
        handler = _generate_json_state
    return handler(episode, dataset_name)

313
RDT-1B/data/producer.py Normal file
View File

@ -0,0 +1,313 @@
import time
import json
import os
import time
import argparse
import sys
import signal
import random
from multiprocessing import Process
import numpy as np
import tensorflow as tf
import yaml
from data.vla_dataset import VLADataset
from data.filelock import FileLock
# Producer does not need GPU
tf.config.set_visible_devices([], "GPU")

# Read the config
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
# Load some constants from the config
# Root directory of the on-disk sample buffer shared with the consumer.
BUF_PATH = config["dataset"]["buf_path"]
# Number of chunk subdirectories in the buffer.
BUF_NUM_CHUNKS = config["dataset"]["buf_num_chunks"]
if BUF_NUM_CHUNKS < 1:
    raise ValueError("Config `buf_num_chunks` must be at least 1.")
# Number of samples stored per chunk.
BUF_CHUNK_SIZE = config["dataset"]["buf_chunk_size"]
if BUF_CHUNK_SIZE < 1:
    raise ValueError("Config `buf_chunk_size` must be at least 1.")
def get_dirty_item(chunk_dir):
    """
    Get indexes of dirty items in a chunk.

    A dirty item is one whose dirty-bit entry is non-zero, i.e. it has
    already been consumed and may be replaced.
    """
    return np.flatnonzero(read_dirty_bit(chunk_dir)).tolist()
def get_clean_item(chunk_dir):
    """
    Get indexes of clean items in a chunk.

    A clean item is one whose dirty-bit entry is zero, i.e. it has not yet
    been consumed.
    """
    return np.flatnonzero(read_dirty_bit(chunk_dir) == 0).tolist()
def save_dirty_bit(chunk_dir, dirty_bit):
    """
    Save the dirty bit to the chunk directory.

    Writes `dirty_bit` (a uint8 array of length BUF_CHUNK_SIZE) to the
    chunk's `dirty_bit` file under a write file-lock. Retries on transient
    errors for up to 10 seconds; deliberately best-effort, printing a
    warning instead of raising if all attempts fail.
    """
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        # BUGFIX: initialize `lock` each attempt. Previously, if FileLock()
        # (or os.path.join) raised, the except handlers dereferenced an
        # unbound `lock` (NameError) or released a stale lock from a
        # previous iteration.
        lock = None
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                file.write(dirty_bit.tobytes())
            lock.release_lock()
            return
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise
        except BaseException:
            # Best-effort release, then retry until the timeout elapses.
            if lock is not None:
                lock.release_lock()
            continue
    # raise RuntimeError("Failed to save dirty bit.")
    print("Failed to save dirty bit.")
def read_dirty_bit(chunk_dir):
    """
    Read the dirty bit from the chunk directory.

    Returns a uint8 array of length BUF_CHUNK_SIZE read under a read
    file-lock. Retries on transient errors for up to 10 seconds; if every
    attempt fails, returns all ones (treat everything as dirty) for
    robustness.
    """
    # If error occurs, retry
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        # BUGFIX: initialize `lock` each attempt. Previously, if FileLock()
        # (or os.path.join) raised, the except handlers dereferenced an
        # unbound `lock` (NameError) or released a stale lock from a
        # previous iteration.
        lock = None
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_read_lock()
            with open(file_path, "rb") as file:
                dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
            lock.release_lock()
            # A short/corrupt read triggers a retry via the except below.
            assert len(dirty_bit) == BUF_CHUNK_SIZE
            return dirty_bit
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise
        except BaseException:
            # Best-effort release, then retry until the timeout elapses.
            if lock is not None:
                lock.release_lock()
            continue
    # If failed to read the dirty bit, return all ones for robustness
    return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
def save_sample(step_dict, chunk_dir, chunk_item_idx):
    """
    Save a sample to the chunk directory.

    Writes the JSON metadata to `json_content_<idx>.json` and all tensors to
    `sample_<idx>.npz`, each under its own write file-lock. Retries for up
    to 10 seconds on transient errors; deliberately best-effort, printing a
    warning instead of raising if all attempts fail.
    """
    # Save the json content
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            # Track every lock taken this attempt so the handlers below can
            # release them all on failure.
            locks = []
            json_content = step_dict["json_content"]
            file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, "w") as file:
                json.dump(json_content, file, indent=4)
            lock.release_lock()
            # Save all other tensors in a npz
            file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                np.savez(
                    file,
                    step_id=step_dict["step_id"].numpy(),
                    state_chunk=step_dict["state_chunk"].numpy(),
                    state_chunk_time_mask=step_dict["state_chunk_time_mask"].numpy(),
                    action_chunk=step_dict["action_chunk"].numpy(),
                    action_chunk_time_mask=step_dict["action_chunk_time_mask"].numpy(),
                    state_vec_mask=step_dict["state_vec_mask"].numpy(),
                    past_frames_0=step_dict["past_frames_0"].numpy(),
                    past_frames_0_time_mask=step_dict["past_frames_0_time_mask"].numpy(),
                    past_frames_1=step_dict["past_frames_1"].numpy(),
                    past_frames_1_time_mask=step_dict["past_frames_1_time_mask"].numpy(),
                    past_frames_2=step_dict["past_frames_2"].numpy(),
                    past_frames_2_time_mask=step_dict["past_frames_2_time_mask"].numpy(),
                    past_frames_3=step_dict["past_frames_3"].numpy(),
                    past_frames_3_time_mask=step_dict["past_frames_3_time_mask"].numpy(),
                    state_std=step_dict["state_std"].numpy(),
                    state_mean=step_dict["state_mean"].numpy(),
                    state_norm=step_dict["state_norm"].numpy(),
                )
            lock.release_lock()
            return
        except KeyboardInterrupt:
            # Release any locks taken this attempt, then propagate.
            for lock in locks:
                lock.release_lock()
            raise KeyboardInterrupt
        except BaseException:
            # Best-effort release, then retry until the timeout elapses.
            for lock in locks:
                lock.release_lock()
            continue
    # raise RuntimeError("Failed to save sample.")
    print("Failed to save sample.")
def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type):
    """
    Run the producer.
    The producer will first fill up the buffer with samples.
    Then it will keep replacing dirty samples
    (i.e., samples that have been read by the consumer)
    with new samples.

    Args:
        seed: random seed for this worker's dataset sampling.
        num_workers: total number of producer processes.
        worker_id: index of this worker; determines its chunk range.
        fill_up: if True, fill the whole buffer before replacing dirty items.
        clean_dirty: if True (and not fill_up), reset all dirty bits first.
        dataset_type: 'pretrain' or 'finetune'.
    """
    vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type)
    # Each worker owns a contiguous, disjoint range of chunks.
    chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers
    chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers
    if fill_up:
        print(f"Worker {worker_id}: Start filling up the buffer...")
    elif clean_dirty:
        # Only refresh the dirty bits
        print(f"Worker {worker_id}: Start refreshing the dirty bits...")
        for chunk_idx in range(chunk_start_idx, chunk_end_idx):
            chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}")
            dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
            save_dirty_bit(chunk_dir, dirty_bit)
        print(f"Worker {worker_id}: Refreshed the dirty bits.")

    # Cursors for the fill-up phase ...
    fill_chunk_idx = chunk_start_idx
    fill_chunk_item_idx = 0
    # ... and for the dirty-replacement phase.
    dirty_chunk_idx = chunk_start_idx
    dirty_chunk_item_idxs = []
    time_stmp = time.time()
    for episode_steps in vla_dataset:
        for step in episode_steps:
            if fill_up and fill_chunk_idx < chunk_end_idx:
                # Fill up the buffer
                chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}")
                if fill_chunk_item_idx == 0:
                    # Create a new chunk
                    os.makedirs(chunk_dir, exist_ok=True)
                    # Write the dirty bit of size BUF_CHUNK_SIZE
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(chunk_dir, dirty_bit)
                # Save the sample
                save_sample(step, chunk_dir, fill_chunk_item_idx)
                # print(f"Filled up chunk {fill_chunk_item_idx+1}/{BUF_CHUNK_SIZE} {fill_chunk_idx+1}/{BUF_NUM_CHUNKS}")
                # Progress logging: every 10th chunk and the last one,
                # printed when a chunk is first touched.
                local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx
                local_num_chunks = chunk_end_idx - chunk_start_idx
                if (local_fill_chunk_idx % 10 == 0
                        or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0:
                    print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx+1}/{local_num_chunks}")
                # Advance the fill cursor.
                fill_chunk_item_idx += 1
                if fill_chunk_item_idx == BUF_CHUNK_SIZE:
                    fill_chunk_idx += 1
                    fill_chunk_item_idx = 0
                    # NOTE(review): compares against the GLOBAL chunk count,
                    # but this worker only fills up to chunk_end_idx — only
                    # the last worker ever prints this; confirm intended.
                    if fill_chunk_idx == BUF_NUM_CHUNKS:
                        print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...")
            else:
                # Search for the dirty chunk to replace
                while len(dirty_chunk_item_idxs) == 0:
                    dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
                    dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir)

                    # Print the dirty ratio (rate-limited to every 2 s)
                    if time.time() - time_stmp > 2.0:
                        dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE
                        print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}")
                        time_stmp = time.time()

                    if len(dirty_chunk_item_idxs) > 0:
                        # Lock the chunk (mark everything dirty so the
                        # consumer skips it while we rewrite it)
                        dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
                        save_dirty_bit(dirty_chunk_dir, dirty_bit)
                    # Iterate over the chunks (wrap within this worker's range)
                    dirty_chunk_idx += 1
                    if dirty_chunk_idx == chunk_end_idx:
                        dirty_chunk_idx = chunk_start_idx

                # Replace the dirty item
                dirty_item_idx = dirty_chunk_item_idxs.pop()
                # NOTE(review): `dirty_chunk_idx` was already advanced past
                # the chunk whose dirty items were collected
                # (`dirty_chunk_dir`), so this writes into the NEXT chunk
                # while the dirty bits are cleared on `dirty_chunk_dir`
                # below. Writing to `dirty_chunk_dir` looks more consistent —
                # confirm against the consumer's expectations.
                chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
                # Save the sample
                save_sample(step, chunk_dir, dirty_item_idx)

                # If we have replaced all dirty items in the chunk
                if len(dirty_chunk_item_idxs) == 0:
                    # Unlock the chunk
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(dirty_chunk_dir, dirty_bit)
                    print(f"Worker {worker_id}: Replaced dirty chunk {dirty_chunk_idx}.")
if __name__ == "__main__":
    # Args: n_workers, fill_up
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_workers",
        type=int,
        default=2,
        help="Number of parallel workers. It should be less than or equal to the number of chunks.",
    )
    parser.add_argument(
        "--fill_up",
        action="store_true",
        help="Whether to fill up the buffer before replacing dirty samples.",
    )
    parser.add_argument(
        "--clean_dirty",
        action="store_true",
        help=
        "Whether to clean the dirty bits before replacing dirty samples. This option is ignored when `fill_up` is set.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed. If not set, the seed will be randomly generated.",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        help="Whether to load the pretrain dataset or finetune dataset.",
    )

    # Run the producer
    args = parser.parse_args()
    if args.seed is not None:
        print(f"Base seed: {args.seed}")
        random.seed(args.seed)
    processes = []
    # Derive one independent seed per worker from the base RNG.
    process_seeds = [random.randint(0, 2**32) for _ in range(args.n_workers)]
    print(f"Process seeds: {process_seeds}")

    def signal_handler(sig, frame):
        # Forward Ctrl+C to every worker, then exit the parent cleanly.
        print("Ctrl+C received. Terminating child processes...")
        for p in processes:
            p.terminate()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    # Spawn one producer process per worker; each owns a disjoint chunk range.
    for worker_id in range(args.n_workers):
        p = Process(
            target=run_producer,
            args=(
                process_seeds[worker_id],
                args.n_workers,
                worker_id,
                args.fill_up,
                args.clean_dirty,
                args.dataset_type,
            ),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

242
RDT-1B/data/utils.py Normal file
View File

@ -0,0 +1,242 @@
import tensorflow as tf
import tensorflow_graphics.geometry.transformation.euler as tf_euler
import tensorflow_graphics.geometry.transformation.quaternion as tf_quat
import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as tf_rotmat
def dataset_to_path(dataset_name: str, dir_name: str) -> str:
    """
    Return the path to the dataset.

    The TFDS version segment depends on the dataset: a few datasets ship
    non-default or empty version directories; everything else uses "0.1.0".
    """
    if dataset_name in ("robo_net", "cmu_playing_with_food", "droid"):
        version = "1.0.0"
    elif dataset_name in ("language_table", "fmb", "dobbe"):
        version = "0.0.1"
    elif dataset_name in (
            "nyu_door_opening_surprising_effectiveness",
            "cmu_play_fusion",
            "berkeley_gnm_recon",
    ):
        version = ""
    else:
        version = "0.1.0"
    return f"{dir_name}/{dataset_name}/{version}"
def clean_task_instruction(task_instruction: tf.Tensor, replacements: dict) -> tf.Tensor:
    """
    Clean up the natural language task instruction.

    Applies every (pattern -> substitute) regex replacement in `replacements`,
    then strips leading and trailing whitespace.
    """
    cleaned = task_instruction
    for pattern, substitute in replacements.items():
        cleaned = tf.strings.regex_replace(cleaned, pattern, substitute)
    return tf.strings.strip(cleaned)
def quaternion_to_euler(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to Euler angles (roll, pitch, yaw).
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
    """
    # Normalize first so tensorflow_graphics receives a unit quaternion.
    return tf_euler.from_quaternion(tf.nn.l2_normalize(quaternion, axis=-1))
def euler_to_quaternion(euler: tf.Tensor) -> tf.Tensor:
    """
    Convert Euler angles (roll, pitch, yaw) to a quaternion (x, y, z, w).
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
    """
    # Normalize the result so callers always get a unit quaternion.
    return tf.nn.l2_normalize(tf_quat.from_euler(euler), axis=-1)
def rotation_matrix_to_euler(matrix: tf.Tensor) -> tf.Tensor:
    """
    Convert a 3x3 rotation matrix to Euler angles (roll, pitch, yaw).
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
    """
    # Thin wrapper over tensorflow_graphics; no normalization needed here.
    return tf_euler.from_rotation_matrix(matrix)
def rotation_matrix_to_quaternion(matrix: tf.Tensor) -> tf.Tensor:
    """
    Convert a 3x3 rotation matrix to a quaternion (x, y, z, w).
    """
    # Normalize the result so callers always get a unit quaternion.
    return tf.nn.l2_normalize(tf_quat.from_rotation_matrix(matrix), axis=-1)
def euler_to_rotation_matrix(euler: tf.Tensor) -> tf.Tensor:
    """
    Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix.
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
    """
    # Thin wrapper over tensorflow_graphics; no normalization needed here.
    return tf_rotmat.from_euler(euler)
def quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
    """
    # Normalize first so tensorflow_graphics receives a unit quaternion.
    return tf_rotmat.from_quaternion(tf.nn.l2_normalize(quaternion, axis=-1))
def quaternion_to_rotation_matrix_wo_static_check(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
    This function is used to make tensorflow happy.

    Unlike `quaternion_to_rotation_matrix`, it expands the standard
    quaternion-to-matrix formula by hand instead of calling
    tensorflow_graphics, avoiding its static shape checks.
    """
    # Normalize the quaternion
    quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
    # Unpack components; `w` is the scalar part.
    x = quaternion[..., 0]
    y = quaternion[..., 1]
    z = quaternion[..., 2]
    w = quaternion[..., 3]
    # Precompute the doubled products used by the closed-form matrix entries.
    tx = 2.0 * x
    ty = 2.0 * y
    tz = 2.0 * z
    twx = tx * w
    twy = ty * w
    twz = tz * w
    txx = tx * x
    txy = ty * x
    txz = tz * x
    tyy = ty * y
    tyz = tz * y
    tzz = tz * z
    # The nine row-major entries of the 3x3 rotation matrix.
    matrix = tf.stack(
        (
            1.0 - (tyy + tzz),
            txy - twz,
            txz + twy,
            txy + twz,
            1.0 - (txx + tzz),
            tyz - twx,
            txz - twy,
            tyz + twx,
            1.0 - (txx + tyy),
        ),
        axis=-1,
    )  # pyformat: disable
    # Reshape (..., 9) back to (..., 3, 3), preserving any batch dims.
    output_shape = tf.concat((tf.shape(input=quaternion)[:-1], (3, 3)), axis=-1)
    return tf.reshape(matrix, shape=output_shape)
"""
Below is a continuous 6D rotation representation adapted from
On the Continuity of Rotation Representations in Neural Networks
https://arxiv.org/pdf/1812.07035.pdf
https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py
"""
def rotation_matrix_to_ortho6d(matrix: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d representation keeps the first two column vectors a1 and a2
    of the rotation matrix:  [ | , | , | ]
                             [ a1, a2, a3]
                             [ | , | , | ]
    Input: (A1, ..., An, 3, 3)
    Output: (A1, ..., An, 6)
    """
    # Keep the first two columns, then lay them out row-wise so the output
    # is [a1; a2] flattened into a single 6-vector.
    cols = matrix[..., :, :2]
    cols_t = tf.linalg.matrix_transpose(cols)  # swap the last two axes
    return tf.reshape(cols_t, cols_t.shape[:-2] + [6])
def rotation_matrix_to_ortho6d_1d(matrix: tf.Tensor) -> tf.Tensor:
    """
    Extract the ortho6d representation (first two columns a1, a2) from a
    single rotation matrix.
    Input: (3, 3)
    Output: (6,)
    This function is used to make tensorflow happy.
    """
    # Take the first two columns, lay them out row-wise, and flatten.
    return tf.reshape(tf.transpose(matrix[:, :2]), [6])
def normalize_vector(v):
    """
    L2-normalize `v` along its last axis.

    v: (..., N)
    The norm is clamped below at 1e-8 to avoid division by zero.
    """
    magnitude = tf.maximum(tf.norm(v, axis=-1, keepdims=True), 1e-8)
    return v / magnitude
def cross_product(u, v):
    """
    Cross product over the last axis.

    u: (..., 3)
    v: (..., 3)
    u x v: (..., 3)
    """
    ux, uy, uz = tf.unstack(u, num=3, axis=-1)
    vx, vy, vz = tf.unstack(v, num=3, axis=-1)
    return tf.stack(
        [uy * vz - uz * vy, uz * vx - ux * vz, ux * vy - uy * vx],
        axis=-1,
    )
def ortho6d_to_rotation_matrix(ortho6d: tf.Tensor) -> tf.Tensor:
    """
    Recover the full rotation matrix from its ortho6d representation (the
    first two columns a1, a2) via Gram-Schmidt:
        x = norm(a1); z = norm(x × a2); y = z × x.
    Input: (A1, ..., An, 6)
    Output: (A1, ..., An, 3, 3)
    """
    a1 = ortho6d[..., 0:3]
    a2 = ortho6d[..., 3:6]
    x_axis = normalize_vector(a1)
    z_axis = normalize_vector(cross_product(x_axis, a2))
    y_axis = cross_product(z_axis, x_axis)
    # Stack the basis vectors as columns of the rotation matrix.
    return tf.stack([x_axis, y_axis, z_axis], axis=-1)
def capitalize_and_period(instr: str) -> str:
    """
    Capitalize the first letter of a string and add a period to the end if
    it's not there. An empty string is returned unchanged.
    """
    if not instr:
        return instr
    # Upper-case the first character only when needed.
    result = instr if instr[0].isupper() else instr[0].upper() + instr[1:]
    # Ensure the sentence ends with a period.
    if not result.endswith("."):
        result += "."
    return result

149
RDT-1B/data/vla_dataset.py Normal file
View File

@ -0,0 +1,149 @@
import json
import random
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import yaml
from data.episode_transform import (
process_episode,
flatten_episode,
flatten_episode_agilex,
bgr_to_rgb,
)
from data.utils import dataset_to_path
from data.preprocess_scripts import *
# Producer does not need GPU
tf.config.set_visible_devices([], "GPU")

# Root directory holding the Open X-Embodiment TFDS datasets.
OPENX_EMBOD_DIR = "data/datasets/openx_embod"
# Datasets loaded via their own `load_dataset` preprocess script rather
# than from the Open X-Embodiment TFDS directory.
DATASET_NAMES_NOOPENX = [
    "aloha_mobile",
    "aloha_static",
    "roboset",
    "agilex",
    "rh20t",
    "calvin",
    "bridgev2",
]

# Read the config
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
# Load some constants from the config:
# episodes shorter than the low threshold are dropped, episodes longer
# than the high threshold are subsampled (see VLADataset.__iter__).
EPSD_LEN_THRESH_LOW = config["dataset"]["epsd_len_thresh_low"]
EPSD_LEN_THRESH_HIGH = config["dataset"]["epsd_len_thresh_high"]

# Read the image keys of each dataset
with open("configs/dataset_img_keys.json", "r") as file:
    IMAGE_KEYS = json.load(file)
class VLADataset:
    """
    This class is used to sample episodes from the embododiment dataset.

    Each iteration yields one episode as a list of flattened step dicts,
    drawn from the configured dataset collection according to the
    per-dataset sample weights.
    """

    def __init__(self, seed, dataset_type, repeat=True):
        """
        seed: the random seed
        dataset_type: 'pretrain' or 'finetune', which dataset to load
        repeat: whether to repeat to infinite length
        """
        dataset_names_cfg = ("configs/pretrain_datasets.json"
                             if dataset_type == "pretrain" else "configs/finetune_datasets.json")
        with open(dataset_names_cfg, "r") as file:
            DATASET_NAMES = json.load(file)
        self.dataset_names = DATASET_NAMES
        sample_weights_cfg = ("configs/pretrain_sample_weights.json"
                              if dataset_type == "pretrain" else "configs/finetune_sample_weights.json")
        # Load the sample weights
        with open(sample_weights_cfg, "r") as file:
            SAMPLE_WEIGHTS = json.load(file)
        self.openx_dir = OPENX_EMBOD_DIR
        self.epsd_len_thresh_low = EPSD_LEN_THRESH_LOW
        self.epsd_len_thresh_high = EPSD_LEN_THRESH_HIGH
        self.repeat = repeat

        # Set the random seed
        tf.random.set_seed(seed)
        np.random.seed(seed)

        # Weights of the each dataset in the collection to sample from
        sample_weights = []
        self.name2dataset = {}
        for dataset_name in self.dataset_names:
            if dataset_name in DATASET_NAMES_NOOPENX:
                # Non-OpenX datasets provide their own loader module.
                dataset = globals()[dataset_name].load_dataset(seed)
            else:
                dataset_path = dataset_to_path(dataset_name, self.openx_dir)
                dataset = tfds.builder_from_directory(builder_dir=dataset_path)
                dataset = dataset.as_dataset(split="all", shuffle_files=True)

            # You can add filter for other datasets
            if dataset_name == "kuka":
                dataset = dataset.filter(lambda x: x["success"])
            elif dataset_name == "bc_z":
                # Keep only episodes whose first step reports success > 0.5.
                dataset = dataset.filter(lambda x: tf.math.greater(
                    next(iter(x["steps"]))["observation"]["episode_success"],
                    0.5,
                ))
            elif (dataset_name == "ucsd_pick_and_place_dataset_converted_externally_to_rlds"):
                dataset = dataset.filter(lambda x: x["episode_metadata"]["success"])
            elif (dataset_name == "utokyo_xarm_bimanual_converted_externally_to_rlds"):
                # Only preserve the meaningful episodes
                dataset = dataset.filter(lambda x: tf.math.equal(
                    next(iter(x["steps"]))["language_instruction"],
                    tf.constant("Unfold a wrinkled towel."),
                ))

            # Note: use cache() will cause the unexpected crash
            # dataset = dataset.map().cache().shuffle().repeat()
            dataset = dataset.map(lambda x: process_episode(
                x,
                dataset_name,
                IMAGE_KEYS[dataset_name]["image_keys"],
                IMAGE_KEYS[dataset_name]["image_mask"],
            ))

            # Change BGR to RGB if needed
            if dataset_name == "fmb":
                dataset = dataset.map(bgr_to_rgb)

            if self.repeat:
                dataset = dataset.repeat()

            # Keep a live iterator per dataset so sampling can interleave them.
            self.name2dataset[dataset_name] = iter(dataset)
            sample_weights.append(SAMPLE_WEIGHTS[dataset_name])

        # Normalize the sample weights
        sample_weights = np.array(sample_weights)
        self.sample_weights = sample_weights / np.sum(sample_weights)

    def __iter__(self):
        """
        Sample batches of episodes for an epoch.

        Yields lists of flattened step dicts. Episodes shorter than
        `epsd_len_thresh_low` are skipped; longer ones are subsampled down
        to `epsd_len_thresh_high` steps.
        """
        while True:
            # Pick a dataset by weight, then take its next episode.
            dataset_name = np.random.choice(self.dataset_names, p=self.sample_weights)
            episode = next(self.name2dataset[dataset_name])
            if dataset_name == "agilex":
                episode_steps = flatten_episode_agilex(episode)
            else:
                episode_steps = flatten_episode(episode)

            # Filter too short
            if len(episode_steps) < self.epsd_len_thresh_low:
                continue
            # Randomly sample too long
            # NOTE(review): random.sample does not preserve temporal order of
            # the selected steps — confirm this is intended downstream.
            if len(episode_steps) > self.epsd_len_thresh_high:
                episode_steps = random.sample(episode_steps, self.epsd_len_thresh_high)
            yield episode_steps
if __name__ == "__main__":
    # Quick sanity check: print the first step of the first sampled episode.
    for sampled_episode in VLADataset(0, "finetune"):
        print(sampled_episode[0])
        break

107
RDT-1B/finetune.sh Normal file
View File

@ -0,0 +1,107 @@
#!/bin/bash
# Fine-tune RDT-1B using model_config/$CONFIG_NAME.yml, optionally
# overridden by input/config.json. Writes checkpoints + output.log to
# the checkpoint_path declared in the YAML.
BEGIN_TIME=$(date +%s)

CONFIG_NAME="Train_Config_Default"
CONFIG_FILE="model_config/$CONFIG_NAME.yml"
echo "CONFIG_FILE_PATH: $CONFIG_FILE"

### ============Read Input Config and ReLoad Config YAML===================
# NOTE(review): machine-specific absolute path — parameterize for other hosts.
# Guard the symlink so re-running the script does not fail on an existing link.
if [ ! -e ../weights ]; then
    ln -s /home/qi.xiong/Temp/RDT-1B/input/weights ../weights
fi
TRAIN_CONFIG_FILE="input/config.json"
echo "TRAIN_CONFIG_FILE_PATH: $TRAIN_CONFIG_FILE"
python scripts/read_config.py "$TRAIN_CONFIG_FILE" "$CONFIG_FILE"
### ============Read Input Config and ReLoad Config YAML===================

export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0
export NCCL_SOCKET_IFNAME=eth0
# export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
export VISION_ENCODER_NAME="../weights/siglip-so400m-patch14-384"
export CFLAGS="-I/usr/include"
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
export WANDB_PROJECT="RDT-1B"
export WANDB_DEFAULT_RUN_NAME=$CONFIG_NAME
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1

# check if YAML exist
if [ ! -f "$CONFIG_FILE" ]; then
    echo "Config file $CONFIG_FILE does not exist!"
    exit 1
fi

PRETRAINED_MODEL_NAME=$(python scripts/read_yaml.py "$CONFIG_FILE" pretrained_model_name_or_path)
TRAIN_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" train_batch_size)
SAMPLE_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_batch_size)
MAX_TRAIN_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" max_train_steps)
CHECKPOINTING_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpointing_period)
SAMPLE_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_period)
CHECKPOINTS_TOTAL_LIMIT=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpoints_total_limit)
LR_SCHEDULER=$(python scripts/read_yaml.py "$CONFIG_FILE" lr_scheduler)
LEARNING_RATE=$(python scripts/read_yaml.py "$CONFIG_FILE" learning_rate)
DATALOADER_NUM_WORKERS=$(python scripts/read_yaml.py "$CONFIG_FILE" dataloader_num_workers)
DATASET_TYPE=$(python scripts/read_yaml.py "$CONFIG_FILE" dataset_type)
STATE_NOISE_SNR=$(python scripts/read_yaml.py "$CONFIG_FILE" state_noise_snr)
GRAD_ACCUM_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" gradient_accumulation_steps)
OUTPUT_DIR=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpoint_path)
CUDA_USE=$(python scripts/read_yaml.py "$CONFIG_FILE" cuda_visible_device)

# Strip the surrounding quotes read_yaml emits for string-valued settings.
PRETRAINED_MODEL_NAME=$(echo "$PRETRAINED_MODEL_NAME" | tr -d '"')
CUDA_USE=$(echo "$CUDA_USE" | tr -d '"')
OUTPUT_DIR=$(echo "$OUTPUT_DIR" | tr -d '"')
LR_SCHEDULER=$(echo "$LR_SCHEDULER" | tr -d '"')
DATASET_TYPE=$(echo "$DATASET_TYPE" | tr -d '"')

# create output path
if [ ! -d "$OUTPUT_DIR" ]; then
    mkdir -p "$OUTPUT_DIR"
    echo "Created output directory: $OUTPUT_DIR"
else
    echo "Output directory already exists: $OUTPUT_DIR"
fi

export CUDA_VISIBLE_DEVICES=$CUDA_USE

python -m data.compute_dataset_stat_hdf5 --task_name $CONFIG_NAME

# BUG FIX: the YAML's lr_scheduler and dataset_type were read above but the
# launch line hard-coded "constant" / "finetune"; pass the config values through.
accelerate launch --main_process_port=28499 main.py \
    --deepspeed="./configs/zero2.json" \
    --pretrained_model_name_or_path=$PRETRAINED_MODEL_NAME \
    --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
    --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
    --output_dir=$OUTPUT_DIR \
    --train_batch_size=$TRAIN_BATCH_SIZE \
    --sample_batch_size=$SAMPLE_BATCH_SIZE \
    --max_train_steps=$MAX_TRAIN_STEPS \
    --checkpointing_period=$CHECKPOINTING_PERIOD \
    --sample_period=$SAMPLE_PERIOD \
    --checkpoints_total_limit=$CHECKPOINTS_TOTAL_LIMIT \
    --lr_scheduler=$LR_SCHEDULER \
    --learning_rate=$LEARNING_RATE \
    --mixed_precision="bf16" \
    --dataloader_num_workers=$DATALOADER_NUM_WORKERS \
    --image_aug \
    --dataset_type=$DATASET_TYPE \
    --state_noise_snr=$STATE_NOISE_SNR \
    --load_from_hdf5 \
    --report_to=wandb \
    --precomp_lang_embed \
    --gradient_accumulation_steps=$GRAD_ACCUM_STEPS \
    --model_config_path=$CONFIG_FILE \
    --CONFIG_NAME=$CONFIG_NAME \
    --output_log_path=$OUTPUT_DIR/output.log

END_TIME=$(date +%s)
RUNTIME=$((END_TIME - BEGIN_TIME))
echo "Total runtime: $RUNTIME seconds"

### ============Generate Output JSON===================
python scripts/generate_output_json.py "$TRAIN_CONFIG_FILE" "$OUTPUT_DIR" "$RUNTIME"
### ============Generate Output JSON===================

5
RDT-1B/generate.sh Normal file
View File

@ -0,0 +1,5 @@
#!/bin/bash
# Generate a per-model finetune config YAML under model_config/.
# Usage: bash generate.sh <model_name>

# BUG FIX: fail fast on a missing argument and quote it so names with
# spaces/globs are passed through intact.
model_name=${1:?"usage: $0 <model_name>"}

python ./model_config/_generate_model_config.py "$model_name"

351
RDT-1B/main.py Normal file
View File

@ -0,0 +1,351 @@
import argparse
import os
from train.train import train
from accelerate.logging import get_logger
def parse_args(input_args=None):
    """Build and parse the command-line arguments for RDT training.

    Args:
        input_args: Optional list of argument strings; when ``None`` the
            arguments are read from ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with all training options. ``local_rank``
        is overridden by the ``LOCAL_RANK`` environment variable when it
        is set (needed for distributed launchers such as accelerate).
    """
    parser = argparse.ArgumentParser(description="Main script for training RDT.")
    parser.add_argument(
        "--model_config_path",
        type=str,
        default="model_config/sjoe_place_D435_100_finetune_config.yaml",
        help="Path to the finetune data and model configuration file. "
        "Default is `model_config/sjoe_place_D435_100_finetune_config.yaml`.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/base.yaml",
        help="Path to the configuration file. Default is `configs/base.yaml`.",
    )
    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help="Enable DeepSpeed and pass the path to its config file or an already initialized DeepSpeed config dictionary",
    )
    parser.add_argument(
        "--pretrained_text_encoder_name_or_path",
        type=str,
        default=None,
        help="Pretrained text encoder name or path if not the same as model_name",
    )
    parser.add_argument(
        "--pretrained_vision_encoder_name_or_path",
        type=str,
        default=None,
        help="Pretrained vision encoder name or path if not the same as model_name",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="checkpoints",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--load_from_hdf5",
        action="store_true",
        default=False,
        help=("Whether to load the dataset directly from HDF5 files. "
              "If False, the dataset will be loaded using producer-consumer pattern, "
              "where the producer reads TFRecords and saves them to buffer, and the consumer reads from buffer."),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the sampling dataloader.",
    )
    parser.add_argument(
        "--num_sample_batches",
        type=int,
        default=2,
        help="Number of batches to sample from the dataset.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_period",
        type=int,
        default=500,
        help=("Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
              "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
              "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
              "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
              "instructions."),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
              " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
              " for more details"),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
              ' `--checkpointing_period`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        # BUG FIX: the original passed a *tuple* of strings as `help` (a stray
        # trailing comma after the first literal), which argparse rendered as
        # the tuple's repr instead of readable help text.
        help=("Path or name of a pretrained checkpoint to load the model from.\n"
              " This can be either:\n"
              " - a string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co, e.g., `robotics-diffusion-transformer/rdt-1b`,\n"
              " - a path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method, e.g., `./my_model_directory/`.\n"
              " - a path to model checkpoint (*.pt), .e.g, `my_model_directory/checkpoint-10000/pytorch_model/mp_rank_00_model_states.pt`"
              " - `None` if you are randomly initializing model using configuration at `config_path`."),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--cond_mask_prob",
        type=float,
        default=0.1,
        help=("The probability to randomly mask the conditions (except states) during training. "
              "If set to 0, the conditions are not masked."),
    )
    parser.add_argument(
        "--cam_ext_mask_prob",
        type=float,
        default=-1.0,
        help=("The probability to randomly mask the external camera image during training. "
              "If set to < 0, the external camera image is masked with the probability of `cond_mask_prob`."),
    )
    parser.add_argument(
        "--state_noise_snr",
        type=float,
        default=None,
        help=("The signal-to-noise ratio (SNR, unit: dB) for adding noise to the states. "
              "Default is None, which means no noise is added."),
    )
    parser.add_argument(
        "--image_aug",
        action="store_true",
        default=False,
        help="Whether or not to apply image augmentation (ColorJitter, blur, noise, etc) to the input images.",
    )
    parser.add_argument(
        "--precomp_lang_embed",
        action="store_true",
        default=False,
        help="Whether or not to use precomputed language embeddings.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.9,
        help="The moving average coefficient for each dataset's loss.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
              " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
              ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'),
    )
    parser.add_argument(
        "--sample_period",
        type=int,
        default=-1,
        help=("Run sampling every X steps. During the sampling phase, the model will sample a trajectory"
              " and report the error between the sampled trajectory and groud-truth trajectory"
              " in the training batch."),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=("Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
              " behaviors, so disable this argument if it causes any problems. More info:"
              " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"),
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        required=False,
        help="Whether to load the pretrain dataset or finetune dataset.",
    )
    parser.add_argument(
        "--CONFIG_NAME",
        type=str,
        default="Null",
        required=True,
    )
    parser.add_argument(
        "--output_log_path",
        type=str,
        default="output/output.log",
        help="The path to the output log file.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Distributed launchers communicate the local rank via the environment;
    # it takes precedence over the command-line value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
if __name__ == "__main__":
    # Entry point: parse the CLI options and hand off to the training loop.
    run_logger = get_logger(__name__)
    run_args = parse_args()
    train(run_args, run_logger)

269
RDT-1B/model.py Normal file
View File

@ -0,0 +1,269 @@
#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -- coding: UTF-8
"""
#!/usr/bin/python3
"""
from pathlib import Path
# get current workspace
current_file = Path(__file__)
import json
import sys
parent_dir = current_file.parent
sys.path.append(str(parent_dir))
import os
import argparse
import threading
import time
import yaml
from collections import deque
import numpy as np
import torch
from PIL import Image as PImage
import cv2
import sys, os
# get current workspace
current_file = Path(__file__)
sys.path.append(os.path.join(current_file.parent, "models"))
from scripts.agilex_model import create_model
from multimodal_encoder.t5_encoder import T5Embedder
global_path = parent_dir.parent
class RDT:
    """Inference-time wrapper around the RDT (Robotics Diffusion Transformer) policy.

    Bundles the diffusion policy (built by ``create_model``), a T5-XXL text
    encoder for language conditioning, and a two-frame observation window of
    camera images plus joint states for a dual-arm robot.
    """

    def __init__(
        self,
        pretrained_model_name_or_path,
        task_name,
        left_arm_dim,
        right_arm_dim,
        rdt_step,
    ):
        """Load the policy and text encoder and initialize the runtime config.

        Args:
            pretrained_model_name_or_path: Checkpoint path/id forwarded to
                ``create_model``.
            task_name: Name of the task (kept for bookkeeping).
            left_arm_dim: Number of joints of the left arm (gripper excluded).
            right_arm_dim: Number of joints of the right arm (gripper excluded).
            rdt_step: Step counter/identifier stored for the caller's use.
        """
        # set path
        current_file = Path(__file__)
        self.global_path = current_file.parent.parent
        # load the config
        self.config = {
            "episode_len": 10000,  # args.max_publish_step
            # State layout: [left joints, left gripper, right joints, right gripper]
            # (the two "+ 1" entries are the grippers).
            "state_dim": left_arm_dim + 1 + right_arm_dim +
            1,  # 14 dims action:[left joint angles,left gripper,right joint angles,right gripper]
            "chunk_size": 64,  # args.chunk_size
            "camera_names": ["cam_high", "cam_right_wrist", "cam_left_wrist"],
        }
        # setup config
        self.args = {
            "max_publish_step": 10000,  # Maximum number of action publishing steps
            "seed": None,  # Random seed
            "ctrl_freq": 25,  # The control frequency of the robot
            "chunk_size": 64,  # Action chunk size
            # 'disable_puppet_arm': False, # Whether to disable the puppet arm
            "config_path": os.path.join(self.global_path, "RDT/configs/base.yaml"),
            "pretrained_model_name_or_path": pretrained_model_name_or_path,
        }
        # Load rdt model
        self.left_arm_dim, self.right_arm_dim = left_arm_dim, right_arm_dim
        self.policy = self.make_policy(self.args)
        self.max_publish_step = self.config["episode_len"]
        self.chunk_size = self.config["chunk_size"]
        self.task_name = task_name
        self.observation_window = None
        self.img_size = (640, 480)
        self.set_language_embed()
        self.rdt_step = rdt_step

    # set img_size
    def set_img_size(self, img_size):
        """Set the (width, height) all incoming camera frames are resized to."""
        self.img_size = img_size

    def set_language_embed(self):
        """Load the T5-XXL tokenizer/encoder onto GPU 0 for language embedding."""
        GPU = 0
        MODEL_PATH = os.path.join(self.global_path, "weights/RDT/t5-v1_1-xxl")
        CONFIG_PATH = os.path.join(self.global_path, "RDT/configs/base.yaml")
        with open(CONFIG_PATH, "r") as fp:
            config = yaml.safe_load(fp)
        device = torch.device(f"cuda:{GPU}")
        text_embedder = T5Embedder(
            from_pretrained=MODEL_PATH,
            model_max_length=config["dataset"]["tokenizer_max_length"],
            device=device,
            use_offload_folder=None,
        )
        self.tokenizer, self.text_encoder = text_embedder.tokenizer, text_embedder.model
        self.text_encoder.eval()

    # set language randomly
    def random_set_language(self, instruction=None):
        """Set the conditioning instruction (thin wrapper with a non-None check)."""
        assert instruction is not None, "Missing input instruction"
        self.set_language_instruction(instruction)

    # encoding language
    def set_language_instruction(self, language_instruction, save_dir=None, task_name=None):
        """Encode (or load) a language instruction into ``self.lang_embeddings``.

        Args:
            language_instruction: Either a path to a pre-embedded ``.pt`` file
                (loaded directly) or a raw instruction string (encoded with T5).
            save_dir: Optional directory to cache the computed embedding;
                must be given together with ``task_name`` (the XOR assert
                below enforces both-or-neither).
            task_name: Name used for the cached ``<task_name>.pt`` file.
        """
        assert ((save_dir is None) ^ (task_name is None)) == False, "input error"

        if os.path.isfile(language_instruction):
            lang_dict = torch.load(language_instruction)
            print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
            self.lang_embeddings = lang_dict["embeddings"]
            print("loading instruction from pre-embed path")
        else:
            device = next(self.text_encoder.parameters()).device
            with torch.no_grad():
                tokens = self.tokenizer(
                    language_instruction,
                    return_tensors="pt",
                    padding="longest",
                    truncation=True,
                )["input_ids"].to(device)

                tokens = tokens.view(1, -1)
                output = self.text_encoder(tokens)
                pred = output.last_hidden_state.detach().cpu()

            if save_dir is not None:
                save_path = os.path.join(save_dir, f"{task_name}.pt")
                torch.save({
                    "name": task_name,
                    "instruction": language_instruction,
                    "embeddings": pred,
                }, save_path)

            # Free GPU memory held by the tokenization/encoding intermediates.
            del tokens, output
            torch.cuda.empty_cache()
            self.lang_embeddings = pred

        print(f"successfully set instruction: {language_instruction}")

    # Update the observation window buffer
    def update_observation_window(self, img_arr, state):
        """Push one (images, joint state) observation into the 2-deep window.

        Args:
            img_arr: Sequence of three frames in order [front, right, left].
                Assumes numpy arrays in the channel order expected by the
                policy — TODO(review) confirm BGR vs RGB with the caller.
            state: Joint state vector matching ``config["state_dim"]``.
        """
        # JPEG transformation
        # Align with training
        def jpeg_mapping(img):
            # Round-trip through JPEG so inference images match the training
            # distribution (which used JPEG-compressed frames).
            if img is None:
                return None
            img = cv2.imencode(".jpg", img)[1].tobytes()
            img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
            return img

        def resize_img(img, size):
            return cv2.resize(img, size)

        if self.observation_window is None:
            self.observation_window = deque(maxlen=2)

            # Append the first dummy image
            self.observation_window.append({
                "qpos": None,
                "images": {
                    self.config["camera_names"][0]: None,
                    self.config["camera_names"][1]: None,
                    self.config["camera_names"][2]: None,
                },
            })

        img_front, img_right, img_left, puppet_arm = (
            img_arr[0],
            img_arr[1],
            img_arr[2],
            state,
        )
        # img resize
        img_front = resize_img(img_front, self.img_size)
        img_left = resize_img(img_left, self.img_size)
        img_right = resize_img(img_right, self.img_size)
        # img jprg encoding
        img_front = jpeg_mapping(img_front)
        img_left = jpeg_mapping(img_left)
        img_right = jpeg_mapping(img_right)

        qpos = np.array(puppet_arm)
        qpos = torch.from_numpy(qpos).float().cuda()
        self.observation_window.append({
            "qpos": qpos,
            "images": {
                self.config["camera_names"][0]: img_front,
                self.config["camera_names"][1]: img_right,
                self.config["camera_names"][2]: img_left,
            },
        })

    def get_action(self, img_arr=None, state=None):
        """Run the policy on the current observation window.

        ``img_arr`` and ``state`` must be given together (or both omitted to
        reuse the existing window). Returns a numpy action chunk
        (``[chunk_size, state_dim]`` per ``inference_fn``).
        """
        assert (img_arr is None) ^ (state is None) == False, "input error"
        if (img_arr is not None) and (state is not None):
            self.update_observation_window(img_arr, state)

        with torch.inference_mode():
            action_buffer = inference_fn(self.config, self.policy, self.lang_embeddings, self.observation_window).copy()

        return action_buffer

    def reset_obsrvationwindows(self):
        """Clear the observation window and the cached language embedding."""
        self.lang_embeddings = None
        self.observation_window = None
        print("successfully unset obs and language intruction")

    # Initialize the model
    def make_policy(self, args):
        """Build the RDT policy from the base YAML config plus arm dimensions."""
        with open(args["config_path"], "r") as fp:
            config_base_yaml = yaml.safe_load(fp)
        args["config"] = config_base_yaml
        args["config"]["arm_dim"] = {
            "left_arm_dim": self.left_arm_dim,
            "right_arm_dim": self.right_arm_dim,
        }
        # pretrained_text_encoder_name_or_path = "weights/RDT/t5-v1_1-xxl"
        pretrained_vision_encoder_name_or_path = os.path.join(self.global_path, "weights/RDT/siglip-so400m-patch14-384")
        model = create_model(
            args=args["config"],
            dtype=torch.bfloat16,
            pretrained=args["pretrained_model_name_or_path"],
            # pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
            pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
            control_frequency=args["ctrl_freq"],
        )
        return model
# RDT inference
def inference_fn(config, policy, lang_embeddings, observation_window):
    """Run a single policy inference step on the current observation window.

    Args:
        config: Dict with at least ``camera_names`` (three camera keys).
        policy: Model exposing ``step(proprio, images, text_embeds)`` that
            returns a ``[1, chunk_size, state_dim]`` tensor.
        lang_embeddings: Precomputed language-conditioning embeddings.
        observation_window: Sequence of (at least) two observation dicts;
            the last two entries are used as frames t-1 and t.

    Returns:
        numpy array of shape ``[chunk_size, state_dim]``.
    """
    # BUG FIX: the original wrapped this body in `while True:` that always
    # returned on the first pass, and computed an unused `time1` timestamp.
    # fetch images in sequence [front, right, left] for frames t-1 and t
    image_arrs = [
        observation_window[-2]["images"][config["camera_names"][0]],
        observation_window[-2]["images"][config["camera_names"][1]],
        observation_window[-2]["images"][config["camera_names"][2]],
        observation_window[-1]["images"][config["camera_names"][0]],
        observation_window[-1]["images"][config["camera_names"][1]],
        observation_window[-1]["images"][config["camera_names"][2]],
    ]
    # Missing frames (e.g. the initial dummy observation) stay None.
    images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]

    # get last qpos in shape [state_dim, ] and unsqueeze to [1, state_dim]
    proprio = observation_window[-1]["qpos"].unsqueeze(0)

    # actions shaped as [1, chunk_size, state_dim] in format [left, right]
    actions = (policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy())
    return actions

View File

@ -0,0 +1,40 @@
import os
import yaml
import argparse
from datetime import datetime
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate finetune config.")
    parser.add_argument("model_name", type=str, help="The name of the task (e.g., beat_block_hammer)")
    cli_args = parser.parse_args()

    model_name = cli_args.model_name
    finetune_data_path = os.path.join("training_data/", f"{model_name}")
    checkpoint_path = os.path.join("checkpoints/", f"{model_name}")

    # Default fine-tuning hyper-parameters written into the per-model YAML.
    config = {
        "model": model_name,
        "data_path": finetune_data_path,
        "checkpoint_path": checkpoint_path,
        "pretrained_model_name_or_path": "../weights/RDT/rdt-1b",
        "cuda_visible_device": "...",  # args.gpu_use,
        "train_batch_size": 32,
        "sample_batch_size": 64,
        "max_train_steps": 20000,
        "checkpointing_period": 2500,
        "sample_period": 100,
        "checkpoints_total_limit": 40,
        "learning_rate": 1e-4,
        "dataloader_num_workers": 8,
        "state_noise_snr": 40,
        "gradient_accumulation_steps": 1,
    }

    # Write the YAML prefixed with a generation timestamp comment.
    config_file = os.path.join("model_config/", f"{model_name}.yml")
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(config_file, "w") as fh:
        fh.write(f"# Generated on {stamp}\n")
        yaml.dump(config, fh, default_flow_style=False, sort_keys=False)

    # Ensure the training-data directory exists for the user to fill in.
    os.makedirs(finetune_data_path, exist_ok=True)

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,82 @@
# Reference: DiffusionPolicy [https://github.com/real-stanford/diffusion_policy]
import torch
from torch.nn.modules.batchnorm import _BatchNorm
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        Args:
            model: the model whose weights are tracked; it is put in eval mode
                and frozen, and is mutated in place by `step`.
            update_after_step (int): number of optimization steps before the
                EMA starts warming up.
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
            max_value (float): The maximum EMA decay rate. Default: 0.9999.
        """
        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.decay = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        # No averaging until warmup has actually begun.
        if step <= 0:
            return 0.0
        value = 1 - (1 + step / self.inv_gamma)**-self.power
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        """Blend `new_model`'s parameters into the averaged model.

        BatchNorm parameters and frozen (requires_grad=False) parameters are
        copied verbatim; all others are updated with the current decay.
        (Removed from the original: an unused `all_dataptrs` set and large
        commented-out pointer-tracking debug code.)
        """
        self.decay = self.get_decay(self.optimization_step)
        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):
            # Iterate over immediate parameters only; recursion is handled by
            # walking modules() in lockstep.
            for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)):
                if isinstance(param, dict):
                    raise RuntimeError('Dict parameter not supported')

                if isinstance(module, _BatchNorm) or not param.requires_grad:
                    # skip batchnorms / frozen params: copy directly
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                else:
                    ema_param.mul_(self.decay)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)

        self.optimization_step += 1

View File

@ -0,0 +1,75 @@
import os
from pathlib import Path
from typing import Dict, Optional, Union
from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub.constants import (PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE)
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, is_torch_available
if is_torch_available():
import torch # type: ignore
class CompatiblePyTorchModelHubMixin(PyTorchModelHubMixin):
    """Mixin class to load Pytorch models from the Hub."""

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        # Always write a pickle checkpoint; this bypasses the parent mixin's
        # default safetensors serialization.
        target = self.module if hasattr(self, "module") else self  # type: ignore
        torch.save(target.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model.

        Prefers a safetensors checkpoint and falls back to the legacy pickle
        checkpoint, both for local directories and for Hub repositories.
        """
        model = cls(**model_kwargs)

        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            try:
                local_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
                return cls._load_as_safetensor(model, local_file, map_location, strict)
            except FileNotFoundError:
                local_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
                return cls._load_as_pickle(model, local_file, map_location, strict)

        download_kwargs = dict(
            repo_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
        try:
            hub_file = hf_hub_download(filename=SAFETENSORS_SINGLE_FILE, **download_kwargs)
            return cls._load_as_safetensor(model, hub_file, map_location, strict)
        except EntryNotFoundError:
            hub_file = hf_hub_download(filename=PYTORCH_WEIGHTS_NAME, **download_kwargs)
            return cls._load_as_pickle(model, hub_file, map_location, strict)

View File

@ -0,0 +1,159 @@
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
class CLIPVisionTower(nn.Module):
    """Frozen CLIP ViT backbone exposing intermediate patch features."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load eagerly unless the caller asked to delay; a tower that will be
        # unfrozen during training is also loaded right away.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            # Keep only the config around until the weights are needed.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the image processor and the frozen CLIP vision model."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Return hidden states of the configured layer, optionally without CLS."""
        hidden = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            return hidden[:, 1:]
        if self.select_feature == 'cls_patch':
            return hidden
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor, or a list of single images, into features."""
        if type(images) is list:
            image_features = []
            for image in images:
                outs = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                         output_hidden_states=True)
                image_features.append(self.feature_select(outs).to(image.dtype))
            return image_features

        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                 output_hidden_states=True)
        return self.feature_select(outs).to(images.dtype)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the standalone config until the weights are loaded.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return self.num_patches_per_side**2
class CLIPVisionTowerS2(CLIPVisionTower):
    """CLIP vision tower with S2 multi-scale feature extraction.

    Each image is encoded at several resolutions via s2wrapper's
    ``multiscale_forward`` and the per-scale features are concatenated
    along the channel dimension.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        # BUG FIX: the multi-scale settings must be initialized *before*
        # super().__init__(): the base constructor may call the overridden
        # load_model(), which reads self.s2_image_size, so the original
        # ordering raised AttributeError on eager loads.
        self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
        self.s2_scales = list(map(int, self.s2_scales.split(',')))
        self.s2_scales.sort()
        self.s2_split_size = self.s2_scales[0]
        self.s2_image_size = self.s2_scales[-1]

        try:
            from s2wrapper import forward as multiscale_forward
        except ImportError:
            raise ImportError(
                'Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git'
            )
        self.multiscale_forward = multiscale_forward

        super().__init__(vision_tower, args, delay_load)

        # change resize/crop size in preprocessing to the largest image size in s2_scale
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.image_processor.size['shortest_edge'] = self.s2_image_size
            self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    def load_model(self, device_map=None):
        """Load processor + frozen CLIP model, enlarging preprocessing to the top S2 scale."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.image_processor.size['shortest_edge'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

        self.is_loaded = True

    @torch.no_grad()
    def forward_feature(self, images):
        """Single-scale encoding used as the callback for multiscale_forward."""
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                               output_hidden_states=True)
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Multi-scale encode a batched tensor, or a list of single images."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.multiscale_forward(self.forward_feature,
                                                        image.unsqueeze(0),
                                                        img_sizes=self.s2_scales,
                                                        max_split_size=self.s2_split_size)
                image_features.append(image_feature)
        else:
            image_features = self.multiscale_forward(self.forward_feature,
                                                     images,
                                                     img_sizes=self.s2_scales,
                                                     max_split_size=self.s2_split_size)
        return image_features

    @property
    def hidden_size(self):
        # Features from all scales are concatenated along the channel dim.
        return self.config.hidden_size * len(self.s2_scales)

View File

@ -0,0 +1,87 @@
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoImageProcessor, AutoModel, Dinov2Model
class DinoV2VisionTower(nn.Module):
    """Wraps a pretrained DINOv2 model as a frozen vision feature extractor."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load immediately unless the caller defers; deferral is overridden
        # when the tower is meant to be unfrozen/fine-tuned.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the image processor and the (frozen) DINOv2 backbone."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)  # FIXME:
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Select patch tokens (optionally with the leading CLS token) from the output."""
        hidden = image_forward_outs.last_hidden_state
        if self.select_feature == 'patch':
            # Drop the leading CLS token.
            return hidden[:, 1:]
        if self.select_feature == 'cls_patch':
            return hidden
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Extract features for a batched tensor or a list of single images."""
        if type(images) is list:
            return [
                self.feature_select(
                    self.vision_tower(img.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                ).to(img.dtype)
                for img in images
            ]
        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
        return self.feature_select(outs).to(images.dtype)

    @property
    def dummy_feature(self):
        """Zero feature placeholder matching the tower's hidden size/device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        """The live model config once loaded, else the pre-fetched config stub."""
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        side = self.config.image_size // self.config.patch_size
        return side * side

View File

@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from transformers import AutoConfig, SiglipImageProcessor, SiglipVisionModel
class SiglipVisionTower(nn.Module):
    """Wraps a pretrained SigLIP vision model as a vision feature extractor."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load now unless deferred; deferral is overridden when the tower
        # is meant to be unfrozen/fine-tuned.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the SigLIP processor and vision model."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        # NOTE(review): only eval mode is set here; unlike the CLIP/DINOv2
        # towers, gradients are not explicitly frozen -- confirm intended.
        self.vision_tower.eval()
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Select patch tokens or the pooled output from the SigLIP forward result."""
        if self.select_feature == 'patch':
            return image_forward_outs.last_hidden_state  # (B, 729, 1536)
        if self.select_feature == 'cls_patch':
            # presumably the pooled output; noted upstream as (B, 1, 1536) --
            # verify against SiglipVisionModel (pooler_output is typically (B, hidden)).
            return image_forward_outs.pooler_output
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Extract features for a batched tensor or a list of single images."""
        if type(images) is list:
            return [
                self.feature_select(
                    self.vision_tower(img.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                ).to(img.dtype)
                for img in images
            ]
        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
        return self.feature_select(outs).to(images.dtype)

    @property
    def dummy_feature(self):
        """Zero feature placeholder matching the tower's hidden size/device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        """The live model config once loaded, else the pre-fetched config stub."""
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        side = self.config.image_size // self.config.patch_size
        return side * side

View File

@ -0,0 +1,111 @@
import torch
from transformers import AutoTokenizer, T5EncoderModel
class T5Embedder:
    """Encodes text prompts with a frozen T5 encoder.

    Optionally splits the encoder across the target device and a disk
    offload folder (first 12 blocks on device, the rest offloaded).
    """

    # available_models = ["google/t5-v1_1-xxl"]

    def __init__(
        self,
        device,
        from_pretrained=None,
        *,
        cache_dir=None,
        hf_token=None,
        use_text_preprocessing=True,
        t5_model_kwargs=None,
        torch_dtype=None,
        use_offload_folder=None,
        model_max_length=120,
        local_files_only=False,
    ):
        # from_pretrained="google/t5-v1_1-xxl" # zijian
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        self.cache_dir = cache_dir

        if t5_model_kwargs is None:
            t5_model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": self.torch_dtype,
            }
            if use_offload_folder is not None:
                t5_model_kwargs["offload_folder"] = use_offload_folder
                # Keep embeddings and the first 12 encoder blocks on the
                # device; offload blocks 12-23 plus the final norm/dropout.
                device_map = {
                    "shared": self.device,
                    "encoder.embed_tokens": self.device,
                }
                for i in range(24):
                    device_map[f"encoder.block.{i}"] = self.device if i < 12 else "disk"
                device_map["encoder.final_layer_norm"] = "disk"
                device_map["encoder.dropout"] = "disk"
                t5_model_kwargs["device_map"] = device_map
            else:
                t5_model_kwargs["device_map"] = {
                    "shared": self.device,
                    "encoder": self.device,
                }

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token

        # assert from_pretrained in self.available_models
        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            model_max_length=model_max_length,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.model = T5EncoderModel.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            **t5_model_kwargs,
        ).eval()
        self.model_max_length = model_max_length

    def get_text_embeddings(self, texts):
        """Tokenize *texts* and return (last_hidden_state, attention_mask)."""
        tokenized = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding="longest",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        input_ids = tokenized["input_ids"].to(self.device)
        attention_mask = tokenized["attention_mask"].to(self.device)
        with torch.no_grad():
            text_encoder_embs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )["last_hidden_state"].detach()
        return text_encoder_embs, attention_mask
if __name__ == "__main__":
    # Manual smoke test: constructing the embedder loads the T5-XXL weights
    # onto cuda:7 (requires the weights and a GPU to be available).
    T5Embedder(from_pretrained="google/t5-v1_1-xxl", device='cuda:7')

Binary file not shown.

Binary file not shown.

304
RDT-1B/models/rdt/blocks.py Normal file
View File

@ -0,0 +1,304 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final
from timm.models.vision_transformer import Attention, Mlp, RmsNorm, use_fused_attn
#################################################################################
# Embedding Layers for Timesteps and Condition Inptus #
#################################################################################
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal encoding of the timestep is projected through a small
    Linear-SiLU-Linear MLP to the model's hidden size.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=torch.bfloat16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.dtype = dtype

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        t: a 1-D tensor of N indices (may be fractional), one per batch element.
        dim: the output dimension.
        max_period: controls the minimum frequency of the embeddings.
        Returns an (N, dim) tensor, [cos | sin] halves (odd dims zero-padded).
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder from 1 down to 1/max_period.
        exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(max_period) * exponents)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd dims with a zero column.
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding.to(self.dtype)

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
#################################################################################
# Cross Attention Layers #
#################################################################################
class CrossAttention(nn.Module):
    """
    A cross-attention layer with flash attention.

    Queries come from the main token stream `x`; keys/values come from the
    condition sequence `c`. Uses torch's fused scaled-dot-product attention
    kernel when timm reports it available, otherwise a manual implementation.
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0,
        proj_drop: float = 0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        # Decided once at construction time (timm helper probes torch support).
        self.fused_attn = use_fused_attn()

        # Separate projections: queries from x, keys/values from the condition.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        # Optional per-head RMS/LayerNorm on queries and keys.
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, c: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Attend from `x` (B, N, C) to condition `c` (B, L, C).

        mask: (B, L) bool tensor, True for valid condition tokens; None = no masking.
        Returns a (B, N, C) tensor.
        """
        B, N, C = x.shape
        _, L, _ = c.shape
        # Reshape to (B, heads, N, head_dim) / (2, B, heads, L, head_dim).
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv(c).reshape(B, L, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # Prepare attn mask (B, L) to mask the condition
        if mask is not None:
            # Broadcast to (B, 1, N, L) so it applies to every head and query.
            mask = mask.reshape(B, 1, 1, L)
            mask = mask.expand(-1, -1, N, -1)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(query=q,
                                               key=k,
                                               value=v,
                                               dropout_p=self.attn_drop.p if self.training else 0.,
                                               attn_mask=mask)
        else:
            # Manual path: scale, fill masked positions with -inf before the
            # softmax, then mix values.
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if mask is not None:
                attn = attn.masked_fill_(mask.logical_not(), float('-inf'))
            attn = attn.softmax(dim=-1)
            if self.attn_drop.p > 0:
                attn = self.attn_drop(attn)
            x = attn @ v

        # Merge heads back to (B, N, C) and apply the output projection.
        x = x.permute(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)
        if self.proj_drop.p > 0:
            x = self.proj_drop(x)
        return x
#################################################################################
# RDT Block #
#################################################################################
class RDTBlock(nn.Module):
    """
    A RDT block with cross-attention conditioning.

    Pre-norm residual layout: self-attention, then cross-attention to the
    condition tokens, then a GELU feed-forward network, each applied as
    x + sublayer(norm(x)).
    """

    def __init__(self, hidden_size, num_heads, **block_kwargs):
        super().__init__()
        self.norm1 = RmsNorm(hidden_size, eps=1e-6)
        self.attn = Attention(dim=hidden_size,
                              num_heads=num_heads,
                              qkv_bias=True,
                              qk_norm=True,
                              norm_layer=RmsNorm,
                              **block_kwargs)
        self.cross_attn = CrossAttention(hidden_size,
                                         num_heads=num_heads,
                                         qkv_bias=True,
                                         qk_norm=True,
                                         norm_layer=RmsNorm,
                                         **block_kwargs)

        self.norm2 = RmsNorm(hidden_size, eps=1e-6)
        gelu_factory = lambda: nn.GELU(approximate="tanh")
        self.ffn = Mlp(in_features=hidden_size, hidden_features=hidden_size, act_layer=gelu_factory, drop=0)
        self.norm3 = RmsNorm(hidden_size, eps=1e-6)

    def forward(self, x, c, mask=None):
        """x: (B, N, D) tokens; c: condition tokens; mask: condition validity mask."""
        x = x + self.attn(self.norm1(x))
        x = x + self.cross_attn(self.norm2(x), c, mask)
        x = x + self.ffn(self.norm3(x))
        return x
class FinalLayer(nn.Module):
    """
    The final layer of RDT: RMSNorm followed by a GELU MLP that projects
    tokens from the hidden size to the output channel count.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = RmsNorm(hidden_size, eps=1e-6)
        gelu_factory = lambda: nn.GELU(approximate="tanh")
        self.ffn_final = Mlp(in_features=hidden_size,
                             hidden_features=hidden_size,
                             out_features=out_channels,
                             act_layer=gelu_factory,
                             drop=0)

    def forward(self, x):
        return self.ffn_final(self.norm_final(x))
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: a list/array of positions to be encoded: size (M,)
    out: (M, D) array laid out as [sin | cos] halves.
    """
    assert embed_dim % 2 == 0
    # Frequencies 1 / 10000^(2i/D) for i in [0, D/2).
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    if not isinstance(pos, np.ndarray):
        pos = np.array(pos, dtype=np.float64)
    pos = pos.reshape(-1)  # (M,)

    # Outer product of positions and frequencies: (M, D/2).
    angles = np.outer(pos, omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
def get_nd_sincos_pos_embed_from_grid(embed_dim, grid_sizes):
    """
    embed_dim: output dimension for each position
    grid_sizes: the grid sizes in each dimension (K,), given as a tuple.
    out: (grid_sizes[0], ..., grid_sizes[K-1], D)

    The embedding dimension is split evenly across the grid dimensions whose
    size exceeds 1; size-1 dimensions receive no positional signal.
    NOTE(review): assumes at least one grid size > 1, otherwise the per-grid
    dimension budget divides by zero -- matches the original behavior.
    """
    num_sizes = len(grid_sizes)
    # For grid size of 1, we do not need to add any positional embedding
    num_valid_sizes = sum(1 for s in grid_sizes if s > 1)
    emb = np.zeros(grid_sizes + (embed_dim, ))

    # Uniformly divide the embedding dimension, rounded down to an even number.
    dim_for_each_grid = embed_dim // num_valid_sizes
    dim_for_each_grid -= dim_for_each_grid % 2

    valid_size_idx = 0
    for size_idx in range(num_sizes):
        grid_size = grid_sizes[size_idx]
        if grid_size <= 1:
            continue
        sincos = get_1d_sincos_pos_embed_from_grid(dim_for_each_grid, np.arange(grid_size))
        # Shape the 1-D embedding so it broadcasts along every other grid axis.
        bcast_shape = [1] * num_sizes + [dim_for_each_grid]
        bcast_shape[size_idx] = -1
        start = valid_size_idx * dim_for_each_grid
        emb[..., start:start + dim_for_each_grid] += sincos.reshape(bcast_shape)
        valid_size_idx += 1
    return emb
def get_multimodal_cond_pos_embed(embed_dim, mm_cond_lens: OrderedDict, embed_modality=True):
    """
    Generate position embeddings for multimodal conditions.
    mm_cond_lens: an OrderedDict containing
        (modality name, modality token length) pairs.
        For `"image"` modality, the value can be a multi-dimensional tuple.
        If the length < 0, it means there is no position embedding for the modality or grid.
    embed_modality: whether to embed the modality information. Default is True.

    Returns a (total_token_len, embed_dim) numpy array: per-modality embedding
    in the first half (when embed_modality), positional sin-cos in the rest.
    """
    num_modalities = len(mm_cond_lens)
    modality_pos_embed = np.zeros((num_modalities, embed_dim))
    if embed_modality:
        # Get embeddings for various modalites
        # We put it in the first half
        modality_sincos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, torch.arange(num_modalities))
        modality_pos_embed[:, :embed_dim // 2] = modality_sincos_embed
        # The second half is for position embeddings
        pos_embed_dim = embed_dim // 2
    else:
        # The whole embedding is for position embeddings
        pos_embed_dim = embed_dim

    # Get embeddings for positions inside each modality
    c_pos_emb = np.zeros((0, embed_dim))
    for idx, (modality, cond_len) in enumerate(mm_cond_lens.items()):
        if modality == "image" and \
            (isinstance(cond_len, tuple) or isinstance(cond_len, list)):
            # Multi-dimensional image grid: a negative entry means "no position
            # embedding along that axis" (size 1 for the sin-cos computation,
            # but the full |size| when tiling the output tokens).
            all_grid_sizes = tuple([abs(x) for x in cond_len])
            embed_grid_sizes = tuple([x if x > 0 else 1 for x in cond_len])
            cond_sincos_embed = get_nd_sincos_pos_embed_from_grid(pos_embed_dim, embed_grid_sizes)
            cond_pos_embed = np.zeros(all_grid_sizes + (embed_dim, ))
            cond_pos_embed[..., -pos_embed_dim:] += cond_sincos_embed
            cond_pos_embed = cond_pos_embed.reshape((-1, embed_dim))
        else:
            # 1-D modality: a negative length disables the positional part
            # (a single position row is broadcast over |cond_len| tokens).
            cond_sincos_embed = get_1d_sincos_pos_embed_from_grid(pos_embed_dim,
                                                                  torch.arange(cond_len if cond_len > 0 else 1))
            cond_pos_embed = np.zeros((abs(cond_len), embed_dim))
            cond_pos_embed[:, -pos_embed_dim:] += cond_sincos_embed
        # Add this modality's identity embedding to every one of its tokens.
        cond_pos_embed += modality_pos_embed[idx]
        c_pos_emb = np.concatenate([c_pos_emb, cond_pos_embed], axis=0)
    return c_pos_emb

156
RDT-1B/models/rdt/model.py Normal file
View File

@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
from collections import OrderedDict
import torch
import torch.nn as nn
from pathlib import Path
import sys, os
# get current workspace
current_file = Path(__file__)
sys.path.append(str(current_file.parent.parent))
from rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder, get_1d_sincos_pos_embed_from_grid,
get_multimodal_cond_pos_embed)
class RDT(nn.Module):
    """
    Class for Robotics Diffusion Transformers.

    A stack of RDTBlocks that denoises an action-token sequence conditioned
    on language tokens and image tokens; conditioning alternates between the
    two streams across successive blocks.
    """

    def __init__(self,
                 output_dim=128,
                 horizon=32,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 max_lang_cond_len=1024,
                 img_cond_len=4096,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        """
        output_dim: dimension of each predicted action vector.
        horizon: number of action tokens to predict.
        hidden_size, depth, num_heads: transformer width / depth / head count.
        max_lang_cond_len: maximum language-condition length.
        img_cond_len: fixed image-condition length.
        lang_pos_embed_config / img_pos_embed_config: optional configs for
            get_multimodal_cond_pos_embed; None falls back to plain 1-D sin-cos.
        dtype: parameter dtype the whole module is cast to.
        """
        super().__init__()
        self.horizon = horizon
        self.hidden_size = hidden_size
        self.max_lang_cond_len = max_lang_cond_len
        self.img_cond_len = img_cond_len
        self.dtype = dtype
        self.lang_pos_embed_config = lang_pos_embed_config
        self.img_pos_embed_config = img_pos_embed_config

        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype)
        self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype)

        # We will use trainable sin-cos embeddings
        # [timestep; ctrl_freq; state; action] -> horizon + 3 input tokens
        self.x_pos_embed = nn.Parameter(torch.zeros(1, horizon + 3, hidden_size))
        # Language conditions
        self.lang_cond_pos_embed = nn.Parameter(torch.zeros(1, max_lang_cond_len, hidden_size))
        # Image conditions
        self.img_cond_pos_embed = nn.Parameter(torch.zeros(1, img_cond_len, hidden_size))

        self.blocks = nn.ModuleList([RDTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, output_dim)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init linear layers, load fixed sin-cos values into the
        position-embedding parameters, and zero-init the final projection."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize pos_embed by sin-cos embedding
        x_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                    mm_cond_lens=OrderedDict([
                                                        ('timestep', 1),
                                                        ('ctrl_freq', 1),
                                                        ('state', 1),
                                                        ('action', self.horizon),
                                                    ]))
        self.x_pos_embed.data.copy_(torch.from_numpy(x_pos_embed).float().unsqueeze(0))

        # Language positions: plain 1-D sin-cos unless a multimodal config is given.
        if self.lang_pos_embed_config is None:
            lang_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size,
                                                                    torch.arange(self.max_lang_cond_len))
        else:
            lang_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                                mm_cond_lens=OrderedDict(self.lang_pos_embed_config),
                                                                embed_modality=False)
        self.lang_cond_pos_embed.data.copy_(torch.from_numpy(lang_cond_pos_embed).float().unsqueeze(0))

        # Image positions: same fallback scheme as language.
        if self.img_pos_embed_config is None:
            img_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.img_cond_len))
        else:
            img_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                               mm_cond_lens=OrderedDict(self.img_pos_embed_config),
                                                               embed_modality=False)
        self.img_cond_pos_embed.data.copy_(torch.from_numpy(img_cond_pos_embed).float().unsqueeze(0))

        # Initialize timestep and control freq embedding MLP
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02)

        # Initialize the final layer: zero-out the final linear layer
        nn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0)
        nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0)

        # Move all the params to given data type:
        self.to(self.dtype)

    def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
        """
        Forward pass of RDT.

        x: (B, T, D), state + action token sequence, T = horizon + 1,
            dimension D is assumed to be the same as the hidden size.
        freq: (B,), a scalar indicating control frequency.
        t: (B,) or (1,), diffusion timesteps.
        lang_c: (B, L_lang, D) or None, language condition tokens (variable length),
            dimension D is assumed to be the same as the hidden size.
        img_c: (B, L_img, D) or None, image condition tokens (fixed length),
            dimension D is assumed to be the same as the hidden size.
        lang_mask: (B, L_lang) or None, language condition mask (True for valid).
        img_mask: (B, L_img) or None, image condition mask (True for valid).
        Returns (B, horizon, output_dim) denoised action tokens.
        """
        t = self.t_embedder(t).unsqueeze(1)  # (B, 1, D) or (1, 1, D)
        freq = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
        # Append timestep to the input tokens
        if t.shape[0] == 1:
            # A shared timestep is broadcast over the batch.
            t = t.expand(x.shape[0], -1, -1)
        x = torch.cat([t, freq, x], dim=1)  # (B, T+1, D)

        # Add multimodal position embeddings
        x = x + self.x_pos_embed
        # Note the lang is of variable length
        lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
        img_c = img_c + self.img_cond_pos_embed

        # Forward pass: blocks alternate between the language condition
        # (even indices) and the image condition (odd indices).
        conds = [lang_c, img_c]
        masks = [lang_mask, img_mask]
        for i, block in enumerate(self.blocks):
            c, mask = conds[i % 2], masks[i % 2]
            x = block(x, c, mask)  # (B, T+1, D)
        # Inject the language condition at the final layer
        x = self.final_layer(x)  # (B, T+1, out_channels)

        # Only preserve the action tokens
        x = x[:, -self.horizon:]
        return x

246
RDT-1B/models/rdt_runner.py Normal file
View File

@ -0,0 +1,246 @@
import re, sys, os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import \
DPMSolverMultistepScheduler
from pathlib import Path
# get current workspace
current_file = Path(__file__)
sys.path.append(os.path.join(current_file.parent))
from hub_mixin import CompatiblePyTorchModelHubMixin
from rdt.model import RDT
class RDTRunner(nn.Module,
CompatiblePyTorchModelHubMixin,
repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):
def __init__(self,
*,
action_dim,
pred_horizon,
config,
lang_token_dim,
img_token_dim,
state_token_dim,
max_lang_cond_len,
img_cond_len,
lang_pos_embed_config=None,
img_pos_embed_config=None,
dtype=torch.bfloat16):
super(RDTRunner, self).__init__()
# Create diffusion model
hidden_size = config['rdt']['hidden_size']
self.model = RDT(
output_dim=action_dim,
horizon=pred_horizon,
hidden_size=hidden_size,
depth=config['rdt']['depth'],
num_heads=config['rdt']['num_heads'],
max_lang_cond_len=max_lang_cond_len,
img_cond_len=img_cond_len,
lang_pos_embed_config=lang_pos_embed_config,
img_pos_embed_config=img_pos_embed_config,
dtype=dtype,
)
# Create adpators for various conditional inputs
self.lang_adaptor = self.build_condition_adapter(config['lang_adaptor'],
in_features=lang_token_dim,
out_features=hidden_size)
self.img_adaptor = self.build_condition_adapter(config['img_adaptor'],
in_features=img_token_dim,
out_features=hidden_size)
# A `state` refers to an action or a proprioception vector
self.state_adaptor = self.build_condition_adapter(
config['state_adaptor'],
in_features=state_token_dim * 2, # state + state mask (indicator)
out_features=hidden_size)
# Create the noise scheduler
noise_scheduler_config = config['noise_scheduler']
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
beta_schedule=noise_scheduler_config['beta_schedule'],
prediction_type=noise_scheduler_config['prediction_type'],
clip_sample=noise_scheduler_config['clip_sample'],
)
self.noise_scheduler_sample = DPMSolverMultistepScheduler(
num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
beta_schedule=noise_scheduler_config['beta_schedule'],
prediction_type=noise_scheduler_config['prediction_type'],
)
self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
self.prediction_type = noise_scheduler_config['prediction_type']
self.pred_horizon = pred_horizon
self.action_dim = action_dim
print("Diffusion params: %e" %
sum([p.numel() for p in self.model.parameters()] + [p.numel() for p in self.lang_adaptor.parameters()] +
[p.numel()
for p in self.img_adaptor.parameters()] + [p.numel() for p in self.state_adaptor.parameters()]))
def build_condition_adapter(self, projector_type, in_features, out_features):
projector = None
if projector_type == 'linear':
projector = nn.Linear(in_features, out_features)
else:
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(in_features, out_features)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU(approximate="tanh"))
modules.append(nn.Linear(out_features, out_features))
projector = nn.Sequential(*modules)
if projector is None:
raise ValueError(f'Unknown projector type: {projector_type}')
return projector
def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
'''
lang_tokens: (batch_size, lang_len, lang_token_dim)
img_tokens: (batch_size, img_len, img_token_dim)
state_tokens: (batch_size, state_len, state_token_dim)
return: adpated (..., hidden_size) for all input tokens
'''
adpated_lang = self.lang_adaptor(lang_tokens)
adpated_img = self.img_adaptor(img_tokens)
adpated_state = self.state_adaptor(state_tokens)
return adpated_lang, adpated_img, adpated_state
def conditional_sample(self, lang_cond, lang_attn_mask, img_cond, state_traj, action_mask, ctrl_freqs):
'''
lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
which should be True-False bool tensor.
img_cond: image conditional data, (batch_size, img_len, hidden_size).
state_traj: (batch_size, 1, hidden_size), state trajectory.
action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
indicating the valid action dimensions.
ctrl_freqs: (batch_size,), control frequency for each sample.
return: (batch_size, horizon, action_dim)
'''
device = state_traj.device
dtype = state_traj.dtype
noisy_action = torch.randn(size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
dtype=dtype,
device=device)
action_mask = action_mask.expand(-1, self.pred_horizon, -1)
# Set step values
self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)
for t in self.noise_scheduler_sample.timesteps:
# Prepare state-action trajectory
action_traj = torch.cat([noisy_action, action_mask], dim=2)
action_traj = self.state_adaptor(action_traj)
state_action_traj = torch.cat([state_traj, action_traj], dim=1)
# Predict the model output
model_output = self.model(state_action_traj,
ctrl_freqs,
t.unsqueeze(-1).to(device),
lang_cond,
img_cond,
lang_mask=lang_attn_mask)
# Compute previous actions: x_t -> x_t-1
noisy_action = self.noise_scheduler_sample.step(model_output, t, noisy_action).prev_sample
noisy_action = noisy_action.to(state_traj.dtype)
# Finally apply the action mask to mask invalid action dimensions
noisy_action = noisy_action * action_mask
return noisy_action
# ========= Train ============
def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_gt, action_mask,
ctrl_freqs) -> torch.Tensor:
'''
lang_tokens: (batch_size, lang_len, lang_token_dim)
lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
which should be True-False bool tensor.
img_tokens: (batch_size, img_len, img_token_dim)
state_tokens: (batch_size, 1, state_token_dim)
action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
ctrl_freqs: (batch_size,), control frequency for each sample.
return: loss_value, a scalar tensor
'''
batch_size = lang_tokens.shape[0]
device = lang_tokens.device
# Sample noise that we'll add to the actions
noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)
# Sample random diffusion timesteps
timesteps = torch.randint(0, self.num_train_timesteps, (batch_size, ), device=device).long()
# Add noise to the clean actions according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps)
# Concatenate the state and action tokens to form the input sequence
state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
# Append the action mask to the input sequence
action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
# Align the dimension with the hidden size
lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj)
# Predict the denoised result
pred = self.model(state_action_traj, ctrl_freqs, timesteps, lang_cond, img_cond, lang_mask=lang_attn_mask)
pred_type = self.prediction_type
if pred_type == 'epsilon':
target = noise
elif pred_type == 'sample':
target = action_gt
else:
raise ValueError(f"Unsupported prediction type {pred_type}")
loss = F.mse_loss(pred, target)
return loss
# ========= Inference ============
def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_mask, ctrl_freqs):
    """Sample an action sequence conditioned on language, images and state.

    Args:
        lang_tokens: (batch, lang_len, lang_token_dim) language condition tokens.
        lang_attn_mask: (batch, lang_len) bool mask, True for valid tokens.
        img_tokens: (batch, img_len, img_token_dim) image condition tokens.
        state_tokens: (batch, 1, state_token_dim) proprioceptive state token.
        action_mask: (batch, 1, action_dim) 0-1 float mask of valid dims.
        ctrl_freqs: (batch,) control frequency per sample.

    Returns:
        (batch, horizon, action_dim) predicted action sequence.
    """
    # Append the action-dimension mask to the state token along the feature
    # axis, then project everything to the model's hidden size.
    masked_state = torch.cat([state_tokens, action_mask], dim=2)
    lang_cond, img_cond, state_traj = self.adapt_conditions(lang_tokens, img_tokens, masked_state)

    # Iterative denoising from pure noise down to an action trajectory.
    return self.conditional_sample(
        lang_cond,
        lang_attn_mask,
        img_cond,
        state_traj,
        action_mask,
        ctrl_freqs,
    )
def forward(self, *args, **kwargs) -> torch.Tensor:
    """Alias for `compute_loss` so the runner behaves like a loss module."""
    loss_fn = self.compute_loss
    return loss_fn(*args, **kwargs)

49
RDT-1B/pretrain.sh Normal file
View File

@ -0,0 +1,49 @@
#!/bin/bash
# Launch RDT-1B pretraining with DeepSpeed across the hosts in hostfile.txt.

# NCCL fabric configuration for multi-node InfiniBand training.
export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=bond0
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0

export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
export OUTPUT_DIR="./checkpoints/rdt-pretrain-1b"
export CFLAGS="-I/usr/include"
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
export CUTLASS_PATH="/path/to/cutlass"

export WANDB_PROJECT="robotics_diffusion_transformer"

# Create the checkpoint directory if missing. -p also creates the parent
# ./checkpoints directory, which a plain mkdir would fail on.
if [ ! -d "$OUTPUT_DIR" ]; then
    mkdir -p "$OUTPUT_DIR"
    echo "Folder '$OUTPUT_DIR' created"
else
    echo "Folder '$OUTPUT_DIR' already exists"
fi

# For run in a single node/machine
# accelerate launch main.py \
#    --deepspeed="./configs/zero2.json" \
#    ...

deepspeed --hostfile=hostfile.txt main.py \
    --deepspeed="./configs/zero2.json" \
    --pretrained_text_encoder_name_or_path="$TEXT_ENCODER_NAME" \
    --pretrained_vision_encoder_name_or_path="$VISION_ENCODER_NAME" \
    --output_dir="$OUTPUT_DIR" \
    --train_batch_size=32 \
    --sample_batch_size=64 \
    --max_train_steps=1000000 \
    --checkpointing_period=1000 \
    --sample_period=500 \
    --checkpoints_total_limit=40 \
    --lr_scheduler="constant" \
    --learning_rate=1e-4 \
    --mixed_precision="bf16" \
    --dataloader_num_workers=8 \
    --dataset_type="pretrain" \
    --report_to=wandb

# Use this to resume training from some previous checkpoint
#    --resume_from_checkpoint="checkpoint-1000" \

View File

@ -0,0 +1,9 @@
#!/bin/bash
# Process one task's expert demonstrations on a chosen GPU.
# Usage: <task_name> <task_config> <expert_data_num> <gpu_id>
task_name="${1}"
task_config="${2}"
expert_data_num="${3}"
gpu_id="${4}"

export CUDA_VISIBLE_DEVICES="${gpu_id}"

# Quote the arguments so names containing spaces survive word splitting.
python scripts/process_data.py "$task_name" "$task_config" "$expert_data_num"

23
RDT-1B/requirements.txt Normal file
View File

@ -0,0 +1,23 @@
numpy<2.0
packaging==24.0
wandb==0.17.0
deepspeed==0.14.2
accelerate==0.30.1
diffusers==0.27.2
timm==1.0.3
transformers==4.41.0
sentencepiece==0.2.0
h5py==3.11.0
opencv-python==4.9.0.80
imgaug==0.4.0
pytz>=2020.1
# requirements_data.txt
tfds-nightly==4.9.4.dev202402070044
gsutil==5.27
tensorflow==2.15.0.post1
pillow==10.2.0
pyyaml==6.0.1
tensorflow-graphics==2021.12.3
imageio==2.34.0
imageio-ffmpeg==0.4.9

View File

@ -0,0 +1,941 @@
#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -- coding: UTF-8
"""
#!/usr/bin/python3
"""
import argparse
import sys
import threading
import time
import yaml
from collections import deque
import numpy as np
import rospy
import torch
from cv_bridge import CvBridge
from geometry_msgs.msg import Twist
from nav_msgs.msg import Odometry
from PIL import Image as PImage
from sensor_msgs.msg import Image, JointState
from std_msgs.msg import Header
import cv2
from scripts.agilex_model import create_model
# sys.path.append("./")
# Camera key order used throughout: [high/front, right wrist, left wrist]
# (see update_observation_window, which stores images under these keys).
CAMERA_NAMES = ["cam_high", "cam_right_wrist", "cam_left_wrist"]
# Rolling buffer holding the two most recent observations; created lazily
# by update_observation_window.
observation_window = None
# Precomputed language-instruction embeddings, loaded in model_inference.
lang_embeddings = None
# debug
preload_images = None
# Initialize the model
def make_policy(args):
    """Build the RDT policy from the YAML config referenced by `args.config_path`."""
    with open(args.config_path, "r") as config_file:
        args.config = yaml.safe_load(config_file)

    # Only the vision encoder is loaded here; instruction embeddings are
    # precomputed offline, so the text encoder stays commented out.
    # pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
    vision_encoder_name = "google/siglip-so400m-patch14-384"
    return create_model(
        args=args.config,
        dtype=torch.bfloat16,
        pretrained=args.pretrained_model_name_or_path,
        # pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
        pretrained_vision_encoder_name_or_path=vision_encoder_name,
        control_frequency=args.ctrl_freq,
    )
def set_seed(seed):
    """Seed the NumPy and PyTorch RNGs so runs are reproducible."""
    np.random.seed(seed)
    torch.manual_seed(seed)
# Interpolate the actions to make the robot move smoothly
def interpolate_action(args, prev_action, cur_action):
    """Linearly interpolate between two actions so no joint jumps too far.

    Args:
        args: Namespace with `arm_steps_length` (per-joint step limits, one arm).
        prev_action: Previous 14-dim action (both arms concatenated).
        cur_action: Target 14-dim action.

    Returns:
        (k, 14) array of intermediate actions ending at `cur_action`;
        a single row when the move already fits within one step.
    """
    # Duplicate the per-joint limits for the second arm.
    per_joint_limit = np.tile(np.asarray(args.arm_steps_length), 2)
    # Number of sub-steps needed so every joint respects its limit.
    num_steps = int(np.max(np.ceil(np.abs(cur_action - prev_action) / per_joint_limit)))
    if num_steps <= 1:
        return cur_action[np.newaxis, :]
    # Drop the first point (== prev_action) so only new targets are returned.
    return np.linspace(prev_action, cur_action, num_steps + 1)[1:]
def get_config(args):
    """Collect the runtime configuration consumed by the control loop."""
    return {
        "episode_len": args.max_publish_step,
        "state_dim": 14,
        "chunk_size": args.chunk_size,
        "camera_names": CAMERA_NAMES,
    }
# Get the observation from the ROS topic
def get_ros_observation(args, ros_operator):
    """Block until a time-synchronized frame is available, then return it.

    Returns:
        Tuple (img_front, img_left, img_right, puppet_arm_left,
        puppet_arm_right); the depth images and robot-base entry of the
        frame are discarded here.
    """
    rate = rospy.Rate(args.publish_rate)
    print_flag = True

    while True and not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if not result:
            # Warn only once per stretch of failed synchronization.
            if print_flag:
                print("syn fail when get_ros_observation")
                print_flag = False
            rate.sleep()
            continue
        print_flag = True
        (
            img_front,
            img_left,
            img_right,
            img_front_depth,
            img_left_depth,
            img_right_depth,
            puppet_arm_left,
            puppet_arm_right,
            robot_base,
        ) = result
        # print(f"sync success when get_ros_observation")
        return (img_front, img_left, img_right, puppet_arm_left, puppet_arm_right)
# Update the observation window buffer
def update_observation_window(args, config, ros_operator):
    """Fetch one synchronized frame and append it to the 2-slot history window."""

    # JPEG transformation
    # Align with training: re-encode/decode through JPEG so the images match
    # the compression applied when the training data was produced.
    def jpeg_mapping(img):
        img = cv2.imencode(".jpg", img)[1].tobytes()
        img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
        return img

    global observation_window
    if observation_window is None:
        observation_window = deque(maxlen=2)

        # Append the first dummy image so index -2 is always valid.
        observation_window.append({
            "qpos": None,
            "images": {
                config["camera_names"][0]: None,
                config["camera_names"][1]: None,
                config["camera_names"][2]: None,
            },
        })

    img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = (get_ros_observation(args, ros_operator))
    img_front = jpeg_mapping(img_front)
    img_left = jpeg_mapping(img_left)
    img_right = jpeg_mapping(img_right)

    # Joint positions of both arms concatenated -> 14-dim qpos on the GPU.
    qpos = np.concatenate(
        (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)),
        axis=0,
    )
    qpos = torch.from_numpy(qpos).float().cuda()
    # NOTE: stored image order is [front, right, left] to match CAMERA_NAMES.
    observation_window.append({
        "qpos": qpos,
        "images": {
            config["camera_names"][0]: img_front,
            config["camera_names"][1]: img_right,
            config["camera_names"][2]: img_left,
        },
    })
# RDT inference
# RDT inference
def inference_fn(args, config, policy, t):
    """Run one policy inference over the observation window.

    Returns:
        np.ndarray action chunk; per the comment below the policy emits
        [1, 64, 14] which is squeezed to (64, 14).
    """
    global observation_window
    global lang_embeddings

    # print(f"Start inference_thread_fn: t={t}")
    while True and not rospy.is_shutdown():
        time1 = time.time()

        # fetch images in sequence [front, right, left]; the previous and the
        # current frame of each camera are both fed to the policy.
        image_arrs = [
            observation_window[-2]["images"][config["camera_names"][0]],
            observation_window[-2]["images"][config["camera_names"][1]],
            observation_window[-2]["images"][config["camera_names"][2]],
            observation_window[-1]["images"][config["camera_names"][0]],
            observation_window[-1]["images"][config["camera_names"][1]],
            observation_window[-1]["images"][config["camera_names"][2]],
        ]

        # fetch debug images in sequence [front, right, left]
        # image_arrs = [
        #     preload_images[config['camera_names'][0]][max(t - 1, 0)],
        #     preload_images[config['camera_names'][2]][max(t - 1, 0)],
        #     preload_images[config['camera_names'][1]][max(t - 1, 0)],
        #     preload_images[config['camera_names'][0]][t],
        #     preload_images[config['camera_names'][2]][t],
        #     preload_images[config['camera_names'][1]][t]
        # ]

        # # encode the images
        # for i in range(len(image_arrs)):
        #     image_arrs[i] = cv2.imdecode(np.frombuffer(image_arrs[i], np.uint8), cv2.IMREAD_COLOR)
        # proprio = torch.from_numpy(preload_images['qpos'][t]).float().cuda()

        # Dummy slots (first iteration) stay None; the policy handles them.
        images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]

        # for i, pos in enumerate(['f', 'r', 'l'] * 2):
        #     images[i].save(f'{t}-{i}-{pos}.png')

        # get last qpos in shape [14, ]
        proprio = observation_window[-1]["qpos"]
        # unsqueeze to [1, 14]
        proprio = proprio.unsqueeze(0)

        # actions shaped as [1, 64, 14] in format [left, right]
        actions = (policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy())
        # print(f"inference_actions: {actions.squeeze()}")

        # print(f"Model inference time: {time.time() - time1} s")

        # print(f"Finish inference_thread_fn: t={t}")
        return actions
# Main loop for the manipulation task
# Main loop for the manipulation task
def model_inference(args, config, ros_operator):
    """Main control loop: load the policy, home the arms, then repeatedly
    infer action chunks and publish them at the configured rate."""
    global lang_embeddings

    # Load rdt model
    policy = make_policy(args)

    # Precomputed instruction embeddings (produced offline by the T5 encoder).
    lang_dict = torch.load(args.lang_embeddings_path)
    print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
    lang_embeddings = lang_dict["embeddings"]

    max_publish_step = config["episode_len"]
    chunk_size = config["chunk_size"]

    # Initialize position of the puppet arm
    # left0/right0: pose with grippers open; left1/right1: grippers closed
    # (only the last joint differs) — TODO confirm gripper semantics.
    left0 = [
        -0.00133514404296875,
        0.00209808349609375,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        3.557830810546875,
    ]
    right0 = [
        -0.00133514404296875,
        0.00438690185546875,
        0.034523963928222656,
        -0.053597450256347656,
        -0.00476837158203125,
        -0.00209808349609375,
        3.557830810546875,
    ]
    left1 = [
        -0.00133514404296875,
        0.00209808349609375,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3393220901489258,
    ]
    right1 = [
        -0.00133514404296875,
        0.00247955322265625,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3397035598754883,
    ]
    ros_operator.puppet_arm_publish_continuous(left0, right0)
    input("Press enter to continue")
    ros_operator.puppet_arm_publish_continuous(left1, right1)

    # Initialize the previous action to be the initial robot state
    pre_action = np.zeros(config["state_dim"])
    pre_action[:14] = np.array([
        -0.00133514404296875,
        0.00209808349609375,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3393220901489258,
    ] + [
        -0.00133514404296875,
        0.00247955322265625,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3397035598754883,
    ])
    action = None
    # Inference loop
    with torch.inference_mode():
        while True and not rospy.is_shutdown():
            # The current time step
            t = 0
            rate = rospy.Rate(args.publish_rate)

            action_buffer = np.zeros([chunk_size, config["state_dim"]])

            while t < max_publish_step and not rospy.is_shutdown():
                # Update observation window
                update_observation_window(args, config, ros_operator)

                # When coming to the end of the action chunk
                if t % chunk_size == 0:
                    # Start inference
                    action_buffer = inference_fn(args, config, policy, t).copy()

                raw_action = action_buffer[t % chunk_size]
                action = raw_action
                # Interpolate the original action sequence
                if args.use_actions_interpolation:
                    # print(f"Time {t}, pre {pre_action}, act {action}")
                    interp_actions = interpolate_action(args, pre_action, action)
                else:
                    interp_actions = action[np.newaxis, :]
                # Execute the interpolated actions one by one
                for act in interp_actions:
                    left_action = act[:7]
                    right_action = act[7:14]

                    if not args.disable_puppet_arm:
                        ros_operator.puppet_arm_publish(left_action,
                                                        right_action)  # puppet_arm_publish_continuous_thread

                    if args.use_robot_base:
                        vel_action = act[14:16]
                        ros_operator.robot_base_publish(vel_action)
                    rate.sleep()
                    # print(f"doing action: {act}")

                t += 1
                print("Published Step", t)
                pre_action = action.copy()
# ROS operator class
# ROS operator class
class RosOperator:
    """Thin ROS1 wrapper: buffers incoming sensor topics in deques and
    publishes arm joint commands and base velocity commands."""

    def __init__(self, args):
        # All buffers/handles are created in init() / init_ros(); the
        # attributes are only declared here.
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        self.bridge = None
        self.puppet_arm_left_publisher = None
        self.puppet_arm_right_publisher = None
        self.robot_base_publisher = None
        self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_lock = None
        self.args = args
        self.init()
        self.init_ros()

    def init(self):
        """Create the CvBridge, the per-topic buffers, and the publish lock."""
        self.bridge = CvBridge()
        self.img_left_deque = deque()
        self.img_right_deque = deque()
        self.img_front_deque = deque()
        self.img_left_depth_deque = deque()
        self.img_right_depth_deque = deque()
        self.img_front_depth_deque = deque()
        self.puppet_arm_left_deque = deque()
        self.puppet_arm_right_deque = deque()
        self.robot_base_deque = deque()
        self.puppet_arm_publish_lock = threading.Lock()
        # The lock is held by default; releasing it (see
        # puppet_arm_publish_continuous_thread) signals a running
        # continuous-publish loop to stop.
        self.puppet_arm_publish_lock.acquire()

    def puppet_arm_publish(self, left, right):
        """Publish one JointState command to each arm."""
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()  # Set timestep
        joint_state_msg.name = [
            "joint0",
            "joint1",
            "joint2",
            "joint3",
            "joint4",
            "joint5",
            "joint6",
        ]  # Set joint names
        joint_state_msg.position = left
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = right
        self.puppet_arm_right_publisher.publish(joint_state_msg)

    def robot_base_publish(self, vel):
        """Publish a Twist with linear x = vel[0] and angular z = vel[1]."""
        vel_msg = Twist()
        vel_msg.linear.x = vel[0]
        vel_msg.linear.y = 0
        vel_msg.linear.z = 0
        vel_msg.angular.x = 0
        vel_msg.angular.y = 0
        vel_msg.angular.z = vel[1]
        self.robot_base_publisher.publish(vel_msg)

    def puppet_arm_publish_continuous(self, left, right):
        """Drive both arms to the target pose in bounded per-joint steps.

        Blocks until the targets are reached, ROS shuts down, or the publish
        lock is released from another thread (cancellation signal).
        """
        rate = rospy.Rate(self.args.publish_rate)
        left_arm = None
        right_arm = None

        # Wait until at least one state message from each arm has arrived.
        while True and not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break

        # Direction of travel for each joint.
        left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
        right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
        flag = True
        step = 0
        while flag and not rospy.is_shutdown():
            # Successful acquire means the lock was released elsewhere:
            # treat it as a cancellation request and stop publishing.
            if self.puppet_arm_publish_lock.acquire(False):
                return
            left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
            right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
            flag = False
            # Advance each joint by at most arm_steps_length towards its target;
            # keep looping while any joint is still short of the target.
            for i in range(len(left)):
                if left_diff[i] < self.args.arm_steps_length[i]:
                    left_arm[i] = left[i]
                else:
                    left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
                    flag = True
            for i in range(len(right)):
                if right_diff[i] < self.args.arm_steps_length[i]:
                    right_arm[i] = right[i]
                else:
                    right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
                    flag = True
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # Set the timestep
            joint_state_msg.name = [
                "joint0",
                "joint1",
                "joint2",
                "joint3",
                "joint4",
                "joint5",
                "joint6",
            ]  # Set joint names
            joint_state_msg.position = left_arm
            self.puppet_arm_left_publisher.publish(joint_state_msg)
            joint_state_msg.position = right_arm
            self.puppet_arm_right_publisher.publish(joint_state_msg)
            step += 1
            print("puppet_arm_publish_continuous:", step)
            rate.sleep()

    def puppet_arm_publish_linear(self, left, right):
        """Move both arms along a fixed 100-point linear trajectory at 200 Hz.

        The last joint is pinned to the target value on every waypoint.
        """
        num_step = 100
        rate = rospy.Rate(200)

        left_arm = None
        right_arm = None

        # Wait until at least one state message from each arm has arrived.
        while True and not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break

        traj_left_list = np.linspace(left_arm, left, num_step)
        traj_right_list = np.linspace(right_arm, right, num_step)

        for i in range(len(traj_left_list)):
            traj_left = traj_left_list[i]
            traj_right = traj_right_list[i]
            traj_left[-1] = left[-1]
            traj_right[-1] = right[-1]
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # Set timestamp
            joint_state_msg.name = [
                "joint0",
                "joint1",
                "joint2",
                "joint3",
                "joint4",
                "joint5",
                "joint6",
            ]  # Set joint names
            joint_state_msg.position = traj_left
            self.puppet_arm_left_publisher.publish(joint_state_msg)
            joint_state_msg.position = traj_right
            self.puppet_arm_right_publisher.publish(joint_state_msg)
            rate.sleep()

    def puppet_arm_publish_continuous_thread(self, left, right):
        """Restart puppet_arm_publish_continuous on a background thread,
        cancelling any previous run by releasing the publish lock."""
        if self.puppet_arm_publish_thread is not None:
            self.puppet_arm_publish_lock.release()
            self.puppet_arm_publish_thread.join()
            self.puppet_arm_publish_lock.acquire(False)
            self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
        self.puppet_arm_publish_thread.start()

    def get_frame(self):
        """Assemble one time-aligned observation from the buffered topics.

        Returns False when any required queue is empty or lags behind
        frame_time; otherwise pops up to frame_time and returns the tuple
        (img_front, img_left, img_right, depths..., arm states, robot_base).

        NOTE(review): frame_time is the min over the *latest* stamps, so the
        `[-1] < frame_time` checks below can only fail for queues excluded
        from that min (e.g. the arm queues) — confirm this is intended.
        """
        if (len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or
                (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0
                                                or len(self.img_front_depth_deque) == 0))):
            return False
        if self.args.use_depth_image:
            frame_time = min([
                self.img_left_deque[-1].header.stamp.to_sec(),
                self.img_right_deque[-1].header.stamp.to_sec(),
                self.img_front_deque[-1].header.stamp.to_sec(),
                self.img_left_depth_deque[-1].header.stamp.to_sec(),
                self.img_right_depth_deque[-1].header.stamp.to_sec(),
                self.img_front_depth_deque[-1].header.stamp.to_sec(),
            ])
        else:
            frame_time = min([
                self.img_left_deque[-1].header.stamp.to_sec(),
                self.img_right_deque[-1].header.stamp.to_sec(),
                self.img_front_deque[-1].header.stamp.to_sec(),
            ])

        if (len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if (len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if (len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if (len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if (len(self.puppet_arm_right_deque) == 0
                or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0
                                          or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0
                                          or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0
                                          or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_robot_base and (len(self.robot_base_deque) == 0
                                         or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
            return False

        # Discard stale entries, then pop the first message at/after frame_time.
        while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
            self.img_left_deque.popleft()
        img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), "passthrough")

        while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
            self.img_right_deque.popleft()
        img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), "passthrough")

        while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
            self.img_front_deque.popleft()
        img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), "passthrough")

        while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_left_deque.popleft()
        puppet_arm_left = self.puppet_arm_left_deque.popleft()

        while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_right_deque.popleft()
        puppet_arm_right = self.puppet_arm_right_deque.popleft()

        img_left_depth = None
        if self.args.use_depth_image:
            while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_left_depth_deque.popleft()
            img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), "passthrough")

        img_right_depth = None
        if self.args.use_depth_image:
            while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_right_depth_deque.popleft()
            img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), "passthrough")

        img_front_depth = None
        if self.args.use_depth_image:
            while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_front_depth_deque.popleft()
            img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), "passthrough")

        robot_base = None
        if self.args.use_robot_base:
            while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
                self.robot_base_deque.popleft()
            robot_base = self.robot_base_deque.popleft()

        return (
            img_front,
            img_left,
            img_right,
            img_front_depth,
            img_left_depth,
            img_right_depth,
            puppet_arm_left,
            puppet_arm_right,
            robot_base,
        )

    # Topic callbacks: append to the matching deque, capped at 2000 entries.
    def img_left_callback(self, msg):
        if len(self.img_left_deque) >= 2000:
            self.img_left_deque.popleft()
        self.img_left_deque.append(msg)

    def img_right_callback(self, msg):
        if len(self.img_right_deque) >= 2000:
            self.img_right_deque.popleft()
        self.img_right_deque.append(msg)

    def img_front_callback(self, msg):
        if len(self.img_front_deque) >= 2000:
            self.img_front_deque.popleft()
        self.img_front_deque.append(msg)

    def img_left_depth_callback(self, msg):
        if len(self.img_left_depth_deque) >= 2000:
            self.img_left_depth_deque.popleft()
        self.img_left_depth_deque.append(msg)

    def img_right_depth_callback(self, msg):
        if len(self.img_right_depth_deque) >= 2000:
            self.img_right_depth_deque.popleft()
        self.img_right_depth_deque.append(msg)

    def img_front_depth_callback(self, msg):
        if len(self.img_front_depth_deque) >= 2000:
            self.img_front_depth_deque.popleft()
        self.img_front_depth_deque.append(msg)

    def puppet_arm_left_callback(self, msg):
        if len(self.puppet_arm_left_deque) >= 2000:
            self.puppet_arm_left_deque.popleft()
        self.puppet_arm_left_deque.append(msg)

    def puppet_arm_right_callback(self, msg):
        if len(self.puppet_arm_right_deque) >= 2000:
            self.puppet_arm_right_deque.popleft()
        self.puppet_arm_right_deque.append(msg)

    def robot_base_callback(self, msg):
        if len(self.robot_base_deque) >= 2000:
            self.robot_base_deque.popleft()
        self.robot_base_deque.append(msg)

    def init_ros(self):
        """Register the ROS node, all subscribers, and the three publishers."""
        rospy.init_node("joint_state_publisher", anonymous=True)
        rospy.Subscriber(
            self.args.img_left_topic,
            Image,
            self.img_left_callback,
            queue_size=1000,
            tcp_nodelay=True,
        )
        rospy.Subscriber(
            self.args.img_right_topic,
            Image,
            self.img_right_callback,
            queue_size=1000,
            tcp_nodelay=True,
        )
        rospy.Subscriber(
            self.args.img_front_topic,
            Image,
            self.img_front_callback,
            queue_size=1000,
            tcp_nodelay=True,
        )
        if self.args.use_depth_image:
            rospy.Subscriber(
                self.args.img_left_depth_topic,
                Image,
                self.img_left_depth_callback,
                queue_size=1000,
                tcp_nodelay=True,
            )
            rospy.Subscriber(
                self.args.img_right_depth_topic,
                Image,
                self.img_right_depth_callback,
                queue_size=1000,
                tcp_nodelay=True,
            )
            rospy.Subscriber(
                self.args.img_front_depth_topic,
                Image,
                self.img_front_depth_callback,
                queue_size=1000,
                tcp_nodelay=True,
            )
        rospy.Subscriber(
            self.args.puppet_arm_left_topic,
            JointState,
            self.puppet_arm_left_callback,
            queue_size=1000,
            tcp_nodelay=True,
        )
        rospy.Subscriber(
            self.args.puppet_arm_right_topic,
            JointState,
            self.puppet_arm_right_callback,
            queue_size=1000,
            tcp_nodelay=True,
        )
        rospy.Subscriber(
            self.args.robot_base_topic,
            Odometry,
            self.robot_base_callback,
            queue_size=1000,
            tcp_nodelay=True,
        )
        self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
        self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic,
                                                          JointState,
                                                          queue_size=10)
        self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
def get_arguments():
    """Build and parse the command-line interface of the inference node."""
    parser = argparse.ArgumentParser()

    # Scalar runtime options.
    parser.add_argument("--max_publish_step", action="store", type=int,
                        help="Maximum number of action publishing steps",
                        default=10000, required=False)
    parser.add_argument("--seed", action="store", type=int,
                        help="Random seed", default=None, required=False)

    # ROS topic names as (flag, help text, default topic).
    topic_options = [
        ("--img_front_topic", "img_front_topic", "/camera_f/color/image_raw"),
        ("--img_left_topic", "img_left_topic", "/camera_l/color/image_raw"),
        ("--img_right_topic", "img_right_topic", "/camera_r/color/image_raw"),
        ("--img_front_depth_topic", "img_front_depth_topic", "/camera_f/depth/image_raw"),
        ("--img_left_depth_topic", "img_left_depth_topic", "/camera_l/depth/image_raw"),
        ("--img_right_depth_topic", "img_right_depth_topic", "/camera_r/depth/image_raw"),
        ("--puppet_arm_left_cmd_topic", "puppet_arm_left_cmd_topic", "/master/joint_left"),
        ("--puppet_arm_right_cmd_topic", "puppet_arm_right_cmd_topic", "/master/joint_right"),
        ("--puppet_arm_left_topic", "puppet_arm_left_topic", "/puppet/joint_left"),
        ("--puppet_arm_right_topic", "puppet_arm_right_topic", "/puppet/joint_right"),
        ("--robot_base_topic", "robot_base_topic", "/odom_raw"),
        ("--robot_base_cmd_topic", "robot_base_topic", "/cmd_vel"),
    ]
    for flag, help_text, default in topic_options:
        parser.add_argument(flag, action="store", type=str, help=help_text,
                            default=default, required=False)

    # Hardware / control options.
    parser.add_argument("--use_robot_base", action="store_true",
                        help="Whether to use the robot base to move around",
                        default=False, required=False)
    parser.add_argument("--publish_rate", action="store", type=int,
                        help="The rate at which to publish the actions",
                        default=30, required=False)
    parser.add_argument("--ctrl_freq", action="store", type=int,
                        help="The control frequency of the robot",
                        default=25, required=False)
    parser.add_argument("--chunk_size", action="store", type=int,
                        help="Action chunk size", default=64, required=False)
    parser.add_argument("--arm_steps_length", action="store", type=float,
                        help="The maximum change allowed for each joint per timestep",
                        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2],
                        required=False)
    parser.add_argument("--use_actions_interpolation", action="store_true",
                        help="Whether to interpolate the actions if the difference is too large",
                        default=False, required=False)
    parser.add_argument("--use_depth_image", action="store_true",
                        help="Whether to use depth images", default=False,
                        required=False)
    parser.add_argument("--disable_puppet_arm", action="store_true",
                        help="Whether to disable the puppet arm. This is useful for safely debugging",
                        default=False)

    # Model options.
    parser.add_argument("--config_path", type=str, default="configs/base.yaml",
                        help="Path to the config file")
    # parser.add_argument('--cfg_scale', type=float, default=2.0,
    #                     help='the scaling factor used to modify the magnitude of the control features during denoising')
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True,
                        help="Name or path to the pretrained model")
    parser.add_argument("--lang_embeddings_path", type=str, required=True,
                        help="Path to the pre-encoded language instruction embeddings")

    return parser.parse_args()
def main():
    """Entry point: parse CLI args, bring up ROS, and run the inference loop."""
    args = get_arguments()
    # The ROS node must exist before the control loop starts.
    operator = RosOperator(args)
    if args.seed is not None:
        set_seed(args.seed)
    model_inference(args, get_config(args), operator)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,344 @@
import os, sys
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from configs.state_vec import STATE_VEC_IDX_MAPPING
from pathlib import Path

# get current workspace
current_file = Path(__file__)
# Make the sibling "models" package importable when this file is run directly.
# (The previous duplicate append of the same path was removed.)
sys.path.append(os.path.join(current_file.parent.parent, "models"))

from multimodal_encoder.siglip_encoder import SiglipVisionTower
from multimodal_encoder.t5_encoder import T5Embedder
from rdt_runner import RDTRunner

# The indices that the raw vector should be mapped to in the unified action vector.
# Kept for reference only — the live computation uses the configured arm
# dimensions (see create_model / _format_joint_to_state).
# AGILEX_STATE_INDICES = [
#     STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(1)
# ] + [
#     STATE_VEC_IDX_MAPPING["left_gripper_open"]
# ] + [
#     STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(1)
# ] + [
#     STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
# ]
# AGILEX_STATE_INDICES = None
# Create the RDT model
# Create the RDT model
def create_model(args, **kwargs):
    """Instantiate the RDT wrapper and optionally load a local checkpoint.

    Args:
        args: Model config dict (must contain args["arm_dim"], consumed by
            RoboticDiffusionTransformerModel).
        **kwargs: Forwarded to RoboticDiffusionTransformerModel. `pretrained`
            may name a local checkpoint file, in which case its weights are
            loaded into the freshly constructed model.

    Returns:
        A ready RoboticDiffusionTransformerModel.
    """
    # NOTE: the previous version also computed AGILEX_STATE_INDICES here but
    # never used it; the equivalent computation lives in
    # RoboticDiffusionTransformerModel._format_joint_to_state.
    model = RoboticDiffusionTransformerModel(args, **kwargs)
    # `pretrained` may also be a hub identifier handled inside the wrapper;
    # only load from disk when the path points at an actual file.
    pretrained = kwargs.get("pretrained", None)
    if pretrained is not None and os.path.isfile(pretrained):
        model.load_pretrained_weights(pretrained)
    return model
class RoboticDiffusionTransformerModel(object):
"""A wrapper for the RDT model, which handles
1. Model initialization
2. Encodings of instructions
3. Model inference
"""
def __init__(
    self,
    args,
    device="cuda",
    dtype=torch.bfloat16,
    image_size=None,
    control_frequency=25,
    pretrained=None,
    pretrained_vision_encoder_name_or_path=None,
):
    """Build the vision encoder and the RDT policy, then move them to
    `device`/`dtype` via reset().

    Args:
        args: Nested config dict; reads args["arm_dim"], args["common"],
            args["model"], args["dataset"].
        device: Target device string.
        dtype: Weight/compute dtype.
        image_size: Unused here beyond storage — TODO confirm consumers.
        control_frequency: Robot control frequency passed to inference.
        pretrained: Checkpoint path or hub id (see get_policy).
        pretrained_vision_encoder_name_or_path: SigLIP checkpoint id/path.
    """
    self.args = args
    self.dtype = dtype
    self.image_size = image_size
    self.device = device
    self.control_frequency = control_frequency
    # We do not use the text encoder due to limited GPU memory
    # self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
    self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
    self.policy = self.get_policy(pretrained)
    # Per-arm joint counts used when mapping qpos into the unified vector.
    self.left_arm_dim, self.right_arm_dim = (
        args["arm_dim"]["left_arm_dim"],
        args["arm_dim"]["right_arm_dim"],
    )

    self.reset()
def get_policy(self, pretrained):
    """Initialize the model.

    When `pretrained` is None or a local file path, a fresh RDTRunner is
    constructed from the config (file weights, if any, are loaded later by
    load_pretrained_weights); otherwise `pretrained` is treated as a hub
    identifier and loaded via RDTRunner.from_pretrained.
    """
    # Initialize model with arguments
    if pretrained is None or os.path.isfile(pretrained):
        # Total image-condition length: history x cameras x patches per image.
        img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                        self.vision_model.num_patches)

        _model = RDTRunner(
            action_dim=self.args["common"]["state_dim"],
            pred_horizon=self.args["common"]["action_chunk_size"],
            config=self.args["model"],
            lang_token_dim=self.args["model"]["lang_token_dim"],
            img_token_dim=self.args["model"]["img_token_dim"],
            state_token_dim=self.args["model"]["state_token_dim"],
            max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
            img_cond_len=img_cond_len,
            img_pos_embed_config=[
                # No initial pos embed in the last grid size
                # since we've already done in ViT
                # NOTE(review): the negative size appears to be the marker
                # for "skip initial pos-embed" — confirm against RDTRunner.
                (
                    "image",
                    (
                        self.args["common"]["img_history_size"],
                        self.args["common"]["num_cameras"],
                        -self.vision_model.num_patches,
                    ),
                ),
            ],
            lang_pos_embed_config=[
                # Similarly, no initial pos embed for language
                ("lang", -self.args["dataset"]["tokenizer_max_length"]),
            ],
            dtype=self.dtype,
        )
    else:
        _model = RDTRunner.from_pretrained(pretrained)

    return _model
def get_text_encoder(self, pretrained_text_encoder_name_or_path):
    """Build the T5 tokenizer/encoder pair used to embed instructions."""
    embedder = T5Embedder(
        from_pretrained=pretrained_text_encoder_name_or_path,
        model_max_length=self.args["dataset"]["tokenizer_max_length"],
        device=self.device,
    )
    return embedder.tokenizer, embedder.model
def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
    """Build the SigLIP vision tower and return its image processor with it."""
    tower = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
    return tower.image_processor, tower
def reset(self):
    """Put the policy and vision tower in eval mode on the target device.

    The text encoder is disabled in this variant (precomputed language
    embeddings are used), so only policy and vision model are handled.
    """
    self.policy.eval()
    self.vision_model.eval()
    self.policy = self.policy.to(self.device, dtype=self.dtype)
    self.vision_model = self.vision_model.to(self.device, dtype=self.dtype)
def load_pretrained_weights(self, pretrained=None):
    """Load policy weights from a checkpoint file.

    Args:
        pretrained: path to a ``.pt`` (DeepSpeed-style, weights under the
            "module" key) or ``.safetensors`` checkpoint; no-op when None.

    Raises:
        NotImplementedError: if the file extension is not recognized.
    """
    if pretrained is None:
        return
    print(f"Loading weights from {pretrained}")
    filename = os.path.basename(pretrained)
    if filename.endswith(".pt"):
        # Fix: map_location="cpu" so a GPU-saved checkpoint also loads on
        # CPU-only hosts; reset() moves the policy to the target device later.
        checkpoint = torch.load(pretrained, map_location="cpu")
        self.policy.load_state_dict(checkpoint["module"])
    elif filename.endswith(".safetensors"):
        from safetensors.torch import load_model
        load_model(self.policy, pretrained)
    else:
        raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
def encode_instruction(self, instruction, device="cuda"):
    """Encode string instruction to latent embeddings.

    Args:
        instruction: a string of instruction
        device: a string of device

    Returns:
        pred: a tensor of latent embeddings of shape (text_max_length, 512)
    """
    # NOTE(review): self.text_tokenizer / self.text_model are only assigned by
    # the (currently commented-out) text-encoder setup in __init__, so calling
    # this method raises AttributeError unless that setup is re-enabled.
    tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                 truncation=True)["input_ids"].to(device)
    # Flatten to a single batch: (1, seq_len).
    tokens = tokens.view(1, -1)
    with torch.no_grad():
        pred = self.text_model(tokens).last_hidden_state.detach()
    return pred
def _format_joint_to_state(self, joints):
    """
    Format the joint proprioception into the unified action vector.

    Args:
        joints (torch.Tensor): joint proprioception,
            shape [B, N, left_arm_dim + 1 + right_arm_dim + 1].

    Returns:
        state (torch.Tensor): the formatted vector for RDT ([B, N, 128]).
        state_elem_mask (torch.Tensor): availability mask per dim ([B, 128]).
    """
    # Indices of this robot's joints inside the 128-dim unified state vector.
    AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"]
                             for i in range(self.left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] +
                            [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
                             for i in range(self.right_arm_dim)] + [STATE_VEC_IDX_MAPPING["right_gripper_open"]])
    # Fix: the original divided `joints` by an all-ones tensor (a placeholder
    # for gripper rescaling) — a no-op that only allocated a copy. Removed;
    # reintroduce a real per-joint scale vector here if rescaling is needed.
    B, N, _ = joints.shape
    state = torch.zeros(
        (B, N, self.args["model"]["state_token_dim"]),
        device=joints.device,
        dtype=joints.dtype,
    )
    # Fill into the unified state vector
    state[:, :, AGILEX_STATE_INDICES] = joints
    # Assemble the mask indicating each dimension's availability
    state_elem_mask = torch.zeros(
        (B, self.args["model"]["state_token_dim"]),
        device=joints.device,
        dtype=joints.dtype,
    )
    state_elem_mask[:, AGILEX_STATE_INDICES] = 1
    return state, state_elem_mask
def _unformat_action_to_joint(self, action):
    """
    Unformat the unified action vector into the joint action to be executed.

    Args:
        action (torch.Tensor): unified action vector ([B, N, 128]).

    Returns:
        joints (torch.Tensor): robot joint action,
            shape [B, N, left_arm_dim + 1 + right_arm_dim + 1].
    """
    AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"]
                             for i in range(self.left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] +
                            [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
                             for i in range(self.right_arm_dim)] + [STATE_VEC_IDX_MAPPING["right_gripper_open"]])
    joints = action[:, :, AGILEX_STATE_INDICES]
    # Fix: the original multiplied by an all-ones tensor (a placeholder for
    # Mobile ALOHA gripper rescaling) — a no-op that only allocated a copy.
    # Removed; reintroduce a real scale vector if action rescaling is needed.
    return joints
@torch.no_grad()
def step(self, proprio, images, text_embeds):
    """
    Predict the next action chunk given the
    proprioceptive states, images, and instruction embeddings.

    Args:
        proprio: proprioceptive states
        images: RGB images, the order should be
            [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},
            ext_{t}, right_wrist_{t}, left_wrist_{t}]
        text_embeds: instruction embeddings

    Returns:
        action: predicted action
    """
    device = self.device
    dtype = self.dtype

    # The background image used for padding (solid processor-mean color).
    background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                dtype=np.uint8).reshape(1, 1, 3)
    background_image = (np.ones(
        (
            self.image_processor.size["height"],
            self.image_processor.size["width"],
            3,
        ),
        dtype=np.uint8,
    ) * background_color)

    # Preprocess the images by order and encode them
    image_tensor_list = []
    for image in images:
        if image is None:
            # Replace it with the background image
            image = Image.fromarray(background_image)
        if self.image_size is not None:
            # NOTE(review): self.data_args is never assigned in this class, so
            # this line raises AttributeError whenever image_size is set —
            # presumably it should read self.image_size. Confirm and fix.
            image = transforms.Resize(self.data_args.image_size)(image)

        if self.args["dataset"].get("auto_adjust_image_brightness", False):
            pixel_values = list(image.getdata())
            # Mean brightness in [0, 1]; brighten frames darker than 15%.
            average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
            if average_brightness <= 0.15:
                image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)
        if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

            def expand2square(pil_img, background_color):
                # Pad the shorter side with the background color to a square.
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
        image_tensor_list.append(image)

    image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
    # Flatten all frames' patch embeddings into one conditioning sequence.
    image_embeds = self.vision_model(image_tensor).detach()
    image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

    # Prepare the proprioception states and the control frequency
    joints = proprio.to(device).unsqueeze(0)  # (1, 1, 14)
    states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
    states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
    states = states[:, -1:, :]  # (1, 1, 128)
    ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

    text_embeds = text_embeds.to(device, dtype=dtype)

    # Predict the next action chunk given the inputs
    trajectory = self.policy.predict_action(
        lang_tokens=text_embeds,
        lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
        img_tokens=image_embeds,
        state_tokens=states,
        action_mask=state_elem_mask.unsqueeze(1),
        ctrl_freqs=ctrl_freqs,
    )
    trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
    return trajectory

View File

@ -0,0 +1,53 @@
import os
import torch
import yaml
from models.multimodal_encoder.t5_encoder import T5Embedder
# Configuration for encoding a single task instruction with T5-XXL.
GPU = 0
MODEL_PATH = "google/t5-v1_1-xxl"
CONFIG_PATH = "configs/base.yaml"
SAVE_DIR = "outs/"

# Modify this to your task name and instruction
TASK_NAME = "handover_pan"
# NOTE(review): this instruction describes a marker/box task, which does not
# obviously match TASK_NAME "handover_pan" — confirm before encoding.
INSTRUCTION = "Pick up the black marker on the right and put it into the packaging box on the left."

# Note: if your GPU VRAM is less than 24GB,
# it is recommended to enable offloading by specifying an offload directory.
OFFLOAD_DIR = (
    None  # Specify your offload directory here, ensuring the directory exists.
)
def main():
    """Encode INSTRUCTION with the T5-XXL encoder and save the embedding.

    Saves a dict {"name", "instruction", "embeddings"} to
    ``SAVE_DIR/TASK_NAME.pt``.
    """
    with open(CONFIG_PATH, "r") as fp:
        config = yaml.safe_load(fp)

    device = torch.device(f"cuda:{GPU}")
    text_embedder = T5Embedder(
        from_pretrained=MODEL_PATH,
        model_max_length=config["dataset"]["tokenizer_max_length"],
        device=device,
        use_offload_folder=OFFLOAD_DIR,
    )
    tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    tokens = tokenizer(INSTRUCTION, return_tensors="pt", padding="longest", truncation=True)["input_ids"].to(device)
    tokens = tokens.view(1, -1)
    with torch.no_grad():
        pred = text_encoder(tokens).last_hidden_state.detach().cpu()

    # Fix: ensure the output directory exists (torch.save does not create it).
    os.makedirs(SAVE_DIR, exist_ok=True)
    save_path = os.path.join(SAVE_DIR, f"{TASK_NAME}.pt")
    # We save the embeddings in a dictionary format
    torch.save({"name": TASK_NAME, "instruction": INSTRUCTION, "embeddings": pred}, save_path)

    print(
        f'"{INSTRUCTION}" from "{TASK_NAME}" is encoded by "{MODEL_PATH}" into shape {pred.shape} and saved to "{save_path}"'
    )


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,57 @@
import os
import json
import argparse
import torch
import yaml
from tqdm import tqdm
from models.multimodal_encoder.t5_encoder import T5Embedder
def encode_lang(
    DATA_FILE_PATH,
    TARGET_DIR,
    GPU,
    desc_type="seen",
    tokenizer=None,
    text_encoder=None,
):
    """Encode all instructions of one episode and save their embeddings.

    Args:
        DATA_FILE_PATH: path to the episode's instruction JSON file.
        TARGET_DIR: directory under which ``instructions/lang_embed_<i>.pt``
            files are written.
        GPU: CUDA device index.
        desc_type: key of the instruction list in the JSON (e.g. "seen").
        tokenizer, text_encoder: optional pre-built T5 pair; when either is
            None, a new T5Embedder is constructed.

    Returns:
        (tokenizer, text_encoder): so callers can reuse them across episodes.
    """
    current_dir = os.path.dirname(__file__)
    with open(os.path.join(current_dir, "../configs/base.yaml"), "r") as fp:
        config = yaml.safe_load(fp)

    device = torch.device(f"cuda:{GPU}")
    # Lazily build the encoder on first call only; later calls reuse it.
    if tokenizer is None or text_encoder is None:
        text_embedder = T5Embedder(
            from_pretrained=os.path.join(current_dir, "../../weights/RDT/t5-v1_1-xxl"),
            model_max_length=config["dataset"]["tokenizer_max_length"],
            device=device,
            use_offload_folder=None,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    with open(DATA_FILE_PATH, "r") as f_instr:
        instruction_dict = json.load(f_instr)
    instructions = instruction_dict[desc_type]

    # Encode the instructions
    tokenized_res = tokenizer(instructions, return_tensors="pt", padding="longest", truncation=True)
    tokens = tokenized_res["input_ids"].to(device)
    attn_mask = tokenized_res["attention_mask"].to(device)

    with torch.no_grad():
        text_embeds = (text_encoder(input_ids=tokens, attention_mask=attn_mask)["last_hidden_state"].detach().cpu())
    attn_mask = attn_mask.cpu().bool()

    # Fix: exist_ok avoids the race between the existence check and makedirs.
    os.makedirs(f"{TARGET_DIR}/instructions", exist_ok=True)

    # Save the embeddings for training use
    for i in range(len(instructions)):
        # Strip padding positions using the attention mask.
        text_embed = text_embeds[i][attn_mask[i]]
        save_path = os.path.join(TARGET_DIR, f"instructions/lang_embed_{i}.pt")
        torch.save(text_embed, save_path)

    return tokenizer, text_encoder

View File

@ -0,0 +1,84 @@
import json
import os
import sys
import re
def extract_metrics_from_log(log_file_path):
    """Scan a training log for metric dicts and report the best values.

    Returns:
        A 4-tuple (agilex_sample_mse, agilex_sample_l2err,
        overall_avg_sample_mse, overall_avg_sample_l2err) holding the minimum
        of each metric over all matches, or (None, None, None, None) when the
        log is unreadable or contains no metrics.
    """
    pattern = re.compile(
        r"\{'agilex_sample_mse':\s*([0-9.eE+-]+),\s*'agilex_sample_l2err':\s*([0-9.eE+-]+),\s*'overall_avg_sample_mse':\s*([0-9.eE+-]+),\s*'overall_avg_sample_l2err':\s*([0-9.eE+-]+)\}"
    )
    all_metrics = []
    try:
        with open(log_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                match = pattern.search(line)
                if match is None:
                    continue
                metrics = tuple(float(match.group(g)) for g in range(1, 5))
                all_metrics.append(metrics)
                print(f"Find Metrics: agilex_sample_mse={metrics[0]}, agilex_sample_l2err={metrics[1]}, "
                      f"overall_avg_sample_mse={metrics[2]}, overall_avg_sample_l2err={metrics[3]}")
    except Exception as e:
        print(f"Failed to read log: {e}")
        return (None, None, None, None)

    if not all_metrics:
        print("No metrics found in the log file")
        return (None, None, None, None)

    print(f"\nTotal {len(all_metrics)} metrics found in the log file")
    # Column-wise minima across all matched lines.
    best = tuple(min(column) for column in zip(*all_metrics))
    print(f"\nBest metrics:")
    print(f"  agilex_sample_mse: {best[0]}")
    print(f"  agilex_sample_l2err: {best[1]}")
    print(f"  overall_avg_sample_mse: {best[2]}")
    print(f"  overall_avg_sample_l2err: {best[3]}")
    return best
def generate_output_json(input_config_file, output_dir, runtime):
    """Assemble output.json for a finished RDT-1B training run.

    Reads the task config JSON, pulls the best metrics out of
    ``<output_dir>/output.log``, and writes a summary to
    ``<output_dir>/output.json``.
    """
    with open(input_config_file, 'r') as f:
        config = json.load(f)

    log_file = os.path.join(output_dir, 'output.log')
    metrics = extract_metrics_from_log(log_file)
    agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err = metrics
    if None in metrics:
        print("Warning: Some metrics are missing in the log file.")

    # Prefer an explicit top-level model_name; fall back to train.model.
    model_name = config.get("model_name") if "model_name" in config else config.get("train", {}).get("model")
    output_json = {
        "task_id": config.get("task_id"),
        "model_type": "RDT-1B",
        "model_name": model_name,
        "gpu_id": config.get("gpu_id"),
        "runtime": runtime,
        "log_path": log_file,
        "output_dir": output_dir,
        "model_path": os.path.join(output_dir, 'pytorch_model.bin'),
        "metrics": {
            "agilex_sample_mse": agilex_sample_mse,
            "agilex_sample_l2err": agilex_sample_l2err,
            "overall_avg_sample_mse": overall_avg_sample_mse,
            "overall_avg_sample_l2err": overall_avg_sample_l2err,
        },
    }

    # Write output.json (pretty-printed, keeps nulls, standard JSON).
    output_json_path = os.path.join(output_dir, 'output.json')
    with open(output_json_path, 'w') as f:
        json.dump(output_json, f, indent=4, ensure_ascii=False)
# CLI entry point: generate_output_json.py <input_config_file> <output_dir> <runtime>
if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("Usage: python generate_output_json.py <input_config_file> <output_dir> <runtime>")
        sys.exit(1)
    generate_output_json(sys.argv[1], sys.argv[2], sys.argv[3])

View File

@ -0,0 +1,325 @@
import os
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from configs.state_vec import STATE_VEC_IDX_MAPPING
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.multimodal_encoder.t5_encoder import T5Embedder
from models.rdt_runner import RDTRunner
# Indices of the ManiSkill setup's joints (single 7-DoF right arm + gripper)
# inside the 128-dim unified RDT state vector.
MANISKILL_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
                     for i in range(7)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]]
def create_model(args, pretrained, **kwargs):
    """Build a RoboticDiffusionTransformerModel, optionally loading weights."""
    wrapper = RoboticDiffusionTransformerModel(args, **kwargs)
    if pretrained is not None:
        wrapper.load_pretrained_weights(pretrained)
    return wrapper
# Dataset statistics (per-dimension min/max of the 7 arm joints + gripper)
# used for min-max normalization of states and denormalization of actions.
DATA_STAT = {
    "state_min": [
        -0.7463043928146362,
        -0.0801204964518547,
        -0.4976441562175751,
        -2.657780647277832,
        -0.5742632150650024,
        1.8309762477874756,
        -2.2423808574676514,
        0.0,
    ],
    "state_max": [
        0.7645499110221863,
        1.4967026710510254,
        0.4650936424732208,
        -0.3866899907588959,
        0.5505855679512024,
        3.2900545597076416,
        2.5737812519073486,
        0.03999999910593033,
    ],
    "action_min": [
        -0.7472005486488342,
        -0.08631071448326111,
        -0.4995281398296356,
        -2.658363103866577,
        -0.5751323103904724,
        1.8290787935256958,
        -2.245187997817993,
        -1.0,
    ],
    "action_max": [
        0.7654682397842407,
        1.4984270334243774,
        0.46786263585090637,
        -0.38181185722351074,
        0.5517147779464722,
        3.291581630706787,
        2.575840711593628,
        1.0,
    ],
}
class RoboticDiffusionTransformerModel(object):
    """A wrapper for the RDT model, which handles
    1. Model initialization
    2. Encodings of instructions
    3. Model inference
    """

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained_text_encoder_name_or_path=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        # args: parsed base config dict with "common", "model", "dataset" keys.
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy()
        # Dataset statistics used for min-max normalization of states to
        # [-1, 1] and denormalization of predicted actions.
        self.state_min = torch.tensor(DATA_STAT["state_min"]).to(device)
        self.state_max = torch.tensor(DATA_STAT["state_max"]).to(device)
        self.action_min = torch.tensor(DATA_STAT["action_min"]).to(device)
        self.action_max = torch.tensor(DATA_STAT["action_max"]).to(device)
        self.reset()

    def get_policy(self):
        """Initialize the model."""
        # Initialize model with arguments
        # Image-condition length = history size x cameras x ViT patches/frame.
        img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                        self.vision_model.num_patches)
        _model = RDTRunner(
            action_dim=self.args["common"]["state_dim"],
            pred_horizon=self.args["common"]["action_chunk_size"],
            config=self.args["model"],
            lang_token_dim=self.args["model"]["lang_token_dim"],
            img_token_dim=self.args["model"]["img_token_dim"],
            state_token_dim=self.args["model"]["state_token_dim"],
            max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
            img_cond_len=img_cond_len,
            img_pos_embed_config=[
                # No initial pos embed in the last grid size
                # since we've already done in ViT
                (
                    "image",
                    (
                        self.args["common"]["img_history_size"],
                        self.args["common"]["num_cameras"],
                        -self.vision_model.num_patches,
                    ),
                ),
            ],
            lang_pos_embed_config=[
                # Similarly, no initial pos embed for language
                ("lang", -self.args["dataset"]["tokenizer_max_length"]),
            ],
            dtype=self.dtype,
        )
        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        """Build the T5 tokenizer and language encoder pair."""
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
        return tokenizer, text_encoder

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        """Build the SigLIP vision tower and its paired image processor."""
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        image_processor = vision_encoder.image_processor
        return image_processor, vision_encoder

    def reset(self):
        """Set model to evaluation mode."""
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        self.text_model.eval()
        self.vision_model.eval()
        self.policy = self.policy.to(device, dtype=weight_dtype)
        self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        """Load policy weights from a .pt or .safetensors checkpoint.

        No-op when `pretrained` is None; raises NotImplementedError for any
        other file extension.
        """
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            checkpoint = torch.load(pretrained)
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model
            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        """Encode string instruction to latent embeddings.

        Args:
            instruction: a string of instruction
            device: a string of device

        Returns:
            pred: a tensor of latent embeddings of shape (text_max_length, 512)
        """
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)
        # Flatten to a single batch: (1, seq_len).
        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()
        return pred

    def _format_joint_to_state(self, joints):
        """
        Format the robot joint state into the unified state vector.

        Args:
            joints (torch.Tensor): The joint state to be formatted.
                qpos ([B, N, 14]).

        Returns:
            state (torch.Tensor): The formatted state for RDT ([B, N, 128]).
            state_elem_mask (torch.Tensor): per-dim availability mask ([B, 128]).
        """
        # Rescale the gripper
        # joints = joints / torch.tensor(
        #     [[[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]],
        #     device=joints.device, dtype=joints.dtype
        # )
        # normalize to -1,1
        joints = (joints - self.state_min) / (self.state_max - self.state_min) * 2 - 1
        B, N, _ = joints.shape
        state = torch.zeros(
            (B, N, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        # assemble the unified state vector
        state[:, :, MANISKILL_INDICES] = joints
        # Mask marking which of the 128 dimensions are valid for this robot.
        state_elem_mask = torch.zeros(
            (B, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        state_elem_mask[:, MANISKILL_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        """Extract this robot's joints from the unified action vector and
        denormalize them from [-1, 1] back to the action range."""
        action_indices = MANISKILL_INDICES
        joints = action[:, :, action_indices]
        # denormalize to action space
        joints = (joints + 1) / 2 * (self.action_max - self.action_min) + self.action_min
        return joints

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        """
        Args:
            proprio: proprioceptive states
            images: RGB images
            text_embeds: instruction embeddings

        Returns:
            action: predicted action
        """
        device = self.device
        dtype = self.dtype

        # Background image (solid processor-mean color) used for missing frames.
        background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                    dtype=np.uint8).reshape(1, 1, 3)
        background_image = (np.ones(
            (
                self.image_processor.size["height"],
                self.image_processor.size["width"],
                3,
            ),
            dtype=np.uint8,
        ) * background_color)

        image_tensor_list = []
        for image in images:
            if image is None:
                # Replace it with the background image
                image = Image.fromarray(background_image)
            if self.image_size is not None:
                # NOTE(review): self.data_args is never assigned in this class;
                # this raises AttributeError when image_size is set — likely
                # should read self.image_size. Confirm and fix.
                image = transforms.Resize(self.data_args.image_size)(image)

            if self.args["dataset"].get("auto_adjust_image_brightness", False):
                pixel_values = list(image.getdata())
                # Mean brightness in [0, 1]; brighten frames darker than 15%.
                average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                if average_brightness <= 0.15:
                    image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)
            if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

                def expand2square(pil_img, background_color):
                    # Pad the shorter side with the background color to square.
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result

                image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            image_tensor_list.append(image)

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
        # Flatten all frames' patch embeddings into one conditioning sequence.
        image_embeds = self.vision_model(image_tensor).detach()
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # history of actions
        joints = proprio.to(device).unsqueeze(0)  # (1, 1, 14)
        states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        states = states[:, -1:, :]  # (1, 1, 128)
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
        return trajectory

View File

@ -0,0 +1,169 @@
import sys
sys.path.append("./")
import os
import h5py
import numpy as np
import pickle
import cv2
import argparse
import yaml
from scripts.encode_lang_batch_once import encode_lang
def load_hdf5(dataset_path):
    """Read one raw episode HDF5 file.

    Returns:
        (left_gripper, left_arm, right_gripper, right_arm, image_dict) where
        image_dict maps each camera name to its stored RGB stream.
    """
    if not os.path.isfile(dataset_path):
        print(f"Dataset does not exist at \n{dataset_path}\n")
        exit()

    with h5py.File(dataset_path, "r") as root:
        left_gripper = root["/joint_action/left_gripper"][()]
        left_arm = root["/joint_action/left_arm"][()]
        right_gripper = root["/joint_action/right_gripper"][()]
        right_arm = root["/joint_action/right_arm"][()]
        # One stored-image array per camera under /observation/<cam>/rgb.
        image_dict = {
            cam_name: root[f"/observation/{cam_name}/rgb"][()]
            for cam_name in root[f"/observation/"].keys()
        }

    return left_gripper, left_arm, right_gripper, right_arm, image_dict
def images_encoding(imgs):
    """JPEG-encode a list of images.

    Args:
        imgs: list of image arrays as used by cv2 (BGR order assumed — TODO confirm).

    Returns:
        (encode_data, max_len): list of JPEG byte strings and the length of
        the longest one, used as the fixed width for the HDF5 "S" dtype.
    """
    encode_data = []
    max_len = 0
    for img in imgs:
        success, encoded_image = cv2.imencode(".jpg", img)
        jpeg_data = encoded_image.tobytes()
        encode_data.append(jpeg_data)
        max_len = max(max_len, len(jpeg_data))
    # Fix: the original also built a null-padded copy of every image
    # (padded_data) that was never returned or used; h5py pads bytes to the
    # "S{max_len}" dtype itself, so that dead work has been removed.
    return encode_data, max_len
def get_task_config(task_name):
    """Load ./task_config/<task_name>.yml and return its contents as a dict."""
    config_path = f"./task_config/{task_name}.yml"
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.load(f.read(), Loader=yaml.FullLoader)
def data_transform(path, episode_num, save_path):
    """Convert the first *episode_num* raw episodes under *path* into the
    RDT training HDF5 layout (qpos/action plus JPEG-encoded camera streams)
    under *save_path*.

    Returns:
        The number of episodes processed.
    """
    begin = 0
    floders = os.listdir(path)  # NOTE(review): typo for "folders"; kept as-is.
    assert episode_num <= len(floders), "data num not enough"
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    for i in range(episode_num):
        left_gripper_all, left_arm_all, right_gripper_all, right_arm_all, image_dict = (load_hdf5(
            os.path.join(path, f"episode{i}.hdf5")))
        qpos = []
        actions = []
        cam_high = []
        cam_right_wrist = []
        cam_left_wrist = []
        left_arm_dim = []
        right_arm_dim = []
        last_state = None  # NOTE(review): assigned but never read in this function.
        for j in range(0, left_gripper_all.shape[0]):
            left_gripper, left_arm, right_gripper, right_arm = (
                left_gripper_all[j],
                left_arm_all[j],
                right_gripper_all[j],
                right_arm_all[j],
            )
            # 14-dim joint state: left arm + left gripper + right arm + right gripper.
            state = np.concatenate((left_arm, [left_gripper], right_arm, [right_gripper]), axis=0)  # joint
            state = state.astype(np.float32)
            if j != left_gripper_all.shape[0] - 1:
                # Every step except the last contributes an observation.
                qpos.append(state)
                # Decode the stored image bytes and resize to 640x480.
                camera_high_bits = image_dict["head_camera"][j]
                camera_high = cv2.imdecode(np.frombuffer(camera_high_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_high_resized = cv2.resize(camera_high, (640, 480))
                cam_high.append(camera_high_resized)
                camera_right_wrist_bits = image_dict["right_camera"][j]
                camera_right_wrist = cv2.imdecode(np.frombuffer(camera_right_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_right_wrist_resized = cv2.resize(camera_right_wrist, (640, 480))
                cam_right_wrist.append(camera_right_wrist_resized)
                camera_left_wrist_bits = image_dict["left_camera"][j]
                camera_left_wrist = cv2.imdecode(np.frombuffer(camera_left_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_left_wrist_resized = cv2.resize(camera_left_wrist, (640, 480))
                cam_left_wrist.append(camera_left_wrist_resized)
            if j != 0:
                # Every step except the first: the state at j is the action
                # taken after observing step j-1.
                action = state
                actions.append(action)
                left_arm_dim.append(left_arm.shape[0])
                right_arm_dim.append(right_arm.shape[0])
        if not os.path.exists(os.path.join(save_path, f"episode_{i}")):
            os.makedirs(os.path.join(save_path, f"episode_{i}"))
        hdf5path = os.path.join(save_path, f"episode_{i}/episode_{i}.hdf5")
        with h5py.File(hdf5path, "w") as f:
            f.create_dataset("action", data=np.array(actions))
            obs = f.create_group("observations")
            obs.create_dataset("qpos", data=np.array(qpos))
            obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))
            obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))
            image = obs.create_group("images")
            # Store JPEG bytes as fixed-width byte strings (h5py pads to S{len}).
            cam_high_enc, len_high = images_encoding(cam_high)
            cam_right_wrist_enc, len_right = images_encoding(cam_right_wrist)
            cam_left_wrist_enc, len_left = images_encoding(cam_left_wrist)
            image.create_dataset("cam_high", data=cam_high_enc, dtype=f"S{len_high}")
            image.create_dataset("cam_right_wrist", data=cam_right_wrist_enc, dtype=f"S{len_right}")
            image.create_dataset("cam_left_wrist", data=cam_left_wrist_enc, dtype=f"S{len_left}")
        begin += 1
        print(f"proccess {i} success!")
    return begin
# CLI: python process_data.py <task_name> <task_config> <expert_data_num>
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process some episodes.")
    parser.add_argument("task_name", type=str)
    parser.add_argument("task_config", type=str)
    parser.add_argument("expert_data_num", type=int)
    args = parser.parse_args()

    task_name = args.task_name
    task_config = args.task_config
    expert_data_num = args.expert_data_num

    load_dir = os.path.join("../../data", str(task_name), str(task_config), "data")
    print(f"read data from path: {load_dir}")
    # Convert raw episodes into the RDT training HDF5 layout.
    begin = data_transform(
        load_dir,
        expert_data_num,
        f"./processed_data/{task_name}-{task_config}-{expert_data_num}",
    )

    # Encode each episode's instructions, reusing one T5 encoder across calls.
    tokenizer, text_encoder = None, None
    for idx in range(expert_data_num):
        print(f"Processing Language: {idx}", end="\r")
        data_file_path = (f"../../data/{task_name}/{task_config}/instructions/episode{idx}.json")
        target_dir = (f"processed_data/{task_name}-{task_config}-{expert_data_num}/episode_{idx}")
        tokenizer, text_encoder = encode_lang(
            DATA_FILE_PATH=data_file_path,
            TARGET_DIR=target_dir,
            GPU=0,
            desc_type="seen",
            tokenizer=tokenizer,
            text_encoder=text_encoder,
        )

View File

@ -0,0 +1,31 @@
import json
import yaml
import sys
def read_config(config_file, yaml_file):
    """Merge values from a task JSON config into a training YAML config.

    Reads both files, overwrites the training-related YAML keys with values
    derived from the JSON, and writes the YAML file back in place.
    """
    with open(config_file, 'r') as f:
        task_cfg = json.load(f)
    with open(yaml_file, 'r') as f:
        train_cfg = yaml.load(f, Loader=yaml.FullLoader)

    train = task_cfg["train"]
    train_cfg["model"] = train["model"] + task_cfg["task_id"]
    train_cfg["data_path"] = train["input_data_path"] + "/data"
    train_cfg["checkpoint_path"] = train["checkpoint_path"] + "/" + task_cfg["task_id"]
    train_cfg["pretrained_model_name_or_path"] = train["input_data_path"] + "/weights/rdt-1b"
    train_cfg["cuda_visible_device"] = str(task_cfg["gpu_id"])
    print(f"cuda_visible_device: {train_cfg['cuda_visible_device']}")

    batch_size = int(train["batch_size"])
    epochs = int(train["epochs"])
    train_cfg["train_batch_size"] = batch_size
    train_cfg["sample_batch_size"] = batch_size * 2
    train_cfg["max_train_steps"] = epochs
    # Checkpoint roughly ten times over the run.
    train_cfg["checkpointing_period"] = int(epochs / 10)
    train_cfg["sample_period"] = 200
    train_cfg["checkpoints_total_limit"] = 50

    with open(yaml_file, 'w') as f:
        yaml.dump(train_cfg, f, default_flow_style=False)
    print("Config YAML file updated successfully")
# CLI: python <script> <task_config.json> <train_config.yaml>
if __name__ == "__main__":
    read_config(sys.argv[1], sys.argv[2])

View File

@ -0,0 +1,22 @@
import sys
import yaml
def read_yaml_value(file_path, key):
    """Print the value stored under *key* at the top level of a YAML file."""
    with open(file_path, "r") as file:
        data = yaml.safe_load(file)
    value = data.get(key)
    if value is None:
        print(f"Key '{key}' not found in {file_path}")
    else:
        print(value)
# CLI: python read_yaml.py <file_path> <key>
if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python read_yaml.py <file_path> <key>")
        sys.exit(1)

    file_path = sys.argv[1]
    key = sys.argv[2]
    read_yaml_value(file_path, key)

0
RDT-1B/train/__init__.py Normal file
View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

479
RDT-1B/train/dataset.py Normal file
View File

@ -0,0 +1,479 @@
import traceback
import time
import os
import json
import math
import random
from typing import Dict, Sequence
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import transformers
from data.filelock import FileLock
from data.hdf5_vla_dataset import HDF5VLADataset
from train.image_corrupt import image_corrupt
def get_clean_item(chunk_dir):
    """Return the indices of items in *chunk_dir* whose dirty bit is unset."""
    dirty_bit = read_dirty_bit(chunk_dir)
    clean_mask = 1 - dirty_bit
    return np.where(clean_mask)[0].tolist()
def save_dirty_bit(chunk_dir, dirty_bit):
    """
    Save the dirty bit to the chunk directory.

    Retries under a write lock for up to 10 seconds before giving up.

    Raises:
        RuntimeError: if the dirty bit cannot be saved within the window.
    """
    file_path = os.path.join(chunk_dir, "dirty_bit")
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        # Fix: the original referenced `lock` in the except handlers even when
        # FileLock() itself raised, producing an UnboundLocalError that killed
        # the retry loop. Track the lock explicitly and only release it if it
        # was actually created.
        lock = None
        try:
            lock = FileLock(file_path)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                file.write(dirty_bit.tobytes())
            lock.release_lock()
            return
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise
        except BaseException:
            # Best-effort unlock, then retry until the deadline.
            if lock is not None:
                lock.release_lock()
            continue
    raise RuntimeError("Failed to save dirty bit.")
def read_dirty_bit(chunk_dir):
    """
    Read the dirty bit from the chunk directory.

    Retries under a read lock for up to 10 seconds before giving up.

    Returns:
        np.ndarray: uint8 array of dirty flags (guaranteed non-empty).

    Raises:
        RuntimeError: if the dirty bit cannot be read within the window.
    """
    # If error occurs, retry
    file_path = os.path.join(chunk_dir, "dirty_bit")
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        # Fix: the original referenced `lock` in the except handlers even when
        # FileLock() itself raised, producing an UnboundLocalError that killed
        # the retry loop. Track the lock explicitly and only release it if it
        # was actually created.
        lock = None
        try:
            lock = FileLock(file_path)
            lock.acquire_read_lock()
            with open(file_path, "rb") as file:
                dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
            lock.release_lock()
            assert len(dirty_bit) > 0
            return dirty_bit
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise
        except BaseException:
            # Best-effort unlock, then retry until the deadline.
            if lock is not None:
                lock.release_lock()
            continue
    raise RuntimeError("Failed to read dirty bit.")
class VLAConsumerDataset(Dataset):
"""A vision-languange-action Dataset for supervised training.
This dataset will load data from the buffer directory.
"""
def __init__(
    self,
    model_config_path,
    config,
    tokenizer,
    image_processor,
    num_cameras,
    img_history_size,
    image_size=None,
    auto_adjust_image_brightness=False,
    image_aug=False,
    dataset_type="pretrain",
    cond_mask_prob=0.1,
    cam_ext_mask_prob=-1.0,
    state_noise_snr=None,
    use_hdf5=False,
    use_precomp_lang_embed=False,
):
    """Build the consumer dataset.

    Args:
        model_config_path: path to the per-task model config (forwarded to
            HDF5VLADataset when use_hdf5 is True).
        config: dict with buffer settings (buf_path, buf_num_chunks,
            buf_chunk_size, tokenizer_max_length, image_aspect_ratio).
        tokenizer: language tokenizer.
        image_processor: vision-tower image processor.
        num_cameras: number of camera views per step.
        img_history_size: number of past frames per camera.
        image_size: optional target image size.
        auto_adjust_image_brightness: enable brightness adjustment when True.
        image_aug: enable image augmentation when True.
        dataset_type: "pretrain" or "finetune"; selects the dataset-name list.
        cond_mask_prob: presumably the probability of masking a condition
            during training — confirm against usage in __getitem__.
        cam_ext_mask_prob: presumably an override mask probability for the
            exterior camera (negative = unused) — confirm against usage.
        state_noise_snr: SNR for state noise; None disables.
        use_hdf5: read samples from HDF5VLADataset instead of the buffer.
        use_precomp_lang_embed: use precomputed language embeddings.
    """
    super(VLAConsumerDataset, self).__init__()

    # Load the control frequency for each dataset
    with open("configs/dataset_control_freq.json", "r") as fp:
        self.control_freq = json.load(fp)
    # Load the dataset names
    dataset_names_cfg = ("configs/pretrain_datasets.json"
                         if dataset_type == "pretrain" else "configs/finetune_datasets.json")
    with open(dataset_names_cfg, "r") as file:
        DATASET_NAMES = json.load(file)
    # Create the mapping between dataset name and id
    self.dataset_name2id = {name: i for i, name in enumerate(DATASET_NAMES)}
    self.dataset_id2name = {i: name for i, name in enumerate(DATASET_NAMES)}

    self.image_processor = image_processor
    self.model_config_path = model_config_path
    # Disk-buffer geometry: num_chunks directories of chunk_size items each.
    self.buffer_dir = config["buf_path"]
    self.num_chunks = config["buf_num_chunks"]
    self.chunk_size = config["buf_chunk_size"]
    self.tokenizer_max_length = config["tokenizer_max_length"]
    self.image_aspect_ratio = config["image_aspect_ratio"]
    self.state_noise_snr = state_noise_snr
    self.num_cameras = num_cameras
    self.img_history_size = img_history_size
    self.cond_mask_prob = cond_mask_prob
    self.cam_ext_mask_prob = cam_ext_mask_prob
    self.use_hdf5 = use_hdf5
    self.hdf5_dataset = None
    if use_hdf5:
        self.hdf5_dataset = HDF5VLADataset(self.model_config_path)
    self.use_precomp_lang_embed = use_precomp_lang_embed
    if use_precomp_lang_embed:
        self.empty_lang_embed = torch.load("data/empty_lang_embed.pt")

    # Load dataset stat
    with open("configs/dataset_stat.json", "r") as f:
        dataset_stat = json.load(f)
    self.dataset_stat = dataset_stat

    self.tokenizer = tokenizer
    self.image_size = image_size
    self.auto_adjust_image_brightness = auto_adjust_image_brightness
    self.image_aug = image_aug

    # Most recently loaded sample — presumably used as a fallback when a
    # later load fails; confirm against __getitem__.
    self.last_content = None
    self.last_meta = None
def get_dataset_name2id(self):
    """Return the mapping from dataset name to its integer id."""
    return self.dataset_name2id
def get_dataset_id2name(self):
    """Return the mapping from integer dataset id back to dataset name."""
    return self.dataset_id2name
@staticmethod
def pairwise(iterable):
a = iter(iterable)
return zip(a, a)
@staticmethod
def _load_data_from_chunk(chunk_dir, chunk_item_idx):
    """Read one buffered sample (JSON metadata + npz arrays) from a chunk.

    Each of the two files is read under a read lock; on any failure all
    acquired locks are released and the whole read is retried for up to
    10 seconds before giving up.

    NOTE(review): `FileLock` here appears to be a project-defined lock with
    `acquire_read_lock()`/`release_lock()` — not the PyPI `filelock`
    package; confirm its release semantics tolerate double release, since
    the except branches re-release locks that may already be released.

    Returns:
        (json_content, meta): the parsed JSON dict and a tuple of the
        arrays stored in ``sample_<idx>.npz`` (in file order).

    Raises:
        RuntimeError: if no attempt succeeds within the retry window.
        KeyboardInterrupt: re-raised after releasing held locks.
    """
    # If error occurs, retry
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            locks = []
            file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_read_lock()
            with open(file_path, "r") as file:
                json_content = json.load(file)
            lock.release_lock()
            file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_read_lock()
            with open(file_path, "rb") as file:
                sample_dict = np.load(file)
                # Materialize all arrays while the file handle is still open
                meta = tuple(sample_dict.values())
            lock.release_lock()
            return json_content, meta
        except KeyboardInterrupt:
            for lock in locks:
                lock.release_lock()
            raise KeyboardInterrupt
        except BaseException:
            # A writer may be mutating this item concurrently; release
            # everything and retry until the deadline.
            for lock in locks:
                lock.release_lock()
            continue
    raise RuntimeError("Failed to load sample.")
def __len__(self) -> int:
if self.use_hdf5:
return len(self.hdf5_dataset)
else:
return self.num_chunks * self.chunk_size
def _safe_load(self, index):
    """Load one sample from the disk buffer, tolerating dirty items.

    Starting from the chunk that `index` maps to, scan forward (wrapping
    around) until a chunk with at least one clean item is found, mark the
    chosen item as dirty so the producer refreshes it, then load it. If
    loading fails, fall back to the most recently loaded sample.

    Returns:
        A tuple ``(content, *meta)`` where `content` is the sample's JSON
        dict and `meta` are the arrays from its npz file.

    Raises:
        RuntimeError: if loading fails and no previously loaded sample is
            available to fall back on.
    """
    read_chunk_item_indices = []
    # Start searching from the chunk this index maps to
    read_chunk_idx = index // self.chunk_size
    while len(read_chunk_item_indices) == 0:
        read_chunk_dir = os.path.join(self.buffer_dir, f"chunk_{read_chunk_idx}")
        try:
            read_chunk_item_indices = get_clean_item(read_chunk_dir)
        except BaseException as e:
            # Print the error info
            print("Error catched when searching a clean chunk:", e)
            traceback.print_exc()
            read_chunk_item_indices = []
        read_chunk_idx = (read_chunk_idx + 1) % self.num_chunks

    # Deterministically pick one clean item based on the requested index
    random_item_index = index % len(read_chunk_item_indices)
    read_chunk_item_index = read_chunk_item_indices[random_item_index]

    # Mark the item as dirty so that the producer will refresh it
    try:
        dirty_bit = read_dirty_bit(read_chunk_dir)
        dirty_bit[read_chunk_item_index] = 1
        save_dirty_bit(read_chunk_dir, dirty_bit)
    except BaseException as e:
        # Print the error info
        print("Error catched when modifying the dirty bit:", e)
        traceback.print_exc()

    # Load the sample
    try:
        content, meta = self._load_data_from_chunk(read_chunk_dir, read_chunk_item_index)
        self.last_content, self.last_meta = content, meta
    except BaseException as e:
        # Print the error info
        print("Error catched when loading sample:", e)
        traceback.print_exc()
        # If failed to load the data, return the last loaded data for robustness
        content, meta = self.last_content, self.last_meta
        if meta is None:
            # Bug fix: previously this fell through to `(content, *meta)`
            # and crashed with an opaque TypeError when no sample had ever
            # been loaded successfully.
            raise RuntimeError("Failed to load a sample and no previously "
                               "loaded sample is available as fallback.") from e
    return (content, *meta)
def __getitem__(self, index):
    """Fetch and fully preprocess one training sample.

    Retries forever (advancing `index` on failure), so this method does
    not raise under normal operation. With probability `cond_mask_prob`,
    the control frequency, states, state mask, each camera image and the
    language instruction are independently masked (classifier-free
    guidance style conditioning dropout).
    """
    # For robustness, we will try to load the data until we succeed
    while True:
        data_dict = None
        try:
            if self.use_hdf5:
                res = self.hdf5_dataset.get_item()
                content = res["meta"]
                states = res["state"]
                actions = res["actions"]
                state_elem_mask = res["state_indicator"]
                # Interleaved (images, mask) pairs, one pair per camera
                image_metas = [
                    res["cam_high"],
                    res["cam_high_mask"],
                    res["cam_right_wrist"],
                    res["cam_right_wrist_mask"],
                    res["cam_left_wrist"],
                    res["cam_left_wrist_mask"],
                ]
                state_std = res["state_std"]
                state_mean = res["state_mean"]
                state_norm = res["state_norm"]
            else:
                (
                    content,
                    _,
                    states,
                    _,
                    actions,
                    _,
                    state_elem_mask,
                    *image_metas,
                    state_std,
                    state_mean,
                    state_norm,
                ) = self._safe_load(index)

            data_dict = {}
            data_dict["dataset_name"] = content["dataset_name"]
            data_dict["data_idx"] = self.dataset_name2id[data_dict["dataset_name"]]
            # Randomly zero the control frequency for conditioning dropout
            data_dict["ctrl_freq"] = (self.control_freq[data_dict["dataset_name"]]
                                      if random.random() > self.cond_mask_prob else 0)

            if self.state_noise_snr is not None:
                # Add Gaussian noise at the given SNR (in dB) scaled by state std
                states += np.random.normal(
                    0.0,
                    state_std / np.sqrt(10**(self.state_noise_snr / 10)),
                    states.shape,
                )
            ds_state_mean = np.array(self.dataset_stat[data_dict["dataset_name"]]["state_mean"])
            ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
            # Randomly mask the states by the mean state
            data_dict["states"] = (states if random.random() > self.cond_mask_prob else ds_state_mean)
            data_dict["actions"] = actions
            data_dict["state_elem_mask"] = (state_elem_mask if random.random() > self.cond_mask_prob else
                                            np.zeros_like(state_elem_mask))
            # Stat for the episode that the step belongs to
            data_dict["state_norm"] = state_norm

            # We replace the invalid images with the background image
            # and also randomly mask images by the background image
            background_color = np.array(
                [int(x * 255) for x in self.image_processor.image_mean],
                dtype=np.uint8,
            ).reshape(1, 1, 3)
            background_image = (np.ones(
                (
                    self.image_processor.size["height"],
                    self.image_processor.size["width"],
                    3,
                ),
                dtype=np.uint8,
            ) * background_color)

            # Re-pair the flat (images, mask) list: one tuple per camera
            image_metas = list(self.pairwise(image_metas))
            mask_probs = [self.cond_mask_prob] * self.num_cameras
            if self.cam_ext_mask_prob >= 0.0:
                # Dedicated dropout probability for the exterior camera
                mask_probs[0] = self.cam_ext_mask_prob
            # Flatten to time-major order: [t0 cam0, t0 cam1, ..., t1 cam0, ...]
            rearranged_images = []
            for i in range(self.img_history_size):
                for j in range(self.num_cameras):
                    images, image_mask = image_metas[j]
                    image, valid = images[i], image_mask[i]
                    if (valid and (math.prod(image.shape) > 0) and (random.random() > mask_probs[j])):
                        rearranged_images.append((image, True))
                    else:
                        rearranged_images.append((background_image.copy(), False))

            preprocessed_images = []
            processor = self.image_processor
            for image, valid in rearranged_images:
                image = Image.fromarray(image)
                if self.image_size is not None:
                    image = transforms.Resize(self.image_size)(image)  # (1008, 336)
                # assert image.height == 336, "We haven't prepare for training with images of different resolutions."

                if valid and self.auto_adjust_image_brightness:
                    pixel_values = list(image.getdata())
                    # Mean brightness in [0, 1] over all pixels and channels
                    average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                    if average_brightness <= 0.15:
                        # Brighten very dark images by a fixed factor
                        image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

                # Only apply image augmentation to 50% of the images
                if valid and self.image_aug and (random.random() > 0.5):
                    aug_type = random.choice(["corrput_only", "color_only", "both"])
                    if aug_type != "corrput_only":
                        image = transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5,
                                                       hue=0.03)(image)
                    if aug_type != "color_only":
                        image = image_corrupt(image)

                if self.image_aspect_ratio == "pad":

                    def expand2square(pil_img, background_color):
                        # Pad the shorter side with the background color so
                        # the image becomes square without distortion.
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result

                    image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
                preprocessed_images.append(image)
            data_dict["images"] = preprocessed_images

            if self.use_precomp_lang_embed:
                # Strip a trailing period so the instruction matches the
                # precomputed embedding's file path.
                if content["instruction"][-1] == ".":
                    content["instruction"] = content["instruction"][:-1]
                data_dict["lang_embed"] = (torch.load(content["instruction"])
                                           if random.random() > self.cond_mask_prob else self.empty_lang_embed)
            else:
                instruction = (content["instruction"] if random.random() > self.cond_mask_prob else "")
                data_dict["input_ids"] = self.tokenizer(
                    instruction,
                    return_tensors="pt",
                    padding="longest",
                    truncation=False,
                ).input_ids[0]

                assert (
                    len(data_dict["input_ids"]) <= self.tokenizer_max_length
                ), f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."

            # Convert all remaining numpy arrays to tensors
            for k, v in data_dict.items():
                if isinstance(v, np.ndarray):
                    data_dict[k] = torch.from_numpy(v)

            # Sanity check: no numpy arrays should remain
            for k, v in data_dict.items():
                assert not isinstance(v, np.ndarray), f"key: {k}, value: {v}"
                # data_dict[k] = torch.from_numpy(v)

            return data_dict
        except BaseException as e:
            # Print the error info
            if data_dict is not None:
                print(
                    f"Error catched when processing sample from {data_dict.get('dataset_name')}:",
                    e,
                )
            else:
                print(f"Error catched when processing sample:", e)
            traceback.print_exc()
            # Try incresing the index
            index = (index + 1) % len(self)
class DataCollatorForVLAConsumerDataset(object):
    """Collate examples for supervised training."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer) -> None:
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Stack per-sample fields into batched tensors.

        Samples carry either tokenized `input_ids` (padded here with the
        tokenizer's pad token, with the attention mask derived from pad
        positions) or precomputed `lang_embed` sequences (zero-padded,
        with a boolean mask built from the true lengths). The two modes
        are mutually exclusive within one batch.
        """
        tensor_keys = ("states", "actions", "state_elem_mask", "state_norm")
        batch = {key: [] for key in tensor_keys}
        batch.update(images=[], data_indices=[], ctrl_freqs=[])

        input_ids = []
        lang_embeds = []
        lang_embed_lens = []

        for sample in instances:
            # Convert any numpy arrays to tensors before batching
            for key in tensor_keys:
                value = sample[key]
                if not isinstance(value, torch.Tensor):
                    value = torch.from_numpy(value)
                batch[key].append(value)

            if "input_ids" in sample:
                input_ids.append(sample["input_ids"])
            else:
                embed = sample["lang_embed"]
                lang_embeds.append(embed)
                lang_embed_lens.append(embed.shape[0])

            batch["images"].append(torch.stack(sample["images"], dim=0))
            batch["data_indices"].append(sample["data_idx"])
            batch["ctrl_freqs"].append(sample["ctrl_freq"])

        for key in tensor_keys + ("images",):
            batch[key] = torch.stack(batch[key], dim=0)
        batch["ctrl_freqs"] = torch.tensor(batch["ctrl_freqs"])

        if len(input_ids) > 0:
            padded = torch.nn.utils.rnn.pad_sequence(input_ids,
                                                     batch_first=True,
                                                     padding_value=self.tokenizer.pad_token_id)
            batch["input_ids"] = padded
            batch["lang_attn_mask"] = padded.ne(self.tokenizer.pad_token_id)
        else:
            padded = torch.nn.utils.rnn.pad_sequence(lang_embeds, batch_first=True, padding_value=0)
            attn_mask = torch.zeros(padded.shape[0], padded.shape[1], dtype=torch.bool)
            for row, length in enumerate(lang_embed_lens):
                attn_mask[row, :length] = True
            batch["lang_embeds"] = padded
            batch["lang_attn_mask"] = attn_mask

        return batch

View File

@ -0,0 +1,45 @@
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

import numpy as np

# Compatibility shim: `np.bool` was removed in NumPy 1.24+, but imgaug
# still references it internally.
np.bool = np.bool_

import imgaug.augmenters as iaa
from PIL import Image

# Define our sequence of augmentation steps that will be applied to every image.
seq = iaa.Sequential(
    [
        # Execute one of the following noise augmentations
        iaa.OneOf([
            iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
            iaa.AdditiveLaplaceNoise(scale=(0.0, 0.05 * 255), per_channel=0.5),
            iaa.AdditivePoissonNoise(lam=(0.0, 0.05 * 255), per_channel=0.5),
        ]),
        # Execute one or none of the following blur augmentations
        iaa.SomeOf(
            (0, 1),
            [
                iaa.OneOf([
                    iaa.GaussianBlur((0, 3.0)),
                    iaa.AverageBlur(k=(2, 7)),
                    iaa.MedianBlur(k=(3, 11)),
                ]),
                iaa.MotionBlur(k=(3, 36)),
            ],
        ),
    ],
    # do all of the above augmentations in random order
    random_order=True,
)
def image_corrupt(image: Image):
    """Apply the module-level random corruption pipeline `seq` to a PIL image."""
    # imgaug operates on batches of HWC arrays, so add a leading batch axis
    batch = np.array(image)[None, ...]
    batch = seq(images=batch)
    return Image.fromarray(batch[0])

101
RDT-1B/train/sample.py Normal file
View File

@ -0,0 +1,101 @@
from collections import defaultdict
import torch
import torch.nn.functional as F
@torch.no_grad()
def log_sample_res(
    text_encoder,
    vision_encoder,
    rdt,
    args,
    accelerator,
    weight_dtype,
    dataset_id2name,
    dataset_id2name,
    dataloader,
    logger,
):
    """Evaluate `rdt` on `args.num_sample_batches` batches of `dataloader`.

    Returns a dict mapping per-dataset metric names (``<name>_sample_mse``,
    ``<name>_sample_l2err``) and the overall averages to rounded floats.
    Temporarily puts `rdt` in eval mode and restores train mode at the end.

    NOTE(review): autocast is hard-coded to float16 here regardless of
    `weight_dtype` — confirm this is intended for bf16 runs.
    """
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logger.info(f"Running sampling for {args.num_sample_batches} batches...")

        rdt.eval()

        loss_for_log = defaultdict(float)
        loss_counter = defaultdict(int)
        for step, batch in enumerate(dataloader):
            if step >= args.num_sample_batches:
                break

            data_indices = batch["data_indices"]
            ctrl_freqs = batch["ctrl_freqs"]
            state_norm = batch["state_norm"].to(dtype=weight_dtype)
            images = batch["images"].to(dtype=weight_dtype)
            states = batch["states"].to(dtype=weight_dtype)
            # We only use the last state as input
            states = states[:, -1:, :]
            actions = batch["actions"].to(dtype=weight_dtype)
            state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)

            batch_size, _, C, H, W = images.shape
            # Encode all frames in one pass, then fold back to (B, N, hidden)
            image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
            image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

            lang_attn_mask = batch["lang_attn_mask"]
            text_embeds = (batch["lang_embeds"].to(dtype=weight_dtype) if args.precomp_lang_embed else text_encoder(
                input_ids=batch["input_ids"], attention_mask=lang_attn_mask)["last_hidden_state"].detach())

            pred_actions = rdt.predict_action(
                lang_tokens=text_embeds,
                lang_attn_mask=lang_attn_mask,
                img_tokens=image_embeds,
                state_tokens=states,
                action_mask=state_elem_mask.unsqueeze(1),
                ctrl_freqs=ctrl_freqs,
            )

            num_steps = pred_actions.shape[1]
            # Broadcast the per-element mask and episode norm over all
            # predicted action steps
            expanded_state_elem_mask = (state_elem_mask.unsqueeze(1).tile((1, num_steps, 1)).float())
            expanded_state_norm = (state_norm.unsqueeze(1).tile((1, num_steps, 1)).float())

            loss = F.mse_loss(pred_actions, actions, reduction="none").float()
            # Per-sample masked mean-squared error
            mse_loss_per_entry = (loss * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)
            # Per-sample L2 error normalized by the episode's state norm
            l2_loss_per_entry = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss_per_entry = (l2_loss_per_entry * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)

            # Gather per-sample metrics from all processes
            dataset_indices, mse_losses, l2_losses = accelerator.gather_for_metrics((
                torch.LongTensor(data_indices).to(device=pred_actions.device),
                mse_loss_per_entry,
                l2_loss_per_entry,
            ), )
            dataset_indices = dataset_indices.tolist()
            if accelerator.is_main_process:
                for loss_suffix, losses in zip(["_sample_mse", "_sample_l2err"], [mse_losses, l2_losses]):
                    for dataset_idx, loss_tensor in zip(dataset_indices, losses):
                        loss_name = dataset_id2name[dataset_idx] + loss_suffix
                        loss_for_log[loss_name] += loss_tensor.item()
                        loss_counter[loss_name] += 1

            mse_loss = (loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            mse_loss_scaler = accelerator.gather(mse_loss).mean().item()
            loss_for_log["overall_avg_sample_mse"] += mse_loss_scaler

            l2_loss = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss = (l2_loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            l2_loss_scaler = accelerator.gather(l2_loss).mean().item()
            loss_for_log["overall_avg_sample_l2err"] += l2_loss_scaler

        # Turn the accumulated sums into averages
        for name in loss_for_log:
            if name in ["overall_avg_sample_mse", "overall_avg_sample_l2err"]:
                loss_scaler = loss_for_log[name]
                loss_for_log[name] = round(loss_scaler / (args.num_sample_batches), 4)
            else:
                loss_for_log[name] = round(loss_for_log[name] / loss_counter[name], 4)

    rdt.train()
    torch.cuda.empty_cache()

    return dict(loss_for_log)

521
RDT-1B/train/train.py Normal file
View File

@ -0,0 +1,521 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import copy
import logging
import math
import os
from pathlib import Path
import diffusers
import torch
import torch.utils.checkpoint
import transformers
import yaml
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, ProjectConfiguration, set_seed
from diffusers.optimization import get_scheduler
from diffusers.utils import is_wandb_available
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from safetensors.torch import load_model
from models.ema_model import EMAModel
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.multimodal_encoder.t5_encoder import T5Embedder
from models.rdt_runner import RDTRunner
from train.dataset import DataCollatorForVLAConsumerDataset, VLAConsumerDataset
from train.sample import log_sample_res
if is_wandb_available():
import wandb
def save_model_card(repo_id: str, base_model=None, repo_folder=None):
    """Write a Hugging Face model card (README.md) into `repo_folder`.

    Args:
        repo_id: Name of the hub repository the card is written for.
        base_model: Identifier of the base model this checkpoint derives
            from. Bug fix: the original default was ``str`` (the built-in
            type object, rendering "<class 'str'>" in the card) instead of
            a sensible "no base model" value.
        repo_folder: Local directory in which README.md is created.
    """
    # Renamed from `yaml` to avoid shadowing the imported `yaml` module.
    front_matter = f"""
---
license: mit
base_model: {base_model}
language:
- en
pipeline_tag: robotics
library_name: transformers
tags:
- robotics
- pytorch
- multimodal
- pretraining
- vla
- diffusion
- rdt
---
"""
    model_card = f"""
# RDT - {repo_id}
This is a RDT model derived from {base_model}. The weights were trained using [RDT](https://rdt-robotics.github.io/rdt-robotics/).
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(front_matter + model_card)
def train(args, logger):
    """Main training entry point for RDT.

    Builds the text/vision encoders, the RDT diffusion model and its EMA
    copy, datasets/dataloaders, optimizer and LR scheduler, then runs the
    accelerate-driven training loop with periodic checkpointing and
    sampling-based evaluation, finally saving (and optionally pushing) the
    trained model.
    """
    # Read the config
    with open(args.config_path, "r") as fp:
        config = yaml.safe_load(fp)

    with open(args.model_config_path, "r") as f:
        model_config = yaml.safe_load(f)
    # print(model_config)
    # The checkpoint directory comes from the model config, not the CLI
    args.output_dir = model_config["checkpoint_path"]

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
    accelerator = Accelerator(
        deepspeed_plugin=(DeepSpeedPlugin(hf_ds_config=args.deepspeed) if args.deepspeed is not None else None),
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        filename=args.output_log_path,
        filemode='w',
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
                token=args.hub_token,
            ).repo_id

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if args.precomp_lang_embed:
        # Language embeddings were precomputed offline; no tokenizer/encoder needed
        tokenizer, text_encoder = None, None
    else:
        text_embedder = T5Embedder(
            from_pretrained=args.pretrained_text_encoder_name_or_path,
            model_max_length=config["dataset"]["tokenizer_max_length"],
            device=accelerator.device,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    vision_encoder = SiglipVisionTower(vision_tower=args.pretrained_vision_encoder_name_or_path, args=None)
    image_processor = vision_encoder.image_processor

    # Load from a pretrained checkpoint (a directory -> from_pretrained;
    # a single file is handled later via torch.load)
    if args.pretrained_model_name_or_path is not None and not os.path.isfile(args.pretrained_model_name_or_path):
        logger.info("Constructing model from pretrained checkpoint.")
        rdt = RDTRunner.from_pretrained(args.pretrained_model_name_or_path)
    else:
        logger.info("Constructing model from provided config.")
        # Calculate the image condition length
        img_cond_len = (config["common"]["img_history_size"] * config["common"]["num_cameras"] *
                        vision_encoder.num_patches)
        rdt = RDTRunner(
            action_dim=config["common"]["state_dim"],
            pred_horizon=config["common"]["action_chunk_size"],
            config=config["model"],
            lang_token_dim=config["model"]["lang_token_dim"],
            img_token_dim=config["model"]["img_token_dim"],
            state_token_dim=config["model"]["state_token_dim"],
            max_lang_cond_len=config["dataset"]["tokenizer_max_length"],
            img_cond_len=img_cond_len,
            img_pos_embed_config=[
                # No initial pos embed in the last grid size
                # since we've already done in ViT
                (
                    "image",
                    (
                        config["common"]["img_history_size"],
                        config["common"]["num_cameras"],
                        -vision_encoder.num_patches,
                    ),
                ),
            ],
            lang_pos_embed_config=[
                # Similarly, no initial pos embed for language
                ("lang", -config["dataset"]["tokenizer_max_length"]),
            ],
            dtype=weight_dtype,
        )

    # EMA shadow copy of the model, updated after each optimizer step
    ema_rdt = copy.deepcopy(rdt)
    ema_model = EMAModel(
        ema_rdt,
        update_after_step=config["model"]["ema"]["update_after_step"],
        inv_gamma=config["model"]["ema"]["inv_gamma"],
        power=config["model"]["ema"]["power"],
        min_value=config["model"]["ema"]["min_value"],
        max_value=config["model"]["ema"]["max_value"],
    )

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    # which ensure saving model in huggingface format (config.json + pytorch_model.bin)
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for model in models:
                model_to_save = model.module if hasattr(model, "module") else model  # type: ignore
                if isinstance(model_to_save, type(accelerator.unwrap_model(rdt))):
                    model_to_save.save_pretrained(output_dir)

    accelerator.register_save_state_pre_hook(save_model_hook)

    if args.gradient_checkpointing:
        # TODO:
        raise NotImplementedError("Gradient checkpointing is not yet implemented.")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size *
                              accelerator.num_processes)

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    params_to_optimize = rdt.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = VLAConsumerDataset(
        model_config_path=args.model_config_path,  # TODO
        config=config["dataset"],
        tokenizer=tokenizer,
        image_processor=image_processor,
        num_cameras=config["common"]["num_cameras"],
        img_history_size=config["common"]["img_history_size"],
        dataset_type=args.dataset_type,
        image_aug=args.image_aug,
        cond_mask_prob=args.cond_mask_prob,
        cam_ext_mask_prob=args.cam_ext_mask_prob,
        state_noise_snr=args.state_noise_snr,
        use_hdf5=args.load_from_hdf5,
        use_precomp_lang_embed=args.precomp_lang_embed,
    )
    # Evaluation dataset: no augmentation, no conditioning dropout, no noise
    sample_dataset = VLAConsumerDataset(
        model_config_path=args.model_config_path,  # TODO
        config=config["dataset"],
        tokenizer=tokenizer,
        image_processor=image_processor,
        num_cameras=config["common"]["num_cameras"],
        img_history_size=config["common"]["img_history_size"],
        dataset_type=args.dataset_type,
        image_aug=False,
        cond_mask_prob=0,
        cam_ext_mask_prob=-1,
        state_noise_snr=None,
        use_hdf5=args.load_from_hdf5,
        use_precomp_lang_embed=args.precomp_lang_embed,
    )

    data_collator = DataCollatorForVLAConsumerDataset(tokenizer)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        persistent_workers=True,
    )
    sample_dataloader = torch.utils.data.DataLoader(
        sample_dataset,
        batch_size=args.sample_batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        persistent_workers=True,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler = (accelerator.prepare(
        rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler))

    # The encoders and EMA copy are inference-only; move them manually
    ema_rdt.to(accelerator.device, dtype=weight_dtype)
    if text_encoder is not None:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
    if vision_encoder is not None:
        vision_encoder.vision_tower.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers(
            "VLA",
            config=vars(args),
            init_kwargs={"wandb": {
                "name": f"RoboTwin_RDT_{args.CONFIG_NAME}",
            }},
        )

    # Train!
    total_batch_size = (args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps)

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Load from a pretrained checkpoint
    if (args.resume_from_checkpoint is None and args.pretrained_model_name_or_path is not None
            and os.path.isfile(args.pretrained_model_name_or_path)):
        # Since EMA is deprecated, we do not load EMA from the pretrained checkpoint
        logger.info("Loading from a pretrained checkpoint.")
        checkpoint = torch.load(args.pretrained_model_name_or_path)
        rdt.module.load_state_dict(checkpoint["module"])

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the mos recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            try:
                accelerator.load_state(os.path.join(args.output_dir, path))  # load_module_strict=False
            except:
                # load deepspeed's state_dict
                logger.info("Resuming training state failed. Attempting to only load from model checkpoint.")
                checkpoint = torch.load(
                    os.path.join(
                        args.output_dir,
                        path,
                        "pytorch_model",
                        "mp_rank_00_model_states.pt",
                    ))
                rdt.module.load_state_dict(checkpoint["module"])

            load_model(ema_rdt, os.path.join(args.output_dir, path, "ema", "model.safetensors"))

            # Recover progress counters from the checkpoint directory name
            global_step = int(path.split("-")[1])
            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(global_step, args.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    loss_for_log = {}
    for epoch in range(first_epoch, args.num_train_epochs):

        rdt.train()

        # Set the progress_bar to correct position
        # NOTE(review): `resume_step` is only defined when resuming actually
        # succeeded; if resume was reset to None above this branch is skipped.
        if args.resume_from_checkpoint and epoch == first_epoch:
            progress_bar.update(resume_step // args.gradient_accumulation_steps)

        # Forward and backward...
        for batch in train_dataloader:
            with accelerator.accumulate(rdt):
                images = batch["images"].to(dtype=weight_dtype)
                states = batch["states"].to(dtype=weight_dtype)  # (B, T, D_a)
                # We only use the last state as input
                states = states[:, -1:, :]
                actions = batch["actions"].to(dtype=weight_dtype)
                state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)
                ctrl_freqs = batch["ctrl_freqs"]

                # Frozen encoders: no gradient through vision/text features
                with torch.no_grad():
                    batch_size, _, C, H, W = images.shape
                    image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
                    image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

                    lang_attn_mask = batch["lang_attn_mask"]
                    text_embeds = (batch["lang_embeds"].to(
                        dtype=weight_dtype) if args.precomp_lang_embed else text_encoder(
                            input_ids=batch["input_ids"], attention_mask=lang_attn_mask)["last_hidden_state"].detach())

                state_elem_mask = state_elem_mask.unsqueeze(1)
                loss = rdt(
                    lang_tokens=text_embeds,
                    lang_attn_mask=lang_attn_mask,
                    img_tokens=image_embeds,
                    state_tokens=states,
                    action_gt=actions,
                    action_mask=state_elem_mask,
                    ctrl_freqs=ctrl_freqs,
                )

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = rdt.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            ema_model.step(accelerator.unwrap_model(rdt))

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_period == 0:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    ema_save_path = os.path.join(save_path, f"ema")
                    accelerator.save_model(ema_rdt, ema_save_path)
                    logger.info(f"Saved state to {save_path}")

                if args.sample_period > 0 and global_step % args.sample_period == 0:
                    sample_loss_for_log = log_sample_res(
                        text_encoder,
                        vision_encoder,
                        rdt,  # We do not use EMA currently
                        args,
                        accelerator,
                        weight_dtype,
                        sample_dataset.get_dataset_id2name(),
                        sample_dataloader,
                        logger,
                    )
                    logger.info(sample_loss_for_log)
                    accelerator.log(sample_loss_for_log, step=global_step)

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            logs.update(loss_for_log)
            # logger.info(logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(rdt).save_pretrained(args.output_dir)
        ema_save_path = os.path.join(args.output_dir, f"ema")
        accelerator.save_model(ema_rdt, ema_save_path)

        logger.info(f"Saved Model to {args.output_dir}")

        if args.push_to_hub:
            save_model_card(
                repo_id,
                base_model=args.pretrained_model_name_or_path,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                token=args.hub_token,
                allow_patterns=["pytorch_model.bin", "*.json", "*.md"],
                # ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()

1
weights Symbolic link
View File

@ -0,0 +1 @@
/home/qi.xiong/Temp/RDT-1B/input/weights