update
This commit is contained in:
commit
52d79bbc5e
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
input/
|
||||
output/
|
||||
Temp/
|
||||
weights/
|
||||
2
RDT-1B/.dockerignore
Normal file
2
RDT-1B/.dockerignore
Normal file
@ -0,0 +1,2 @@
|
||||
input/*
|
||||
output/*
|
||||
7
RDT-1B/.gitignore
vendored
Normal file
7
RDT-1B/.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
processed_data/
|
||||
training_data/
|
||||
checkpoints/
|
||||
model_config/*.yml
|
||||
wandb/*
|
||||
!models/
|
||||
!data/
|
||||
45
RDT-1B/Dockerfile
Normal file
45
RDT-1B/Dockerfile
Normal file
@ -0,0 +1,45 @@
|
||||
|
||||
FROM registry.d-robotics.cc/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
|
||||
# ccr-29eug8s3-pub.cnc.bj.baidubce.com/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
|
||||
WORKDIR /app
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV TZ=Asia/Shanghai
|
||||
|
||||
RUN sed -i 's/archive.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list && \
|
||||
sed -i 's/security.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3.10-distutils \
|
||||
libgl1-mesa-glx \
|
||||
libglib2.0-0 \
|
||||
wget \
|
||||
ffmpeg \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
|
||||
|
||||
COPY . /app/
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
|
||||
RUN pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN pip install packaging==24.0
|
||||
|
||||
RUN pip install flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
|
||||
RUN mkdir -p /app/dataset/input /app/dataset/output
|
||||
|
||||
ENTRYPOINT ["bash", "finetune.sh"]
|
||||
1
RDT-1B/__init__.py
Normal file
1
RDT-1B/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .deploy_policy import *
|
||||
BIN
RDT-1B/assets/head.png
Normal file
BIN
RDT-1B/assets/head.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 726 KiB |
BIN
RDT-1B/configs/__pycache__/state_vec.cpython-310.pyc
Normal file
BIN
RDT-1B/configs/__pycache__/state_vec.cpython-310.pyc
Normal file
Binary file not shown.
71
RDT-1B/configs/base.yaml
Normal file
71
RDT-1B/configs/base.yaml
Normal file
@ -0,0 +1,71 @@
|
||||
common:
|
||||
# The number of historical images
|
||||
img_history_size: 2
|
||||
# The number of future actions to predict
|
||||
action_chunk_size: 64
|
||||
# The number of cameras to be used in the model
|
||||
num_cameras: 3
|
||||
# Dimension for state/action, we use the same space for both state and action
|
||||
# This MUST be equal to configs/state_vec.py
|
||||
state_dim: 128
|
||||
|
||||
|
||||
dataset:
|
||||
# We will extract the data from raw dataset
|
||||
# and store them in the disk buffer by producer
|
||||
# When training, we will read the data
|
||||
# randomly from the buffer by consumer
|
||||
# The producer will replace the data which has been
|
||||
# read by the consumer with new data
|
||||
|
||||
# The path to the buffer (at least 400GB)
|
||||
buf_path: /path/to/buffer
|
||||
# The number of chunks in the buffer
|
||||
buf_num_chunks: 512
|
||||
# The number of samples (step rather than episode) in each chunk
|
||||
buf_chunk_size: 512
|
||||
|
||||
# We will filter the episodes with length less than `epsd_len_thresh_low`
|
||||
epsd_len_thresh_low: 32
|
||||
# For those more than `epsd_len_thresh_high`,
|
||||
# we will randomly sample `epsd_len_thresh_high` steps each time we load the episode
|
||||
# to better balance the training datasets
|
||||
epsd_len_thresh_high: 2048
|
||||
# How to fit the image size
|
||||
image_aspect_ratio: pad
|
||||
# Maximum number of language tokens
|
||||
tokenizer_max_length: 1024
|
||||
|
||||
model:
|
||||
# Config for condition adpators
|
||||
lang_adaptor: mlp2x_gelu
|
||||
img_adaptor: mlp2x_gelu
|
||||
state_adaptor: mlp3x_gelu
|
||||
lang_token_dim: 4096
|
||||
img_token_dim: 1152
|
||||
# Dim of action or proprioception vector
|
||||
# A `state` refers to an action or a proprioception vector
|
||||
state_token_dim: 128
|
||||
# Config for RDT structure
|
||||
rdt:
|
||||
# 1B: num_head 32 hidden_size 2048
|
||||
hidden_size: 2048
|
||||
depth: 28
|
||||
num_heads: 32
|
||||
cond_pos_embed_type: multimodal
|
||||
# For noise scheduler
|
||||
noise_scheduler:
|
||||
type: ddpm
|
||||
num_train_timesteps: 1000
|
||||
num_inference_timesteps: 5
|
||||
beta_schedule: squaredcos_cap_v2 # Critical choice
|
||||
prediction_type: sample
|
||||
clip_sample: False
|
||||
# For EMA (params averaging)
|
||||
# We do not use EMA currently
|
||||
ema:
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
@ -0,0 +1,50 @@
|
||||
{
|
||||
"A": [
|
||||
[
|
||||
-0.2691913843154907,
|
||||
-0.21995729207992554,
|
||||
-0.182277649641037
|
||||
],
|
||||
[
|
||||
0.35127854347229004,
|
||||
0.2769763469696045,
|
||||
0.17159393429756165
|
||||
]
|
||||
],
|
||||
"B": [
|
||||
[
|
||||
-0.2576896846294403,
|
||||
-0.22244493663311005,
|
||||
-0.20557966828346252
|
||||
],
|
||||
[
|
||||
0.32854634523391724,
|
||||
0.2922680974006653,
|
||||
0.17373555898666382
|
||||
]
|
||||
],
|
||||
"C": [
|
||||
[
|
||||
-0.29205888509750366,
|
||||
-0.24688798189163208,
|
||||
-0.17577645182609558
|
||||
],
|
||||
[
|
||||
0.25053921341896057,
|
||||
0.3277084231376648,
|
||||
0.16431939601898193
|
||||
]
|
||||
],
|
||||
"D": [
|
||||
[
|
||||
-0.25131964683532715,
|
||||
-0.15233077108860016,
|
||||
-0.13294968008995056
|
||||
],
|
||||
[
|
||||
0.19209328293800354,
|
||||
0.19344553351402283,
|
||||
0.1370421051979065
|
||||
]
|
||||
]
|
||||
}
|
||||
65
RDT-1B/configs/dataset_control_freq.json
Normal file
65
RDT-1B/configs/dataset_control_freq.json
Normal file
@ -0,0 +1,65 @@
|
||||
{
|
||||
"fractal20220817_data": 3,
|
||||
"taco_play": 15,
|
||||
"jaco_play": 10,
|
||||
"berkeley_cable_routing": 10,
|
||||
"nyu_door_opening_surprising_effectiveness": 3,
|
||||
"viola": 20,
|
||||
"berkeley_autolab_ur5": 5,
|
||||
"toto": 30,
|
||||
"kuka": 10,
|
||||
"language_table": 10,
|
||||
"columbia_cairlab_pusht_real": 10,
|
||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20,
|
||||
"nyu_rot_dataset_converted_externally_to_rlds":3,
|
||||
"stanford_hydra_dataset_converted_externally_to_rlds": 10,
|
||||
"austin_buds_dataset_converted_externally_to_rlds": 20,
|
||||
"nyu_franka_play_dataset_converted_externally_to_rlds": 3,
|
||||
"maniskill_dataset_converted_externally_to_rlds": 20,
|
||||
"furniture_bench_dataset_converted_externally_to_rlds": 10,
|
||||
"ucsd_kitchen_dataset_converted_externally_to_rlds": 2,
|
||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3,
|
||||
"austin_sailor_dataset_converted_externally_to_rlds": 20,
|
||||
"austin_sirius_dataset_converted_externally_to_rlds": 20,
|
||||
"bc_z": 10,
|
||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": 10,
|
||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 10,
|
||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
|
||||
"utokyo_xarm_bimanual_converted_externally_to_rlds": 10,
|
||||
"berkeley_mvp_converted_externally_to_rlds": 5,
|
||||
"berkeley_rpt_converted_externally_to_rlds": 30,
|
||||
"kaist_nonprehensile_converted_externally_to_rlds": 10,
|
||||
"stanford_mask_vit_converted_externally_to_rlds": 0,
|
||||
"tokyo_u_lsmo_converted_externally_to_rlds": 10,
|
||||
"dlr_sara_pour_converted_externally_to_rlds": 10,
|
||||
"dlr_sara_grid_clamp_converted_externally_to_rlds": 10,
|
||||
"dlr_edan_shared_control_converted_externally_to_rlds": 5,
|
||||
"asu_table_top_converted_externally_to_rlds": 12.5,
|
||||
"stanford_robocook_converted_externally_to_rlds": 5,
|
||||
"eth_agent_affordances": 66.6,
|
||||
"imperialcollege_sawyer_wrist_cam": 10,
|
||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20,
|
||||
"uiuc_d3field": 1,
|
||||
"utaustin_mutex": 20,
|
||||
"berkeley_fanuc_manipulation": 10,
|
||||
"cmu_play_fusion": 5,
|
||||
"cmu_stretch": 10,
|
||||
"berkeley_gnm_recon": 3,
|
||||
"berkeley_gnm_cory_hall": 5,
|
||||
"berkeley_gnm_sac_son": 10,
|
||||
"robo_net": 1,
|
||||
"roboturk_real_towercreation": 10,
|
||||
"roboturk_real_laundrylayout": 10,
|
||||
"roboturk_real_objectsearch": 10,
|
||||
"aloha_mobile": 50,
|
||||
"aloha_static": 50,
|
||||
"roboset": 5,
|
||||
"droid": 15,
|
||||
"fmb": 10,
|
||||
"dobbe": 30,
|
||||
"qut_dexterous_manpulation": 30,
|
||||
"agilex": 25,
|
||||
"rh20t": 10,
|
||||
"calvin": 30,
|
||||
"bridgev2": 5
|
||||
}
|
||||
575
RDT-1B/configs/dataset_img_keys.json
Normal file
575
RDT-1B/configs/dataset_img_keys.json
Normal file
@ -0,0 +1,575 @@
|
||||
{
|
||||
"fractal20220817_data": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[
|
||||
1,0,0,0
|
||||
]
|
||||
},
|
||||
"taco_play": {
|
||||
"image_keys": [
|
||||
"rgb_static",
|
||||
"rgb_gripper",
|
||||
"rgb_static",
|
||||
"rgb_static"
|
||||
],
|
||||
"image_mask":[
|
||||
1,1,0,0
|
||||
]
|
||||
},
|
||||
"jaco_play": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image_wrist",
|
||||
"image_wrist",
|
||||
"image_wrist"
|
||||
],
|
||||
"image_mask":[
|
||||
1,1,0,0
|
||||
]
|
||||
},
|
||||
"berkeley_cable_routing": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist45_image",
|
||||
"wrist225_image",
|
||||
"top_image"
|
||||
],
|
||||
"image_mask":[1,1,0,1]
|
||||
},
|
||||
"nyu_door_opening_surprising_effectiveness": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"viola": {
|
||||
"image_keys": [
|
||||
"agentview_rgb",
|
||||
"eye_in_hand_rgb",
|
||||
"eye_in_hand_rgb",
|
||||
"eye_in_hand_rgb"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"berkeley_autolab_ur5": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"hand_image",
|
||||
"hand_image",
|
||||
"hand_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"toto": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"kuka": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"language_table": {
|
||||
"image_keys": [
|
||||
"rgb",
|
||||
"rgb",
|
||||
"rgb",
|
||||
"rgb"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"columbia_cairlab_pusht_real": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"nyu_rot_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"stanford_hydra_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"austin_buds_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"nyu_franka_play_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image_additional_view",
|
||||
"image_additional_view",
|
||||
"image_additional_view"
|
||||
],
|
||||
"image_mask":[1,0,0,1]
|
||||
},
|
||||
"maniskill_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"furniture_bench_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"austin_sailor_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"austin_sirius_dataset_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"bc_z": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"hand_image",
|
||||
"hand_image",
|
||||
"image2"
|
||||
],
|
||||
"image_mask":[1,1,0,1]
|
||||
},
|
||||
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"berkeley_mvp_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"hand_image",
|
||||
"hand_image",
|
||||
"hand_image",
|
||||
"hand_image"
|
||||
],
|
||||
"image_mask":[0,1,0,0]
|
||||
},
|
||||
"berkeley_rpt_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"hand_image",
|
||||
"hand_image",
|
||||
"hand_image",
|
||||
"hand_image"
|
||||
],
|
||||
"image_mask":[0,1,0,0]
|
||||
},
|
||||
"kaist_nonprehensile_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"stanford_mask_vit_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"tokyo_u_lsmo_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"dlr_sara_pour_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"dlr_sara_grid_clamp_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"dlr_edan_shared_control_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"asu_table_top_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"stanford_robocook_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image_2",
|
||||
"image_1",
|
||||
"image_3",
|
||||
"image_4"
|
||||
],
|
||||
"image_mask":[1,0,0,1]
|
||||
},
|
||||
"eth_agent_affordances": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"imperialcollege_sawyer_wrist_cam": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[0,1,0,0]
|
||||
},
|
||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"uiuc_d3field": {
|
||||
"image_keys": [
|
||||
"image_1",
|
||||
"image_2",
|
||||
"image_3",
|
||||
"image_4"
|
||||
],
|
||||
"image_mask":[1,0,0,1]
|
||||
},
|
||||
"utaustin_mutex": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"berkeley_fanuc_manipulation": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"cmu_play_fusion": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"cmu_stretch": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"berkeley_gnm_recon": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"berkeley_gnm_cory_hall": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"berkeley_gnm_sac_son": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"robo_net": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image1",
|
||||
"image2",
|
||||
"image2"
|
||||
],
|
||||
"image_mask":[1,0,0,1]
|
||||
},
|
||||
"roboturk_real_towercreation": {
|
||||
"image_keys": [
|
||||
"top_rgb_frame",
|
||||
"front_rgb_frame",
|
||||
"front_rgb_frame",
|
||||
"front_rgb_frame"
|
||||
],
|
||||
"image_mask":[1,0,0,1]
|
||||
},
|
||||
"roboturk_real_laundrylayout": {
|
||||
"image_keys": [
|
||||
"top_rgb_frame",
|
||||
"front_rgb_frame",
|
||||
"front_rgb_frame",
|
||||
"front_rgb_frame"
|
||||
],
|
||||
"image_mask":[1,0,0,1]
|
||||
},
|
||||
"roboturk_real_objectsearch": {
|
||||
"image_keys": [
|
||||
"top_rgb_frame",
|
||||
"front_rgb_frame",
|
||||
"front_rgb_frame",
|
||||
"front_rgb_frame"
|
||||
],
|
||||
"image_mask":[1,0,0,1]
|
||||
},
|
||||
"aloha_mobile": {
|
||||
"image_keys": [
|
||||
"cam_high",
|
||||
"cam_right_wrist",
|
||||
"cam_left_wrist",
|
||||
"cam_right_wrist"
|
||||
],
|
||||
"image_mask":[1,1,1,0]
|
||||
},
|
||||
"aloha_static": {
|
||||
"image_keys": [
|
||||
"cam_high",
|
||||
"cam_right_wrist",
|
||||
"cam_left_wrist",
|
||||
"cam_low"
|
||||
],
|
||||
"image_mask":[1,1,1,1]
|
||||
},
|
||||
"roboset": {
|
||||
"image_keys": [
|
||||
"rgb_top",
|
||||
"rgb_right",
|
||||
"rgb_left",
|
||||
"rgb_right"
|
||||
],
|
||||
"image_mask":[1,1,1,0]
|
||||
},
|
||||
"droid": {
|
||||
"image_keys": [
|
||||
"exterior_image_1_left",
|
||||
"wrist_image_left",
|
||||
"wrist_image_left",
|
||||
"exterior_image_2_left"
|
||||
],
|
||||
"image_mask":[1,1,0,1]
|
||||
},
|
||||
"fmb": {
|
||||
"image_keys": [
|
||||
"image_side_1",
|
||||
"image_wrist_1",
|
||||
"image_wrist_1",
|
||||
"image_side_2"
|
||||
],
|
||||
"image_mask":[1,1,0,1]
|
||||
},
|
||||
"dobbe": {
|
||||
"image_keys": [
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[0,1,0,0]
|
||||
},
|
||||
"qut_dexterous_manpulation": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"wrist_image",
|
||||
"wrist_image",
|
||||
"wrist_image"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"agilex": {
|
||||
"image_keys": [
|
||||
"cam_high",
|
||||
"cam_right_wrist",
|
||||
"cam_left_wrist",
|
||||
"cam_right_wrist"
|
||||
],
|
||||
"image_mask":[1,1,1,0]
|
||||
},
|
||||
"rh20t": {
|
||||
"image_keys": [
|
||||
"image",
|
||||
"image",
|
||||
"image",
|
||||
"image"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
},
|
||||
"calvin": {
|
||||
"image_keys": [
|
||||
"rgb_static",
|
||||
"rgb_gripper",
|
||||
"rgb_gripper",
|
||||
"rgb_gripper"
|
||||
],
|
||||
"image_mask":[1,1,0,0]
|
||||
},
|
||||
"bridgev2": {
|
||||
"image_keys": [
|
||||
"images0",
|
||||
"images0",
|
||||
"images0",
|
||||
"images0"
|
||||
],
|
||||
"image_mask":[1,0,0,0]
|
||||
}
|
||||
}
|
||||
525
RDT-1B/configs/dataset_stat.json
Normal file
525
RDT-1B/configs/dataset_stat.json
Normal file
@ -0,0 +1,525 @@
|
||||
{
|
||||
"agilex": {
|
||||
"dataset_name": "agilex",
|
||||
"state_mean": [
|
||||
-0.0036545392947090432,
|
||||
-0.2773659935760079,
|
||||
0.3147616748061523,
|
||||
0.3813313179910183,
|
||||
0.04028575944090457,
|
||||
0.034888520819083294,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"state_std": [
|
||||
0.05763674563578847,
|
||||
0.2580181064167735,
|
||||
0.19785840483767897,
|
||||
0.05020347749331385,
|
||||
0.054529239104671424,
|
||||
0.05020521339363586,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"state_min": [
|
||||
-0.17447535196940103,
|
||||
-0.5522612677680121,
|
||||
-0.3340397516886393,
|
||||
0.21861712137858072,
|
||||
-0.09725829230414497,
|
||||
0.003396739231215583,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"state_max": [
|
||||
0.21961932712131077,
|
||||
0.30613206227620443,
|
||||
0.5444545321994357,
|
||||
0.4866888682047526,
|
||||
0.31486290825737845,
|
||||
0.3355223337809245,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
}
|
||||
}
|
||||
3
RDT-1B/configs/finetune_datasets.json
Normal file
3
RDT-1B/configs/finetune_datasets.json
Normal file
@ -0,0 +1,3 @@
|
||||
[
|
||||
"agilex"
|
||||
]
|
||||
3
RDT-1B/configs/finetune_sample_weights.json
Normal file
3
RDT-1B/configs/finetune_sample_weights.json
Normal file
@ -0,0 +1,3 @@
|
||||
{
|
||||
"agilex": 100
|
||||
}
|
||||
48
RDT-1B/configs/pretrain_datasets.json
Normal file
48
RDT-1B/configs/pretrain_datasets.json
Normal file
@ -0,0 +1,48 @@
|
||||
[
|
||||
"fractal20220817_data",
|
||||
"jaco_play",
|
||||
"taco_play",
|
||||
"berkeley_cable_routing",
|
||||
"viola",
|
||||
"berkeley_autolab_ur5",
|
||||
"toto",
|
||||
"nyu_door_opening_surprising_effectiveness",
|
||||
"columbia_cairlab_pusht_real",
|
||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds",
|
||||
"austin_buds_dataset_converted_externally_to_rlds",
|
||||
"kuka",
|
||||
"utokyo_xarm_bimanual_converted_externally_to_rlds",
|
||||
"stanford_hydra_dataset_converted_externally_to_rlds",
|
||||
"maniskill_dataset_converted_externally_to_rlds",
|
||||
"ucsd_kitchen_dataset_converted_externally_to_rlds",
|
||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds",
|
||||
"austin_sailor_dataset_converted_externally_to_rlds",
|
||||
"austin_sirius_dataset_converted_externally_to_rlds",
|
||||
"bc_z",
|
||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds",
|
||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds",
|
||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds",
|
||||
"berkeley_mvp_converted_externally_to_rlds",
|
||||
"berkeley_rpt_converted_externally_to_rlds",
|
||||
"kaist_nonprehensile_converted_externally_to_rlds",
|
||||
"tokyo_u_lsmo_converted_externally_to_rlds",
|
||||
"dlr_sara_grid_clamp_converted_externally_to_rlds",
|
||||
"stanford_robocook_converted_externally_to_rlds",
|
||||
"imperialcollege_sawyer_wrist_cam",
|
||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds",
|
||||
"utaustin_mutex",
|
||||
"berkeley_fanuc_manipulation",
|
||||
"cmu_play_fusion",
|
||||
"language_table",
|
||||
"furniture_bench_dataset_converted_externally_to_rlds",
|
||||
"droid",
|
||||
"fmb",
|
||||
"dobbe",
|
||||
"qut_dexterous_manpulation",
|
||||
"aloha_mobile",
|
||||
"aloha_static",
|
||||
"roboset",
|
||||
"rh20t",
|
||||
"calvin",
|
||||
"bridgev2"
|
||||
]
|
||||
48
RDT-1B/configs/pretrain_sample_weights.json
Normal file
48
RDT-1B/configs/pretrain_sample_weights.json
Normal file
@ -0,0 +1,48 @@
|
||||
{
|
||||
"fractal20220817_data": 271,
|
||||
"taco_play": 60,
|
||||
"jaco_play": 33,
|
||||
"berkeley_cable_routing": 8,
|
||||
"nyu_door_opening_surprising_effectiveness": 10,
|
||||
"viola": 12,
|
||||
"berkeley_autolab_ur5": 32,
|
||||
"toto": 32,
|
||||
"kuka": 50,
|
||||
"language_table": 100,
|
||||
"columbia_cairlab_pusht_real": 12,
|
||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 55,
|
||||
"stanford_hydra_dataset_converted_externally_to_rlds": 24,
|
||||
"austin_buds_dataset_converted_externally_to_rlds": 7,
|
||||
"maniskill_dataset_converted_externally_to_rlds": 174,
|
||||
"furniture_bench_dataset_converted_externally_to_rlds": 71,
|
||||
"ucsd_kitchen_dataset_converted_externally_to_rlds": 12,
|
||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": 37,
|
||||
"austin_sailor_dataset_converted_externally_to_rlds": 15,
|
||||
"austin_sirius_dataset_converted_externally_to_rlds": 24,
|
||||
"bc_z": 208,
|
||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": 9,
|
||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 15,
|
||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
|
||||
"utokyo_xarm_bimanual_converted_externally_to_rlds": 1,
|
||||
"berkeley_mvp_converted_externally_to_rlds": 22,
|
||||
"berkeley_rpt_converted_externally_to_rlds": 30,
|
||||
"kaist_nonprehensile_converted_externally_to_rlds": 14,
|
||||
"tokyo_u_lsmo_converted_externally_to_rlds": 7,
|
||||
"dlr_sara_grid_clamp_converted_externally_to_rlds": 1,
|
||||
"stanford_robocook_converted_externally_to_rlds": 50,
|
||||
"imperialcollege_sawyer_wrist_cam": 13,
|
||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": 25,
|
||||
"utaustin_mutex": 39,
|
||||
"berkeley_fanuc_manipulation": 20,
|
||||
"cmu_play_fusion": 24,
|
||||
"droid": 303,
|
||||
"fmb": 42,
|
||||
"dobbe": 36,
|
||||
"qut_dexterous_manpulation": 14,
|
||||
"aloha_mobile": 150,
|
||||
"aloha_static": 150,
|
||||
"roboset": 135,
|
||||
"rh20t": 331,
|
||||
"calvin": 100,
|
||||
"bridgev2": 224
|
||||
}
|
||||
126
RDT-1B/configs/state_vec.py
Normal file
126
RDT-1B/configs/state_vec.py
Normal file
@ -0,0 +1,126 @@
|
||||
# Mapping from semantic state/action channel names to their index in the
# unified flat state vector shared by all datasets. Several aliases map to
# the same index (e.g. "arm_joint_0_pos" and "right_arm_joint_0_pos" are
# both index 0), so single-arm data can be written with un-prefixed names.
# Index ranges [45, 50), [95, 100) and [103, 128) are reserved/unused.
STATE_VEC_IDX_MAPPING = {
    # [0, 10): right arm joint positions
    **{
        "arm_joint_{}_pos".format(i): i
        for i in range(10)
    },
    **{
        "right_arm_joint_{}_pos".format(i): i
        for i in range(10)
    },
    # [10, 15): right gripper joint positions
    **{
        "gripper_joint_{}_pos".format(i): i + 10
        for i in range(5)
    },
    **{
        "right_gripper_joint_{}_pos".format(i): i + 10
        for i in range(5)
    },
    "gripper_open": 10,  # alias of right_gripper_joint_0_pos
    "right_gripper_open": 10,
    # [15, 25): right arm joint velocities
    **{
        "arm_joint_{}_vel".format(i): i + 15
        for i in range(10)
    },
    **{
        "right_arm_joint_{}_vel".format(i): i + 15
        for i in range(10)
    },
    # [25, 30): right gripper joint velocities
    **{
        "gripper_joint_{}_vel".format(i): i + 25
        for i in range(5)
    },
    **{
        "right_gripper_joint_{}_vel".format(i): i + 25
        for i in range(5)
    },
    "gripper_open_vel": 25,  # alias of right_gripper_joint_0_vel
    "right_gripper_open_vel": 25,
    # [30, 33): right end effector positions
    "eef_pos_x": 30,
    "right_eef_pos_x": 30,
    "eef_pos_y": 31,
    "right_eef_pos_y": 31,
    "eef_pos_z": 32,
    "right_eef_pos_z": 32,
    # [33, 39): right end effector 6D pose
    "eef_angle_0": 33,
    "right_eef_angle_0": 33,
    "eef_angle_1": 34,
    "right_eef_angle_1": 34,
    "eef_angle_2": 35,
    "right_eef_angle_2": 35,
    "eef_angle_3": 36,
    "right_eef_angle_3": 36,
    "eef_angle_4": 37,
    "right_eef_angle_4": 37,
    "eef_angle_5": 38,
    "right_eef_angle_5": 38,
    # [39, 42): right end effector velocities
    "eef_vel_x": 39,
    "right_eef_vel_x": 39,
    "eef_vel_y": 40,
    "right_eef_vel_y": 40,
    "eef_vel_z": 41,
    "right_eef_vel_z": 41,
    # [42, 45): right end effector angular velocities
    "eef_angular_vel_roll": 42,
    "right_eef_angular_vel_roll": 42,
    "eef_angular_vel_pitch": 43,
    "right_eef_angular_vel_pitch": 43,
    "eef_angular_vel_yaw": 44,
    "right_eef_angular_vel_yaw": 44,
    # [45, 50): reserved
    # [50, 60): left arm joint positions
    **{
        "left_arm_joint_{}_pos".format(i): i + 50
        for i in range(10)
    },
    # [60, 65): left gripper joint positions
    **{
        "left_gripper_joint_{}_pos".format(i): i + 60
        for i in range(5)
    },
    "left_gripper_open": 60,  # alias of left_gripper_joint_0_pos
    # [65, 75): left arm joint velocities
    **{
        "left_arm_joint_{}_vel".format(i): i + 65
        for i in range(10)
    },
    # [75, 80): left gripper joint velocities
    **{
        "left_gripper_joint_{}_vel".format(i): i + 75
        for i in range(5)
    },
    "left_gripper_open_vel": 75,  # alias of left_gripper_joint_0_vel
    # [80, 83): left end effector positions
    "left_eef_pos_x": 80,
    "left_eef_pos_y": 81,
    "left_eef_pos_z": 82,
    # [83, 89): left end effector 6D pose
    "left_eef_angle_0": 83,
    "left_eef_angle_1": 84,
    "left_eef_angle_2": 85,
    "left_eef_angle_3": 86,
    "left_eef_angle_4": 87,
    "left_eef_angle_5": 88,
    # [89, 92): left end effector velocities
    "left_eef_vel_x": 89,
    "left_eef_vel_y": 90,
    "left_eef_vel_z": 91,
    # [92, 95): left end effector angular velocities
    "left_eef_angular_vel_roll": 92,
    "left_eef_angular_vel_pitch": 93,
    "left_eef_angular_vel_yaw": 94,
    # [95, 100): reserved
    # [100, 102): base linear velocities
    "base_vel_x": 100,
    "base_vel_y": 101,
    # [102, 103): base angular velocities
    "base_angular_vel": 102,
    # [103, 128): reserved
}

# Total length of the unified state vector, including reserved slots.
STATE_VEC_LEN = 128
|
||||
14
RDT-1B/configs/zero2.json
Normal file
14
RDT-1B/configs/zero2.json
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9
|
||||
}
|
||||
}
|
||||
2
RDT-1B/data/.gitignore
vendored
Normal file
2
RDT-1B/data/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
# Ignore data files
|
||||
datasets
|
||||
Binary file not shown.
BIN
RDT-1B/data/__pycache__/filelock.cpython-310.pyc
Normal file
BIN
RDT-1B/data/__pycache__/filelock.cpython-310.pyc
Normal file
Binary file not shown.
BIN
RDT-1B/data/__pycache__/hdf5_vla_dataset.cpython-310.pyc
Normal file
BIN
RDT-1B/data/__pycache__/hdf5_vla_dataset.cpython-310.pyc
Normal file
Binary file not shown.
154
RDT-1B/data/agilex/hdf5totfrecords.py
Normal file
154
RDT-1B/data/agilex/hdf5totfrecords.py
Normal file
@ -0,0 +1,154 @@
|
||||
import tensorflow as tf
|
||||
import h5py
|
||||
import os
|
||||
import fnmatch
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _bytes_feature(value):
    """Wrap *value* in a ``tf.train.Feature`` holding a single bytes entry.

    Eager tensors are first unwrapped to their numpy value, because a
    BytesList won't unpack a string from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
|
||||
|
||||
|
||||
def _bool_feature(value):
    """Encode a boolean as a single-element int64 feature (0 or 1)."""
    as_int = int(value)
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[as_int]))
|
||||
|
||||
|
||||
def serialize_example(
    action,
    base_action,
    qpos,
    qvel,
    cam_high,
    cam_left_wrist,
    cam_right_wrist,
    instruction,
    terminate_episode,
):
    """Serialize one step of an episode into a ``tf.train.Example`` string.

    Low-dimensional arrays (action, base_action, qpos, qvel) are stored as
    serialized tensors; camera images are stored as their raw bytes wrapped
    in a string tensor; the instruction is raw bytes and the episode
    terminator is an int64-encoded bool.
    """

    def _tensor_feature(tensor):
        # Serialize a numeric array/tensor to bytes.
        return _bytes_feature(tf.io.serialize_tensor(tensor))

    def _image_feature(image):
        # Store the image's raw buffer as a serialized string tensor.
        raw = tf.convert_to_tensor(image.tobytes(), dtype=tf.string)
        return _bytes_feature(tf.io.serialize_tensor(raw))

    feature = {
        "action": _tensor_feature(action),
        "base_action": _tensor_feature(base_action),
        "qpos": _tensor_feature(qpos),
        "qvel": _tensor_feature(qvel),
        "cam_high": _image_feature(cam_high),
        "cam_left_wrist": _image_feature(cam_left_wrist),
        "cam_right_wrist": _image_feature(cam_right_wrist),
        "instruction": _bytes_feature(instruction),
        "terminate_episode": _bool_feature(terminate_episode),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()
|
||||
|
||||
|
||||
def process_hdf5_file(args):
    """Convert a single HDF5 episode file into a TFRecord file.

    Args:
        args: Tuple of ``(filepath, root_dir, out_dir)``. The output file
            mirrors the input's path relative to ``root_dir`` under
            ``out_dir``, with the ``.hdf5`` suffix replaced by ``.tfrecord``.

    Returns:
        A human-readable status message (already-converted, success, or
        failure). Failures are also appended to ``error_log.txt``.
    """
    filepath, root_dir, out_dir = args
    output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir))
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.basename(filepath)
    tfrecord_path = os.path.join(output_dir, filename.replace(".hdf5", ".tfrecord"))

    # A non-empty TFRecord is assumed to be a completed conversion.
    if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
        return f"TFRecords already exist at {tfrecord_path}"
    try:
        with h5py.File(filepath, "r") as f, tf.io.TFRecordWriter(tfrecord_path) as writer:
            # Number of timesteps in this episode.
            num_steps = f["action"].shape[0]
            # Remove the first few still steps: find where the robot first
            # moves away from its initial pose by more than EPS.
            EPS = 1e-2
            all_qpos = f["observations"]["qpos"][:]
            # Get the idx of the first qpos whose delta exceeds the threshold
            qpos_delta = np.abs(all_qpos - all_qpos[0:1])
            moving_indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
            if len(moving_indices) > 0:
                first_idx = moving_indices[0]
            else:
                raise ValueError("Found no qpos that exceeds the threshold.")

            # Keep one still step before the motion starts (first_idx - 1).
            for i in range(first_idx - 1, num_steps):
                action = f["action"][i]
                base_action = f["base_action"][i]
                qpos = f["observations"]["qpos"][i]
                qvel = f["observations"]["qvel"][i]
                cam_high = f["observations"]["images"]["cam_high"][i]
                cam_left_wrist = f["observations"]["images"]["cam_left_wrist"][i]
                cam_right_wrist = f["observations"]["images"]["cam_right_wrist"][i]
                instruction = f["instruction"][()]
                terminate_episode = i == num_steps - 1
                serialized_example = serialize_example(
                    action,
                    base_action,
                    qpos,
                    qvel,
                    cam_high,
                    cam_left_wrist,
                    cam_right_wrist,
                    instruction,
                    terminate_episode,
                )
                writer.write(serialized_example)
    except Exception as e:
        # Remove the partially-written TFRecord: otherwise the size > 0
        # guard above would treat it as complete on the next run and the
        # truncated file would silently poison the dataset.
        if os.path.exists(tfrecord_path):
            os.remove(tfrecord_path)
        with open("error_log.txt", "a") as log_file:
            log_file.write(f"{filepath}\n")
        print(f"error at {filepath}: {e}")
        return f"Failed to convert {filepath}: {e}"
    return f"TFRecords written to {tfrecord_path}"
|
||||
|
||||
|
||||
def write_tfrecords(root_dir, out_dir):
    """Convert every HDF5 episode under `root_dir` into TFRecords under `out_dir`.

    Walks `root_dir`, mirrors each directory's instruction JSON into the
    matching output directory (normalizing its filename to
    `expanded_instruction_gpt-4-turbo.json`), then converts all `*.hdf5`
    files in parallel with a 16-worker process pool.

    Args:
        root_dir: Root directory containing per-episode HDF5 files and
            instruction JSON files.
        out_dir: Output root; the input directory structure is mirrored here.
    """
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Collect (filepath, root_dir, out_dir) work items for the pool.
    hdf5_files = []
    for root, dirs, files in os.walk(root_dir):
        if os.path.exists(os.path.join(root, "expanded_instruction_gpt-4-turbo.json")):
            # copy the instruction file
            target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
            os.makedirs(target_path, exist_ok=True)
            shutil.copy(os.path.join(root, "expanded_instruction_gpt-4-turbo.json"), target_path)
        elif os.path.exists(os.path.join(root, "expanded_instruction.json")):
            # Fallback: copy the older-named file, then normalize its name
            # in the output so downstream code only needs one filename.
            print(root)
            target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
            os.makedirs(target_path, exist_ok=True)
            shutil.copy(os.path.join(root, "expanded_instruction.json"), target_path)
            # rename into expanded_instruction_gpt-4-turbo.json
            os.rename(
                os.path.join(
                    out_dir,
                    os.path.relpath(root, root_dir),
                    "expanded_instruction.json",
                ),
                os.path.join(
                    out_dir,
                    os.path.relpath(root, root_dir),
                    "expanded_instruction_gpt-4-turbo.json",
                ),
            )
        for filename in fnmatch.filter(files, "*.hdf5"):
            filepath = os.path.join(root, filename)
            hdf5_files.append((filepath, root_dir, out_dir))

    # Convert in parallel; imap_unordered lets the progress bar advance as
    # soon as any worker finishes, regardless of submission order.
    with Pool(16) as pool:
        max_count = len(hdf5_files)
        with tqdm(total=max_count) as pbar:
            for _ in pool.imap_unordered(process_hdf5_file, hdf5_files):
                pbar.update(1)

    print(f"TFRecords written to {out_dir}")


# Script entry: convert the agilex RDT data in-place next to the datasets dir.
root_dir = "../datasets/agilex/rdt_data/"
out_dir = "../datasets/agilex/tfrecords/"
write_tfrecords(root_dir, out_dir)
|
||||
256
RDT-1B/data/compute_dataset_stat.py
Normal file
256
RDT-1B/data/compute_dataset_stat.py
Normal file
@ -0,0 +1,256 @@
|
||||
"""
|
||||
This file will compute the min, max, mean, and standard deviation of each datasets
|
||||
in `pretrain_datasets.json` or `pretrain_datasets.json`.
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import os
|
||||
|
||||
# from multiprocessing import Pool, Manager
|
||||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from data.vla_dataset import VLADataset
|
||||
from data.hdf5_vla_dataset import HDF5VLADataset
|
||||
from data.preprocess import generate_json_state
|
||||
|
||||
|
||||
# Process each dataset to get the statistics
|
||||
# Process each dataset to get the statistics
@tf.autograph.experimental.do_not_convert
def process_dataset(name_dataset_pair):
    """Stream one dataset and accumulate per-dimension state statistics.

    Args:
        name_dataset_pair: Tuple of (dataset_name, episode iterator); each
            episode yields a dict with "episode_dict" and "dataset_name".

    Returns:
        Dict with keys "dataset_name", "state_mean", "state_std",
        "state_min", "state_max" (statistics as plain Python lists).
        The std is computed over *non-zero* entries only: values with
        |x| <= EPS are zeroed and the variance is normalized by the
        per-dimension non-zero count rather than the total count.
    """
    # print(f"PID {os.getpid()} processing {name_dataset_pair[0]}")
    dataset_iter = name_dataset_pair[1]

    MAX_EPISODES = 100000  # hard cap on episodes scanned per dataset
    EPS = 1e-8  # values with magnitude <= EPS are treated as zero padding
    # For debugging
    # MAX_EPISODES = 10
    episode_cnt = 0
    # Running accumulators; scalars 0 broadcast to arrays on first update.
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None  # per-dimension count of non-(near-)zero entries
    state_max = None
    state_min = None
    for episode in dataset_iter:
        episode_cnt += 1
        if episode_cnt % 1000 == 0:
            print(f"Processing episodes {episode_cnt}/{MAX_EPISODES}")
        if episode_cnt > MAX_EPISODES:
            break
        episode_dict = episode["episode_dict"]
        dataset_name = episode["dataset_name"]

        # generate_json_state returns a tuple; index 1 is the (T, D) state tensor.
        res_tup = generate_json_state(episode_dict, dataset_name)
        states = res_tup[1]

        # Convert to numpy
        states = states.numpy()

        # Zero the values that are close to zero
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0
        # Compute the non-zero count
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)

        # Update statistics
        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))

    # Add one to avoid division by zero
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))

    result = {
        "dataset_name":
            name_dataset_pair[0],
        "state_mean": (state_sum / state_cnt).tolist(),
        # Variance over non-zero entries, clipped at 0 to guard against
        # negative values from floating-point cancellation.
        "state_std":
            np.sqrt(
                np.maximum(
                    (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                    np.zeros_like(state_sum_sq),
                )).tolist(),
        "state_min":
            state_min.tolist(),
        "state_max":
            state_max.tolist(),
    }

    return result
|
||||
|
||||
|
||||
def process_hdf5_dataset(vla_dataset):
    """Compute per-dimension state statistics over an HDF5 VLA dataset.

    Args:
        vla_dataset: Dataset exposing __len__, get_item(i, state_only=True)
            (returning a dict whose "state" is a (T, D) numpy array), and
            get_dataset_name().

    Returns:
        Dict with "dataset_name", "state_mean", "state_std", "state_min",
        "state_max" as plain Python lists. The std is computed over
        *non-zero* entries only: values with |x| <= EPS are zeroed and the
        variance is normalized by the per-dimension non-zero count.
    """
    EPS = 1e-8  # values with magnitude <= EPS are treated as zero padding
    episode_cnt = 0
    # Running accumulators; scalars 0 broadcast to arrays on first update.
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None  # per-dimension count of non-(near-)zero entries
    state_max = None
    state_min = None
    for i in tqdm(range(len(vla_dataset))):
        episode = vla_dataset.get_item(i, state_only=True)
        episode_cnt += 1

        states = episode["state"]

        # Zero the values that are close to zero
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0
        # Compute the non-zero count
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)

        # Update statistics
        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))

    # Add one to avoid division by zero
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))

    result = {
        "dataset_name":
            vla_dataset.get_dataset_name(),
        "state_mean": (state_sum / state_cnt).tolist(),
        # Variance over non-zero entries, clipped at 0 to guard against
        # negative values from floating-point cancellation.
        "state_std":
            np.sqrt(
                np.maximum(
                    (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                    np.zeros_like(state_sum_sq),
                )).tolist(),
        "state_min":
            state_min.tolist(),
        "state_max":
            state_max.tolist(),
    }

    return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Multiprocessing currently with bugs
|
||||
# parser.add_argument('--n_workers', type=int, default=1,
|
||||
# help="Number of parallel workers.")
|
||||
parser.add_argument(
|
||||
"--dataset_type",
|
||||
type=str,
|
||||
default="pretrain",
|
||||
help="Whether to load the pretrain dataset or finetune dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="configs/dataset_stat.json",
|
||||
help="JSON file path to save the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_exist",
|
||||
action="store_true",
|
||||
help="Whether to skip the existing dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hdf5_dataset",
|
||||
action="store_true",
|
||||
help="Whether to load the dataset from the HDF5 files.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.hdf5_dataset:
|
||||
vla_dataset = HDF5VLADataset()
|
||||
dataset_name = vla_dataset.get_dataset_name()
|
||||
|
||||
try:
|
||||
with open(args.save_path, "r") as f:
|
||||
results = json.load(f)
|
||||
except FileNotFoundError:
|
||||
results = {}
|
||||
if args.skip_exist and dataset_name in results:
|
||||
print(f"Skipping existed {dataset_name} dataset statistics")
|
||||
else:
|
||||
print(f"Processing {dataset_name} dataset")
|
||||
result = process_hdf5_dataset(vla_dataset)
|
||||
results[result["dataset_name"]] = result
|
||||
with open(args.save_path, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
print("All datasets have been processed.")
|
||||
os._exit(0)
|
||||
|
||||
vla_dataset = VLADataset(seed=0, dataset_type=args.dataset_type, repeat=False)
|
||||
name_dataset_pairs = vla_dataset.name2dataset.items()
|
||||
# num_workers = args.n_workers
|
||||
|
||||
for name_dataset_pair in tqdm(name_dataset_pairs):
|
||||
try:
|
||||
with open(args.save_path, "r") as f:
|
||||
results = json.load(f)
|
||||
except FileNotFoundError:
|
||||
results = {}
|
||||
|
||||
if args.skip_exist and name_dataset_pair[0] in results:
|
||||
print(f"Skipping existed {name_dataset_pair[0]} dataset statistics")
|
||||
continue
|
||||
print(f"Processing {name_dataset_pair[0]} dataset")
|
||||
|
||||
result = process_dataset(name_dataset_pair)
|
||||
|
||||
results[result["dataset_name"]] = result
|
||||
|
||||
# Save the results in the json file after each dataset (for resume)
|
||||
with open(args.save_path, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
||||
print("All datasets have been processed.")
|
||||
|
||||
# with Manager() as manager:
|
||||
# # Create shared dictionary and lock through the manager, accessible by all processes
|
||||
# progress = manager.dict(processed=0, results={})
|
||||
# progress_lock = manager.Lock()
|
||||
|
||||
# # Callback function to update progress
|
||||
# def update_progress(result):
|
||||
# with progress_lock:
|
||||
# progress['processed'] += 1
|
||||
# print(f"{result['dataset_name']} - {progress['processed']}/{len(name_dataset_pairs)} datasets have been processed")
|
||||
# # Append the result to the shared dictionary
|
||||
# progress['results'][result["dataset_name"]] = result
|
||||
|
||||
# with Pool(num_workers) as p:
|
||||
# for name_dataset_pair in name_dataset_pairs:
|
||||
# p.apply_async(process_dataset, args=(name_dataset_pair,), callback=update_progress)
|
||||
|
||||
# # Close the pool and wait for the work to finish
|
||||
# p.close()
|
||||
# p.join()
|
||||
|
||||
# # Save the results in the json file
|
||||
# with open(args.save_path, 'w') as f:
|
||||
# json.dump(progress['results'], f, indent=4)
|
||||
112
RDT-1B/data/compute_dataset_stat_hdf5.py
Normal file
112
RDT-1B/data/compute_dataset_stat_hdf5.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""
|
||||
This file will compute the min, max, mean, and standard deviation of each datasets
|
||||
in `pretrain_datasets.json` or `pretrain_datasets.json`.
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from data.hdf5_vla_dataset import HDF5VLADataset
|
||||
|
||||
|
||||
def process_hdf5_dataset(vla_dataset):
    """Compute per-dimension state statistics over an HDF5 VLA dataset.

    Args:
        vla_dataset: Dataset exposing __len__, get_item(i, state_only=True)
            (returning a dict whose "state" is a (T, D) numpy array), and
            get_dataset_name().

    Returns:
        Dict with "dataset_name", "state_mean", "state_std", "state_min",
        "state_max" as plain Python lists. The std is computed over
        *non-zero* entries only: values with |x| <= EPS are zeroed and the
        variance is normalized by the per-dimension non-zero count.
    """
    EPS = 1e-8  # values with magnitude <= EPS are treated as zero padding
    episode_cnt = 0
    # Running accumulators; scalars 0 broadcast to arrays on first update.
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None  # per-dimension count of non-(near-)zero entries
    state_max = None
    state_min = None
    for i in tqdm(range(len(vla_dataset))):
        episode = vla_dataset.get_item(i, state_only=True)
        episode_cnt += 1

        states = episode["state"]

        # Zero the values that are close to zero
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0
        # Compute the non-zero count
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)

        # Update statistics
        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))

    # Add one to avoid division by zero
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))

    result = {
        "dataset_name":
            vla_dataset.get_dataset_name(),
        "state_mean": (state_sum / state_cnt).tolist(),
        # Variance over non-zero entries, clipped at 0 to guard against
        # negative values from floating-point cancellation.
        "state_std":
            np.sqrt(
                np.maximum(
                    (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                    np.zeros_like(state_sum_sq),
                )).tolist(),
        "state_min":
            state_min.tolist(),
        "state_max":
            state_max.tolist(),
    }

    return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
type=str,
|
||||
default="configs/dataset_stat.json",
|
||||
help="JSON file path to save the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="configs/dataset_stat.json",
|
||||
help="JSON file path to save the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_exist",
|
||||
action="store_true",
|
||||
help="Whether to skip the existing dataset statistics.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
vla_dataset = HDF5VLADataset(f"model_config/{args.task_name}.yml")
|
||||
dataset_name = vla_dataset.get_dataset_name()
|
||||
|
||||
try:
|
||||
with open(args.save_path, "r") as f:
|
||||
results = json.load(f)
|
||||
except FileNotFoundError:
|
||||
results = {}
|
||||
if args.skip_exist and dataset_name in results:
|
||||
print(f"Skipping existed {dataset_name} dataset statistics")
|
||||
else:
|
||||
print(f"Processing {dataset_name} dataset")
|
||||
result = process_hdf5_dataset(vla_dataset)
|
||||
results[result["dataset_name"]] = result
|
||||
with open(args.save_path, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
print("All datasets have been processed.")
|
||||
BIN
RDT-1B/data/empty_lang_embed.pt
Normal file
BIN
RDT-1B/data/empty_lang_embed.pt
Normal file
Binary file not shown.
398
RDT-1B/data/episode_transform.py
Normal file
398
RDT-1B/data/episode_transform.py
Normal file
@ -0,0 +1,398 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
from data.preprocess import generate_json_state
|
||||
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
||||
|
||||
# Read the config
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
# Load some constants from the config
# Number of past frames bundled with each step (window size).
IMG_HISTORY_SIZE = config["common"]["img_history_size"]
if IMG_HISTORY_SIZE < 1:
    raise ValueError("Config `img_history_size` must be at least 1.")
# Number of past/future states bundled with each step (window size).
ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
if ACTION_CHUNK_SIZE < 1:
    raise ValueError("Config `action_chunk_size` must be at least 1.")
|
||||
|
||||
|
||||
@tf.function
def process_episode(epsd: dict, dataset_name: str, image_keys: list, image_mask: list) -> dict:
    """
    Process an episode to extract the frames and the json content.

    For each of (up to) four cameras, collects the per-step images and
    builds a sliding history window of size IMG_HISTORY_SIZE per step,
    left-padded by repeating the first frame; a boolean time mask marks
    which window slots are real frames (False = padding). Cameras whose
    `image_mask` entry is not 1 are stored as empty [0, 0, 0] placeholders.

    Args:
        epsd: Episode dict with a "steps" dataset; each step's
            "observation" holds the camera images under `image_keys`.
        dataset_name: Name of the source dataset (passed through).
        image_keys: Four observation keys, one per camera slot.
        image_mask: Four 0/1 flags; 1 means the camera is present.

    Returns:
        Dict with the dataset name, the original episode dict, per-step ids,
        and past_frames_{0..3} / past_frames_{0..3}_time_mask tensors of
        shape (num_steps, IMG_HISTORY_SIZE, ...).
    """
    # Frames of each camera
    # Ugly code due to tf's poor compatibility: tf.function tracing prevents
    # holding the four TensorArrays in a Python list, so each is unrolled.
    frames_0 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_1 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_2 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_3 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    # Traverse the episode to collect the per-step camera images.
    for step in iter(epsd["steps"]):
        # Parse the image; masked-out cameras get an empty placeholder.
        frames_0 = frames_0.write(
            frames_0.size(),
            tf.cond(
                tf.equal(image_mask[0], 1),
                lambda: step["observation"][image_keys[0]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        # Very ugly code due to tf's poor compatibility
        frames_1 = frames_1.write(
            frames_1.size(),
            tf.cond(
                tf.equal(image_mask[1], 1),
                lambda: step["observation"][image_keys[1]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_2 = frames_2.write(
            frames_2.size(),
            tf.cond(
                tf.equal(image_mask[2], 1),
                lambda: step["observation"][image_keys[2]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_3 = frames_3.write(
            frames_3.size(),
            tf.cond(
                tf.equal(image_mask[3], 1),
                lambda: step["observation"][image_keys[3]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )

    # Calculate the past_frames_0 for each step
    # Each step has a window of previous frames with size IMG_HISTORY_SIZE
    # Use the first state to pad the frames
    # past_frames_0 will have shape (num_steps, IMG_HISTORY_SIZE, height, width, channels)
    frames_0 = frames_0.stack()
    first_frame = tf.expand_dims(frames_0[0], axis=0)
    first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
    padded_frames_0 = tf.concat([first_frame, frames_0], axis=0)
    indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_0)[0] + IMG_HISTORY_SIZE)
    past_frames_0 = tf.map_fn(lambda i: padded_frames_0[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
    # Time mask: False marks window slots filled by left-padding.
    frames_0_time_mask = tf.ones([tf.shape(frames_0)[0]], dtype=tf.bool)
    padded_frames_0_time_mask = tf.pad(
        frames_0_time_mask,
        [[IMG_HISTORY_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_frames_0_time_mask = tf.map_fn(
        lambda i: padded_frames_0_time_mask[i - IMG_HISTORY_SIZE:i],
        indices,
        dtype=tf.bool,
    )

    # For past_frames_1
    frames_1 = frames_1.stack()
    first_frame = tf.expand_dims(frames_1[0], axis=0)
    first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
    padded_frames_1 = tf.concat([first_frame, frames_1], axis=0)
    indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_1)[0] + IMG_HISTORY_SIZE)
    past_frames_1 = tf.map_fn(lambda i: padded_frames_1[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
    frames_1_time_mask = tf.ones([tf.shape(frames_1)[0]], dtype=tf.bool)
    padded_frames_1_time_mask = tf.pad(
        frames_1_time_mask,
        [[IMG_HISTORY_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_frames_1_time_mask = tf.map_fn(
        lambda i: padded_frames_1_time_mask[i - IMG_HISTORY_SIZE:i],
        indices,
        dtype=tf.bool,
    )

    # For past_frames_2
    frames_2 = frames_2.stack()
    first_frame = tf.expand_dims(frames_2[0], axis=0)
    first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
    padded_frames_2 = tf.concat([first_frame, frames_2], axis=0)
    indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_2)[0] + IMG_HISTORY_SIZE)
    past_frames_2 = tf.map_fn(lambda i: padded_frames_2[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
    frames_2_time_mask = tf.ones([tf.shape(frames_2)[0]], dtype=tf.bool)
    padded_frames_2_time_mask = tf.pad(
        frames_2_time_mask,
        [[IMG_HISTORY_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_frames_2_time_mask = tf.map_fn(
        lambda i: padded_frames_2_time_mask[i - IMG_HISTORY_SIZE:i],
        indices,
        dtype=tf.bool,
    )

    # For past_frames_3
    frames_3 = frames_3.stack()
    first_frame = tf.expand_dims(frames_3[0], axis=0)
    first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
    padded_frames_3 = tf.concat([first_frame, frames_3], axis=0)
    indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_3)[0] + IMG_HISTORY_SIZE)
    past_frames_3 = tf.map_fn(lambda i: padded_frames_3[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
    frames_3_time_mask = tf.ones([tf.shape(frames_3)[0]], dtype=tf.bool)
    padded_frames_3_time_mask = tf.pad(
        frames_3_time_mask,
        [[IMG_HISTORY_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_frames_3_time_mask = tf.map_fn(
        lambda i: padded_frames_3_time_mask[i - IMG_HISTORY_SIZE:i],
        indices,
        dtype=tf.bool,
    )

    # Create the ids for each step
    step_id = tf.range(0, tf.shape(frames_0)[0])

    return {
        "dataset_name": dataset_name,
        "episode_dict": epsd,
        "step_id": step_id,
        "past_frames_0": past_frames_0,
        "past_frames_0_time_mask": past_frames_0_time_mask,
        "past_frames_1": past_frames_1,
        "past_frames_1_time_mask": past_frames_1_time_mask,
        "past_frames_2": past_frames_2,
        "past_frames_2_time_mask": past_frames_2_time_mask,
        "past_frames_3": past_frames_3,
        "past_frames_3_time_mask": past_frames_3_time_mask,
    }
|
||||
|
||||
|
||||
@tf.function
def bgr_to_rgb(epsd: dict):
    """
    Convert BGR images to RGB images.

    The channel order is reversed only when the last dimension is 3;
    empty placeholder frames (shape [0, 0, 0]) pass through unchanged.
    All other episode fields are forwarded as-is.
    """
    frame_keys = ["past_frames_0", "past_frames_1", "past_frames_2", "past_frames_3"]
    out = {
        "dataset_name": epsd["dataset_name"],
        "episode_dict": epsd["episode_dict"],
        "step_id": epsd["step_id"],
    }
    for key in frame_keys:
        frames = epsd[key]
        # Bind the current tensor as a default argument so each traced
        # lambda captures it rather than the loop variable.
        out[key] = tf.cond(
            tf.equal(tf.shape(frames)[-1], 3),
            lambda fr=frames: tf.stack([fr[..., 2], fr[..., 1], fr[..., 0]], axis=-1),
            lambda fr=frames: fr,
        )
        out[key + "_time_mask"] = epsd[key + "_time_mask"]
    return out
|
||||
|
||||
|
||||
def flatten_episode(episode: dict) -> tf.data.Dataset:
    """
    Flatten the episode to a list of steps.

    For every timestep this builds:
      * a past-state window of length ACTION_CHUNK_SIZE (left-padded with
        the first state) plus a boolean validity mask,
      * a future-state window of length ACTION_CHUNK_SIZE starting at t+1
        (right-padded with the last state) plus its validity mask,
      * per-episode state statistics (std / mean / RMS norm) broadcast to
        every step.

    Args:
        episode (dict): record holding "episode_dict", "dataset_name",
            "step_id" and the cached "past_frames_*" image histories with
            their "_time_mask" entries.

    Returns:
        A Python list of per-step dicts (NOTE(review): the annotated return
        type `tf.data.Dataset` does not match the actual list return —
        confirm against callers).
    """
    episode_dict = episode["episode_dict"]
    dataset_name = episode["dataset_name"]

    # NOTE(review): here `generate_json_state` is unpacked into 3 values,
    # while flatten_episode_agilex unpacks 4 — presumably it dispatches on
    # dataset_name; confirm.
    json_content, states, masks = generate_json_state(episode_dict, dataset_name)

    # Calculate the past_states for each step
    # Each step has a window of previous states with size ACTION_CHUNK_SIZE
    # Use the first state to pad the states
    # past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
    first_state = tf.expand_dims(states[0], axis=0)
    first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0)
    padded_states = tf.concat([first_state, states], axis=0)
    # Window i covers padded indices [i - CHUNK, i), i.e. states up to and
    # including step i - CHUNK in the original (unpadded) numbering.
    indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
    past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
    # Padded (replicated) positions are marked invalid (False).
    padded_states_time_mask = tf.pad(
        states_time_mask,
        [[ACTION_CHUNK_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
        indices,
        dtype=tf.bool,
    )

    # Calculate the future_states for each step
    # Each step has a window of future states with size ACTION_CHUNK_SIZE
    # Use the last state to pad the states
    # future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
    last_state = tf.expand_dims(states[-1], axis=0)
    last_state = tf.repeat(last_state, ACTION_CHUNK_SIZE, axis=0)
    padded_states = tf.concat([states, last_state], axis=0)
    # Windows start at t + 1, so the "actions" are the next states.
    indices = tf.range(1, tf.shape(states)[0] + 1)
    future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
    future_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
        indices,
        dtype=tf.bool,
    )

    # Calculate the mean and std for state; each statistic is repeated to
    # shape (num_steps, state_dim) so it can be indexed per step below.
    state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
    state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
    state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
    state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)

    # Root-mean-square norm of each state dimension over the episode.
    state_norm = tf.math.reduce_mean(tf.math.square(states), axis=0, keepdims=True)
    state_norm = tf.math.sqrt(state_norm)
    state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)

    # Create a list of steps (eager-mode Python loop over the tensors).
    step_data = []
    for i in range(tf.shape(states)[0]):
        step_data.append({
            "step_id": episode["step_id"][i],
            "json_content": json_content,
            "state_chunk": past_states[i],
            "state_chunk_time_mask": past_states_time_mask[i],
            "action_chunk": future_states[i],
            "action_chunk_time_mask": future_states_time_mask[i],
            "state_vec_mask": masks[i],
            "past_frames_0": episode["past_frames_0"][i],
            "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
            "past_frames_1": episode["past_frames_1"][i],
            "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
            "past_frames_2": episode["past_frames_2"][i],
            "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
            "past_frames_3": episode["past_frames_3"][i],
            "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
            "state_std": state_std[i],
            "state_mean": state_mean[i],
            "state_norm": state_norm[i],
        })

    return step_data
|
||||
|
||||
|
||||
def flatten_episode_agilex(episode: dict) -> tf.data.Dataset:
    """
    Flatten the episode to a list of steps (AgileX variant).

    Differs from flatten_episode in that the future ("action") windows are
    drawn from the recorded action stream rather than from shifted states,
    and the per-dimension norm statistic is computed from actions.

    Args:
        episode (dict): record holding "episode_dict", "dataset_name",
            "step_id" and the cached "past_frames_*" image histories with
            their "_time_mask" entries.

    Returns:
        A Python list of per-step dicts (NOTE(review): the annotated return
        type `tf.data.Dataset` does not match the actual list return —
        confirm against callers).
    """
    episode_dict = episode["episode_dict"]
    dataset_name = episode["dataset_name"]

    # NOTE(review): unpacked into 4 values here vs. 3 in flatten_episode —
    # presumably generate_json_state dispatches on dataset_name; confirm.
    json_content, states, masks, acts = generate_json_state(episode_dict, dataset_name)

    # Calculate the past_states for each step
    # Each step has a window of previous states with size ACTION_CHUNK_SIZE
    # Use the first state to pad the states
    # past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
    first_state = tf.expand_dims(states[0], axis=0)
    first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0)
    padded_states = tf.concat([first_state, states], axis=0)
    indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
    past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
    # Left-padded (replicated) positions are marked invalid (False).
    padded_states_time_mask = tf.pad(
        states_time_mask,
        [[ACTION_CHUNK_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
        indices,
        dtype=tf.bool,
    )

    # NOTE bg the future states shall be actions
    # Calculate the future_states for each step
    # Each step has a window of future states with size ACTION_CHUNK_SIZE
    # Use the last action to pad the states
    # future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
    last_act = tf.expand_dims(acts[-1], axis=0)
    last_act = tf.repeat(last_act, ACTION_CHUNK_SIZE, axis=0)
    padded_states = tf.concat([acts, last_act], axis=0)
    # indices = tf.range(1, tf.shape(states)[0] + 1)
    indices = tf.range(0, tf.shape(acts)[0])  # NOTE time 0 action = time 1 state
    future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(acts)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
    future_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
        indices,
        dtype=tf.bool,
    )

    # Calculate the std and mean for state, broadcast per step.
    state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
    state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
    state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
    state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)

    # RMS norm computed from the ACTIONS (unlike flatten_episode, which
    # uses states) — intentional for this variant per the NOTE above.
    state_norm = tf.math.reduce_mean(tf.math.square(acts), axis=0, keepdims=True)
    state_norm = tf.math.sqrt(state_norm)
    state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)

    # Create a list of steps (eager-mode Python loop over the tensors).
    step_data = []
    for i in range(tf.shape(states)[0]):
        step_data.append({
            "step_id": episode["step_id"][i],
            "json_content": json_content,
            "state_chunk": past_states[i],
            "state_chunk_time_mask": past_states_time_mask[i],
            "action_chunk": future_states[i],
            "action_chunk_time_mask": future_states_time_mask[i],
            "state_vec_mask": masks[i],
            "past_frames_0": episode["past_frames_0"][i],
            "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
            "past_frames_1": episode["past_frames_1"][i],
            "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
            "past_frames_2": episode["past_frames_2"][i],
            "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
            "past_frames_3": episode["past_frames_3"][i],
            "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
            "state_std": state_std[i],
            "state_mean": state_mean[i],
            "state_norm": state_norm[i],
        })

    return step_data
|
||||
25
RDT-1B/data/filelock.py
Normal file
25
RDT-1B/data/filelock.py
Normal file
@ -0,0 +1,25 @@
|
||||
import fcntl
|
||||
|
||||
|
||||
class FileLock:
    """
    A file lock class based on ``fcntl.flock``.

    Locks are taken on a sidecar file named ``<filename>.lock``. Both
    acquire methods are non-blocking (LOCK_NB): if the lock is already
    held elsewhere, ``fcntl.flock`` raises (typically BlockingIOError)
    instead of waiting.
    """

    def __init__(self, filename):
        # Path the lock protects; the actual lock file is filename + ".lock".
        self.filename = filename
        # Open handle on the lock file while the lock is held, else None.
        self.handle = None

    def acquire_read_lock(self):
        """Take a shared (read) lock; raises if the lock is contended."""
        self._acquire(fcntl.LOCK_SH | fcntl.LOCK_NB, "r")

    def acquire_write_lock(self):
        """Take an exclusive (write) lock; raises if the lock is contended."""
        self._acquire(fcntl.LOCK_EX | fcntl.LOCK_NB, "w")

    def _acquire(self, flags, mode):
        # Shared helper. The original implementation leaked the open file
        # handle (and left self.handle pointing at it) when flock raised;
        # close and reset state before re-raising so the object stays usable.
        handle = open(self.filename + ".lock", mode)
        try:
            fcntl.flock(handle, flags)
        except OSError:
            handle.close()
            self.handle = None
            raise
        self.handle = handle

    def release_lock(self):
        """Release the lock (if held) and close the lock-file handle."""
        if self.handle is not None:
            fcntl.flock(self.handle, fcntl.LOCK_UN)
            self.handle.close()
            self.handle = None
|
||||
372
RDT-1B/data/hdf5_vla_dataset.py
Normal file
372
RDT-1B/data/hdf5_vla_dataset.py
Normal file
@ -0,0 +1,372 @@
|
||||
import os
|
||||
import fnmatch
|
||||
import json
|
||||
|
||||
import h5py
|
||||
import yaml
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
||||
|
||||
class HDF5VLADataset:
    """
    This class is used to sample episodes from the embodiment dataset
    stored in HDF5.

    Each HDF5 file under the configured ``data_path`` contains exactly one
    episode. Episodes are sampled with probability proportional to their
    length in timesteps.
    """

    def __init__(self, model_config_path) -> None:
        """
        Args:
            model_config_path (str): path to a YAML model config whose
                ``data_path`` entry is the root directory of the per-episode
                HDF5 files.
        """
        # [Modify] The path to the HDF5 dataset directory
        # Each HDF5 file contains one episode
        with open(model_config_path, "r") as f:
            model_config = yaml.safe_load(f)
        HDF5_DIR = model_config["data_path"]
        self.DATASET_NAME = "agilex"

        # Recursively collect every episode file.
        self.file_paths = []
        for root, _, files in os.walk(HDF5_DIR):
            for filename in fnmatch.filter(files, "*.hdf5"):
                file_path = os.path.join(root, filename)
                self.file_paths.append(file_path)

        # Load the config
        with open("configs/base.yaml", "r") as file:
            config = yaml.safe_load(file)
        self.CHUNK_SIZE = config["common"]["action_chunk_size"]
        self.IMG_HISORY_SIZE = config["common"]["img_history_size"]
        self.STATE_DIM = config["common"]["state_dim"]

        # Get each episode's len (use original length, not standardized
        # length) so long episodes are sampled proportionally more often.
        episode_lens = []
        for file_path in self.file_paths:
            try:
                with h5py.File(file_path, "r") as f:
                    qpos = f["observations"]["qpos"][:]
                    episode_lens.append(qpos.shape[0])
            except Exception as e:
                # An unreadable file gets zero sampling weight instead of
                # aborting dataset construction.
                print(f"Warning: Could not read {file_path}: {e}")
                episode_lens.append(0)
        total_len = np.sum(episode_lens)
        if total_len == 0:
            # Without this guard the division below would yield NaN weights
            # (or a ZeroDivisionError for an empty directory).
            raise ValueError(f"No readable HDF5 episode found under {HDF5_DIR}")
        self.episode_sample_weights = np.array(episode_lens) / total_len

    def __len__(self):
        """Number of episode files discovered."""
        return len(self.file_paths)

    def get_dataset_name(self):
        """Return the dataset identifier used in sample metadata."""
        return self.DATASET_NAME

    def get_item(self, index: int = None, state_only=False):
        """Get a training sample at a random timestep.

        Args:
            index (int, optional): the index of the episode.
                If not provided, a random episode will be selected
                (length-weighted).
            state_only (bool, optional): Whether to return only the state.
                In this way, the sample will contain a complete trajectory rather
                than a single timestep. Defaults to False.

        Returns:
            sample (dict): a dictionary containing the training sample.
        """
        while True:
            if index is None:
                file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
            else:
                file_path = self.file_paths[index]
            valid, sample = (self.parse_hdf5_file(file_path)
                             if not state_only else self.parse_hdf5_file_state_only(file_path))
            if valid:
                return sample
            else:
                # Invalid episode: retry with a fresh random index.
                index = np.random.randint(0, len(self.file_paths))

    def parse_hdf5_file(self, file_path):
        """[Modify] Parse a hdf5 file to generate a training sample at
        a random timestep.

        Args:
            file_path (str): the path to the hdf5 file

        Returns:
            valid (bool): whether the episode is valid, which is useful for filtering.
                If False, this episode will be dropped.
            dict: a dictionary containing the training sample,
                {
                    "meta": {
                        "dataset_name": str,    # the name of your dataset.
                        "#steps": int,          # the number of steps in the episode,
                                                # also the total timesteps.
                        "instruction": str      # the language instruction for this episode.
                    },
                    "step_id": int,             # the index of the sampled step,
                                                # also the timestep t.
                    "state": ndarray,           # state[t], (1, STATE_DIM).
                    "state_std": ndarray,       # std(state[:]), (STATE_DIM,).
                    "state_mean": ndarray,      # mean(state[:]), (STATE_DIM,).
                    "state_norm": ndarray,      # norm(state[:]), (STATE_DIM,).
                    "actions": ndarray,         # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
                    "state_indicator": ndarray, # indicates the validness of each dim, (STATE_DIM,).
                    "cam_high": ndarray,        # external camera image, (IMG_HISORY_SIZE, H, W, 3)
                                                # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
                    "cam_high_mask": ndarray,   # per-timestep validity, (IMG_HISORY_SIZE,) bool.
                    "cam_left_wrist": ndarray,  # left wrist camera image (unavailable here).
                    "cam_left_wrist_mask": ndarray,
                    "cam_right_wrist": ndarray, # right wrist camera image.
                    "cam_right_wrist_mask": ndarray
                } or None if the episode is invalid.
        """
        with h5py.File(file_path, "r") as f:
            qpos = f["observations"]["qpos"][:]
            # Read arm dimensionalities for compatibility even though this
            # single-arm pipeline does not use them directly.
            left_arm_dim = f["observations"]["left_arm_dim"][:]
            right_arm_dim = f["observations"]["right_arm_dim"][:]
            num_steps = qpos.shape[0]

            # [Optional] We skip the first few still steps
            EPS = 1e-2
            # Get the idx of the first qpos whose delta exceeds the threshold
            qpos_delta = np.abs(qpos - qpos[0:1])
            indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
            if len(indices) > 0:
                first_idx = indices[0]
            else:
                raise ValueError("Found no qpos that exceeds the threshold.")

            # We randomly sample a timestep
            step_id = np.random.randint(first_idx - 1, num_steps)

            # Load the instruction: pick a random precomputed language
            # embedding (*.pt) from the episode's `instructions` directory.
            dir_path = os.path.dirname(file_path)
            instructions_path = os.path.join(dir_path, "instructions")
            instructions_names = []
            for filename in os.listdir(instructions_path):
                # Keep only the precomputed *.pt embedding files.
                if filename.endswith(".pt"):
                    instructions_names.append(os.path.join(instructions_path, filename))
            instruction = np.random.choice(instructions_names)

            # Assemble the meta
            meta = {
                "dataset_name": self.DATASET_NAME,
                "#steps": num_steps,
                "step_id": step_id,
                "instruction": instruction,
            }

            # Normalize the six joint angles from degrees into [-1, 1].
            # NOTE(review): assumes a fixed 6-DoF single-arm layout — the
            # left_arm_dim/right_arm_dim values read above are not consulted;
            # confirm for other embodiments.
            qpos = qpos / np.array([[180, 180, 180, 180, 180, 180]])
            target_qpos = f['action'][step_id:step_id + self.CHUNK_SIZE] / np.array(
                [[180, 180, 180, 180, 180, 180]])

            # Parse the state and action
            state = qpos[step_id:step_id + 1]
            state_std = np.std(qpos, axis=0)
            state_mean = np.mean(qpos, axis=0)
            state_norm = np.sqrt(np.mean(qpos**2, axis=0))
            actions = target_qpos
            if actions.shape[0] < self.CHUNK_SIZE:
                # Pad the actions using the last action
                actions = np.concatenate(
                    [
                        actions,
                        np.tile(actions[-1:], (self.CHUNK_SIZE - actions.shape[0], 1)),
                    ],
                    axis=0,
                )

            # Fill the state/action into the unified vector
            def fill_in_state(values):
                # Target indices corresponding to your state space.
                # In this example: 6 joints of the right arm (no gripper).
                UNI_STATE_INDICES = [
                    STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
                ]
                uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
                uni_vec[..., UNI_STATE_INDICES] = values
                return uni_vec

            state = fill_in_state(state)
            state_indicator = fill_in_state(np.ones_like(state_std))
            state_std = fill_in_state(state_std)
            state_mean = fill_in_state(state_mean)
            state_norm = fill_in_state(state_norm)
            # If action's format is different from state's,
            # you may implement fill_in_action()
            actions = fill_in_state(actions)

            # Parse the images
            def parse_img(key):
                imgs = []
                for i in range(max(step_id - self.IMG_HISORY_SIZE + 1, 0), step_id + 1):
                    img_bits = f["observations"]["images"][key][i]
                    img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)
                    imgs.append(img)
                imgs = np.stack(imgs)
                if imgs.shape[0] < self.IMG_HISORY_SIZE:
                    # Pad the images using the first image
                    imgs = np.concatenate(
                        [
                            np.tile(
                                imgs[:1],
                                (self.IMG_HISORY_SIZE - imgs.shape[0], 1, 1, 1),
                            ),
                            imgs,
                        ],
                        axis=0,
                    )
                return imgs

            # `cam_high` is the external camera image
            cam_high = parse_img("cam_high")
            # For step_id = first_idx - 1, the valid_len should be one
            valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
            cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len)
            # The left wrist camera is unavailable on this robot: return the
            # zero-shape placeholder with an all-False mask.
            cam_left_wrist = np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0))
            cam_left_wrist_mask = np.array([False] * self.IMG_HISORY_SIZE)
            cam_right_wrist = parse_img("cam_right_wrist")
            cam_right_wrist_mask = cam_high_mask.copy()  # same validity window

            # Return the resulting sample
            # For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)
            return True, {
                "meta": meta,
                "state": state,
                "state_std": state_std,
                "state_mean": state_mean,
                "state_norm": state_norm,
                "actions": actions,
                "state_indicator": state_indicator,
                "cam_high": cam_high,
                "cam_high_mask": cam_high_mask,
                "cam_left_wrist": cam_left_wrist,
                "cam_left_wrist_mask": cam_left_wrist_mask,
                "cam_right_wrist": cam_right_wrist,
                "cam_right_wrist_mask": cam_right_wrist_mask,
            }

    def parse_hdf5_file_state_only(self, file_path):
        """[Modify] Parse a hdf5 file to generate a state trajectory.

        Args:
            file_path (str): the path to the hdf5 file

        Returns:
            valid (bool): whether the episode is valid, which is useful for filtering.
                If False, this episode will be dropped.
            dict: a dictionary containing the training sample,
                {
                    "state": ndarray,   # state[:], (T, STATE_DIM).
                    "action": ndarray,  # action[:], (T, STATE_DIM).
                } or None if the episode is invalid.
        """
        with h5py.File(file_path, "r") as f:
            qpos = f["observations"]["qpos"][:]
            # Read for compatibility; not used by the 6-DoF normalization below.
            left_arm_dim = f["observations"]["left_arm_dim"][:]
            right_arm_dim = f["observations"]["right_arm_dim"][:]

            num_steps = qpos.shape[0]

            # [Optional] We skip the first few still steps
            EPS = 1e-2
            # Get the idx of the first qpos whose delta exceeds the threshold
            qpos_delta = np.abs(qpos - qpos[0:1])
            indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
            if len(indices) > 0:
                first_idx = indices[0]
            else:
                raise ValueError("Found no qpos that exceeds the threshold.")

            # Normalize joint angles from degrees into [-1, 1].
            qpos = qpos / np.array([[180, 180, 180, 180, 180, 180]])
            target_qpos = f['action'][first_idx - 1:] / np.array(
                [[180, 180, 180, 180, 180, 180]])

            # Parse the state and action.
            # BUGFIX: `target_qpos` was already sliced from `first_idx - 1`
            # when it was read above; slicing it by `first_idx - 1` again
            # (as the original code did) skipped the offset twice and
            # desynchronized the state and action trajectories.
            state = qpos[first_idx - 1:]
            action = target_qpos

            # Standardize trajectory length to avoid batch size mismatch
            # Use a fixed length (e.g., 128) or pad/truncate to match
            target_length = 128  # You can adjust this value
            if state.shape[0] > target_length:
                # Truncate to target length
                state = state[:target_length]
                action = action[:target_length]
            elif state.shape[0] < target_length:
                # Pad with the last state/action
                pad_length = target_length - state.shape[0]
                state = np.concatenate([state, np.tile(state[-1:], (pad_length, 1))], axis=0)
                action = np.concatenate([action, np.tile(action[-1:], (pad_length, 1))], axis=0)

            # Fill the state/action into the unified vector
            def fill_in_state(values):
                # Target indices corresponding to your state space.
                # In this example: 6 joints of the right arm (no gripper).
                UNI_STATE_INDICES = [
                    STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
                ]
                uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
                uni_vec[..., UNI_STATE_INDICES] = values
                return uni_vec

            state = fill_in_state(state)
            action = fill_in_state(action)

            # Return the resulting sample
            return True, {"state": state, "action": action}
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # HDF5VLADataset requires a model-config path; the original call
    # `HDF5VLADataset()` always raised TypeError (missing argument).
    if len(sys.argv) < 2:
        raise SystemExit(f"usage: {sys.argv[0]} <model_config_path>")
    ds = HDF5VLADataset(sys.argv[1])
    for i in range(len(ds)):
        print(f"Processing episode {i}/{len(ds)}...")
        ds.get_item(i)
|
||||
299
RDT-1B/data/preprocess.py
Normal file
299
RDT-1B/data/preprocess.py
Normal file
@ -0,0 +1,299 @@
|
||||
import json
|
||||
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
from data.preprocess_scripts import *
|
||||
from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN
|
||||
from data.utils import capitalize_and_period
|
||||
|
||||
# The dataset without state
|
||||
# Datasets that provide no proprioceptive state; for these, the previous
# action is substituted for the state downstream.
DATASET_NAMES_NO_STATE = [
    "nyu_door_opening_surprising_effectiveness",
    "usc_cloth_sim_converted_externally_to_rlds",
    "cmu_franka_exploration_dataset_converted_externally_to_rlds",
    "imperialcollege_sawyer_wrist_cam",
]

# Read the image keys of each dataset
# (presumably maps dataset name -> its camera feature keys; confirm against
# configs/dataset_img_keys.json).
with open("configs/dataset_img_keys.json", "r") as file:
    IMAGE_KEYS = json.load(file)
# Read the global preprocessing config (chunk size, history size, etc.).
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
|
||||
|
||||
|
||||
def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str, base_concat=None, base_format=None) -> tf.Tensor:
    """
    Assemble the state/action vector from the arm and base.

    Values are scattered into a zero-initialized unified vector of length
    STATE_VEC_LEN at the indices named by the comma-separated format
    strings; a parallel 0/1 mask marks which dimensions were filled.

    Args:
        arm_concat: concatenated arm values, one per name in `arm_format`.
        arm_format: comma-separated dimension names (keys of
            STATE_VEC_IDX_MAPPING).
        base_concat: optional concatenated base values.
        base_format: comma-separated names for the base values.

    Returns:
        (state_vec, mask_vec): two float32 tensors of length STATE_VEC_LEN.
    """

    def _scatter(state, mask, values, fmt):
        # Scatter one group of named values into the unified vector and
        # mark the touched indices in the mask.
        names = fmt.split(",")
        # Use scatter_nd-style updates to avoid issues with duplicate indices.
        idx = [[STATE_VEC_IDX_MAPPING[name]] for name in names]
        state = tf.tensor_scatter_nd_update(state, idx, tf.cast(values, tf.float32))
        mask = tf.tensor_scatter_nd_update(mask, idx, tf.ones(len(names), dtype=tf.float32))
        return state, mask

    state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
    mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)

    # Arm dimensions are always present.
    state_vec, mask_vec = _scatter(state_vec, mask_vec, arm_concat, arm_format)

    # Base (mobile platform) dimensions are optional.
    if base_concat is not None:
        state_vec, mask_vec = _scatter(state_vec, mask_vec, base_concat, base_format)

    return state_vec, mask_vec
|
||||
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
def _generate_json_state_agilex(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for a given episode (AgileX variant).

    Iterates the episode's steps eagerly, assembling a unified state vector,
    a unified action vector, and a dimension mask per step.

    Args:
        episode (dict): a TFDS-style episode with a "steps" iterable.
        dataset_name (str): name recorded in the episode metadata.

    Returns:
        (episode_metadata, episode_states, episode_masks, episode_acts):
        metadata dict plus three stacked tensors of shape (T, STATE_VEC_LEN).
    """
    # Load some constants from the config
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    base_act = None
    last_base_act = None
    episode_states = []
    episode_acts = []
    episode_masks = []
    # Determined from the first step and assumed constant for the episode.
    has_base = None
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]

        # Parse the state
        state = step["observation"]

        arm_format = state["format"].numpy().decode("utf-8")
        base_format = None
        if has_base:
            # The base dimension names are the suffix of the action format
            # starting at the first "base" entry.
            act_format = action["format"].numpy().decode("utf-8")
            base_formate_idx = act_format.find("base")
            base_format = act_format[base_formate_idx:]

        arm_state = state["arm_concat"]
        base_state = None
        if has_base:
            # The base state is the previous step's base action (zeros for
            # the first step, via base_act * 0).
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
            last_base_act = base_act

        # Assemble the state vector
        state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)

        # NOTE(review): this overwrites the state's mask_vec with the one
        # from the action assembly — presumably the two formats are
        # identical so the masks match; confirm.
        act_vec, mask_vec = assemble_state_vec(action["arm_concat"], arm_format, base_state, base_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)
        episode_acts.append(act_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata (first step's instruction wins)
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

    # NOTE(review): this stores the LAST step index (T - 1), not the step
    # count T — looks like an off-by-one if consumers expect a count; confirm.
    episode_metadata["#steps"] = step_id

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)
    episode_acts = tf.stack(episode_acts)

    return episode_metadata, episode_states, episode_masks, episode_acts
|
||||
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
def _generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for a given episode.

    Iterates the episode's steps eagerly, assembling a unified state vector
    and a dimension mask per step (no separate action stream — compare with
    _generate_json_state_agilex).

    Args:
        episode (dict): a TFDS-style episode with a "steps" iterable.
        dataset_name (str): name recorded in the episode metadata.

    Returns:
        (episode_metadata, episode_states, episode_masks): metadata dict
        plus two stacked tensors of shape (T, STATE_VEC_LEN).
    """
    # Load some constants from the config
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    base_act = None
    last_base_act = None
    episode_states = []
    episode_masks = []
    # Determined from the first step and assumed constant for the episode.
    has_base = None
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]

        # Parse the state
        state = step["observation"]

        arm_format = state["format"].numpy().decode("utf-8")
        base_format = None
        if has_base:
            # The base dimension names are the suffix of the action format
            # starting at the first "base" entry.
            act_format = action["format"].numpy().decode("utf-8")
            base_formate_idx = act_format.find("base")
            base_format = act_format[base_formate_idx:]

        arm_state = state["arm_concat"]
        base_state = None
        if has_base:
            # The base state is the previous step's base action (zeros for
            # the first step, via base_act * 0).
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
            last_base_act = base_act

        # Assemble the state vector
        state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata (first step's instruction wins)
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

    # NOTE(review): this stores the LAST step index (T - 1), not the step
    # count T — looks like an off-by-one if consumers expect a count; confirm.
    episode_metadata["#steps"] = step_id
    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks
|
||||
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode in the dataset without state.
    If not state, we use the last action as current state.

    For step t the "state" is the action taken at step t-1 (zeros at t=0).

    Args:
        episode: dict with a "steps" iterable; each step carries "action"
            (with "arm_concat", optional "base_concat", and a "format"
            string tensor) and "observation" (with
            "natural_language_instruction").
        dataset_name: name of the source dataset.

    Returns:
        (episode_metadata, episode_states, episode_masks) where the latter
        two are stacked tensors, one row per step.
    """
    # Load some constants from the config.
    # NOTE(review): both constants are only validated here, never used in
    # this function — presumably kept as an early sanity check on the config.
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    # Previous-step actions, used as the current "state" (None until seen).
    last_base_act = None
    last_arm_act = None
    episode_states = []
    episode_masks = []
    # Whether the action dict contains a mobile-base component; decided once
    # on the first step and assumed constant for the whole episode.
    has_base = None
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]
            if last_base_act is None:
                last_base_act = base_act * 0  # Initialize with zeros of the same shape

        # Parse the arm action
        arm_act = action["arm_concat"]
        if last_arm_act is None:
            last_arm_act = arm_act * 0  # Initialize with zeros of the same shape

        # Parse the act format
        # Action format as the state format
        act_format = action["format"].numpy().decode("utf-8")

        # Assemble the state vector from the previous step's action
        if has_base:
            last_act_concat = tf.concat([last_arm_act, last_base_act], axis=0)
        else:
            last_act_concat = last_arm_act
        state_vec, mask_vec = assemble_state_vec(last_act_concat, act_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata (first step's instruction only)
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

        # Update the last_arm_act and last_base_act for the next iteration
        last_arm_act = arm_act
        if has_base:
            last_base_act = base_act

    # NOTE(review): "#steps" is set to the LAST step index (N-1), not the
    # step count N; also `step_id` is unbound if the episode is empty —
    # presumably empty episodes never occur. Confirm against consumers.
    episode_metadata["#steps"] = step_id
    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks
|
||||
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
def generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode.

    Dispatches to the dataset-specific generator after running the
    per-dataset `process_step` transform over every step.
    """
    # The name may arrive as a scalar string tensor; decode it to str.
    if isinstance(dataset_name, tf.Tensor):
        dataset_name = dataset_name.numpy().decode("utf-8")

    # Apply the dataset-specific per-step preprocessing.
    preprocessor = globals()[dataset_name].process_step
    episode["steps"] = episode["steps"].map(preprocessor)

    # Pick the generator that matches this dataset's conventions.
    if dataset_name == "agilex":
        generator = _generate_json_state_agilex
    elif dataset_name in DATASET_NAMES_NO_STATE:
        generator = _generate_json_state_nostate_ds
    else:
        generator = _generate_json_state
    return generator(episode, dataset_name)
|
||||
313
RDT-1B/data/producer.py
Normal file
313
RDT-1B/data/producer.py
Normal file
@ -0,0 +1,313 @@
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import sys
|
||||
import signal
|
||||
import random
|
||||
from multiprocessing import Process
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
from data.vla_dataset import VLADataset
|
||||
from data.filelock import FileLock
|
||||
|
||||
# Producer does not need GPU
tf.config.set_visible_devices([], "GPU")

# Read the config
# NOTE: path is relative, so the producer must be launched from the repo root.
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
# Load some constants from the config
# BUF_PATH: root directory of the on-disk sample buffer shared with the consumer.
BUF_PATH = config["dataset"]["buf_path"]
# BUF_NUM_CHUNKS: number of chunk subdirectories in the buffer.
BUF_NUM_CHUNKS = config["dataset"]["buf_num_chunks"]
if BUF_NUM_CHUNKS < 1:
    raise ValueError("Config `buf_num_chunks` must be at least 1.")
# BUF_CHUNK_SIZE: number of sample slots per chunk.
BUF_CHUNK_SIZE = config["dataset"]["buf_chunk_size"]
if BUF_CHUNK_SIZE < 1:
    raise ValueError("Config `buf_chunk_size` must be at least 1.")
|
||||
|
||||
|
||||
def get_dirty_item(chunk_dir):
    """
    Get indexes of dirty items in a chunk.

    An item is dirty when its dirty-bit entry is nonzero.
    """
    bits = read_dirty_bit(chunk_dir)
    return np.flatnonzero(bits).tolist()
|
||||
|
||||
|
||||
def get_clean_item(chunk_dir):
    """
    Get indexes of clean items in a chunk.

    An item is clean when its dirty-bit entry is zero.
    """
    bits = read_dirty_bit(chunk_dir)
    return np.flatnonzero(1 - bits).tolist()
|
||||
|
||||
|
||||
def save_dirty_bit(chunk_dir, dirty_bit):
    """
    Save the dirty bit to the chunk directory.

    Writes `dirty_bit` (a uint8 array of length BUF_CHUNK_SIZE) to the
    chunk's `dirty_bit` file under a write file-lock. Retries on any
    failure for up to 10 seconds, then gives up with a message (the
    producer loop tolerates a stale dirty bit).

    Args:
        chunk_dir: directory of the chunk whose dirty bit is written.
        dirty_bit: numpy array; its raw bytes are written verbatim.
    """
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        # Bug fix: `lock` must exist before the except handlers run.
        # Previously, if FileLock() itself raised, the handler hit an
        # UnboundLocalError and escaped the retry loop.
        lock = None
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                file.write(dirty_bit.tobytes())
            lock.release_lock()
            return
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise
        except BaseException:
            # Any other failure (lock contention, transient I/O error):
            # release and retry until the deadline.
            if lock is not None:
                lock.release_lock()
            continue
    print("Failed to save dirty bit.")
|
||||
|
||||
|
||||
def read_dirty_bit(chunk_dir):
    """
    Read the dirty bit from the chunk directory.

    Reads the chunk's `dirty_bit` file under a read file-lock and returns
    it as a uint8 array of length BUF_CHUNK_SIZE. Retries on any failure
    for up to 10 seconds; if it still cannot read, returns all ones
    (treat everything as dirty) for robustness.
    """
    # If error occurs, retry
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        # Bug fix: `lock` must exist before the except handlers run.
        # Previously, if FileLock() itself raised, the handler hit an
        # UnboundLocalError and escaped the retry loop.
        lock = None
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_read_lock()
            with open(file_path, "rb") as file:
                # .copy() detaches the array from the read-only buffer.
                dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
            lock.release_lock()
            assert len(dirty_bit) == BUF_CHUNK_SIZE
            return dirty_bit
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise
        except BaseException:
            if lock is not None:
                lock.release_lock()
            continue
    # If failed to read the dirty bit, return all ones for robustness
    return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
|
||||
|
||||
|
||||
def save_sample(step_dict, chunk_dir, chunk_item_idx):
    """
    Save a sample to the chunk directory.

    A sample occupies two files in its slot: `json_content_<idx>.json`
    (human-readable metadata) and `sample_<idx>.npz` (all tensors). Each
    file is written under its own write file-lock. Retries on any failure
    for up to 10 seconds, then gives up with a message.

    Args:
        step_dict: dict holding "json_content" plus tensor entries; each
            tensor exposes `.numpy()` (presumably tf tensors — confirm).
        chunk_dir: directory of the target chunk.
        chunk_item_idx: slot index inside the chunk to (over)write.
    """
    # Save the json content
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            # Track acquired locks so all can be released on failure.
            locks = []
            json_content = step_dict["json_content"]
            file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, "w") as file:
                json.dump(json_content, file, indent=4)
            lock.release_lock()
            # Save all other tensors in a npz
            file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                np.savez(
                    file,
                    step_id=step_dict["step_id"].numpy(),
                    state_chunk=step_dict["state_chunk"].numpy(),
                    state_chunk_time_mask=step_dict["state_chunk_time_mask"].numpy(),
                    action_chunk=step_dict["action_chunk"].numpy(),
                    action_chunk_time_mask=step_dict["action_chunk_time_mask"].numpy(),
                    state_vec_mask=step_dict["state_vec_mask"].numpy(),
                    past_frames_0=step_dict["past_frames_0"].numpy(),
                    past_frames_0_time_mask=step_dict["past_frames_0_time_mask"].numpy(),
                    past_frames_1=step_dict["past_frames_1"].numpy(),
                    past_frames_1_time_mask=step_dict["past_frames_1_time_mask"].numpy(),
                    past_frames_2=step_dict["past_frames_2"].numpy(),
                    past_frames_2_time_mask=step_dict["past_frames_2_time_mask"].numpy(),
                    past_frames_3=step_dict["past_frames_3"].numpy(),
                    past_frames_3_time_mask=step_dict["past_frames_3_time_mask"].numpy(),
                    state_std=step_dict["state_std"].numpy(),
                    state_mean=step_dict["state_mean"].numpy(),
                    state_norm=step_dict["state_norm"].numpy(),
                )
            lock.release_lock()
            return
        except KeyboardInterrupt:
            # Release everything acquired so far before propagating.
            for lock in locks:
                lock.release_lock()
            raise KeyboardInterrupt
        except BaseException:
            # Any other failure: release locks and retry until the deadline.
            # NOTE(review): already-released locks get release_lock() called
            # again here — presumably FileLock tolerates double release.
            for lock in locks:
                lock.release_lock()
            continue
    # raise RuntimeError("Failed to save sample.")
    print("Failed to save sample.")
|
||||
|
||||
|
||||
def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type):
    """
    Run the producer.
    The producer will first fill up the buffer with samples.
    Then it will keep replacing dirty samples
    (i.e., samples that have been read by the consumer)
    with new samples.

    Args:
        seed: random seed for this worker's dataset sampling.
        num_workers: total number of producer processes.
        worker_id: index of this worker in [0, num_workers).
        fill_up: if True, first write a sample into every slot of this
            worker's chunk range.
        clean_dirty: if True (and `fill_up` is not set), reset all dirty
            bits in this worker's chunk range before replacing.
        dataset_type: "pretrain" or "finetune".
    """
    vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type)
    # Each worker owns the contiguous chunk range [chunk_start_idx, chunk_end_idx).
    chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers
    chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers
    if fill_up:
        print(f"Worker {worker_id}: Start filling up the buffer...")
    elif clean_dirty:
        # Only refresh the dirty bits
        print(f"Worker {worker_id}: Start refreshing the dirty bits...")
        for chunk_idx in range(chunk_start_idx, chunk_end_idx):
            chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}")
            dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
            save_dirty_bit(chunk_dir, dirty_bit)
        print(f"Worker {worker_id}: Refreshed the dirty bits.")

    fill_chunk_idx = chunk_start_idx
    fill_chunk_item_idx = 0
    dirty_chunk_idx = chunk_start_idx
    dirty_chunk_item_idxs = []
    time_stmp = time.time()
    for episode_steps in vla_dataset:
        for step in episode_steps:
            if fill_up and fill_chunk_idx < chunk_end_idx:
                # Fill up the buffer
                chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}")
                if fill_chunk_item_idx == 0:
                    # Create a new chunk
                    os.makedirs(chunk_dir, exist_ok=True)
                    # Write the dirty bit of size BUF_CHUNK_SIZE
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(chunk_dir, dirty_bit)

                # Save the sample
                save_sample(step, chunk_dir, fill_chunk_item_idx)

                # Progress report every 10 chunks (and for the last chunk).
                local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx
                local_num_chunks = chunk_end_idx - chunk_start_idx
                if (local_fill_chunk_idx % 10 == 0
                        or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0:
                    print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx+1}/{local_num_chunks}")
                fill_chunk_item_idx += 1
                if fill_chunk_item_idx == BUF_CHUNK_SIZE:
                    fill_chunk_idx += 1
                    fill_chunk_item_idx = 0
                    # Fix: compare against this worker's own range end instead
                    # of the global BUF_NUM_CHUNKS, so every worker (not only
                    # the one owning the last chunk) reports completion.
                    if fill_chunk_idx == chunk_end_idx:
                        print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...")

            else:
                # Search for the dirty chunk to replace
                while len(dirty_chunk_item_idxs) == 0:
                    dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
                    dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir)
                    # Print the dirty ratio (rate-limited to every 2 seconds)
                    if time.time() - time_stmp > 2.0:
                        dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE
                        print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}")
                        time_stmp = time.time()

                    if len(dirty_chunk_item_idxs) > 0:
                        # Lock the chunk by marking every slot dirty so the
                        # consumer skips it while we rewrite it.
                        dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
                        save_dirty_bit(dirty_chunk_dir, dirty_bit)

                    # Iterate over the chunks (wrap within this worker's range)
                    dirty_chunk_idx += 1
                    if dirty_chunk_idx == chunk_end_idx:
                        dirty_chunk_idx = chunk_start_idx

                # Replace the dirty item.
                dirty_item_idx = dirty_chunk_item_idxs.pop()
                # Fix: write into the chunk whose dirty items we found
                # (`dirty_chunk_dir`). `dirty_chunk_idx` has already advanced
                # past it, so rebuilding the path from the index wrote the
                # sample into the wrong (unlocked) chunk.
                save_sample(step, dirty_chunk_dir, dirty_item_idx)

                # If we have replaced all dirty items in the chunk
                if len(dirty_chunk_item_idxs) == 0:
                    # Unlock the chunk
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(dirty_chunk_dir, dirty_bit)
                    print(f"Worker {worker_id}: Replaced dirty chunk {dirty_chunk_idx}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: spawn `n_workers` producer processes, each owning a
    # disjoint slice of the buffer chunks, and wait for them to finish.
    # Args: n_workers, fill_up
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_workers",
        type=int,
        default=2,
        help="Number of parallel workers. It should be less than or equal to the number of chunks.",
    )
    parser.add_argument(
        "--fill_up",
        action="store_true",
        help="Whether to fill up the buffer before replacing dirty samples.",
    )
    parser.add_argument(
        "--clean_dirty",
        action="store_true",
        help=
        "Whether to clean the dirty bits before replacing dirty samples. This option is ignored when `fill_up` is set.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed. If not set, the seed will be randomly generated.",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        help="Whether to load the pretrain dataset or finetune dataset.",
    )

    # Run the producer
    args = parser.parse_args()
    # Seed the parent RNG so the per-worker seeds below are reproducible.
    if args.seed is not None:
        print(f"Base seed: {args.seed}")
        random.seed(args.seed)

    processes = []
    # One independent seed per worker, derived from the base seed.
    process_seeds = [random.randint(0, 2**32) for _ in range(args.n_workers)]
    print(f"Process seeds: {process_seeds}")

    def signal_handler(sig, frame):
        # Forward Ctrl+C to all children so no orphan producers remain.
        print("Ctrl+C received. Terminating child processes...")
        for p in processes:
            p.terminate()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    for worker_id in range(args.n_workers):
        p = Process(
            target=run_producer,
            args=(
                process_seeds[worker_id],
                args.n_workers,
                worker_id,
                args.fill_up,
                args.clean_dirty,
                args.dataset_type,
            ),
        )
        p.start()
        processes.append(p)

    # Block until every worker exits.
    for p in processes:
        p.join()
|
||||
242
RDT-1B/data/utils.py
Normal file
242
RDT-1B/data/utils.py
Normal file
@ -0,0 +1,242 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow_graphics.geometry.transformation.euler as tf_euler
|
||||
import tensorflow_graphics.geometry.transformation.quaternion as tf_quat
|
||||
import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as tf_rotmat
|
||||
|
||||
|
||||
def dataset_to_path(dataset_name: str, dir_name: str) -> str:
    """
    Return the path to the dataset.

    The path layout is `<dir_name>/<dataset_name>/<version>`, where most
    datasets use version "0.1.0" and a few use a special version string
    (possibly empty).
    """
    # Datasets whose builder directories use a non-default version.
    special_versions = {
        "robo_net": "1.0.0",
        "cmu_playing_with_food": "1.0.0",
        "droid": "1.0.0",
        "language_table": "0.0.1",
        "fmb": "0.0.1",
        "dobbe": "0.0.1",
        "nyu_door_opening_surprising_effectiveness": "",
        "cmu_play_fusion": "",
        "berkeley_gnm_recon": "",
    }
    version = special_versions.get(dataset_name, "0.1.0")
    return f"{dir_name}/{dataset_name}/{version}"
|
||||
|
||||
|
||||
def clean_task_instruction(task_instruction: tf.Tensor, replacements: dict) -> tf.Tensor:
    """
    Clean up the natural language task instruction.

    Applies every (pattern -> substitute) regex replacement in
    `replacements`, in dict order, then strips surrounding whitespace.
    """
    cleaned = task_instruction
    for pattern, substitute in replacements.items():
        cleaned = tf.strings.regex_replace(cleaned, pattern, substitute)
    # Strip leading and trailing spaces after all substitutions.
    return tf.strings.strip(cleaned)
|
||||
|
||||
|
||||
def quaternion_to_euler(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to Euler angles (roll, pitch, yaw).
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
    """
    # Normalize first so the conversion operates on a unit quaternion.
    return tf_euler.from_quaternion(tf.nn.l2_normalize(quaternion, axis=-1))
|
||||
|
||||
|
||||
def euler_to_quaternion(euler: tf.Tensor) -> tf.Tensor:
    """
    Convert Euler angles (roll, pitch, yaw) to a quaternion (x, y, z, w).
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
    """
    # Normalize the result to guarantee a unit quaternion.
    return tf.nn.l2_normalize(tf_quat.from_euler(euler), axis=-1)
|
||||
|
||||
|
||||
def rotation_matrix_to_euler(matrix: tf.Tensor) -> tf.Tensor:
    """
    Convert a 3x3 rotation matrix to Euler angles (roll, pitch, yaw).
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.

    Thin wrapper over tensorflow_graphics; unlike the quaternion variants,
    no normalization step is applied here.
    """
    return tf_euler.from_rotation_matrix(matrix)
|
||||
|
||||
|
||||
def rotation_matrix_to_quaternion(matrix: tf.Tensor) -> tf.Tensor:
    """
    Convert a 3x3 rotation matrix to a quaternion (x, y, z, w).
    """
    # Normalize the converted quaternion to unit length.
    return tf.nn.l2_normalize(tf_quat.from_rotation_matrix(matrix), axis=-1)
|
||||
|
||||
|
||||
def euler_to_rotation_matrix(euler: tf.Tensor) -> tf.Tensor:
    """
    Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix.
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.

    Thin wrapper over tensorflow_graphics.
    """
    return tf_rotmat.from_euler(euler)
|
||||
|
||||
|
||||
def quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
    """
    # Normalize first so the conversion operates on a unit quaternion.
    return tf_rotmat.from_quaternion(tf.nn.l2_normalize(quaternion, axis=-1))
|
||||
|
||||
|
||||
def quaternion_to_rotation_matrix_wo_static_check(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
    This function is used to make tensorflow happy.

    Hand-rolled version of the standard unit-quaternion-to-matrix formula,
    avoiding the tensorflow_graphics implementation (and, per the name,
    its static shape checks). Input shape (..., 4), output (..., 3, 3).
    """
    # Normalize the quaternion so the formula below is exact.
    quaternion = tf.nn.l2_normalize(quaternion, axis=-1)

    x = quaternion[..., 0]
    y = quaternion[..., 1]
    z = quaternion[..., 2]
    w = quaternion[..., 3]

    # Precompute the doubled products used by the rotation-matrix formula.
    tx = 2.0 * x
    ty = 2.0 * y
    tz = 2.0 * z
    twx = tx * w
    twy = ty * w
    twz = tz * w
    txx = tx * x
    txy = ty * x
    txz = tz * x
    tyy = ty * y
    tyz = tz * y
    tzz = tz * z
    # Nine matrix entries in row-major order, stacked along the last axis.
    matrix = tf.stack(
        (
            1.0 - (tyy + tzz),
            txy - twz,
            txz + twy,
            txy + twz,
            1.0 - (txx + tzz),
            tyz - twx,
            txz - twy,
            tyz + twx,
            1.0 - (txx + tyy),
        ),
        axis=-1,
    )  # pyformat: disable
    # Reshape the flat 9-vector back into (..., 3, 3).
    output_shape = tf.concat((tf.shape(input=quaternion)[:-1], (3, 3)), axis=-1)
    return tf.reshape(matrix, shape=output_shape)
|
||||
|
||||
|
||||
"""
|
||||
Below is a continuous 6D rotation representation adapted from
|
||||
On the Continuity of Rotation Representations in Neural Networks
|
||||
https://arxiv.org/pdf/1812.07035.pdf
|
||||
https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py
|
||||
"""
|
||||
|
||||
|
||||
def rotation_matrix_to_ortho6d(matrix: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d represents the first two column vectors a1 and a2 of the
    rotation matrix: [ | , |, | ]
                     [ a1, a2, a3]
                     [ | , |, | ]
    Input: (A1, ..., An, 3, 3)
    Output: (A1, ..., An, 6)
    """
    # Keep only the first two columns of each rotation matrix.
    ortho6d = matrix[..., :, :2]
    # Transpose the last two dimension (columns become rows so the flatten
    # below yields [a1; a2] rather than interleaved components).
    perm = list(range(len(ortho6d.shape)))
    perm[-2], perm[-1] = perm[-1], perm[-2]
    ortho6d = tf.transpose(ortho6d, perm)
    # Flatten the last two dimension into a 6-vector.
    # NOTE(review): relies on static shapes (`ortho6d.shape`); see the
    # `_1d` variant for the dynamic-shape-friendly version.
    ortho6d = tf.reshape(ortho6d, ortho6d.shape[:-2] + [6])
    return ortho6d
|
||||
|
||||
|
||||
def rotation_matrix_to_ortho6d_1d(matrix: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d represents the first two column vectors a1 and a2 of the
    rotation matrix: [ | , |, | ]
                     [ a1, a2, a3]
                     [ | , |, | ]
    Input: (3, 3)
    Output: (6,)
    This function is used to make tensorflow happy.
    """
    # Take the first two columns, transpose so each column becomes a row,
    # then flatten into a single 6-vector [a1; a2].
    first_two_cols = matrix[:, :2]
    return tf.reshape(tf.transpose(first_two_cols), [6])
|
||||
|
||||
|
||||
def normalize_vector(v):
    """
    Normalize `v` along its last axis.

    v: (..., N)
    """
    magnitude = tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
    # Clamp the magnitude to avoid division by zero for (near-)zero vectors.
    safe_magnitude = tf.maximum(magnitude, 1e-8)
    return v / safe_magnitude
|
||||
|
||||
|
||||
def cross_product(u, v):
    """
    Batched 3-D cross product.

    u: (..., 3)
    v: (..., 3)
    u x v: (..., 3)
    """
    comp_x = u[..., 1] * v[..., 2] - u[..., 2] * v[..., 1]
    comp_y = u[..., 2] * v[..., 0] - u[..., 0] * v[..., 2]
    comp_z = u[..., 0] * v[..., 1] - u[..., 1] * v[..., 0]
    return tf.stack([comp_x, comp_y, comp_z], axis=-1)
|
||||
|
||||
|
||||
def ortho6d_to_rotation_matrix(ortho6d: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d represents the first two column vectors a1 and a2 of the
    rotation matrix: [ | , |, | ]
                     [ a1, a2, a3]
                     [ | , |, | ]
    Input: (A1, ..., An, 6)
    Output: (A1, ..., An, 3, 3)
    """
    # Split the 6D representation into the two raw column vectors.
    x_raw = ortho6d[..., 0:3]
    y_raw = ortho6d[..., 3:6]

    # Gram–Schmidt-style orthonormalization: x is the normalized first
    # vector; z is orthogonal to the (x, y_raw) plane; y completes the
    # right-handed orthonormal frame.
    x = normalize_vector(x_raw)
    z = cross_product(x, y_raw)
    z = normalize_vector(z)
    y = cross_product(z, x)

    # Stack x, y, z to form the matrix (as columns, via the last axis).
    matrix = tf.stack([x, y, z], axis=-1)
    return matrix
|
||||
|
||||
|
||||
def capitalize_and_period(instr: str) -> str:
    """
    Capitalize the first letter of a string and add a period to the end if it's not there.

    The empty string is returned unchanged.
    """
    if not instr:
        return instr
    head, tail = instr[0], instr[1:]
    # Make the first character uppercase if it is not already.
    if not head.isupper():
        head = head.upper()
    result = head + tail
    # Ensure the instruction ends with a period.
    if not result.endswith("."):
        result += "."
    return result
|
||||
149
RDT-1B/data/vla_dataset.py
Normal file
149
RDT-1B/data/vla_dataset.py
Normal file
@ -0,0 +1,149 @@
|
||||
import json
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import yaml
|
||||
|
||||
from data.episode_transform import (
|
||||
process_episode,
|
||||
flatten_episode,
|
||||
flatten_episode_agilex,
|
||||
bgr_to_rgb,
|
||||
)
|
||||
from data.utils import dataset_to_path
|
||||
from data.preprocess_scripts import *
|
||||
|
||||
# Producer does not need GPU
|
||||
tf.config.set_visible_devices([], "GPU")
|
||||
|
||||
OPENX_EMBOD_DIR = "data/datasets/openx_embod"
|
||||
|
||||
DATASET_NAMES_NOOPENX = [
|
||||
"aloha_mobile",
|
||||
"aloha_static",
|
||||
"roboset",
|
||||
"agilex",
|
||||
"rh20t",
|
||||
"calvin",
|
||||
"bridgev2",
|
||||
]
|
||||
|
||||
# Read the config
|
||||
with open("configs/base.yaml", "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
# Load some constants from the config
|
||||
EPSD_LEN_THRESH_LOW = config["dataset"]["epsd_len_thresh_low"]
|
||||
EPSD_LEN_THRESH_HIGH = config["dataset"]["epsd_len_thresh_high"]
|
||||
# Read the image keys of each dataset
|
||||
with open("configs/dataset_img_keys.json", "r") as file:
|
||||
IMAGE_KEYS = json.load(file)
|
||||
|
||||
|
||||
class VLADataset:
    """
    This class is used to sample episodes from the embodiment dataset
    collection, weighted per-dataset, yielding lists of flattened steps.
    """

    def __init__(self, seed, dataset_type, repeat=True):
        """
        seed: the random seed
        dataset_type: 'pretrain' or 'finetune', which dataset to load
        repeat: whether to repeat to infinite length
        """
        # Pick the dataset list / sample-weight configs for this phase.
        dataset_names_cfg = ("configs/pretrain_datasets.json"
                            if dataset_type == "pretrain" else "configs/finetune_datasets.json")
        with open(dataset_names_cfg, "r") as file:
            DATASET_NAMES = json.load(file)
        self.dataset_names = DATASET_NAMES
        sample_weights_cfg = ("configs/pretrain_sample_weights.json"
                             if dataset_type == "pretrain" else "configs/finetune_sample_weights.json")
        # Load the sample weights
        with open(sample_weights_cfg, "r") as file:
            SAMPLE_WEIGHTS = json.load(file)
        self.openx_dir = OPENX_EMBOD_DIR
        self.epsd_len_thresh_low = EPSD_LEN_THRESH_LOW
        self.epsd_len_thresh_high = EPSD_LEN_THRESH_HIGH
        self.repeat = repeat

        # Set the random seed
        tf.random.set_seed(seed)
        np.random.seed(seed)

        # Weights of the each dataset in the collection to sample from
        sample_weights = []

        # Map: dataset name -> infinite iterator over processed episodes.
        self.name2dataset = {}
        for dataset_name in self.dataset_names:
            if dataset_name in DATASET_NAMES_NOOPENX:
                # Non-Open-X datasets ship their own loader module.
                dataset = globals()[dataset_name].load_dataset(seed)
            else:
                dataset_path = dataset_to_path(dataset_name, self.openx_dir)
                dataset = tfds.builder_from_directory(builder_dir=dataset_path)
                dataset = dataset.as_dataset(split="all", shuffle_files=True)

            # You can add filter for other datasets
            if dataset_name == "kuka":
                dataset = dataset.filter(lambda x: x["success"])
            elif dataset_name == "bc_z":
                # Keep only episodes flagged successful on their first step.
                dataset = dataset.filter(lambda x: tf.math.greater(
                    next(iter(x["steps"]))["observation"]["episode_success"],
                    0.5,
                ))
            elif (dataset_name == "ucsd_pick_and_place_dataset_converted_externally_to_rlds"):
                dataset = dataset.filter(lambda x: x["episode_metadata"]["success"])
            elif (dataset_name == "utokyo_xarm_bimanual_converted_externally_to_rlds"):
                # Only preserve the meaningful episodes
                dataset = dataset.filter(lambda x: tf.math.equal(
                    next(iter(x["steps"]))["language_instruction"],
                    tf.constant("Unfold a wrinkled towel."),
                ))

            # Note: use cache() will cause the unexpected crash
            # dataset = dataset.map().cache().shuffle().repeat()
            # NOTE(review): the lambda closes over the loop variable
            # `dataset_name` — presumably safe because tf.data traces the
            # function at map() time, capturing the current values; confirm.
            dataset = dataset.map(lambda x: process_episode(
                x,
                dataset_name,
                IMAGE_KEYS[dataset_name]["image_keys"],
                IMAGE_KEYS[dataset_name]["image_mask"],
            ))

            # Change BGR to RGB if needed
            if dataset_name == "fmb":
                dataset = dataset.map(bgr_to_rgb)

            if self.repeat:
                dataset = dataset.repeat()
            self.name2dataset[dataset_name] = iter(dataset)
            sample_weights.append(SAMPLE_WEIGHTS[dataset_name])
        # Normalize the sample weights so they form a probability vector.
        sample_weights = np.array(sample_weights)
        self.sample_weights = sample_weights / np.sum(sample_weights)

    def __iter__(self):
        """
        Sample batches of episodes for an epoch.

        Yields lists of flattened steps, one list per sampled episode.
        Episodes shorter than `epsd_len_thresh_low` are dropped; longer
        ones are randomly subsampled down to `epsd_len_thresh_high`.
        """
        while True:
            # Pick a dataset proportionally to its normalized weight.
            dataset_name = np.random.choice(self.dataset_names, p=self.sample_weights)
            episode = next(self.name2dataset[dataset_name])
            if dataset_name == "agilex":
                episode_steps = flatten_episode_agilex(episode)
            else:
                episode_steps = flatten_episode(episode)
            # Filter too short
            if len(episode_steps) < self.epsd_len_thresh_low:
                continue
            # Randomly sample too long
            if len(episode_steps) > self.epsd_len_thresh_high:
                episode_steps = random.sample(episode_steps, self.epsd_len_thresh_high)

            yield episode_steps
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build the finetune dataset and print the first step of
    # the first sampled episode.
    dataset = VLADataset(0, "finetune")
    for episode in dataset:
        print(episode[0])
        break
|
||||
107
RDT-1B/finetune.sh
Normal file
107
RDT-1B/finetune.sh
Normal file
@ -0,0 +1,107 @@
|
||||
#!/bin/bash
# Fine-tune RDT-1B: read the run config, export the distributed-training
# environment, compute dataset stats, launch training via `accelerate`,
# then emit the output JSON summary.

BEGIN_TIME=$(date +%s)

CONFIG_NAME="Train_Config_Default"
CONFIG_FILE="model_config/$CONFIG_NAME.yml"
echo "CONFIG_FILE_PATH: $CONFIG_FILE"

### ============Read Input Config and ReLoad Config YAML===================
# Symlink the downloaded weights into the expected location.
ln -s /home/qi.xiong/Temp/RDT-1B/input/weights ../weights

TRAIN_CONFIG_FILE="input/config.json"
echo "TRAIN_CONFIG_FILE_PATH: $TRAIN_CONFIG_FILE"
# Merge the input JSON config into the model YAML config.
python scripts/read_config.py "$TRAIN_CONFIG_FILE" "$CONFIG_FILE"

### ============Read Input Config and ReLoad Config YAML===================

# NCCL / encoder / wandb environment for distributed training.
export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0
export NCCL_SOCKET_IFNAME=eth0
# export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
# NOTE(review): TEXT_ENCODER_NAME is commented out above but still
# referenced in the accelerate launch below, so that flag expands empty —
# presumably intentional with --precomp_lang_embed; confirm.
export VISION_ENCODER_NAME="../weights/siglip-so400m-patch14-384"
export CFLAGS="-I/usr/include"
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
export WANDB_PROJECT="RDT-1B"
export WANDB_DEFAULT_RUN_NAME=$CONFIG_NAME
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1

# check if YAML exist
if [ ! -f "$CONFIG_FILE" ]; then
    echo "Config file $CONFIG_FILE does not exist!"
    exit 1
fi

# Pull individual hyperparameters out of the YAML config.
PRETRAINED_MODEL_NAME=$(python scripts/read_yaml.py "$CONFIG_FILE" pretrained_model_name_or_path)
TRAIN_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" train_batch_size)
SAMPLE_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_batch_size)
MAX_TRAIN_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" max_train_steps)
CHECKPOINTING_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpointing_period)
SAMPLE_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_period)
CHECKPOINTS_TOTAL_LIMIT=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpoints_total_limit)
# NOTE(review): LR_SCHEDULER and DATASET_TYPE are read here but the launch
# below hardcodes "constant" and "finetune" instead — confirm intent.
LR_SCHEDULER=$(python scripts/read_yaml.py "$CONFIG_FILE" lr_scheduler)
LEARNING_RATE=$(python scripts/read_yaml.py "$CONFIG_FILE" learning_rate)
DATALOADER_NUM_WORKERS=$(python scripts/read_yaml.py "$CONFIG_FILE" dataloader_num_workers)
DATASET_TYPE=$(python scripts/read_yaml.py "$CONFIG_FILE" dataset_type)
STATE_NOISE_SNR=$(python scripts/read_yaml.py "$CONFIG_FILE" state_noise_snr)
GRAD_ACCUM_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" gradient_accumulation_steps)
OUTPUT_DIR=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpoint_path)
CUDA_USE=$(python scripts/read_yaml.py "$CONFIG_FILE" cuda_visible_device)


# Strip the surrounding double quotes emitted by read_yaml.py.
PRETRAINED_MODEL_NAME=$(echo "$PRETRAINED_MODEL_NAME" | tr -d '"')
CUDA_USE=$(echo "$CUDA_USE" | tr -d '"')
OUTPUT_DIR=$(echo "$OUTPUT_DIR" | tr -d '"')

# create output path
if [ ! -d "$OUTPUT_DIR" ]; then
    mkdir -p "$OUTPUT_DIR"
    echo "Created output directory: $OUTPUT_DIR"
else
    echo "Output directory already exists: $OUTPUT_DIR"
fi

export CUDA_VISIBLE_DEVICES=$CUDA_USE

# Precompute the dataset statistics used for state normalization.
python -m data.compute_dataset_stat_hdf5 --task_name $CONFIG_NAME

accelerate launch --main_process_port=28499 main.py \
    --deepspeed="./configs/zero2.json" \
    --pretrained_model_name_or_path=$PRETRAINED_MODEL_NAME \
    --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
    --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
    --output_dir=$OUTPUT_DIR \
    --train_batch_size=$TRAIN_BATCH_SIZE \
    --sample_batch_size=$SAMPLE_BATCH_SIZE \
    --max_train_steps=$MAX_TRAIN_STEPS \
    --checkpointing_period=$CHECKPOINTING_PERIOD \
    --sample_period=$SAMPLE_PERIOD \
    --checkpoints_total_limit=$CHECKPOINTS_TOTAL_LIMIT \
    --lr_scheduler="constant" \
    --learning_rate=$LEARNING_RATE \
    --mixed_precision="bf16" \
    --dataloader_num_workers=$DATALOADER_NUM_WORKERS \
    --image_aug \
    --dataset_type="finetune" \
    --state_noise_snr=$STATE_NOISE_SNR \
    --load_from_hdf5 \
    --report_to=wandb \
    --precomp_lang_embed \
    --gradient_accumulation_steps=$GRAD_ACCUM_STEPS \
    --model_config_path=$CONFIG_FILE \
    --CONFIG_NAME=$CONFIG_NAME \
    --output_log_path=$OUTPUT_DIR/output.log

END_TIME=$(date +%s)
RUNTIME=$((END_TIME - BEGIN_TIME))
echo "Total runtime: $RUNTIME seconds"

### ============Generate Output JSON===================

# Summarize the run (config, output dir, wall time) into the output JSON.
python scripts/generate_output_json.py "$TRAIN_CONFIG_FILE" "$OUTPUT_DIR" "$RUNTIME"

### ============Generate Output JSON===================
|
||||
|
||||
|
||||
Binary file not shown.
5
RDT-1B/generate.sh
Normal file
5
RDT-1B/generate.sh
Normal file
@ -0,0 +1,5 @@
|
||||
#!/bin/bash
# Generate a finetune model-config YAML for the given model name.
#
# Usage: bash generate.sh <model_name>

# Fail with a usage message when the argument is missing; quote the
# expansion so names with spaces are passed through intact.
model_name="${1:?Usage: bash generate.sh <model_name>}"

python ./model_config/_generate_model_config.py "$model_name"
|
||||
351
RDT-1B/main.py
Normal file
351
RDT-1B/main.py
Normal file
@ -0,0 +1,351 @@
|
||||
import argparse
|
||||
import os
|
||||
from train.train import train
|
||||
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
|
||||
def parse_args(input_args=None):
    """Build and parse the command-line arguments for RDT training.

    Args:
        input_args: Optional list of argument strings. When None, argparse
            falls back to ``sys.argv`` as usual.

    Returns:
        argparse.Namespace with all training options. If the ``LOCAL_RANK``
        environment variable is set, it overrides ``--local_rank``.
    """
    parser = argparse.ArgumentParser(description="Main script for training RDT.")
    parser.add_argument(
        "--model_config_path",
        type=str,
        default="model_config/sjoe_place_D435_100_finetune_config.yaml",
        help="Path to the finetune data and model configuration file. "
        "Default is `model_config/sjoe_place_D435_100_finetune_config.yaml`.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/base.yaml",
        help="Path to the configuration file. Default is `configs/base.yaml`.",
    )
    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help="Enable DeepSpeed and pass the path to its config file or an already initialized DeepSpeed config dictionary",
    )
    parser.add_argument(
        "--pretrained_text_encoder_name_or_path",
        type=str,
        default=None,
        help="Pretrained text encoder name or path if not the same as model_name",
    )
    parser.add_argument(
        "--pretrained_vision_encoder_name_or_path",
        type=str,
        default=None,
        help="Pretrained vision encoder name or path if not the same as model_name",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default="checkpoints",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")

    parser.add_argument(
        "--load_from_hdf5",
        action="store_true",
        default=False,
        help=("Whether to load the dataset directly from HDF5 files. "
              "If False, the dataset will be loaded using producer-consumer pattern, "
              "where the producer reads TFRecords and saves them to buffer, and the consumer reads from buffer."),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the sampling dataloader.",
    )
    parser.add_argument(
        "--num_sample_batches",
        type=int,
        default=2,
        help="Number of batches to sample from the dataset.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_period",
        type=int,
        default=500,
        help=("Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
              "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
              "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
              "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
              "instructions."),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
              " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
              " for more details"),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
              ' `--checkpointing_period`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        # BUG FIX: the original help was accidentally a tuple — a stray comma
        # after the first string fragment split the implicit concatenation.
        help=("Path or name of a pretrained checkpoint to load the model from.\n"
              " This can be either:\n"
              " - a string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co, e.g., `robotics-diffusion-transformer/rdt-1b`,\n"
              " - a path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method, e.g., `./my_model_directory/`.\n"
              " - a path to model checkpoint (*.pt), .e.g, `my_model_directory/checkpoint-10000/pytorch_model/mp_rank_00_model_states.pt`"
              " - `None` if you are randomly initializing model using configuration at `config_path`."),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--cond_mask_prob",
        type=float,
        default=0.1,
        help=("The probability to randomly mask the conditions (except states) during training. "
              "If set to 0, the conditions are not masked."),
    )
    parser.add_argument(
        "--cam_ext_mask_prob",
        type=float,
        default=-1.0,
        help=("The probability to randomly mask the external camera image during training. "
              "If set to < 0, the external camera image is masked with the probability of `cond_mask_prob`."),
    )
    parser.add_argument(
        "--state_noise_snr",
        type=float,
        default=None,
        help=("The signal-to-noise ratio (SNR, unit: dB) for adding noise to the states. "
              "Default is None, which means no noise is added."),
    )
    parser.add_argument(
        "--image_aug",
        action="store_true",
        default=False,
        help="Whether or not to apply image augmentation (ColorJitter, blur, noise, etc) to the input images.",
    )
    parser.add_argument(
        "--precomp_lang_embed",
        action="store_true",
        default=False,
        help="Whether or not to use precomputed language embeddings.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.9,
        help="The moving average coefficient for each dataset's loss.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
              " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
              ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'),
    )
    parser.add_argument(
        "--sample_period",
        type=int,
        default=-1,
        help=("Run sampling every X steps. During the sampling phase, the model will sample a trajectory"
              " and report the error between the sampled trajectory and groud-truth trajectory"
              " in the training batch."),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
              " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
              " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
    )

    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=("Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
              " behaviors, so disable this argument if it causes any problems. More info:"
              " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"),
    )

    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        required=False,
        help="Whether to load the pretrain dataset or finetune dataset.",
    )

    parser.add_argument(
        "--CONFIG_NAME",
        type=str,
        default="Null",
        required=True,
    )

    parser.add_argument(
        "--output_log_path",
        type=str,
        default="output/output.log",
        help="The path to the output log file.",
    )

    args = parser.parse_args(input_args) if input_args is not None else parser.parse_args()

    # Distributed launchers export LOCAL_RANK; it takes precedence over the flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse CLI options and hand off to the training entry point with an
    # accelerate-aware logger for this module.
    cli_args = parse_args()
    run_logger = get_logger(__name__)
    train(cli_args, run_logger)
|
||||
269
RDT-1B/model.py
Normal file
269
RDT-1B/model.py
Normal file
@ -0,0 +1,269 @@
|
||||
#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -- coding: UTF-8
"""
RDT inference wrapper: loads the policy and T5 text encoder and exposes the
RDT class used to run action inference from camera images and robot state.
"""
from pathlib import Path
import sys
import os
import json
import argparse
import threading
import time
from collections import deque

import numpy as np
import torch
import yaml
import cv2
from PIL import Image as PImage

# Make this package and its bundled model code importable regardless of the
# current working directory. (The original duplicated these path/imports;
# consolidated into a single pass.)
current_file = Path(__file__)
parent_dir = current_file.parent
sys.path.append(str(parent_dir))
sys.path.append(os.path.join(current_file.parent, "models"))

from scripts.agilex_model import create_model
from multimodal_encoder.t5_encoder import T5Embedder

# Repository root (one level above this file's directory).
global_path = parent_dir.parent
|
||||
|
||||
class RDT:
    """Inference-side wrapper around the RDT policy.

    Owns the T5 text encoder, a two-frame observation window, and the
    diffusion policy used to produce action chunks from camera images and
    robot proprioception.
    """

    def __init__(
        self,
        pretrained_model_name_or_path,
        task_name,
        left_arm_dim,
        right_arm_dim,
        rdt_step,
    ):
        # Resolve the repository root relative to this file.
        this_file = Path(__file__)
        self.global_path = this_file.parent.parent

        # Static inference configuration. state_dim layout:
        # [left joint angles, left gripper, right joint angles, right gripper]
        self.config = {
            "episode_len": 10000,  # args.max_publish_step
            "state_dim": left_arm_dim + 1 + right_arm_dim + 1,
            "chunk_size": 64,  # args.chunk_size
            "camera_names": ["cam_high", "cam_right_wrist", "cam_left_wrist"],
        }
        # Runtime arguments passed to the policy factory.
        self.args = {
            "max_publish_step": 10000,  # Maximum number of action publishing steps
            "seed": None,  # Random seed
            "ctrl_freq": 25,  # The control frequency of the robot
            "chunk_size": 64,  # Action chunk size
            "config_path": os.path.join(self.global_path, "RDT/configs/base.yaml"),
            "pretrained_model_name_or_path": pretrained_model_name_or_path,
        }

        # Build the policy, then cache frequently used settings.
        self.left_arm_dim, self.right_arm_dim = left_arm_dim, right_arm_dim
        self.policy = self.make_policy(self.args)
        self.max_publish_step = self.config["episode_len"]
        self.chunk_size = self.config["chunk_size"]
        self.task_name = task_name
        self.observation_window = None
        self.img_size = (640, 480)
        self.set_language_embed()
        self.rdt_step = rdt_step

    def set_img_size(self, img_size):
        # Override the (width, height) used when resizing camera frames.
        self.img_size = img_size

    def set_language_embed(self):
        """Load the T5-XXL tokenizer/encoder used to embed instructions."""
        gpu_id = 0
        t5_path = os.path.join(self.global_path, "weights/RDT/t5-v1_1-xxl")
        cfg_path = os.path.join(self.global_path, "RDT/configs/base.yaml")
        with open(cfg_path, "r") as fp:
            cfg = yaml.safe_load(fp)
        embedder = T5Embedder(
            from_pretrained=t5_path,
            model_max_length=cfg["dataset"]["tokenizer_max_length"],
            device=torch.device(f"cuda:{gpu_id}"),
            use_offload_folder=None,
        )
        self.tokenizer = embedder.tokenizer
        self.text_encoder = embedder.model
        self.text_encoder.eval()

    def random_set_language(self, instruction=None):
        # Thin wrapper kept for API compatibility; an instruction is required.
        assert instruction is not None, "Missing input instruction"
        self.set_language_instruction(instruction)

    def set_language_instruction(self, language_instruction, save_dir=None, task_name=None):
        """Set self.lang_embeddings from raw text or a pre-embedded .pt file.

        save_dir and task_name must be given together (or both omitted);
        when given, the freshly computed embedding is cached to disk.
        """
        assert ((save_dir is None) ^ (task_name is None)) == False, "input error"

        if os.path.isfile(language_instruction):
            # Instruction given as a path to a pre-computed embedding file.
            lang_dict = torch.load(language_instruction)
            print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
            self.lang_embeddings = lang_dict["embeddings"]
            print("loading instruction from pre-embed path")
        else:
            # Encode the raw text with T5 on the encoder's device.
            device = next(self.text_encoder.parameters()).device
            with torch.no_grad():
                token_ids = self.tokenizer(
                    language_instruction,
                    return_tensors="pt",
                    padding="longest",
                    truncation=True,
                )["input_ids"].to(device)
                token_ids = token_ids.view(1, -1)
                pred = self.text_encoder(token_ids).last_hidden_state.detach().cpu()

            if save_dir is not None:
                # Cache the embedding for later precomputed-embed runs.
                torch.save(
                    {
                        "name": task_name,
                        "instruction": language_instruction,
                        "embeddings": pred,
                    },
                    os.path.join(save_dir, f"{task_name}.pt"),
                )

            del token_ids
            torch.cuda.empty_cache()
            self.lang_embeddings = pred

        print(f"successfully set instruction: {language_instruction}")

    def update_observation_window(self, img_arr, state):
        """Push the latest (images, state) observation into the window."""

        # JPEG round-trip so inference-time images carry the same compression
        # artifacts seen during training.
        def jpeg_mapping(img):
            if img is None:
                return None
            encoded = cv2.imencode(".jpg", img)[1].tobytes()
            return cv2.imdecode(np.frombuffer(encoded, np.uint8), cv2.IMREAD_COLOR)

        if self.observation_window is None:
            # Seed a two-slot window with a dummy frame so that indices
            # -2 and -1 are always valid.
            self.observation_window = deque(maxlen=2)
            self.observation_window.append({
                "qpos": None,
                "images": {name: None for name in self.config["camera_names"]},
            })

        img_front, img_right, img_left, puppet_arm = (
            img_arr[0],
            img_arr[1],
            img_arr[2],
            state,
        )
        # Resize, then JPEG round-trip each camera frame.
        img_front = jpeg_mapping(cv2.resize(img_front, self.img_size))
        img_left = jpeg_mapping(cv2.resize(img_left, self.img_size))
        img_right = jpeg_mapping(cv2.resize(img_right, self.img_size))

        qpos = torch.from_numpy(np.array(puppet_arm)).float().cuda()
        self.observation_window.append({
            "qpos": qpos,
            "images": {
                self.config["camera_names"][0]: img_front,
                self.config["camera_names"][1]: img_right,
                self.config["camera_names"][2]: img_left,
            },
        })

    def get_action(self, img_arr=None, state=None):
        """Run one inference step; img_arr and state must come together."""
        assert (img_arr is None) ^ (state is None) == False, "input error"
        if (img_arr is not None) and (state is not None):
            self.update_observation_window(img_arr, state)

        with torch.inference_mode():
            action_buffer = inference_fn(self.config, self.policy, self.lang_embeddings, self.observation_window).copy()

        return action_buffer

    def reset_obsrvationwindows(self):
        # NOTE: method name typo preserved for caller compatibility.
        self.lang_embeddings = None
        self.observation_window = None
        print("successfully unset obs and language intruction")

    def make_policy(self, args):
        """Instantiate the RDT policy from the base YAML config plus arm dims."""
        with open(args["config_path"], "r") as fp:
            base_cfg = yaml.safe_load(fp)
        args["config"] = base_cfg
        args["config"]["arm_dim"] = {
            "left_arm_dim": self.left_arm_dim,
            "right_arm_dim": self.right_arm_dim,
        }
        # pretrained_text_encoder_name_or_path = "weights/RDT/t5-v1_1-xxl"
        vision_encoder_path = os.path.join(self.global_path, "weights/RDT/siglip-so400m-patch14-384")
        return create_model(
            args=args["config"],
            dtype=torch.bfloat16,
            pretrained=args["pretrained_model_name_or_path"],
            pretrained_vision_encoder_name_or_path=vision_encoder_path,
            control_frequency=args["ctrl_freq"],
        )
|
||||
|
||||
|
||||
# RDT inference
|
||||
def inference_fn(config, policy, lang_embeddings, observation_window):
    """Run one step of policy inference over the current observation window.

    Args:
        config: dict with at least "camera_names" (ordered camera keys).
        policy: object exposing step(proprio, images, text_embeds) returning
            a tensor shaped [1, chunk_size, state_dim].
        lang_embeddings: pre-computed language embedding tensor.
        observation_window: sequence of >= 2 observation dicts, each with
            "images" (per-camera arrays or None) and "qpos" (torch tensor).

    Returns:
        numpy array of actions shaped [chunk_size, state_dim] in the
        [left, right] layout.

    Note: the original wrapped this body in a `while True:` loop that always
    returned on the first iteration (and timed it into an unused variable);
    the dead loop has been removed.
    """
    # Fetch images for the two most recent frames, each in the order
    # [front, right, left].
    image_arrs = [
        observation_window[-2]["images"][config["camera_names"][0]],
        observation_window[-2]["images"][config["camera_names"][1]],
        observation_window[-2]["images"][config["camera_names"][2]],
        observation_window[-1]["images"][config["camera_names"][0]],
        observation_window[-1]["images"][config["camera_names"][1]],
        observation_window[-1]["images"][config["camera_names"][2]],
    ]
    images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]

    # Latest proprioception: [state_dim] -> [1, state_dim].
    proprio = observation_window[-1]["qpos"].unsqueeze(0)

    # Actions come back as [1, chunk_size, state_dim]; drop the batch dim.
    actions = (policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy())
    return actions
|
||||
40
RDT-1B/model_config/_generate_model_config.py
Normal file
40
RDT-1B/model_config/_generate_model_config.py
Normal file
@ -0,0 +1,40 @@
|
||||
import os
import yaml
import argparse
from datetime import datetime

if __name__ == "__main__":
    # Generate a per-model finetune YAML config and ensure its training-data
    # directory exists.
    parser = argparse.ArgumentParser(description="Generate finetune config.")
    parser.add_argument("model_name", type=str, help="The name of the task (e.g., beat_block_hammer)")
    args = parser.parse_args()
    model_name = args.model_name

    # Per-model paths for training data and checkpoints.
    # (Typo fix: previously spelled `fintune_data_path`.)
    finetune_data_path = os.path.join("training_data/", f"{model_name}")
    checkpoint_path = os.path.join("checkpoints/", f"{model_name}")

    # Default finetune hyper-parameters written into the YAML config.
    data = {
        "model": model_name,
        "data_path": finetune_data_path,
        "checkpoint_path": checkpoint_path,
        "pretrained_model_name_or_path": "../weights/RDT/rdt-1b",
        "cuda_visible_device": "...",  # args.gpu_use,
        "train_batch_size": 32,
        "sample_batch_size": 64,
        "max_train_steps": 20000,
        "checkpointing_period": 2500,
        "sample_period": 100,
        "checkpoints_total_limit": 40,
        "learning_rate": 1e-4,
        "dataloader_num_workers": 8,
        "state_noise_snr": 40,
        "gradient_accumulation_steps": 1,
    }
    task_config_path = os.path.join("model_config/", f"{model_name}.yml")

    # Stamp the file so regenerated configs are distinguishable.
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    time_comment = f"# Generated on {current_time}\n"

    with open(task_config_path, "w") as f:
        f.write(time_comment)
        yaml.dump(data, f, default_flow_style=False, sort_keys=False)

    # exist_ok avoids the check-then-create race of the original
    # `if not os.path.exists(...): os.makedirs(...)` pattern.
    os.makedirs(finetune_data_path, exist_ok=True)
|
||||
0
RDT-1B/models/__init__.py
Normal file
0
RDT-1B/models/__init__.py
Normal file
BIN
RDT-1B/models/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
RDT-1B/models/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
RDT-1B/models/__pycache__/ema_model.cpython-310.pyc
Normal file
BIN
RDT-1B/models/__pycache__/ema_model.cpython-310.pyc
Normal file
Binary file not shown.
BIN
RDT-1B/models/__pycache__/hub_mixin.cpython-310.pyc
Normal file
BIN
RDT-1B/models/__pycache__/hub_mixin.cpython-310.pyc
Normal file
Binary file not shown.
BIN
RDT-1B/models/__pycache__/rdt_runner.cpython-310.pyc
Normal file
BIN
RDT-1B/models/__pycache__/rdt_runner.cpython-310.pyc
Normal file
Binary file not shown.
82
RDT-1B/models/ema_model.py
Normal file
82
RDT-1B/models/ema_model.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Reference: DiffusionPolicy [https://github.com/real-stanford/diffusion_policy]
|
||||
|
||||
import torch
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
|
||||
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
        """
        @crowsonkb's notes on EMA Warmup:
        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
        at 215.4k steps).
        Args:
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
        """
        # The shadow model: frozen and kept in eval mode; its parameters are
        # updated in-place by step().
        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        # Warmup schedule parameters.
        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        # Mutable state advanced once per step() call.
        self.decay = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        step = optimization_step - self.update_after_step - 1
        # Before warmup starts the average is just a copy (decay 0).
        if step <= 0:
            return 0.0
        value = 1.0 - (1.0 + step / self.inv_gamma) ** -self.power
        # Clamp into [min_value, max_value].
        return min(max(value, self.min_value), self.max_value)

    @torch.no_grad()
    def step(self, new_model):
        """Blend new_model's parameters into the averaged (shadow) model."""
        self.decay = self.get_decay(self.optimization_step)

        # Walk both module trees in lockstep; iterating immediate parameters
        # of each module pair covers every parameter exactly once.
        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):
            param_pairs = zip(module.parameters(recurse=False), ema_module.parameters(recurse=False))
            for param, ema_param in param_pairs:
                if isinstance(param, dict):
                    raise RuntimeError('Dict parameter not supported')

                if isinstance(module, _BatchNorm) or not param.requires_grad:
                    # BatchNorm parameters and frozen parameters are copied
                    # verbatim instead of being averaged.
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                else:
                    # ema = decay * ema + (1 - decay) * param
                    ema_param.mul_(self.decay)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)

        self.optimization_step += 1
|
||||
75
RDT-1B/models/hub_mixin.py
Normal file
75
RDT-1B/models/hub_mixin.py
Normal file
@ -0,0 +1,75 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from huggingface_hub.constants import (PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE)
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
import torch # type: ignore
|
||||
|
||||
|
||||
class CompatiblePyTorchModelHubMixin(PyTorchModelHubMixin):
    """Mixin class to load Pytorch models from the Hub."""

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        # Save a plain pickle checkpoint, bypassing the default safetensors
        # serialization of the parent mixin.
        target = self.module if hasattr(self, "module") else self  # type: ignore
        torch.save(target.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""
        model = cls(**model_kwargs)

        if os.path.isdir(model_id):
            # Local checkout: prefer safetensors, fall back to pickle weights.
            print("Loading weights from local directory")
            try:
                local_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
                return cls._load_as_safetensor(model, local_file, map_location, strict)
            except FileNotFoundError:
                local_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
                return cls._load_as_pickle(model, local_file, map_location, strict)

        # Remote repo: same preference order, downloading from the Hub.
        download_kwargs = dict(
            repo_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
        try:
            model_file = hf_hub_download(filename=SAFETENSORS_SINGLE_FILE, **download_kwargs)
            return cls._load_as_safetensor(model, model_file, map_location, strict)
        except EntryNotFoundError:
            model_file = hf_hub_download(filename=PYTORCH_WEIGHTS_NAME, **download_kwargs)
            return cls._load_as_pickle(model, model_file, map_location, strict)
|
||||
0
RDT-1B/models/multimodal_encoder/__init__.py
Normal file
0
RDT-1B/models/multimodal_encoder/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
159
RDT-1B/models/multimodal_encoder/clip_encoder.py
Normal file
159
RDT-1B/models/multimodal_encoder/clip_encoder.py
Normal file
@ -0,0 +1,159 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
||||
|
||||
|
||||
class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision encoder exposing intermediate patch features."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load immediately unless the caller delays; an unfrozen tower is
        # loaded regardless so trainable weights exist. Otherwise keep only
        # the config for shape queries.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate processor + vision model and freeze the weights."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden layer and token subset."""
        features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            return features[:, 1:]  # drop the leading CLS token
        if self.select_feature == 'cls_patch':
            return features  # keep CLS + patches
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor, or a list of per-image tensors."""
        if type(images) is list:
            return [
                self.feature_select(
                    self.vision_tower(img.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                      output_hidden_states=True)).to(img.dtype)
                for img in images
            ]
        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                 output_hidden_states=True)
        return self.feature_select(outs).to(images.dtype)

    @property
    def dummy_feature(self):
        # Zero vector with the tower's feature width, on its device/dtype.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the cached config when loading was delayed.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return self.num_patches_per_side**2
|
||||
|
||||
|
||||
class CLIPVisionTowerS2(CLIPVisionTower):
    """CLIP tower with multi-scale (S2) feature extraction via s2wrapper."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        scales_spec = getattr(args, 's2_scales', '336,672,1008')
        self.s2_scales = sorted(int(s) for s in scales_spec.split(','))
        self.s2_split_size = self.s2_scales[0]   # tiling size = smallest scale
        self.s2_image_size = self.s2_scales[-1]  # preprocess to largest scale

        try:
            from s2wrapper import forward as multiscale_forward
        except ImportError:
            raise ImportError(
                'Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git'
            )
        self.multiscale_forward = multiscale_forward

        # change resize/crop size in preprocessing to the largest image size in s2_scale
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.image_processor.size['shortest_edge'] = self.s2_image_size
            self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    def load_model(self, device_map=None):
        """Load the frozen tower and retarget preprocessing to the top S2 scale."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.image_processor.size['shortest_edge'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

        self.is_loaded = True

    @torch.no_grad()
    def forward_feature(self, images):
        """Single-scale encoding pass used as the s2wrapper callback."""
        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                 output_hidden_states=True)
        return self.feature_select(outs).to(images.dtype)

    @torch.no_grad()
    def forward(self, images):
        """Multi-scale encode a batched tensor or a list of image tensors."""
        if type(images) is list:
            return [
                self.multiscale_forward(self.forward_feature,
                                        img.unsqueeze(0),
                                        img_sizes=self.s2_scales,
                                        max_split_size=self.s2_split_size)
                for img in images
            ]
        return self.multiscale_forward(self.forward_feature,
                                       images,
                                       img_sizes=self.s2_scales,
                                       max_split_size=self.s2_split_size)

    @property
    def hidden_size(self):
        # Scale features are concatenated, so the width grows per scale.
        return self.config.hidden_size * len(self.s2_scales)
|
||||
87
RDT-1B/models/multimodal_encoder/dinov2_encoder.py
Normal file
87
RDT-1B/models/multimodal_encoder/dinov2_encoder.py
Normal file
@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig, AutoImageProcessor, AutoModel, Dinov2Model
|
||||
|
||||
|
||||
class DinoV2VisionTower(nn.Module):
    """Frozen DINOv2 vision backbone exposing last-layer token features."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load now unless delayed; an unfrozen tower is loaded regardless.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate processor + model and freeze the weights."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)  # FIXME: freezes unconditionally

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Select patch tokens or the full token sequence from the output."""
        features = image_forward_outs.last_hidden_state
        if self.select_feature == 'patch':
            return features[:, 1:]  # drop leading token -> (B, 1369, 1536)
        if self.select_feature == 'cls_patch':
            return features  # (B, 1, 1536)
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor, or a list of per-image tensors."""
        if type(images) is list:
            return [
                self.feature_select(
                    self.vision_tower(img.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                ).to(img.dtype)
                for img in images
            ]
        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
        return self.feature_select(outs).to(images.dtype)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the cached config when loading was delayed.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return self.num_patches_per_side**2
|
||||
86
RDT-1B/models/multimodal_encoder/siglip_encoder.py
Normal file
86
RDT-1B/models/multimodal_encoder/siglip_encoder.py
Normal file
@ -0,0 +1,86 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig, SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
|
||||
class SiglipVisionTower(nn.Module):
    """Frozen SigLIP vision encoder exposing token or pooled features."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load now unless delayed; an unfrozen tower is loaded regardless.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate processor + model and put the tower in eval mode."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.eval()

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Return patch tokens or the pooled output depending on config."""
        if self.select_feature == 'patch':
            return image_forward_outs.last_hidden_state  # (B, 729, 1536)
        if self.select_feature == 'cls_patch':
            return image_forward_outs.pooler_output  # (B, 1, 1536)
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor, or a list of per-image tensors."""
        if type(images) is list:
            return [
                self.feature_select(
                    self.vision_tower(img.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                ).to(img.dtype)
                for img in images
            ]
        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
        return self.feature_select(outs).to(images.dtype)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the cached config when loading was delayed.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return self.num_patches_per_side**2
|
||||
111
RDT-1B/models/multimodal_encoder/t5_encoder.py
Normal file
111
RDT-1B/models/multimodal_encoder/t5_encoder.py
Normal file
@ -0,0 +1,111 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
|
||||
class T5Embedder:
    """Frozen T5 encoder + tokenizer for producing text embeddings."""

    # available_models = ["google/t5-v1_1-xxl"]

    def __init__(
        self,
        device,
        from_pretrained=None,
        *,
        cache_dir=None,
        hf_token=None,
        use_text_preprocessing=True,
        t5_model_kwargs=None,
        torch_dtype=None,
        use_offload_folder=None,
        model_max_length=120,
        local_files_only=False,
    ):
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        self.cache_dir = cache_dir

        if t5_model_kwargs is None:
            t5_model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": self.torch_dtype,
            }

        if use_offload_folder is not None:
            # Disk-offload mode: keep the embedding table and the first 12
            # encoder blocks on the target device, spill the remaining 12
            # blocks plus the final norm/dropout to disk.
            t5_model_kwargs["offload_folder"] = use_offload_folder
            device_map = {
                "shared": self.device,
                "encoder.embed_tokens": self.device,
            }
            for block_idx in range(24):
                device_map["encoder.block.{}".format(block_idx)] = (
                    self.device if block_idx < 12 else "disk")
            device_map["encoder.final_layer_norm"] = "disk"
            device_map["encoder.dropout"] = "disk"
            t5_model_kwargs["device_map"] = device_map
        else:
            # Everything fits: place the whole encoder on the device.
            t5_model_kwargs["device_map"] = {
                "shared": self.device,
                "encoder": self.device,
            }

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token

        # assert from_pretrained in self.available_models
        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            model_max_length=model_max_length,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.model = T5EncoderModel.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            **t5_model_kwargs,
        ).eval()
        self.model_max_length = model_max_length

    def get_text_embeddings(self, texts):
        """Tokenize *texts* and return (last_hidden_state, attention_mask)."""
        tokenized = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding="longest",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        input_ids = tokenized["input_ids"].to(self.device)
        attention_mask = tokenized["attention_mask"].to(self.device)
        with torch.no_grad():
            embeddings = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )["last_hidden_state"].detach()
        return embeddings, attention_mask
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: constructing the embedder downloads/loads the full
    # T5-XXL encoder onto cuda:7 — heavyweight, run deliberately.
    T5Embedder(from_pretrained="google/t5-v1_1-xxl", device='cuda:7')
|
||||
BIN
RDT-1B/models/rdt/__pycache__/blocks.cpython-310.pyc
Normal file
BIN
RDT-1B/models/rdt/__pycache__/blocks.cpython-310.pyc
Normal file
Binary file not shown.
BIN
RDT-1B/models/rdt/__pycache__/model.cpython-310.pyc
Normal file
BIN
RDT-1B/models/rdt/__pycache__/model.cpython-310.pyc
Normal file
Binary file not shown.
304
RDT-1B/models/rdt/blocks.py
Normal file
304
RDT-1B/models/rdt/blocks.py
Normal file
@ -0,0 +1,304 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# DiT: https://github.com/facebookresearch/DiT
|
||||
# GLIDE: https://github.com/openai/glide-text2im
|
||||
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.jit import Final
|
||||
from timm.models.vision_transformer import Attention, Mlp, RmsNorm, use_fused_attn
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Condition Inptus #
|
||||
#################################################################################
|
||||
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal frequency embedding is projected through a small MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=torch.bfloat16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.dtype = dtype

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponent = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(max_period) * exponent / half)
        phases = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            # Odd dim: pad with one zero column to reach the requested width.
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding.to(self.dtype)

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Cross Attention Layers #
|
||||
#################################################################################
|
||||
class CrossAttention(nn.Module):
    """
    A cross-attention layer with flash attention.

    Queries come from `x`; keys/values come from the condition `c`.
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0,
        proj_drop: float = 0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = use_fused_attn()

        # Separate projections: q from the target sequence, k/v from the condition.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, c: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        B, N, C = x.shape
        L = c.shape[1]
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv(c).reshape(B, L, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Broadcast the (B, L) condition mask across heads and query positions.
        if mask is not None:
            mask = mask.reshape(B, 1, 1, L).expand(-1, -1, N, -1)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(query=q,
                                               key=k,
                                               value=v,
                                               dropout_p=self.attn_drop.p if self.training else 0.,
                                               attn_mask=mask)
        else:
            # Manual attention path: scale, mask, softmax, optional dropout.
            attn = (q * self.scale) @ k.transpose(-2, -1)
            if mask is not None:
                attn = attn.masked_fill_(mask.logical_not(), float('-inf'))
            attn = attn.softmax(dim=-1)
            if self.attn_drop.p > 0:
                attn = self.attn_drop(attn)
            x = attn @ v

        x = self.proj(x.permute(0, 2, 1, 3).reshape(B, N, C))
        if self.proj_drop.p > 0:
            x = self.proj_drop(x)
        return x
|
||||
|
||||
|
||||
#################################################################################
|
||||
# RDT Block #
|
||||
#################################################################################
|
||||
class RDTBlock(nn.Module):
    """
    A RDT block with cross-attention conditioning.

    Pre-norm residual layout: self-attn -> cross-attn -> feed-forward.
    """

    def __init__(self, hidden_size, num_heads, **block_kwargs):
        super().__init__()
        self.norm1 = RmsNorm(hidden_size, eps=1e-6)
        self.attn = Attention(dim=hidden_size,
                              num_heads=num_heads,
                              qkv_bias=True,
                              qk_norm=True,
                              norm_layer=RmsNorm,
                              **block_kwargs)
        self.cross_attn = CrossAttention(hidden_size,
                                         num_heads=num_heads,
                                         qkv_bias=True,
                                         qk_norm=True,
                                         norm_layer=RmsNorm,
                                         **block_kwargs)

        self.norm2 = RmsNorm(hidden_size, eps=1e-6)
        self.ffn = Mlp(in_features=hidden_size,
                       hidden_features=hidden_size,
                       act_layer=lambda: nn.GELU(approximate="tanh"),
                       drop=0)
        self.norm3 = RmsNorm(hidden_size, eps=1e-6)

    def forward(self, x, c, mask=None):
        # Self-attention with pre-norm residual.
        x = x + self.attn(self.norm1(x))
        # Cross-attention over the condition tokens (mask covers condition padding).
        x = x + self.cross_attn(self.norm2(x), c, mask)
        # Position-wise feed-forward.
        x = x + self.ffn(self.norm3(x))
        return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
    """
    The final layer of RDT: RmsNorm followed by an MLP projection
    down to the output channel count.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = RmsNorm(hidden_size, eps=1e-6)
        self.ffn_final = Mlp(in_features=hidden_size,
                             hidden_features=hidden_size,
                             out_features=out_channels,
                             act_layer=lambda: nn.GELU(approximate="tanh"),
                             drop=0)

    def forward(self, x):
        return self.ffn_final(self.norm_final(x))
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: a list of positions to be encoded: size (M,)
    out: (M, D) with [sin | cos] halves.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder: 1 / 10000^(2i/D), i = 0..D/2-1.
    omega = 1. / 10000**(np.arange(half, dtype=np.float64) / (embed_dim / 2.))

    if not isinstance(pos, np.ndarray):
        pos = np.array(pos, dtype=np.float64)
    angles = np.einsum('m,d->md', pos.reshape(-1), omega)  # (M, D/2) outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
||||
|
||||
|
||||
def get_nd_sincos_pos_embed_from_grid(embed_dim, grid_sizes):
    """
    Build an N-D sin-cos positional embedding grid.

    embed_dim: output dimension for each position.
    grid_sizes: the grid sizes in each dimension, any sequence of length K
        (previously a tuple was required; lists now work too).
    out: ndarray of shape (grid_sizes[0], ..., grid_sizes[K-1], embed_dim).

    The embedding dimension is split evenly across the non-trivial
    (size > 1) grid dimensions; size-1 dimensions get no embedding.
    """
    # Accept any sequence — the original `grid_sizes + (embed_dim,)`
    # raised TypeError when a list was passed.
    grid_sizes = tuple(grid_sizes)
    num_sizes = len(grid_sizes)
    # For grid size of 1, we do not need to add any positional embedding.
    num_valid_sizes = len([x for x in grid_sizes if x > 1])
    emb = np.zeros(grid_sizes + (embed_dim, ))
    # Degenerate case: every dimension is trivial — return the all-zero
    # embedding instead of dividing by zero below.
    if num_valid_sizes == 0:
        return emb
    # Uniformly divide the embedding dimension for each valid grid size.
    dim_for_each_grid = embed_dim // num_valid_sizes
    # Each 1-D sin-cos embedding needs an even width.
    if dim_for_each_grid % 2 != 0:
        dim_for_each_grid -= 1
    valid_size_idx = 0
    for size_idx, grid_size in enumerate(grid_sizes):
        if grid_size <= 1:
            continue
        pos = np.arange(grid_size)
        # Shape the 1-D embedding so it broadcasts along all other axes.
        posemb_shape = [1] * num_sizes + [dim_for_each_grid]
        posemb_shape[size_idx] = -1
        emb[..., valid_size_idx * dim_for_each_grid:(valid_size_idx + 1) * dim_for_each_grid] += \
            get_1d_sincos_pos_embed_from_grid(dim_for_each_grid, pos).reshape(posemb_shape)
        valid_size_idx += 1
    return emb
|
||||
|
||||
|
||||
def get_multimodal_cond_pos_embed(embed_dim, mm_cond_lens: OrderedDict, embed_modality=True):
    """
    Generate position embeddings for multimodal conditions.

    mm_cond_lens: an OrderedDict containing
        (modality name, modality token length) pairs.
        For `"image"` modality, the value can be a multi-dimensional tuple.
        If the length < 0, it means there is no position embedding for the modality or grid.
    embed_modality: whether to embed the modality information. Default is True.

    Returns an (sum of token lengths, embed_dim) numpy array where each row is
    the sum of a per-modality embedding (first half of embed_dim, optional)
    and a within-modality positional embedding (last pos_embed_dim columns).
    """
    num_modalities = len(mm_cond_lens)
    modality_pos_embed = np.zeros((num_modalities, embed_dim))
    if embed_modality:
        # Get embeddings for the various modalities; stored in the first half
        # of the embedding dimension.
        modality_sincos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, torch.arange(num_modalities))
        modality_pos_embed[:, :embed_dim // 2] = modality_sincos_embed
        # The second half is for position embeddings.
        pos_embed_dim = embed_dim // 2
    else:
        # The whole embedding is for position embeddings.
        pos_embed_dim = embed_dim

    # Get embeddings for positions inside each modality, concatenated in
    # the order the OrderedDict provides.
    c_pos_emb = np.zeros((0, embed_dim))
    for idx, (modality, cond_len) in enumerate(mm_cond_lens.items()):
        if modality == "image" and \
            (isinstance(cond_len, tuple) or isinstance(cond_len, list)):
            # Multi-axis image grid: a negative axis length means "no position
            # embedding along this axis" (embed as size 1, broadcast), but the
            # output still allocates abs(length) rows for that axis.
            all_grid_sizes = tuple([abs(x) for x in cond_len])
            embed_grid_sizes = tuple([x if x > 0 else 1 for x in cond_len])
            cond_sincos_embed = get_nd_sincos_pos_embed_from_grid(pos_embed_dim, embed_grid_sizes)
            cond_pos_embed = np.zeros(all_grid_sizes + (embed_dim, ))
            # Positional part occupies the trailing pos_embed_dim columns.
            cond_pos_embed[..., -pos_embed_dim:] += cond_sincos_embed
            cond_pos_embed = cond_pos_embed.reshape((-1, embed_dim))
        else:
            # 1-D modality: negative cond_len embeds a single position that is
            # broadcast over abs(cond_len) rows (i.e. no distinct positions).
            cond_sincos_embed = get_1d_sincos_pos_embed_from_grid(pos_embed_dim,
                                                                  torch.arange(cond_len if cond_len > 0 else 1))
            cond_pos_embed = np.zeros((abs(cond_len), embed_dim))
            cond_pos_embed[:, -pos_embed_dim:] += cond_sincos_embed
        # Add the modality identity embedding (zeros when embed_modality=False).
        cond_pos_embed += modality_pos_embed[idx]
        c_pos_emb = np.concatenate([c_pos_emb, cond_pos_embed], axis=0)

    return c_pos_emb
|
||||
156
RDT-1B/models/rdt/model.py
Normal file
156
RDT-1B/models/rdt/model.py
Normal file
@ -0,0 +1,156 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# DiT: https://github.com/facebookresearch/DiT
|
||||
# GLIDE: https://github.com/openai/glide-text2im
|
||||
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
||||
# --------------------------------------------------------
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from pathlib import Path
|
||||
import sys, os
|
||||
# get current workspace
|
||||
current_file = Path(__file__)
|
||||
sys.path.append(str(current_file.parent.parent))
|
||||
|
||||
from rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder, get_1d_sincos_pos_embed_from_grid,
|
||||
get_multimodal_cond_pos_embed)
|
||||
|
||||
|
||||
class RDT(nn.Module):
    """
    Class for Robotics Diffusion Transformers.

    A DiT-style denoising transformer. The input sequence is
    [timestep; ctrl_freq; state; action] tokens; conditioning alternates
    between language and image cross-attention across successive blocks
    (see ``forward``).

    Args:
        output_dim: dimension of each predicted action vector.
        horizon: number of action tokens to predict per sample.
        hidden_size: transformer hidden size.
        depth: number of RDTBlock layers.
        num_heads: attention heads per block.
        max_lang_cond_len: maximum number of language condition tokens.
        img_cond_len: fixed number of image condition tokens.
        lang_pos_embed_config: optional multimodal position-embedding spec
            (list of (modality, length) pairs) for language tokens;
            falls back to plain 1-D sin-cos when None.
        img_pos_embed_config: same, for image tokens.
        dtype: parameter dtype; all parameters are cast at the end of
            ``initialize_weights``.
    """

    def __init__(self,
                 output_dim=128,
                 horizon=32,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 max_lang_cond_len=1024,
                 img_cond_len=4096,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super().__init__()
        self.horizon = horizon
        self.hidden_size = hidden_size
        self.max_lang_cond_len = max_lang_cond_len
        self.img_cond_len = img_cond_len
        self.dtype = dtype
        self.lang_pos_embed_config = lang_pos_embed_config
        self.img_pos_embed_config = img_pos_embed_config

        # Scalar embedders for the diffusion timestep and the control frequency.
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype)
        self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype)

        # We will use trainable sin-cos embeddings.
        # Input tokens: [timestep; ctrl_freq; state; action] -> horizon + 3 positions.
        self.x_pos_embed = nn.Parameter(torch.zeros(1, horizon + 3, hidden_size))
        # Language conditions (variable length, up to max_lang_cond_len)
        self.lang_cond_pos_embed = nn.Parameter(torch.zeros(1, max_lang_cond_len, hidden_size))
        # Image conditions (fixed length)
        self.img_cond_pos_embed = nn.Parameter(torch.zeros(1, img_cond_len, hidden_size))

        self.blocks = nn.ModuleList([RDTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, output_dim)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize all weights: Xavier for linear layers, sin-cos values for
        the (trainable) position embeddings, zero for the output head, then cast
        every parameter to ``self.dtype``."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize pos_embed by sin-cos embedding
        x_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                    mm_cond_lens=OrderedDict([
                                                        ('timestep', 1),
                                                        ('ctrl_freq', 1),
                                                        ('state', 1),
                                                        ('action', self.horizon),
                                                    ]))
        self.x_pos_embed.data.copy_(torch.from_numpy(x_pos_embed).float().unsqueeze(0))

        # Language positions: plain 1-D sin-cos unless a multimodal layout was given.
        if self.lang_pos_embed_config is None:
            lang_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size,
                                                                    torch.arange(self.max_lang_cond_len))
        else:
            lang_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                                mm_cond_lens=OrderedDict(self.lang_pos_embed_config),
                                                                embed_modality=False)
        self.lang_cond_pos_embed.data.copy_(torch.from_numpy(lang_cond_pos_embed).float().unsqueeze(0))

        # Image positions: same fallback logic as language.
        if self.img_pos_embed_config is None:
            img_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.img_cond_len))
        else:
            img_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                               mm_cond_lens=OrderedDict(self.img_pos_embed_config),
                                                               embed_modality=False)
        self.img_cond_pos_embed.data.copy_(torch.from_numpy(img_cond_pos_embed).float().unsqueeze(0))

        # Initialize timestep and control freq embedding MLP
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02)

        # Initialize the final layer: zero-out the final linear layer
        nn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0)
        nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0)

        # Move all the params to given data type:
        self.to(self.dtype)

    def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
        """
        Forward pass of RDT.

        x: (B, T, D), state + action token sequence, T = horizon + 1,
            dimension D is assumed to be the same as the hidden size.
        freq: (B,), a scalar indicating control frequency.
        t: (B,) or (1,), diffusion timesteps.
        lang_c: (B, L_lang, D) or None, language condition tokens (variable length),
            dimension D is assumed to be the same as the hidden size.
        img_c: (B, L_img, D) or None, image condition tokens (fixed length),
            dimension D is assumed to be the same as the hidden size.
        lang_mask: (B, L_lang) or None, language condition mask (True for valid).
        img_mask: (B, L_img) or None, image condition mask (True for valid).

        Returns: (B, horizon, output_dim), the denoised action tokens.
        """
        t = self.t_embedder(t).unsqueeze(1)  # (B, 1, D) or (1, 1, D)
        freq = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
        # Append timestep to the input tokens
        if t.shape[0] == 1:
            # A single shared timestep is broadcast across the batch.
            t = t.expand(x.shape[0], -1, -1)
        x = torch.cat([t, freq, x], dim=1)  # (B, T+1, D)

        # Add multimodal position embeddings
        x = x + self.x_pos_embed
        # Note the lang is of variable length; slice the embedding to match.
        lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
        img_c = img_c + self.img_cond_pos_embed

        # Forward pass: alternate language / image cross-attention per block.
        conds = [lang_c, img_c]
        masks = [lang_mask, img_mask]
        for i, block in enumerate(self.blocks):
            c, mask = conds[i % 2], masks[i % 2]
            x = block(x, c, mask)  # (B, T+1, D)
        # Inject the language condition at the final layer
        x = self.final_layer(x)  # (B, T+1, out_channels)

        # Only preserve the action tokens (drop timestep/freq/state prefixes).
        x = x[:, -self.horizon:]
        return x
|
||||
246
RDT-1B/models/rdt_runner.py
Normal file
246
RDT-1B/models/rdt_runner.py
Normal file
@ -0,0 +1,246 @@
|
||||
import re, sys, os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusers.schedulers.scheduling_dpmsolver_multistep import \
|
||||
DPMSolverMultistepScheduler
|
||||
|
||||
from pathlib import Path
|
||||
# get current workspace
|
||||
current_file = Path(__file__)
|
||||
sys.path.append(os.path.join(current_file.parent))
|
||||
from hub_mixin import CompatiblePyTorchModelHubMixin
|
||||
from rdt.model import RDT
|
||||
|
||||
|
||||
class RDTRunner(nn.Module,
                CompatiblePyTorchModelHubMixin,
                repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):
    """Training/inference wrapper around the RDT diffusion model.

    Owns (1) the RDT transformer, (2) the adapters that project language,
    image, and state tokens into the transformer's hidden size, and (3) the
    DDPM (training) / DPM-Solver++ (sampling) noise schedulers.

    ``forward`` delegates to ``compute_loss``; use ``predict_action`` for
    inference.
    """

    def __init__(self,
                 *,
                 action_dim,
                 pred_horizon,
                 config,
                 lang_token_dim,
                 img_token_dim,
                 state_token_dim,
                 max_lang_cond_len,
                 img_cond_len,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super(RDTRunner, self).__init__()
        # Create diffusion model
        hidden_size = config['rdt']['hidden_size']
        self.model = RDT(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=config['rdt']['depth'],
            num_heads=config['rdt']['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Create adapters for various conditional inputs
        self.lang_adaptor = self.build_condition_adapter(config['lang_adaptor'],
                                                         in_features=lang_token_dim,
                                                         out_features=hidden_size)
        self.img_adaptor = self.build_condition_adapter(config['img_adaptor'],
                                                        in_features=img_token_dim,
                                                        out_features=hidden_size)
        # A `state` refers to an action or a proprioception vector
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,  # state + state mask (indicator)
            out_features=hidden_size)

        # Create the noise scheduler: DDPM for the training-time forward
        # process, DPM-Solver++ (multistep) for fast sampling at inference.
        noise_scheduler_config = config['noise_scheduler']
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
            clip_sample=noise_scheduler_config['clip_sample'],
        )
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
        )

        self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.prediction_type = noise_scheduler_config['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        # Report the total trainable parameter count (model + all adapters).
        print("Diffusion params: %e" %
              sum([p.numel() for p in self.model.parameters()] + [p.numel() for p in self.lang_adaptor.parameters()] +
                  [p.numel()
                   for p in self.img_adaptor.parameters()] + [p.numel() for p in self.state_adaptor.parameters()]))

    def build_condition_adapter(self, projector_type, in_features, out_features):
        """Build a projection module from a spec string.

        projector_type: 'linear' for a single nn.Linear, or 'mlp{N}x_gelu'
        (e.g. 'mlp2x_gelu') for an N-layer MLP with tanh-approximated GELU
        activations between layers.

        Raises ValueError for any unrecognized spec.
        """
        projector = None
        if projector_type == 'linear':
            projector = nn.Linear(in_features, out_features)
        else:
            mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
            if mlp_gelu_match:
                mlp_depth = int(mlp_gelu_match.group(1))
                modules = [nn.Linear(in_features, out_features)]
                for _ in range(1, mlp_depth):
                    modules.append(nn.GELU(approximate="tanh"))
                    modules.append(nn.Linear(out_features, out_features))
                projector = nn.Sequential(*modules)

        if projector is None:
            raise ValueError(f'Unknown projector type: {projector_type}')

        return projector

    def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
        '''
        Project each condition stream into the transformer hidden size.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, state_len, state_token_dim)

        return: adapted (..., hidden_size) for all input tokens
        '''
        adpated_lang = self.lang_adaptor(lang_tokens)
        adpated_img = self.img_adaptor(img_tokens)
        adpated_state = self.state_adaptor(state_tokens)

        return adpated_lang, adpated_img, adpated_state

    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond, state_traj, action_mask, ctrl_freqs):
        '''
        Run the reverse diffusion (DPM-Solver++) to sample an action sequence.

        lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_cond: image conditional data, (batch_size, img_len, hidden_size).
        state_traj: (batch_size, 1, hidden_size), state trajectory.
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
            indicating the valid action dimensions.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim)
        '''
        device = state_traj.device
        dtype = state_traj.dtype
        # Start from pure Gaussian noise in action space.
        noisy_action = torch.randn(size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
                                   dtype=dtype,
                                   device=device)
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        # Set step values
        self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)

        for t in self.noise_scheduler_sample.timesteps:
            # Prepare state-action trajectory: concat the mask indicator onto
            # each action token, project, then prepend the state token.
            action_traj = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(action_traj)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # Predict the model output
            # NOTE(review): img_mask is not passed here (all image tokens are
            # treated as valid) — consistent with compute_loss; confirm intended.
            model_output = self.model(state_action_traj,
                                      ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond,
                                      img_cond,
                                      lang_mask=lang_attn_mask)

            # Compute previous actions: x_t -> x_t-1
            noisy_action = self.noise_scheduler_sample.step(model_output, t, noisy_action).prev_sample
            noisy_action = noisy_action.to(state_traj.dtype)

        # Finally apply the action mask to mask invalid action dimensions
        noisy_action = noisy_action * action_mask

        return noisy_action

    # ========= Train ============
    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_gt, action_mask,
                     ctrl_freqs) -> torch.Tensor:
        '''
        Diffusion training loss: noise the ground-truth actions, denoise with
        the model, and regress against the scheduler's target.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: loss_value, a scalar tensor
        '''
        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device
        # Sample noise that we'll add to the actions
        noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)
        # Sample random diffusion timesteps
        timesteps = torch.randint(0, self.num_train_timesteps, (batch_size, ), device=device).long()
        # Add noise to the clean actions according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps)

        # Concatenate the state and action tokens to form the input sequence
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        # Append the action mask to the input sequence
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
        # Align the dimension with the hidden size
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj)
        # Predict the denoised result
        pred = self.model(state_action_traj, ctrl_freqs, timesteps, lang_cond, img_cond, lang_mask=lang_attn_mask)

        # Regression target depends on the scheduler's parameterization.
        pred_type = self.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = action_gt
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")
        loss = F.mse_loss(pred, target)
        return loss

    # ========= Inference ============
    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_mask, ctrl_freqs):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim),
            which should be a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), predicted action sequence
        '''
        # Prepare the state and conditions: the mask indicator is appended to
        # the state token before projection, mirroring the training input.
        state_tokens = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(lang_tokens, img_tokens, state_tokens)

        # Run sampling
        action_pred = self.conditional_sample(
            lang_cond,
            lang_attn_mask,
            img_cond,
            state_traj,
            action_mask,
            ctrl_freqs,
        )

        return action_pred

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Training entry point: alias for ``compute_loss``."""
        return self.compute_loss(*args, **kwargs)
|
||||
49
RDT-1B/pretrain.sh
Normal file
49
RDT-1B/pretrain.sh
Normal file
@ -0,0 +1,49 @@
|
||||
#!/bin/bash
# Launch multi-node RDT-1B pretraining via DeepSpeed (hosts in hostfile.txt).

# NCCL / InfiniBand settings for multi-node communication.
export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=bond0
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0

export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
export OUTPUT_DIR="./checkpoints/rdt-pretrain-1b"
export CFLAGS="-I/usr/include"
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
export CUTLASS_PATH="/path/to/cutlass"

export WANDB_PROJECT="robotics_diffusion_transformer"

if [ ! -d "$OUTPUT_DIR" ]; then
    # -p creates missing parent directories (e.g. ./checkpoints); a plain
    # `mkdir` would fail here on a fresh clone because the path is nested.
    mkdir -p "$OUTPUT_DIR"
    echo "Folder '$OUTPUT_DIR' created"
else
    echo "Folder '$OUTPUT_DIR' already exists"
fi

# For run in a single node/machine
# accelerate launch main.py \
#   --deepspeed="./configs/zero2.json" \
#   ...

deepspeed --hostfile=hostfile.txt main.py \
    --deepspeed="./configs/zero2.json" \
    --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
    --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
    --output_dir=$OUTPUT_DIR \
    --train_batch_size=32 \
    --sample_batch_size=64 \
    --max_train_steps=1000000 \
    --checkpointing_period=1000 \
    --sample_period=500 \
    --checkpoints_total_limit=40 \
    --lr_scheduler="constant" \
    --learning_rate=1e-4 \
    --mixed_precision="bf16" \
    --dataloader_num_workers=8 \
    --dataset_type="pretrain" \
    --report_to=wandb

# Use this to resume training from some previous checkpoint
# --resume_from_checkpoint="checkpoint-1000" \
|
||||
9
RDT-1B/process_data_rdt.sh
Normal file
9
RDT-1B/process_data_rdt.sh
Normal file
@ -0,0 +1,9 @@
|
||||
#!/bin/bash
# Preprocess expert demonstrations for RDT training.
# Usage: process_data_rdt.sh <task_name> <task_config> <expert_data_num> <gpu_id>

task_name=${1}
task_config=${2}
expert_data_num=${3}
gpu_id=${4}

export CUDA_VISIBLE_DEVICES=${gpu_id}
# Quote expansions so arguments containing spaces survive word splitting.
python scripts/process_data.py "$task_name" "$task_config" "$expert_data_num"
|
||||
23
RDT-1B/requirements.txt
Normal file
23
RDT-1B/requirements.txt
Normal file
@ -0,0 +1,23 @@
|
||||
numpy<2.0
|
||||
packaging==24.0
|
||||
wandb==0.17.0
|
||||
deepspeed==0.14.2
|
||||
accelerate==0.30.1
|
||||
diffusers==0.27.2
|
||||
timm==1.0.3
|
||||
transformers==4.41.0
|
||||
sentencepiece==0.2.0
|
||||
h5py==3.11.0
|
||||
opencv-python==4.9.0.80
|
||||
imgaug==0.4.0
|
||||
pytz>=2020.1
|
||||
|
||||
# requirements_data.txt
|
||||
tfds-nightly==4.9.4.dev202402070044
|
||||
gsutil==5.27
|
||||
tensorflow==2.15.0.post1
|
||||
pillow==10.2.0
|
||||
pyyaml==6.0.1
|
||||
tensorflow-graphics==2021.12.3
|
||||
imageio==2.34.0
|
||||
imageio-ffmpeg==0.4.9
|
||||
941
RDT-1B/scripts/agilex_inference.py
Normal file
941
RDT-1B/scripts/agilex_inference.py
Normal file
@ -0,0 +1,941 @@
|
||||
#!/home/lin/software/miniconda3/envs/aloha/bin/python
|
||||
# -- coding: UTF-8
|
||||
"""
|
||||
#!/usr/bin/python3
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import yaml
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import rospy
|
||||
import torch
|
||||
from cv_bridge import CvBridge
|
||||
from geometry_msgs.msg import Twist
|
||||
from nav_msgs.msg import Odometry
|
||||
from PIL import Image as PImage
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from std_msgs.msg import Header
|
||||
import cv2
|
||||
|
||||
from scripts.agilex_model import create_model
|
||||
|
||||
# sys.path.append("./")
|
||||
|
||||
CAMERA_NAMES = ["cam_high", "cam_right_wrist", "cam_left_wrist"]
|
||||
|
||||
observation_window = None
|
||||
|
||||
lang_embeddings = None
|
||||
|
||||
# debug
|
||||
preload_images = None
|
||||
|
||||
|
||||
# Initialize the model
|
||||
def make_policy(args):
    """Build the RDT policy from the YAML config referenced by ``args.config_path``.

    Side effect: the parsed config dict is stored on ``args.config`` so later
    stages (e.g. get_config / the inference loop) can access it.
    Returns the model produced by ``create_model``.
    """
    with open(args.config_path, "r") as fp:
        config = yaml.safe_load(fp)
    args.config = config

    # pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
    pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
    model = create_model(
        args=args.config,
        dtype=torch.bfloat16,
        pretrained=args.pretrained_model_name_or_path,
        # pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
        pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
        control_frequency=args.ctrl_freq,
    )

    return model
|
||||
|
||||
|
||||
def set_seed(seed):
    """Seed both the torch and numpy RNGs for reproducible rollouts."""
    np.random.seed(seed)
    torch.manual_seed(seed)
|
||||
|
||||
|
||||
# Interpolate the actions to make the robot move smoothly
|
||||
def interpolate_action(args, prev_action, cur_action):
|
||||
steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
|
||||
diff = np.abs(cur_action - prev_action)
|
||||
step = np.ceil(diff / steps).astype(int)
|
||||
step = np.max(step)
|
||||
if step <= 1:
|
||||
return cur_action[np.newaxis, :]
|
||||
new_actions = np.linspace(prev_action, cur_action, step + 1)
|
||||
return new_actions[1:]
|
||||
|
||||
|
||||
def get_config(args):
    """Assemble the runtime config dict consumed by the inference loop."""
    return {
        "episode_len": args.max_publish_step,
        "state_dim": 14,
        "chunk_size": args.chunk_size,
        "camera_names": CAMERA_NAMES,
    }
|
||||
|
||||
|
||||
# Get the observation from the ROS topic
|
||||
def get_ros_observation(args, ros_operator):
|
||||
rate = rospy.Rate(args.publish_rate)
|
||||
print_flag = True
|
||||
|
||||
while True and not rospy.is_shutdown():
|
||||
result = ros_operator.get_frame()
|
||||
if not result:
|
||||
if print_flag:
|
||||
print("syn fail when get_ros_observation")
|
||||
print_flag = False
|
||||
rate.sleep()
|
||||
continue
|
||||
print_flag = True
|
||||
(
|
||||
img_front,
|
||||
img_left,
|
||||
img_right,
|
||||
img_front_depth,
|
||||
img_left_depth,
|
||||
img_right_depth,
|
||||
puppet_arm_left,
|
||||
puppet_arm_right,
|
||||
robot_base,
|
||||
) = result
|
||||
# print(f"sync success when get_ros_observation")
|
||||
return (img_front, img_left, img_right, puppet_arm_left, puppet_arm_right)
|
||||
|
||||
|
||||
# Update the observation window buffer
|
||||
def update_observation_window(args, config, ros_operator):
|
||||
# JPEG transformation
|
||||
# Align with training
|
||||
def jpeg_mapping(img):
|
||||
img = cv2.imencode(".jpg", img)[1].tobytes()
|
||||
img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
|
||||
return img
|
||||
|
||||
global observation_window
|
||||
if observation_window is None:
|
||||
observation_window = deque(maxlen=2)
|
||||
|
||||
# Append the first dummy image
|
||||
observation_window.append({
|
||||
"qpos": None,
|
||||
"images": {
|
||||
config["camera_names"][0]: None,
|
||||
config["camera_names"][1]: None,
|
||||
config["camera_names"][2]: None,
|
||||
},
|
||||
})
|
||||
|
||||
img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = (get_ros_observation(args, ros_operator))
|
||||
img_front = jpeg_mapping(img_front)
|
||||
img_left = jpeg_mapping(img_left)
|
||||
img_right = jpeg_mapping(img_right)
|
||||
|
||||
qpos = np.concatenate(
|
||||
(np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)),
|
||||
axis=0,
|
||||
)
|
||||
qpos = torch.from_numpy(qpos).float().cuda()
|
||||
observation_window.append({
|
||||
"qpos": qpos,
|
||||
"images": {
|
||||
config["camera_names"][0]: img_front,
|
||||
config["camera_names"][1]: img_right,
|
||||
config["camera_names"][2]: img_left,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
# RDT inference
|
||||
def inference_fn(args, config, policy, t):
|
||||
global observation_window
|
||||
global lang_embeddings
|
||||
|
||||
# print(f"Start inference_thread_fn: t={t}")
|
||||
while True and not rospy.is_shutdown():
|
||||
time1 = time.time()
|
||||
|
||||
# fetch images in sequence [front, right, left]
|
||||
image_arrs = [
|
||||
observation_window[-2]["images"][config["camera_names"][0]],
|
||||
observation_window[-2]["images"][config["camera_names"][1]],
|
||||
observation_window[-2]["images"][config["camera_names"][2]],
|
||||
observation_window[-1]["images"][config["camera_names"][0]],
|
||||
observation_window[-1]["images"][config["camera_names"][1]],
|
||||
observation_window[-1]["images"][config["camera_names"][2]],
|
||||
]
|
||||
|
||||
# fetch debug images in sequence [front, right, left]
|
||||
# image_arrs = [
|
||||
# preload_images[config['camera_names'][0]][max(t - 1, 0)],
|
||||
# preload_images[config['camera_names'][2]][max(t - 1, 0)],
|
||||
# preload_images[config['camera_names'][1]][max(t - 1, 0)],
|
||||
# preload_images[config['camera_names'][0]][t],
|
||||
# preload_images[config['camera_names'][2]][t],
|
||||
# preload_images[config['camera_names'][1]][t]
|
||||
# ]
|
||||
# # encode the images
|
||||
# for i in range(len(image_arrs)):
|
||||
# image_arrs[i] = cv2.imdecode(np.frombuffer(image_arrs[i], np.uint8), cv2.IMREAD_COLOR)
|
||||
# proprio = torch.from_numpy(preload_images['qpos'][t]).float().cuda()
|
||||
|
||||
images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]
|
||||
|
||||
# for i, pos in enumerate(['f', 'r', 'l'] * 2):
|
||||
# images[i].save(f'{t}-{i}-{pos}.png')
|
||||
|
||||
# get last qpos in shape [14, ]
|
||||
proprio = observation_window[-1]["qpos"]
|
||||
# unsqueeze to [1, 14]
|
||||
proprio = proprio.unsqueeze(0)
|
||||
|
||||
# actions shaped as [1, 64, 14] in format [left, right]
|
||||
actions = (policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy())
|
||||
# print(f"inference_actions: {actions.squeeze()}")
|
||||
|
||||
# print(f"Model inference time: {time.time() - time1} s")
|
||||
|
||||
# print(f"Finish inference_thread_fn: t={t}")
|
||||
return actions
|
||||
|
||||
|
||||
# Main loop for the manipulation task
|
||||
def model_inference(args, config, ros_operator):
|
||||
global lang_embeddings
|
||||
|
||||
# Load rdt model
|
||||
policy = make_policy(args)
|
||||
|
||||
lang_dict = torch.load(args.lang_embeddings_path)
|
||||
print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
|
||||
lang_embeddings = lang_dict["embeddings"]
|
||||
|
||||
max_publish_step = config["episode_len"]
|
||||
chunk_size = config["chunk_size"]
|
||||
|
||||
# Initialize position of the puppet arm
|
||||
left0 = [
|
||||
-0.00133514404296875,
|
||||
0.00209808349609375,
|
||||
0.01583099365234375,
|
||||
-0.032616615295410156,
|
||||
-0.00286102294921875,
|
||||
0.00095367431640625,
|
||||
3.557830810546875,
|
||||
]
|
||||
right0 = [
|
||||
-0.00133514404296875,
|
||||
0.00438690185546875,
|
||||
0.034523963928222656,
|
||||
-0.053597450256347656,
|
||||
-0.00476837158203125,
|
||||
-0.00209808349609375,
|
||||
3.557830810546875,
|
||||
]
|
||||
left1 = [
|
||||
-0.00133514404296875,
|
||||
0.00209808349609375,
|
||||
0.01583099365234375,
|
||||
-0.032616615295410156,
|
||||
-0.00286102294921875,
|
||||
0.00095367431640625,
|
||||
-0.3393220901489258,
|
||||
]
|
||||
right1 = [
|
||||
-0.00133514404296875,
|
||||
0.00247955322265625,
|
||||
0.01583099365234375,
|
||||
-0.032616615295410156,
|
||||
-0.00286102294921875,
|
||||
0.00095367431640625,
|
||||
-0.3397035598754883,
|
||||
]
|
||||
ros_operator.puppet_arm_publish_continuous(left0, right0)
|
||||
input("Press enter to continue")
|
||||
ros_operator.puppet_arm_publish_continuous(left1, right1)
|
||||
# Initialize the previous action to be the initial robot state
|
||||
pre_action = np.zeros(config["state_dim"])
|
||||
pre_action[:14] = np.array([
|
||||
-0.00133514404296875,
|
||||
0.00209808349609375,
|
||||
0.01583099365234375,
|
||||
-0.032616615295410156,
|
||||
-0.00286102294921875,
|
||||
0.00095367431640625,
|
||||
-0.3393220901489258,
|
||||
] + [
|
||||
-0.00133514404296875,
|
||||
0.00247955322265625,
|
||||
0.01583099365234375,
|
||||
-0.032616615295410156,
|
||||
-0.00286102294921875,
|
||||
0.00095367431640625,
|
||||
-0.3397035598754883,
|
||||
])
|
||||
action = None
|
||||
# Inference loop
|
||||
with torch.inference_mode():
|
||||
while True and not rospy.is_shutdown():
|
||||
# The current time step
|
||||
t = 0
|
||||
rate = rospy.Rate(args.publish_rate)
|
||||
|
||||
action_buffer = np.zeros([chunk_size, config["state_dim"]])
|
||||
|
||||
while t < max_publish_step and not rospy.is_shutdown():
|
||||
# Update observation window
|
||||
update_observation_window(args, config, ros_operator)
|
||||
|
||||
# When coming to the end of the action chunk
|
||||
if t % chunk_size == 0:
|
||||
# Start inference
|
||||
action_buffer = inference_fn(args, config, policy, t).copy()
|
||||
|
||||
raw_action = action_buffer[t % chunk_size]
|
||||
action = raw_action
|
||||
# Interpolate the original action sequence
|
||||
if args.use_actions_interpolation:
|
||||
# print(f"Time {t}, pre {pre_action}, act {action}")
|
||||
interp_actions = interpolate_action(args, pre_action, action)
|
||||
else:
|
||||
interp_actions = action[np.newaxis, :]
|
||||
# Execute the interpolated actions one by one
|
||||
for act in interp_actions:
|
||||
left_action = act[:7]
|
||||
right_action = act[7:14]
|
||||
|
||||
if not args.disable_puppet_arm:
|
||||
ros_operator.puppet_arm_publish(left_action,
|
||||
right_action) # puppet_arm_publish_continuous_thread
|
||||
|
||||
if args.use_robot_base:
|
||||
vel_action = act[14:16]
|
||||
ros_operator.robot_base_publish(vel_action)
|
||||
rate.sleep()
|
||||
# print(f"doing action: {act}")
|
||||
t += 1
|
||||
|
||||
print("Published Step", t)
|
||||
pre_action = action.copy()
|
||||
|
||||
|
||||
# ROS operator class
|
||||
class RosOperator:
|
||||
|
||||
    def __init__(self, args):
        """Create the ROS operator: declare all members, then set up buffers
        (``init``) and ROS pubs/subs (``init_ros``, defined elsewhere in this
        file)."""
        # Message deques, filled by topic callbacks; created in init().
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        # OpenCV <-> ROS image converter.
        self.bridge = None
        # Publishers, created in init_ros().
        self.puppet_arm_left_publisher = None
        self.puppet_arm_right_publisher = None
        self.robot_base_publisher = None
        self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_lock = None
        self.args = args
        self.init()
        self.init_ros()
|
||||
|
||||
def init(self):
|
||||
self.bridge = CvBridge()
|
||||
self.img_left_deque = deque()
|
||||
self.img_right_deque = deque()
|
||||
self.img_front_deque = deque()
|
||||
self.img_left_depth_deque = deque()
|
||||
self.img_right_depth_deque = deque()
|
||||
self.img_front_depth_deque = deque()
|
||||
self.puppet_arm_left_deque = deque()
|
||||
self.puppet_arm_right_deque = deque()
|
||||
self.robot_base_deque = deque()
|
||||
self.puppet_arm_publish_lock = threading.Lock()
|
||||
self.puppet_arm_publish_lock.acquire()
|
||||
|
||||
def puppet_arm_publish(self, left, right):
|
||||
joint_state_msg = JointState()
|
||||
joint_state_msg.header = Header()
|
||||
joint_state_msg.header.stamp = rospy.Time.now() # Set timestep
|
||||
joint_state_msg.name = [
|
||||
"joint0",
|
||||
"joint1",
|
||||
"joint2",
|
||||
"joint3",
|
||||
"joint4",
|
||||
"joint5",
|
||||
"joint6",
|
||||
] # 设置关节名称
|
||||
joint_state_msg.position = left
|
||||
self.puppet_arm_left_publisher.publish(joint_state_msg)
|
||||
joint_state_msg.position = right
|
||||
self.puppet_arm_right_publisher.publish(joint_state_msg)
|
||||
|
||||
def robot_base_publish(self, vel):
|
||||
vel_msg = Twist()
|
||||
vel_msg.linear.x = vel[0]
|
||||
vel_msg.linear.y = 0
|
||||
vel_msg.linear.z = 0
|
||||
vel_msg.angular.x = 0
|
||||
vel_msg.angular.y = 0
|
||||
vel_msg.angular.z = vel[1]
|
||||
self.robot_base_publisher.publish(vel_msg)
|
||||
|
||||
def puppet_arm_publish_continuous(self, left, right):
    """Drive both arms toward the target joint vectors in bounded steps.

    Blocks until at least one state message has arrived for each arm,
    then repeatedly publishes commands that move each joint by at most
    args.arm_steps_length[i] per tick, at args.publish_rate Hz, until
    every joint has reached its target.

    Cancellation: this loop exits early when puppet_arm_publish_lock is
    released by another thread (see puppet_arm_publish_continuous_thread);
    the lock is normally held, so the non-blocking acquire below fails.

    Args:
        left: target joint positions for the left arm.
        right: target joint positions for the right arm.
    """
    rate = rospy.Rate(self.args.publish_rate)
    left_arm = None
    right_arm = None
    # Wait for the first observed state of each arm (latest deque entry).
    while True and not rospy.is_shutdown():
        if len(self.puppet_arm_left_deque) != 0:
            left_arm = list(self.puppet_arm_left_deque[-1].position)
        if len(self.puppet_arm_right_deque) != 0:
            right_arm = list(self.puppet_arm_right_deque[-1].position)
        if left_arm is None or right_arm is None:
            rate.sleep()
            continue
        else:
            break
    # Per-joint step direction (+1/-1) toward the target, fixed up front.
    left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
    right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
    flag = True
    step = 0
    while flag and not rospy.is_shutdown():
        # Non-blocking acquire succeeds only after another thread released
        # the lock — that is the cancellation signal for this loop.
        if self.puppet_arm_publish_lock.acquire(False):
            return
        left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
        right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
        flag = False
        for i in range(len(left)):
            # Snap to the target when within one step; otherwise take one
            # bounded step and keep looping.
            if left_diff[i] < self.args.arm_steps_length[i]:
                left_arm[i] = left[i]
            else:
                left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
                flag = True
        for i in range(len(right)):
            if right_diff[i] < self.args.arm_steps_length[i]:
                right_arm[i] = right[i]
            else:
                right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
                flag = True
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()  # Set the timestamp
        joint_state_msg.name = [
            "joint0",
            "joint1",
            "joint2",
            "joint3",
            "joint4",
            "joint5",
            "joint6",
        ]  # Set the joint names
        # One message is reused for both arms; only position changes.
        joint_state_msg.position = left_arm
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = right_arm
        self.puppet_arm_right_publisher.publish(joint_state_msg)
        step += 1
        print("puppet_arm_publish_continuous:", step)
        rate.sleep()
def puppet_arm_publish_linear(self, left, right, num_step=100, rate_hz=200):
    """Move both puppet arms to the targets along a linear interpolation
    from their currently measured joint positions.

    Generalized: num_step and rate_hz were previously hard-coded to
    100 and 200; the defaults preserve the original behavior.

    Args:
        left: target joint vector for the left arm (last element: gripper).
        right: target joint vector for the right arm.
        num_step: number of interpolation waypoints to publish.
        rate_hz: publish rate in Hz.
    """
    rate = rospy.Rate(rate_hz)

    left_arm = None
    right_arm = None

    # Wait until at least one state message has arrived for each arm.
    while not rospy.is_shutdown():
        if len(self.puppet_arm_left_deque) != 0:
            left_arm = list(self.puppet_arm_left_deque[-1].position)
        if len(self.puppet_arm_right_deque) != 0:
            right_arm = list(self.puppet_arm_right_deque[-1].position)
        if left_arm is None or right_arm is None:
            rate.sleep()
            continue
        else:
            break

    traj_left_list = np.linspace(left_arm, left, num_step)
    traj_right_list = np.linspace(right_arm, right, num_step)

    for i in range(len(traj_left_list)):
        traj_left = traj_left_list[i]
        traj_right = traj_right_list[i]
        # The gripper (last element) jumps straight to its target rather
        # than being interpolated.
        traj_left[-1] = left[-1]
        traj_right[-1] = right[-1]
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()  # set the timestamp
        joint_state_msg.name = [
            "joint0",
            "joint1",
            "joint2",
            "joint3",
            "joint4",
            "joint5",
            "joint6",
        ]  # set the joint names
        # One message is reused for both arms; only position changes.
        joint_state_msg.position = traj_left
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = traj_right
        self.puppet_arm_right_publisher.publish(joint_state_msg)
        rate.sleep()
def puppet_arm_publish_continuous_thread(self, left, right):
    """Start puppet_arm_publish_continuous in a background thread,
    cancelling any previous one first.

    Cancellation protocol: the publish loop polls puppet_arm_publish_lock
    with a non-blocking acquire and exits once it succeeds, i.e. once this
    method releases the (normally held) lock. After joining the old
    thread, the lock is re-acquired non-blockingly to restore the
    "held" state for the next run.

    Args:
        left: target joint positions for the left arm.
        right: target joint positions for the right arm.
    """
    if self.puppet_arm_publish_thread is not None:
        # Release the lock -> signals the running loop to return.
        self.puppet_arm_publish_lock.release()
        self.puppet_arm_publish_thread.join()
        # Re-arm the cancellation lock (non-blocking: the loop may have
        # already grabbed it on its way out).
        self.puppet_arm_publish_lock.acquire(False)
        self.puppet_arm_publish_thread = None
    self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
    self.puppet_arm_publish_thread.start()
def get_frame(self):
    """Assemble one time-synchronized frame from all buffered streams.

    frame_time is the minimum of the newest timestamps across the image
    streams (and depth streams when enabled); every stream is then
    fast-forwarded past older messages and the first message at/after
    frame_time is consumed.

    Returns:
        False if any required stream cannot yet provide a message at
        frame_time; otherwise a tuple
        (img_front, img_left, img_right,
         img_front_depth, img_left_depth, img_right_depth,
         puppet_arm_left, puppet_arm_right, robot_base)
        where the depth images and robot_base are None when the
        corresponding feature flags are off.
    """
    # Bail out early if any required image buffer is still empty.
    if (len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or
            (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0
                                            or len(self.img_front_depth_deque) == 0))):
        return False
    # The sync point is the oldest of the latest timestamps, so every
    # stream is guaranteed to have at least one message at/after it.
    if self.args.use_depth_image:
        frame_time = min([
            self.img_left_deque[-1].header.stamp.to_sec(),
            self.img_right_deque[-1].header.stamp.to_sec(),
            self.img_front_deque[-1].header.stamp.to_sec(),
            self.img_left_depth_deque[-1].header.stamp.to_sec(),
            self.img_right_depth_deque[-1].header.stamp.to_sec(),
            self.img_front_depth_deque[-1].header.stamp.to_sec(),
        ])
    else:
        frame_time = min([
            self.img_left_deque[-1].header.stamp.to_sec(),
            self.img_right_deque[-1].header.stamp.to_sec(),
            self.img_front_deque[-1].header.stamp.to_sec(),
        ])

    # Verify every stream (including arm state / odometry, which did not
    # participate in the min above) has reached frame_time.
    if (len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time):
        return False
    if (len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time):
        return False
    if (len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time):
        return False
    if (len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time):
        return False
    if (len(self.puppet_arm_right_deque) == 0
            or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time):
        return False
    if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0
                                      or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
        return False
    if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0
                                      or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
        return False
    if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0
                                      or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
        return False
    if self.args.use_robot_base and (len(self.robot_base_deque) == 0
                                     or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
        return False

    # Drop everything older than frame_time, then consume the next message
    # (the popped message is the one at/after the sync point).
    while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
        self.img_left_deque.popleft()
    img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), "passthrough")

    while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
        self.img_right_deque.popleft()
    img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), "passthrough")

    while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
        self.img_front_deque.popleft()
    img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), "passthrough")

    while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
        self.puppet_arm_left_deque.popleft()
    puppet_arm_left = self.puppet_arm_left_deque.popleft()

    while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
        self.puppet_arm_right_deque.popleft()
    puppet_arm_right = self.puppet_arm_right_deque.popleft()

    img_left_depth = None
    if self.args.use_depth_image:
        while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
            self.img_left_depth_deque.popleft()
        img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), "passthrough")

    img_right_depth = None
    if self.args.use_depth_image:
        while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
            self.img_right_depth_deque.popleft()
        img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), "passthrough")

    img_front_depth = None
    if self.args.use_depth_image:
        while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
            self.img_front_depth_deque.popleft()
        img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), "passthrough")

    robot_base = None
    if self.args.use_robot_base:
        while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
            self.robot_base_deque.popleft()
        robot_base = self.robot_base_deque.popleft()

    return (
        img_front,
        img_left,
        img_right,
        img_front_depth,
        img_left_depth,
        img_right_depth,
        puppet_arm_left,
        puppet_arm_right,
        robot_base,
    )
def img_left_callback(self, msg):
    """Subscriber callback: buffer the latest left RGB image message.

    Keeps at most 2000 entries, dropping the oldest when full.
    """
    buffer = self.img_left_deque
    if len(buffer) >= 2000:
        buffer.popleft()
    buffer.append(msg)
def img_right_callback(self, msg):
    """Subscriber callback: buffer the latest right RGB image message.

    Keeps at most 2000 entries, dropping the oldest when full.
    """
    buffer = self.img_right_deque
    if len(buffer) >= 2000:
        buffer.popleft()
    buffer.append(msg)
def img_front_callback(self, msg):
    """Subscriber callback: buffer the latest front RGB image message.

    Keeps at most 2000 entries, dropping the oldest when full.
    """
    buffer = self.img_front_deque
    if len(buffer) >= 2000:
        buffer.popleft()
    buffer.append(msg)
def img_left_depth_callback(self, msg):
    """Subscriber callback: buffer the latest left depth image message.

    Keeps at most 2000 entries, dropping the oldest when full.
    """
    buffer = self.img_left_depth_deque
    if len(buffer) >= 2000:
        buffer.popleft()
    buffer.append(msg)
def img_right_depth_callback(self, msg):
    """Subscriber callback: buffer the latest right depth image message.

    Keeps at most 2000 entries, dropping the oldest when full.
    """
    buffer = self.img_right_depth_deque
    if len(buffer) >= 2000:
        buffer.popleft()
    buffer.append(msg)
def img_front_depth_callback(self, msg):
    """Subscriber callback: buffer the latest front depth image message.

    Keeps at most 2000 entries, dropping the oldest when full.
    """
    buffer = self.img_front_depth_deque
    if len(buffer) >= 2000:
        buffer.popleft()
    buffer.append(msg)
def puppet_arm_left_callback(self, msg):
    """Subscriber callback: buffer the latest left-arm JointState message.

    Keeps at most 2000 entries, dropping the oldest when full.
    """
    buffer = self.puppet_arm_left_deque
    if len(buffer) >= 2000:
        buffer.popleft()
    buffer.append(msg)
def puppet_arm_right_callback(self, msg):
    """Subscriber callback: buffer the latest right-arm JointState message.

    Keeps at most 2000 entries, dropping the oldest when full.
    """
    buffer = self.puppet_arm_right_deque
    if len(buffer) >= 2000:
        buffer.popleft()
    buffer.append(msg)
def robot_base_callback(self, msg):
    """Subscriber callback: buffer the latest base odometry message.

    Keeps at most 2000 entries, dropping the oldest when full.
    """
    buffer = self.robot_base_deque
    if len(buffer) >= 2000:
        buffer.popleft()
    buffer.append(msg)
def init_ros(self):
    """Initialize the ROS node, all stream subscribers, and the command
    publishers for both arms and the mobile base.
    """
    rospy.init_node("joint_state_publisher", anonymous=True)

    # Image subscriptions (RGB always; depth only when enabled), registered
    # in the same order as before.
    image_subscriptions = [
        (self.args.img_left_topic, self.img_left_callback),
        (self.args.img_right_topic, self.img_right_callback),
        (self.args.img_front_topic, self.img_front_callback),
    ]
    for topic, callback in image_subscriptions:
        rospy.Subscriber(topic, Image, callback, queue_size=1000, tcp_nodelay=True)
    if self.args.use_depth_image:
        depth_subscriptions = [
            (self.args.img_left_depth_topic, self.img_left_depth_callback),
            (self.args.img_right_depth_topic, self.img_right_depth_callback),
            (self.args.img_front_depth_topic, self.img_front_depth_callback),
        ]
        for topic, callback in depth_subscriptions:
            rospy.Subscriber(topic, Image, callback, queue_size=1000, tcp_nodelay=True)

    # Arm state and base odometry subscriptions.
    rospy.Subscriber(
        self.args.puppet_arm_left_topic,
        JointState,
        self.puppet_arm_left_callback,
        queue_size=1000,
        tcp_nodelay=True,
    )
    rospy.Subscriber(
        self.args.puppet_arm_right_topic,
        JointState,
        self.puppet_arm_right_callback,
        queue_size=1000,
        tcp_nodelay=True,
    )
    rospy.Subscriber(
        self.args.robot_base_topic,
        Odometry,
        self.robot_base_callback,
        queue_size=1000,
        tcp_nodelay=True,
    )

    # Command publishers.
    self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
    self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
    self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
def get_arguments():
    """Build and parse the command-line arguments for the inference script.

    Returns:
        argparse.Namespace holding topic names, model paths, and control
        settings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_publish_step", type=int, default=10000, required=False,
                        help="Maximum number of action publishing steps")
    parser.add_argument("--seed", type=int, default=None, required=False, help="Random seed")

    # Camera topics (RGB).
    parser.add_argument("--img_front_topic", type=str, default="/camera_f/color/image_raw",
                        required=False, help="img_front_topic")
    parser.add_argument("--img_left_topic", type=str, default="/camera_l/color/image_raw",
                        required=False, help="img_left_topic")
    parser.add_argument("--img_right_topic", type=str, default="/camera_r/color/image_raw",
                        required=False, help="img_right_topic")

    # Camera topics (depth).
    parser.add_argument("--img_front_depth_topic", type=str, default="/camera_f/depth/image_raw",
                        required=False, help="img_front_depth_topic")
    parser.add_argument("--img_left_depth_topic", type=str, default="/camera_l/depth/image_raw",
                        required=False, help="img_left_depth_topic")
    parser.add_argument("--img_right_depth_topic", type=str, default="/camera_r/depth/image_raw",
                        required=False, help="img_right_depth_topic")

    # Puppet-arm command and state topics.
    parser.add_argument("--puppet_arm_left_cmd_topic", type=str, default="/master/joint_left",
                        required=False, help="puppet_arm_left_cmd_topic")
    parser.add_argument("--puppet_arm_right_cmd_topic", type=str, default="/master/joint_right",
                        required=False, help="puppet_arm_right_cmd_topic")
    parser.add_argument("--puppet_arm_left_topic", type=str, default="/puppet/joint_left",
                        required=False, help="puppet_arm_left_topic")
    parser.add_argument("--puppet_arm_right_topic", type=str, default="/puppet/joint_right",
                        required=False, help="puppet_arm_right_topic")

    # Mobile-base topics.
    parser.add_argument("--robot_base_topic", type=str, default="/odom_raw",
                        required=False, help="robot_base_topic")
    # Fix: the help text previously said "robot_base_topic" for this option.
    parser.add_argument("--robot_base_cmd_topic", type=str, default="/cmd_vel",
                        required=False, help="robot_base_cmd_topic")
    parser.add_argument("--use_robot_base", action="store_true", default=False, required=False,
                        help="Whether to use the robot base to move around")

    # Control settings.
    parser.add_argument("--publish_rate", type=int, default=30, required=False,
                        help="The rate at which to publish the actions")
    parser.add_argument("--ctrl_freq", type=int, default=25, required=False,
                        help="The control frequency of the robot")
    parser.add_argument("--chunk_size", type=int, default=64, required=False,
                        help="Action chunk size")
    # Fix: nargs="+" so the per-joint list can actually be overridden on the
    # CLI; previously a single float was parsed, which broke per-joint
    # indexing (arm_steps_length[i]) downstream.
    parser.add_argument("--arm_steps_length", type=float, nargs="+",
                        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2],
                        required=False,
                        help="The maximum change allowed for each joint per timestep")

    parser.add_argument("--use_actions_interpolation", action="store_true", default=False, required=False,
                        help="Whether to interpolate the actions if the difference is too large")
    parser.add_argument("--use_depth_image", action="store_true", default=False, required=False,
                        help="Whether to use depth images")
    parser.add_argument("--disable_puppet_arm", action="store_true", default=False,
                        help="Whether to disable the puppet arm. This is useful for safely debugging")

    parser.add_argument("--config_path", type=str, default="configs/base.yaml",
                        help="Path to the config file")
    # parser.add_argument('--cfg_scale', type=float, default=2.0,
    #                     help='the scaling factor used to modify the magnitude of the control features during denoising')
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True,
                        help="Name or path to the pretrained model")
    parser.add_argument("--lang_embeddings_path", type=str, required=True,
                        help="Path to the pre-encoded language instruction embeddings")

    args = parser.parse_args()
    return args
def main():
    """CLI entry point: parse arguments, set up ROS I/O, seed RNGs if
    requested, and run the inference loop.
    """
    args = get_arguments()
    ros_operator = RosOperator(args)
    if args.seed is not None:
        set_seed(args.seed)
    config = get_config(args)
    model_inference(args, config, ros_operator)
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
344
RDT-1B/scripts/agilex_model.py
Normal file
344
RDT-1B/scripts/agilex_model.py
Normal file
@ -0,0 +1,344 @@
|
||||
import os, sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# get current workspace
# Make the repository's `models` package importable regardless of where this
# script is launched from. (Fix: the identical sys.path.append line was
# previously duplicated, adding the same entry twice.)
current_file = Path(__file__)
sys.path.append(os.path.join(current_file.parent.parent, "models"))
||||
from multimodal_encoder.siglip_encoder import SiglipVisionTower
|
||||
from multimodal_encoder.t5_encoder import T5Embedder
|
||||
from rdt_runner import RDTRunner
|
||||
|
||||
# The indices that the raw vector should be mapped to in the unified action vector
|
||||
# AGILEX_STATE_INDICES = [
|
||||
# STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(1)
|
||||
# ] + [
|
||||
# STATE_VEC_IDX_MAPPING["left_gripper_open"]
|
||||
# ] + [
|
||||
# STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(1)
|
||||
# ] + [
|
||||
# STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
|
||||
# ]
|
||||
# AGILEX_STATE_INDICES = None
|
||||
|
||||
|
||||
# Create the RDT model
|
||||
# Create the RDT model
def create_model(args, **kwargs):
    """Instantiate the RDT wrapper and optionally load single-file weights.

    Fix: a local AGILEX_STATE_INDICES list (and the arm-dim unpacking that
    fed it) was computed here but never used — dead code removed; the
    wrapper derives the indices from args itself.

    Args:
        args: config dict passed through to RoboticDiffusionTransformerModel
            (must contain "arm_dim" among other sections).
        **kwargs: forwarded to the wrapper; "pretrained" may name a
            checkpoint file or directory.

    Returns:
        A ready RoboticDiffusionTransformerModel instance.
    """
    model = RoboticDiffusionTransformerModel(args, **kwargs)
    pretrained = kwargs.get("pretrained", None)
    # A directory checkpoint is handled inside get_policy via from_pretrained;
    # a plain file (.pt / .safetensors) must be loaded explicitly here.
    if pretrained is not None and os.path.isfile(pretrained):
        model.load_pretrained_weights(pretrained)

    return model
class RoboticDiffusionTransformerModel(object):
    """A wrapper for the RDT model, which handles
    1. Model initialization
    2. Encodings of instructions
    3. Model inference
    """

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        """Build the policy and vision encoder and move them to the device.

        Args:
            args: config dict; this class reads its "common", "model",
                "dataset", and "arm_dim" sections.
            device: torch device string used for inference.
            dtype: parameter/compute dtype (bfloat16 by default).
            image_size: optional size used to resize incoming images in
                step(); None disables resizing.
            control_frequency: robot control frequency passed to the policy.
            pretrained: checkpoint name/path; a directory is loaded via
                RDTRunner.from_pretrained, a plain file is loaded later by
                load_pretrained_weights (see create_model).
            pretrained_vision_encoder_name_or_path: SigLIP tower name/path.
        """
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        # We do not use the text encoder due to limited GPU memory
        # self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy(pretrained)
        # Per-arm joint counts; used to build the unified-vector indices.
        self.left_arm_dim, self.right_arm_dim = (
            args["arm_dim"]["left_arm_dim"],
            args["arm_dim"]["right_arm_dim"],
        )

        self.reset()

    def get_policy(self, pretrained):
        """Initialize the RDT policy.

        A fresh RDTRunner is built when `pretrained` is None or points to a
        plain file (the file itself is loaded later by
        load_pretrained_weights); a directory checkpoint is loaded directly
        via RDTRunner.from_pretrained.
        """
        if pretrained is None or os.path.isfile(pretrained):
            # Total number of image-condition tokens:
            # history x cameras x patches-per-image.
            img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                            self.vision_model.num_patches)

            _model = RDTRunner(
                action_dim=self.args["common"]["state_dim"],
                pred_horizon=self.args["common"]["action_chunk_size"],
                config=self.args["model"],
                lang_token_dim=self.args["model"]["lang_token_dim"],
                img_token_dim=self.args["model"]["img_token_dim"],
                state_token_dim=self.args["model"]["state_token_dim"],
                max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
                img_cond_len=img_cond_len,
                img_pos_embed_config=[
                    # No initial pos embed in the last grid size
                    # since we've already done in ViT
                    (
                        "image",
                        (
                            self.args["common"]["img_history_size"],
                            self.args["common"]["num_cameras"],
                            -self.vision_model.num_patches,
                        ),
                    ),
                ],
                lang_pos_embed_config=[
                    # Similarly, no initial pos embed for language
                    ("lang", -self.args["dataset"]["tokenizer_max_length"]),
                ],
                dtype=self.dtype,
            )
        else:
            _model = RDTRunner.from_pretrained(pretrained)

        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        """Build the T5 text embedder and return (tokenizer, encoder).

        NOTE(review): currently unused — __init__ keeps the text encoder
        disabled to save GPU memory.
        """
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
        return tokenizer, text_encoder

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        """Build the SigLIP vision tower and return (image_processor, tower)."""
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        image_processor = vision_encoder.image_processor
        return image_processor, vision_encoder

    def reset(self):
        """Put the models in eval mode and move them to the target device/dtype."""
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        # self.text_model.eval()
        self.vision_model.eval()

        self.policy = self.policy.to(device, dtype=weight_dtype)
        # self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        """Load policy weights from a .pt (DeepSpeed-style, "module" key)
        or .safetensors checkpoint file.

        Raises:
            NotImplementedError: for any other file extension.
        """
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            checkpoint = torch.load(pretrained)
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model

            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        """Encode string instruction to latent embeddings.

        NOTE(review): __init__ never creates self.text_tokenizer /
        self.text_model (the text encoder is disabled to save memory), so
        calling this raises AttributeError unless get_text_encoder results
        are assigned first — confirm before use.

        Args:
            instruction: a string of instruction
            device: a string of device

        Returns:
            pred: a tensor of latent embeddings of shape (text_max_length, 512)
        """
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)

        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()

        return pred

    def _agilex_state_indices(self):
        """Indices of this robot's joints/grippers inside the unified
        state/action vector (left joints, left gripper, right joints,
        right gripper — in that order).
        """
        return ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(self.left_arm_dim)] +
                [STATE_VEC_IDX_MAPPING["left_gripper_open"]] +
                [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(self.right_arm_dim)] +
                [STATE_VEC_IDX_MAPPING["right_gripper_open"]])

    def _format_joint_to_state(self, joints):
        """
        Format the joint proprioception into the unified action vector.

        Args:
            joints (torch.Tensor): The joint proprioception to be formatted.
                qpos ([B, N, left_arm_dim + 1 + right_arm_dim + 1]).

        Returns:
            state (torch.Tensor): The formatted vector for RDT
                ([B, N, state_token_dim]).
            state_elem_mask (torch.Tensor): per-dimension availability mask
                ([B, state_token_dim]).
        """
        AGILEX_STATE_INDICES = self._agilex_state_indices()
        # Rescale the gripper to the range of [0, 1].
        # NOTE(review): the divisor is currently all-ones (identity) — a
        # placeholder kept for when real gripper scaling is needed.
        joints = joints / torch.tensor(
            [[[1 for i in range(self.left_arm_dim + 1 + self.right_arm_dim + 1)]]],
            device=joints.device,
            dtype=joints.dtype,
        )

        B, N, _ = joints.shape
        state = torch.zeros(
            (B, N, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        # Fill into the unified state vector
        state[:, :, AGILEX_STATE_INDICES] = joints
        # Assemble the mask indicating each dimension's availability
        state_elem_mask = torch.zeros(
            (B, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        state_elem_mask[:, AGILEX_STATE_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        """
        Unformat the unified action vector into the joint action to be executed.

        Args:
            action (torch.Tensor): The unified action vector to be unformatted.
                ([B, N, state_token_dim])

        Returns:
            joints (torch.Tensor): The unformatted robot joint action.
                qpos ([B, N, left_arm_dim + 1 + right_arm_dim + 1]).
        """
        action_indices = self._agilex_state_indices()
        joints = action[:, :, action_indices]

        # Rescale the gripper back to the action range
        # Note that the action range and proprioception range are different
        # for Mobile ALOHA robot
        # NOTE(review): the multiplier is currently all-ones (identity) — a
        # placeholder mirroring _format_joint_to_state.
        joints = joints * torch.tensor(
            [[[1 for i in range(self.left_arm_dim + 1 + self.right_arm_dim + 1)]]],
            device=joints.device,
            dtype=joints.dtype,
        )

        return joints

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        """
        Predict the next action chunk given the
        proprioceptive states, images, and instruction embeddings.

        Args:
            proprio: proprioceptive states
            images: RGB images, the order should be
                [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},
                ext_{t}, right_wrist_{t}, left_wrist_{t}];
                None entries are replaced by a background-color image.
            text_embeds: instruction embeddings

        Returns:
            action: predicted action
        """
        device = self.device
        dtype = self.dtype

        # The background image used for padding
        background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                    dtype=np.uint8).reshape(1, 1, 3)
        background_image = (np.ones(
            (
                self.image_processor.size["height"],
                self.image_processor.size["width"],
                3,
            ),
            dtype=np.uint8,
        ) * background_color)

        # Preprocess the images by order and encode them
        image_tensor_list = []
        for image in images:
            if image is None:
                # Replace it with the background image
                image = Image.fromarray(background_image)

            if self.image_size is not None:
                # Bug fix: previously read self.data_args.image_size, but no
                # data_args attribute exists on this class (AttributeError
                # whenever image_size was set).
                image = transforms.Resize(self.image_size)(image)

            if self.args["dataset"].get("auto_adjust_image_brightness", False):
                # Brighten very dark images (mean brightness <= 0.15).
                pixel_values = list(image.getdata())
                average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                if average_brightness <= 0.15:
                    image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

            if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

                def expand2square(pil_img, background_color):
                    # Pad the shorter side with the background color so the
                    # image becomes square without distortion.
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result

                image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            image_tensor_list.append(image)

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)

        image_embeds = self.vision_model(image_tensor).detach()
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # Prepare the proprioception states and the control frequency
        joints = proprio.to(device).unsqueeze(0)  # (1, 1, joint_dim)
        states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        states = states[:, -1:, :]  # (1, 1, 128) — only the latest state is used
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        # Predict the next action chunk given the inputs
        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)

        return trajectory
53
RDT-1B/scripts/encode_lang.py
Normal file
53
RDT-1B/scripts/encode_lang.py
Normal file
@ -0,0 +1,53 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||||
|
||||
# CUDA device index used for encoding.
GPU = 0
# HF model name or local path of the T5 text encoder.
MODEL_PATH = "google/t5-v1_1-xxl"
# Config file providing dataset.tokenizer_max_length.
CONFIG_PATH = "configs/base.yaml"
# Output directory for the saved embedding file.
SAVE_DIR = "outs/"

# Modify this to your task name and instruction
# NOTE(review): TASK_NAME ("handover_pan") does not obviously match the
# marker/box INSTRUCTION below — confirm the pair before encoding.
TASK_NAME = "handover_pan"
INSTRUCTION = "Pick up the black marker on the right and put it into the packaging box on the left."

# Note: if your GPU VRAM is less than 24GB,
# it is recommended to enable offloading by specifying an offload directory.
OFFLOAD_DIR = (
    None  # Specify your offload directory here, ensuring the directory exists.
)
def main():
|
||||
with open(CONFIG_PATH, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
|
||||
device = torch.device(f"cuda:{GPU}")
|
||||
text_embedder = T5Embedder(
|
||||
from_pretrained=MODEL_PATH,
|
||||
model_max_length=config["dataset"]["tokenizer_max_length"],
|
||||
device=device,
|
||||
use_offload_folder=OFFLOAD_DIR,
|
||||
)
|
||||
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
||||
|
||||
tokens = tokenizer(INSTRUCTION, return_tensors="pt", padding="longest", truncation=True)["input_ids"].to(device)
|
||||
|
||||
tokens = tokens.view(1, -1)
|
||||
with torch.no_grad():
|
||||
pred = text_encoder(tokens).last_hidden_state.detach().cpu()
|
||||
|
||||
save_path = os.path.join(SAVE_DIR, f"{TASK_NAME}.pt")
|
||||
# We save the embeddings in a dictionary format
|
||||
torch.save({"name": TASK_NAME, "instruction": INSTRUCTION, "embeddings": pred}, save_path)
|
||||
|
||||
print(
|
||||
f'"{INSTRUCTION}" from "{TASK_NAME}" is encoded by "{MODEL_PATH}" into shape {pred.shape} and saved to "{save_path}"'
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
57
RDT-1B/scripts/encode_lang_batch_once.py
Normal file
57
RDT-1B/scripts/encode_lang_batch_once.py
Normal file
@ -0,0 +1,57 @@
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import torch
|
||||
import yaml
|
||||
from tqdm import tqdm
|
||||
|
||||
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||||
|
||||
|
||||
def encode_lang(
|
||||
DATA_FILE_PATH,
|
||||
TARGET_DIR,
|
||||
GPU,
|
||||
desc_type="seen",
|
||||
tokenizer=None,
|
||||
text_encoder=None,
|
||||
):
|
||||
current_dir = os.path.dirname(__file__)
|
||||
|
||||
with open(os.path.join(current_dir, "../configs/base.yaml"), "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
|
||||
device = torch.device(f"cuda:{GPU}")
|
||||
if tokenizer is None or text_encoder is None:
|
||||
text_embedder = T5Embedder(
|
||||
from_pretrained=os.path.join(current_dir, "../../weights/RDT/t5-v1_1-xxl"),
|
||||
model_max_length=config["dataset"]["tokenizer_max_length"],
|
||||
device=device,
|
||||
use_offload_folder=None,
|
||||
)
|
||||
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
||||
|
||||
with open(DATA_FILE_PATH, "r") as f_instr:
|
||||
instruction_dict = json.load(f_instr)
|
||||
|
||||
instructions = instruction_dict[desc_type]
|
||||
|
||||
# Encode the instructions
|
||||
tokenized_res = tokenizer(instructions, return_tensors="pt", padding="longest", truncation=True)
|
||||
tokens = tokenized_res["input_ids"].to(device)
|
||||
attn_mask = tokenized_res["attention_mask"].to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
text_embeds = (text_encoder(input_ids=tokens, attention_mask=attn_mask)["last_hidden_state"].detach().cpu())
|
||||
|
||||
attn_mask = attn_mask.cpu().bool()
|
||||
if not os.path.exists(f"{TARGET_DIR}/instructions"):
|
||||
os.makedirs(f"{TARGET_DIR}/instructions")
|
||||
# Save the embeddings for training use
|
||||
for i in range(len(instructions)):
|
||||
text_embed = text_embeds[i][attn_mask[i]]
|
||||
save_path = os.path.join(TARGET_DIR, f"instructions/lang_embed_{i}.pt")
|
||||
# print("encoded instructions save_path:",save_path)
|
||||
torch.save(text_embed, save_path)
|
||||
|
||||
return tokenizer, text_encoder
|
||||
84
RDT-1B/scripts/generate_output_json.py
Normal file
84
RDT-1B/scripts/generate_output_json.py
Normal file
@ -0,0 +1,84 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
|
||||
def extract_metrics_from_log(log_file_path):
|
||||
all_metrics = []
|
||||
pattern = re.compile(
|
||||
r"\{'agilex_sample_mse':\s*([0-9.eE+-]+),\s*'agilex_sample_l2err':\s*([0-9.eE+-]+),\s*'overall_avg_sample_mse':\s*([0-9.eE+-]+),\s*'overall_avg_sample_l2err':\s*([0-9.eE+-]+)\}"
|
||||
)
|
||||
try:
|
||||
with open(log_file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
m = pattern.search(line)
|
||||
if m:
|
||||
metrics = (
|
||||
float(m.group(1)),
|
||||
float(m.group(2)),
|
||||
float(m.group(3)),
|
||||
float(m.group(4))
|
||||
)
|
||||
all_metrics.append(metrics)
|
||||
print(f"Find Metrics: agilex_sample_mse={metrics[0]}, agilex_sample_l2err={metrics[1]}, "
|
||||
f"overall_avg_sample_mse={metrics[2]}, overall_avg_sample_l2err={metrics[3]}")
|
||||
except Exception as e:
|
||||
print(f"Failed to read log: {e}")
|
||||
return (None, None, None, None)
|
||||
|
||||
if not all_metrics:
|
||||
print("No metrics found in the log file")
|
||||
return (None, None, None, None)
|
||||
|
||||
print(f"\nTotal {len(all_metrics)} metrics found in the log file")
|
||||
|
||||
best_agilex_mse = min(m[0] for m in all_metrics)
|
||||
best_agilex_l2err = min(m[1] for m in all_metrics)
|
||||
best_overall_mse = min(m[2] for m in all_metrics)
|
||||
best_overall_l2err = min(m[3] for m in all_metrics)
|
||||
|
||||
print(f"\nBest metrics:")
|
||||
print(f" agilex_sample_mse: {best_agilex_mse}")
|
||||
print(f" agilex_sample_l2err: {best_agilex_l2err}")
|
||||
print(f" overall_avg_sample_mse: {best_overall_mse}")
|
||||
print(f" overall_avg_sample_l2err: {best_overall_l2err}")
|
||||
|
||||
return (best_agilex_mse, best_agilex_l2err, best_overall_mse, best_overall_l2err)
|
||||
|
||||
def generate_output_json(input_config_file, output_dir, runtime):
|
||||
with open(input_config_file, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
log_file = os.path.join(output_dir, 'output.log')
|
||||
agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err = extract_metrics_from_log(log_file)
|
||||
|
||||
if None in [agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err]:
|
||||
print("Warning: Some metrics are missing in the log file.")
|
||||
|
||||
output_json = {
|
||||
"task_id": config.get("task_id"),
|
||||
"model_type": "RDT-1B",
|
||||
"model_name": config.get("model_name") if "model_name" in config else config.get("train", {}).get("model"),
|
||||
"gpu_id": config.get("gpu_id"),
|
||||
"runtime": runtime,
|
||||
"log_path": log_file,
|
||||
"output_dir": output_dir,
|
||||
"model_path": os.path.join(output_dir, 'pytorch_model.bin'),
|
||||
"metrics": {
|
||||
"agilex_sample_mse": agilex_sample_mse,
|
||||
"agilex_sample_l2err": agilex_sample_l2err,
|
||||
"overall_avg_sample_mse": overall_avg_sample_mse,
|
||||
"overall_avg_sample_l2err": overall_avg_sample_l2err
|
||||
}
|
||||
}
|
||||
|
||||
# 写入 output.json,格式化输出、确保null与规范json一致
|
||||
output_json_path = os.path.join(output_dir, 'output.json')
|
||||
with open(output_json_path, 'w') as f:
|
||||
json.dump(output_json, f, indent=4, ensure_ascii=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 4:
|
||||
print("Usage: python generate_output_json.py <input_config_file> <output_dir> <runtime>")
|
||||
sys.exit(1)
|
||||
generate_output_json(sys.argv[1], sys.argv[2], sys.argv[3])
|
||||
325
RDT-1B/scripts/maniskill_model.py
Normal file
325
RDT-1B/scripts/maniskill_model.py
Normal file
@ -0,0 +1,325 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
||||
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
||||
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||||
from models.rdt_runner import RDTRunner
|
||||
|
||||
MANISKILL_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
|
||||
for i in range(7)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]]
|
||||
|
||||
|
||||
def create_model(args, pretrained, **kwargs):
|
||||
model = RoboticDiffusionTransformerModel(args, **kwargs)
|
||||
if pretrained is not None:
|
||||
model.load_pretrained_weights(pretrained)
|
||||
return model
|
||||
|
||||
|
||||
DATA_STAT = {
|
||||
"state_min": [
|
||||
-0.7463043928146362,
|
||||
-0.0801204964518547,
|
||||
-0.4976441562175751,
|
||||
-2.657780647277832,
|
||||
-0.5742632150650024,
|
||||
1.8309762477874756,
|
||||
-2.2423808574676514,
|
||||
0.0,
|
||||
],
|
||||
"state_max": [
|
||||
0.7645499110221863,
|
||||
1.4967026710510254,
|
||||
0.4650936424732208,
|
||||
-0.3866899907588959,
|
||||
0.5505855679512024,
|
||||
3.2900545597076416,
|
||||
2.5737812519073486,
|
||||
0.03999999910593033,
|
||||
],
|
||||
"action_min": [
|
||||
-0.7472005486488342,
|
||||
-0.08631071448326111,
|
||||
-0.4995281398296356,
|
||||
-2.658363103866577,
|
||||
-0.5751323103904724,
|
||||
1.8290787935256958,
|
||||
-2.245187997817993,
|
||||
-1.0,
|
||||
],
|
||||
"action_max": [
|
||||
0.7654682397842407,
|
||||
1.4984270334243774,
|
||||
0.46786263585090637,
|
||||
-0.38181185722351074,
|
||||
0.5517147779464722,
|
||||
3.291581630706787,
|
||||
2.575840711593628,
|
||||
1.0,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class RoboticDiffusionTransformerModel(object):
|
||||
"""A wrapper for the RDT model, which handles
|
||||
1. Model initialization
|
||||
2. Encodings of instructions
|
||||
3. Model inference
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
image_size=None,
|
||||
control_frequency=25,
|
||||
pretrained_text_encoder_name_or_path=None,
|
||||
pretrained_vision_encoder_name_or_path=None,
|
||||
):
|
||||
self.args = args
|
||||
self.dtype = dtype
|
||||
self.image_size = image_size
|
||||
self.device = device
|
||||
self.control_frequency = control_frequency
|
||||
self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
|
||||
self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
|
||||
self.policy = self.get_policy()
|
||||
|
||||
self.state_min = torch.tensor(DATA_STAT["state_min"]).to(device)
|
||||
self.state_max = torch.tensor(DATA_STAT["state_max"]).to(device)
|
||||
self.action_min = torch.tensor(DATA_STAT["action_min"]).to(device)
|
||||
self.action_max = torch.tensor(DATA_STAT["action_max"]).to(device)
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_policy(self):
|
||||
"""Initialize the model."""
|
||||
# Initialize model with arguments
|
||||
img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
|
||||
self.vision_model.num_patches)
|
||||
|
||||
_model = RDTRunner(
|
||||
action_dim=self.args["common"]["state_dim"],
|
||||
pred_horizon=self.args["common"]["action_chunk_size"],
|
||||
config=self.args["model"],
|
||||
lang_token_dim=self.args["model"]["lang_token_dim"],
|
||||
img_token_dim=self.args["model"]["img_token_dim"],
|
||||
state_token_dim=self.args["model"]["state_token_dim"],
|
||||
max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
|
||||
img_cond_len=img_cond_len,
|
||||
img_pos_embed_config=[
|
||||
# No initial pos embed in the last grid size
|
||||
# since we've already done in ViT
|
||||
(
|
||||
"image",
|
||||
(
|
||||
self.args["common"]["img_history_size"],
|
||||
self.args["common"]["num_cameras"],
|
||||
-self.vision_model.num_patches,
|
||||
),
|
||||
),
|
||||
],
|
||||
lang_pos_embed_config=[
|
||||
# Similarly, no initial pos embed for language
|
||||
("lang", -self.args["dataset"]["tokenizer_max_length"]),
|
||||
],
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
return _model
|
||||
|
||||
def get_text_encoder(self, pretrained_text_encoder_name_or_path):
|
||||
text_embedder = T5Embedder(
|
||||
from_pretrained=pretrained_text_encoder_name_or_path,
|
||||
model_max_length=self.args["dataset"]["tokenizer_max_length"],
|
||||
device=self.device,
|
||||
)
|
||||
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
||||
return tokenizer, text_encoder
|
||||
|
||||
def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
|
||||
vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
|
||||
image_processor = vision_encoder.image_processor
|
||||
return image_processor, vision_encoder
|
||||
|
||||
def reset(self):
|
||||
"""Set model to evaluation mode."""
|
||||
device = self.device
|
||||
weight_dtype = self.dtype
|
||||
self.policy.eval()
|
||||
self.text_model.eval()
|
||||
self.vision_model.eval()
|
||||
|
||||
self.policy = self.policy.to(device, dtype=weight_dtype)
|
||||
self.text_model = self.text_model.to(device, dtype=weight_dtype)
|
||||
self.vision_model = self.vision_model.to(device, dtype=weight_dtype)
|
||||
|
||||
def load_pretrained_weights(self, pretrained=None):
|
||||
if pretrained is None:
|
||||
return
|
||||
print(f"Loading weights from {pretrained}")
|
||||
filename = os.path.basename(pretrained)
|
||||
if filename.endswith(".pt"):
|
||||
checkpoint = torch.load(pretrained)
|
||||
self.policy.load_state_dict(checkpoint["module"])
|
||||
elif filename.endswith(".safetensors"):
|
||||
from safetensors.torch import load_model
|
||||
|
||||
load_model(self.policy, pretrained)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
|
||||
|
||||
def encode_instruction(self, instruction, device="cuda"):
|
||||
"""Encode string instruction to latent embeddings.
|
||||
|
||||
Args:
|
||||
instruction: a string of instruction
|
||||
device: a string of device
|
||||
|
||||
Returns:
|
||||
pred: a tensor of latent embeddings of shape (text_max_length, 512)
|
||||
"""
|
||||
tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
|
||||
truncation=True)["input_ids"].to(device)
|
||||
|
||||
tokens = tokens.view(1, -1)
|
||||
with torch.no_grad():
|
||||
pred = self.text_model(tokens).last_hidden_state.detach()
|
||||
|
||||
return pred
|
||||
|
||||
def _format_joint_to_state(self, joints):
|
||||
"""
|
||||
Format the robot joint state into the unified state vector.
|
||||
|
||||
Args:
|
||||
joints (torch.Tensor): The joint state to be formatted.
|
||||
qpos ([B, N, 14]).
|
||||
|
||||
Returns:
|
||||
state (torch.Tensor): The formatted state for RDT ([B, N, 128]).
|
||||
"""
|
||||
# Rescale the gripper
|
||||
# joints = joints / torch.tensor(
|
||||
# [[[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]],
|
||||
# device=joints.device, dtype=joints.dtype
|
||||
# )
|
||||
|
||||
# normalize to -1,1
|
||||
joints = (joints - self.state_min) / (self.state_max - self.state_min) * 2 - 1
|
||||
B, N, _ = joints.shape
|
||||
state = torch.zeros(
|
||||
(B, N, self.args["model"]["state_token_dim"]),
|
||||
device=joints.device,
|
||||
dtype=joints.dtype,
|
||||
)
|
||||
# assemble the unifed state vector
|
||||
state[:, :, MANISKILL_INDICES] = joints
|
||||
state_elem_mask = torch.zeros(
|
||||
(B, self.args["model"]["state_token_dim"]),
|
||||
device=joints.device,
|
||||
dtype=joints.dtype,
|
||||
)
|
||||
state_elem_mask[:, MANISKILL_INDICES] = 1
|
||||
return state, state_elem_mask
|
||||
|
||||
def _unformat_action_to_joint(self, action):
|
||||
action_indices = MANISKILL_INDICES
|
||||
joints = action[:, :, action_indices]
|
||||
|
||||
# denormalize to action space
|
||||
|
||||
joints = (joints + 1) / 2 * (self.action_max - self.action_min) + self.action_min
|
||||
|
||||
return joints
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, proprio, images, text_embeds):
|
||||
"""
|
||||
Args:
|
||||
proprio: proprioceptive states
|
||||
images: RGB images
|
||||
text_embeds: instruction embeddings
|
||||
|
||||
Returns:
|
||||
action: predicted action
|
||||
"""
|
||||
device = self.device
|
||||
dtype = self.dtype
|
||||
|
||||
background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
|
||||
dtype=np.uint8).reshape(1, 1, 3)
|
||||
background_image = (np.ones(
|
||||
(
|
||||
self.image_processor.size["height"],
|
||||
self.image_processor.size["width"],
|
||||
3,
|
||||
),
|
||||
dtype=np.uint8,
|
||||
) * background_color)
|
||||
|
||||
image_tensor_list = []
|
||||
for image in images:
|
||||
if image is None:
|
||||
# Replace it with the background image
|
||||
image = Image.fromarray(background_image)
|
||||
|
||||
if self.image_size is not None:
|
||||
image = transforms.Resize(self.data_args.image_size)(image)
|
||||
|
||||
if self.args["dataset"].get("auto_adjust_image_brightness", False):
|
||||
pixel_values = list(image.getdata())
|
||||
average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
|
||||
if average_brightness <= 0.15:
|
||||
image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)
|
||||
|
||||
if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":
|
||||
|
||||
def expand2square(pil_img, background_color):
|
||||
width, height = pil_img.size
|
||||
if width == height:
|
||||
return pil_img
|
||||
elif width > height:
|
||||
result = Image.new(pil_img.mode, (width, width), background_color)
|
||||
result.paste(pil_img, (0, (width - height) // 2))
|
||||
return result
|
||||
else:
|
||||
result = Image.new(pil_img.mode, (height, height), background_color)
|
||||
result.paste(pil_img, ((height - width) // 2, 0))
|
||||
return result
|
||||
|
||||
image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
|
||||
image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
||||
image_tensor_list.append(image)
|
||||
|
||||
image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
|
||||
|
||||
image_embeds = self.vision_model(image_tensor).detach()
|
||||
image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)
|
||||
|
||||
# history of actions
|
||||
joints = proprio.to(device).unsqueeze(0) # (1, 1, 14)
|
||||
states, state_elem_mask = self._format_joint_to_state(joints) # (1, 1, 128), (1, 128)
|
||||
states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
|
||||
states = states[:, -1:, :] # (1, 1, 128)
|
||||
ctrl_freqs = torch.tensor([self.control_frequency]).to(device)
|
||||
|
||||
text_embeds = text_embeds.to(device, dtype=dtype)
|
||||
|
||||
trajectory = self.policy.predict_action(
|
||||
lang_tokens=text_embeds,
|
||||
lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
|
||||
img_tokens=image_embeds,
|
||||
state_tokens=states,
|
||||
action_mask=state_elem_mask.unsqueeze(1),
|
||||
ctrl_freqs=ctrl_freqs,
|
||||
)
|
||||
trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
|
||||
|
||||
return trajectory
|
||||
169
RDT-1B/scripts/process_data.py
Normal file
169
RDT-1B/scripts/process_data.py
Normal file
@ -0,0 +1,169 @@
|
||||
import sys
|
||||
|
||||
sys.path.append("./")
|
||||
|
||||
import os
|
||||
import h5py
|
||||
import numpy as np
|
||||
import pickle
|
||||
import cv2
|
||||
import argparse
|
||||
import yaml
|
||||
from scripts.encode_lang_batch_once import encode_lang
|
||||
|
||||
|
||||
def load_hdf5(dataset_path):
|
||||
if not os.path.isfile(dataset_path):
|
||||
print(f"Dataset does not exist at \n{dataset_path}\n")
|
||||
exit()
|
||||
|
||||
with h5py.File(dataset_path, "r") as root:
|
||||
left_gripper, left_arm = (
|
||||
root["/joint_action/left_gripper"][()],
|
||||
root["/joint_action/left_arm"][()],
|
||||
)
|
||||
right_gripper, right_arm = (
|
||||
root["/joint_action/right_gripper"][()],
|
||||
root["/joint_action/right_arm"][()],
|
||||
)
|
||||
image_dict = dict()
|
||||
for cam_name in root[f"/observation/"].keys():
|
||||
image_dict[cam_name] = root[f"/observation/{cam_name}/rgb"][()]
|
||||
|
||||
return left_gripper, left_arm, right_gripper, right_arm, image_dict
|
||||
|
||||
|
||||
def images_encoding(imgs):
|
||||
encode_data = []
|
||||
padded_data = []
|
||||
max_len = 0
|
||||
for i in range(len(imgs)):
|
||||
success, encoded_image = cv2.imencode(".jpg", imgs[i])
|
||||
jpeg_data = encoded_image.tobytes()
|
||||
encode_data.append(jpeg_data)
|
||||
max_len = max(max_len, len(jpeg_data))
|
||||
# padding
|
||||
for i in range(len(imgs)):
|
||||
padded_data.append(encode_data[i].ljust(max_len, b"\0"))
|
||||
return encode_data, max_len
|
||||
|
||||
|
||||
def get_task_config(task_name):
|
||||
with open(f"./task_config/{task_name}.yml", "r", encoding="utf-8") as f:
|
||||
args = yaml.load(f.read(), Loader=yaml.FullLoader)
|
||||
return args
|
||||
|
||||
|
||||
def data_transform(path, episode_num, save_path):
|
||||
begin = 0
|
||||
floders = os.listdir(path)
|
||||
assert episode_num <= len(floders), "data num not enough"
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
for i in range(episode_num):
|
||||
left_gripper_all, left_arm_all, right_gripper_all, right_arm_all, image_dict = (load_hdf5(
|
||||
os.path.join(path, f"episode{i}.hdf5")))
|
||||
qpos = []
|
||||
actions = []
|
||||
cam_high = []
|
||||
cam_right_wrist = []
|
||||
cam_left_wrist = []
|
||||
left_arm_dim = []
|
||||
right_arm_dim = []
|
||||
|
||||
last_state = None
|
||||
for j in range(0, left_gripper_all.shape[0]):
|
||||
|
||||
left_gripper, left_arm, right_gripper, right_arm = (
|
||||
left_gripper_all[j],
|
||||
left_arm_all[j],
|
||||
right_gripper_all[j],
|
||||
right_arm_all[j],
|
||||
)
|
||||
|
||||
state = np.concatenate((left_arm, [left_gripper], right_arm, [right_gripper]), axis=0) # joint
|
||||
state = state.astype(np.float32)
|
||||
|
||||
if j != left_gripper_all.shape[0] - 1:
|
||||
|
||||
qpos.append(state)
|
||||
|
||||
camera_high_bits = image_dict["head_camera"][j]
|
||||
camera_high = cv2.imdecode(np.frombuffer(camera_high_bits, np.uint8), cv2.IMREAD_COLOR)
|
||||
camera_high_resized = cv2.resize(camera_high, (640, 480))
|
||||
cam_high.append(camera_high_resized)
|
||||
|
||||
camera_right_wrist_bits = image_dict["right_camera"][j]
|
||||
camera_right_wrist = cv2.imdecode(np.frombuffer(camera_right_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
|
||||
camera_right_wrist_resized = cv2.resize(camera_right_wrist, (640, 480))
|
||||
cam_right_wrist.append(camera_right_wrist_resized)
|
||||
|
||||
camera_left_wrist_bits = image_dict["left_camera"][j]
|
||||
camera_left_wrist = cv2.imdecode(np.frombuffer(camera_left_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
|
||||
camera_left_wrist_resized = cv2.resize(camera_left_wrist, (640, 480))
|
||||
cam_left_wrist.append(camera_left_wrist_resized)
|
||||
|
||||
if j != 0:
|
||||
action = state
|
||||
actions.append(action)
|
||||
left_arm_dim.append(left_arm.shape[0])
|
||||
right_arm_dim.append(right_arm.shape[0])
|
||||
|
||||
if not os.path.exists(os.path.join(save_path, f"episode_{i}")):
|
||||
os.makedirs(os.path.join(save_path, f"episode_{i}"))
|
||||
hdf5path = os.path.join(save_path, f"episode_{i}/episode_{i}.hdf5")
|
||||
|
||||
with h5py.File(hdf5path, "w") as f:
|
||||
f.create_dataset("action", data=np.array(actions))
|
||||
obs = f.create_group("observations")
|
||||
obs.create_dataset("qpos", data=np.array(qpos))
|
||||
obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))
|
||||
obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))
|
||||
image = obs.create_group("images")
|
||||
cam_high_enc, len_high = images_encoding(cam_high)
|
||||
cam_right_wrist_enc, len_right = images_encoding(cam_right_wrist)
|
||||
cam_left_wrist_enc, len_left = images_encoding(cam_left_wrist)
|
||||
image.create_dataset("cam_high", data=cam_high_enc, dtype=f"S{len_high}")
|
||||
image.create_dataset("cam_right_wrist", data=cam_right_wrist_enc, dtype=f"S{len_right}")
|
||||
image.create_dataset("cam_left_wrist", data=cam_left_wrist_enc, dtype=f"S{len_left}")
|
||||
|
||||
begin += 1
|
||||
print(f"proccess {i} success!")
|
||||
|
||||
return begin
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Process some episodes.")
|
||||
parser.add_argument("task_name", type=str)
|
||||
parser.add_argument("task_config", type=str)
|
||||
parser.add_argument("expert_data_num", type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
task_name = args.task_name
|
||||
task_config = args.task_config
|
||||
expert_data_num = args.expert_data_num
|
||||
|
||||
load_dir = os.path.join("../../data", str(task_name), str(task_config), "data")
|
||||
|
||||
print(f"read data from path: {load_dir}")
|
||||
begin = data_transform(
|
||||
load_dir,
|
||||
expert_data_num,
|
||||
f"./processed_data/{task_name}-{task_config}-{expert_data_num}",
|
||||
)
|
||||
tokenizer, text_encoder = None, None
|
||||
for idx in range(expert_data_num):
|
||||
print(f"Processing Language: {idx}", end="\r")
|
||||
data_file_path = (f"../../data/{task_name}/{task_config}/instructions/episode{idx}.json")
|
||||
target_dir = (f"processed_data/{task_name}-{task_config}-{expert_data_num}/episode_{idx}")
|
||||
tokenizer, text_encoder = encode_lang(
|
||||
DATA_FILE_PATH=data_file_path,
|
||||
TARGET_DIR=target_dir,
|
||||
GPU=0,
|
||||
desc_type="seen",
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
)
|
||||
31
RDT-1B/scripts/read_config.py
Normal file
31
RDT-1B/scripts/read_config.py
Normal file
@ -0,0 +1,31 @@
|
||||
import json
|
||||
import yaml
|
||||
import sys
|
||||
|
||||
def read_config(config_file, yaml_file):
|
||||
with open(config_file, 'r') as f:
|
||||
json_config = json.load(f)
|
||||
with open(yaml_file, 'r') as f:
|
||||
yaml_config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
yaml_config["model"] = json_config["train"]["model"] + json_config["task_id"]
|
||||
yaml_config["data_path"] = json_config["train"]["input_data_path"] + "/data"
|
||||
yaml_config["checkpoint_path"] = json_config["train"]["checkpoint_path"] + "/" + json_config["task_id"]
|
||||
yaml_config["pretrained_model_name_or_path"] = json_config["train"]["input_data_path"] + "/weights/rdt-1b"
|
||||
yaml_config["cuda_visible_device"] = str(json_config["gpu_id"])
|
||||
print(f"cuda_visible_device: {yaml_config['cuda_visible_device']}")
|
||||
yaml_config["train_batch_size"] = int(json_config["train"]["batch_size"])
|
||||
yaml_config["sample_batch_size"] = int(json_config["train"]["batch_size"]) * 2
|
||||
yaml_config["max_train_steps"] = int(json_config["train"]["epochs"])
|
||||
yaml_config["checkpointing_period"] = int(int(json_config["train"]["epochs"]) / 10)
|
||||
yaml_config["sample_period"] = 200
|
||||
yaml_config["checkpoints_total_limit"] = 50
|
||||
|
||||
|
||||
with open(yaml_file, 'w') as f:
|
||||
yaml.dump(yaml_config, f, default_flow_style=False)
|
||||
|
||||
print("Config YAML file updated successfully")
|
||||
|
||||
if __name__ == "__main__":
|
||||
read_config(sys.argv[1], sys.argv[2])
|
||||
22
RDT-1B/scripts/read_yaml.py
Normal file
22
RDT-1B/scripts/read_yaml.py
Normal file
@ -0,0 +1,22 @@
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
|
||||
def read_yaml_value(file_path, key):
|
||||
with open(file_path, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
value = data.get(key)
|
||||
if value is not None:
|
||||
print(value)
|
||||
else:
|
||||
print(f"Key '{key}' not found in {file_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print("Usage: python read_yaml.py <file_path> <key>")
|
||||
sys.exit(1)
|
||||
|
||||
file_path = sys.argv[1]
|
||||
key = sys.argv[2]
|
||||
read_yaml_value(file_path, key)
|
||||
0
RDT-1B/train/__init__.py
Normal file
0
RDT-1B/train/__init__.py
Normal file
BIN
RDT-1B/train/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
RDT-1B/train/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
RDT-1B/train/__pycache__/dataset.cpython-310.pyc
Normal file
BIN
RDT-1B/train/__pycache__/dataset.cpython-310.pyc
Normal file
Binary file not shown.
BIN
RDT-1B/train/__pycache__/image_corrupt.cpython-310.pyc
Normal file
BIN
RDT-1B/train/__pycache__/image_corrupt.cpython-310.pyc
Normal file
Binary file not shown.
BIN
RDT-1B/train/__pycache__/sample.cpython-310.pyc
Normal file
BIN
RDT-1B/train/__pycache__/sample.cpython-310.pyc
Normal file
Binary file not shown.
BIN
RDT-1B/train/__pycache__/train.cpython-310.pyc
Normal file
BIN
RDT-1B/train/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
479
RDT-1B/train/dataset.py
Normal file
479
RDT-1B/train/dataset.py
Normal file
@ -0,0 +1,479 @@
|
||||
import traceback
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import transformers
|
||||
|
||||
from data.filelock import FileLock
|
||||
from data.hdf5_vla_dataset import HDF5VLADataset
|
||||
from train.image_corrupt import image_corrupt
|
||||
|
||||
|
||||
def get_clean_item(chunk_dir):
|
||||
"""
|
||||
Get indexes of clean items in a chunk.
|
||||
"""
|
||||
dirty_bit = read_dirty_bit(chunk_dir)
|
||||
return np.where(1 - dirty_bit)[0].tolist()
|
||||
|
||||
|
||||
def save_dirty_bit(chunk_dir, dirty_bit):
|
||||
"""
|
||||
Save the dirty bit to the chunk directory.
|
||||
"""
|
||||
time_stmp = time.time()
|
||||
while time.time() - time_stmp < 10.0:
|
||||
try:
|
||||
file_path = os.path.join(chunk_dir, "dirty_bit")
|
||||
lock = FileLock(file_path)
|
||||
lock.acquire_write_lock()
|
||||
with open(file_path, "wb") as file:
|
||||
file.write(dirty_bit.tobytes())
|
||||
lock.release_lock()
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
lock.release_lock()
|
||||
raise KeyboardInterrupt
|
||||
except BaseException:
|
||||
lock.release_lock()
|
||||
continue
|
||||
raise RuntimeError("Failed to save dirty bit.")
|
||||
|
||||
|
||||
def read_dirty_bit(chunk_dir):
|
||||
"""
|
||||
Read the dirty bit from the chunk directory.
|
||||
"""
|
||||
# If error occurs, retry
|
||||
time_stmp = time.time()
|
||||
while time.time() - time_stmp < 10.0:
|
||||
try:
|
||||
file_path = os.path.join(chunk_dir, "dirty_bit")
|
||||
lock = FileLock(file_path)
|
||||
lock.acquire_read_lock()
|
||||
with open(file_path, "rb") as file:
|
||||
dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
|
||||
lock.release_lock()
|
||||
assert len(dirty_bit) > 0
|
||||
return dirty_bit
|
||||
except KeyboardInterrupt:
|
||||
lock.release_lock()
|
||||
raise KeyboardInterrupt
|
||||
except BaseException:
|
||||
lock.release_lock()
|
||||
continue
|
||||
raise RuntimeError("Failed to read dirty bit.")
|
||||
|
||||
|
||||
class VLAConsumerDataset(Dataset):
|
||||
"""A vision-languange-action Dataset for supervised training.
|
||||
This dataset will load data from the buffer directory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config_path,
|
||||
config,
|
||||
tokenizer,
|
||||
image_processor,
|
||||
num_cameras,
|
||||
img_history_size,
|
||||
image_size=None,
|
||||
auto_adjust_image_brightness=False,
|
||||
image_aug=False,
|
||||
dataset_type="pretrain",
|
||||
cond_mask_prob=0.1,
|
||||
cam_ext_mask_prob=-1.0,
|
||||
state_noise_snr=None,
|
||||
use_hdf5=False,
|
||||
use_precomp_lang_embed=False,
|
||||
):
|
||||
super(VLAConsumerDataset, self).__init__()
|
||||
|
||||
# Load the control frequency for each dataset
|
||||
with open("configs/dataset_control_freq.json", "r") as fp:
|
||||
self.control_freq = json.load(fp)
|
||||
# Load the dataset names
|
||||
dataset_names_cfg = ("configs/pretrain_datasets.json"
|
||||
if dataset_type == "pretrain" else "configs/finetune_datasets.json")
|
||||
with open(dataset_names_cfg, "r") as file:
|
||||
DATASET_NAMES = json.load(file)
|
||||
# Create the mapping between dataset name and id
|
||||
self.dataset_name2id = {name: i for i, name in enumerate(DATASET_NAMES)}
|
||||
self.dataset_id2name = {i: name for i, name in enumerate(DATASET_NAMES)}
|
||||
|
||||
self.image_processor = image_processor
|
||||
self.model_config_path = model_config_path
|
||||
self.buffer_dir = config["buf_path"]
|
||||
self.num_chunks = config["buf_num_chunks"]
|
||||
self.chunk_size = config["buf_chunk_size"]
|
||||
self.tokenizer_max_length = config["tokenizer_max_length"]
|
||||
self.image_aspect_ratio = config["image_aspect_ratio"]
|
||||
self.state_noise_snr = state_noise_snr
|
||||
self.num_cameras = num_cameras
|
||||
self.img_history_size = img_history_size
|
||||
self.cond_mask_prob = cond_mask_prob
|
||||
self.cam_ext_mask_prob = cam_ext_mask_prob
|
||||
self.use_hdf5 = use_hdf5
|
||||
self.hdf5_dataset = None
|
||||
if use_hdf5:
|
||||
self.hdf5_dataset = HDF5VLADataset(self.model_config_path)
|
||||
self.use_precomp_lang_embed = use_precomp_lang_embed
|
||||
if use_precomp_lang_embed:
|
||||
self.empty_lang_embed = torch.load("data/empty_lang_embed.pt")
|
||||
|
||||
# Load dataset stat
|
||||
with open("configs/dataset_stat.json", "r") as f:
|
||||
dataset_stat = json.load(f)
|
||||
self.dataset_stat = dataset_stat
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.image_size = image_size
|
||||
self.auto_adjust_image_brightness = auto_adjust_image_brightness
|
||||
self.image_aug = image_aug
|
||||
|
||||
self.last_content = None
|
||||
self.last_meta = None
|
||||
|
||||
def get_dataset_name2id(self):
    """Return the name -> integer-id mapping built from the dataset config.

    Note: the internal dict itself is returned, not a copy.
    """
    return self.dataset_name2id
|
||||
|
||||
def get_dataset_id2name(self):
    """Return the integer-id -> name mapping (inverse of `get_dataset_name2id`).

    Note: the internal dict itself is returned, not a copy.
    """
    return self.dataset_id2name
|
||||
|
||||
@staticmethod
|
||||
def pairwise(iterable):
|
||||
a = iter(iterable)
|
||||
return zip(a, a)
|
||||
|
||||
@staticmethod
def _load_data_from_chunk(chunk_dir, chunk_item_idx):
    """Read one buffered sample (JSON metadata + npz arrays) from a chunk dir.

    Both files are read under a read lock to coordinate with the producer.
    Any transient failure (e.g. the producer is mid-write) is retried for up
    to ~10 seconds before giving up.

    Args:
        chunk_dir: Directory of the buffer chunk holding the sample files.
        chunk_item_idx: Index of the sample within the chunk.

    Returns:
        (json_content, meta): the parsed JSON dict and a tuple of the arrays
        stored in ``sample_{idx}.npz`` (in file order).

    Raises:
        RuntimeError: If no attempt succeeds within the retry window.
    """
    # If an error occurs, retry until the ~10 s deadline expires.
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            locks = []
            file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_read_lock()
            with open(file_path, "r") as file:
                json_content = json.load(file)
            lock.release_lock()
            file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_read_lock()
            with open(file_path, "rb") as file:
                sample_dict = np.load(file)
                # Materialize the arrays while the file is still open:
                # np.load on an .npz handle is lazy.
                meta = tuple(sample_dict.values())
            lock.release_lock()
            return json_content, meta
        except KeyboardInterrupt:
            for lock in locks:
                lock.release_lock()
            # Bare re-raise preserves the original traceback (the original
            # code raised a fresh KeyboardInterrupt instance).
            raise
        except BaseException:
            # NOTE(review): on failure, locks already released above may get
            # release_lock() called a second time — confirm FileLock
            # tolerates double release.
            for lock in locks:
                lock.release_lock()
            continue
    raise RuntimeError("Failed to load sample.")
|
||||
|
||||
def __len__(self) -> int:
    """Dataset length: the HDF5 dataset's length when reading from HDF5,
    otherwise the disk buffer's total capacity (chunks x items per chunk)."""
    if self.use_hdf5:
        return len(self.hdf5_dataset)
    return self.num_chunks * self.chunk_size
|
||||
|
||||
def _safe_load(self, index):
    """Load one sample from the disk buffer, tolerating broken chunks.

    Starting from the chunk `index` maps to, scans (round-robin) for a chunk
    with at least one "clean" (not yet consumed) item, marks the chosen item
    dirty so the producer can refill it, then reads the sample. If the read
    fails, the previously loaded sample is returned instead for robustness.

    Returns:
        A flat tuple ``(json_content, *npz_arrays)``.
    """
    read_chunk_item_indices = []
    # Start searching from the chunk this index maps to
    # (comment said "random chunk", but the start is derived from `index`).
    read_chunk_idx = index // self.chunk_size
    # Loops indefinitely until some chunk yields clean items.
    while len(read_chunk_item_indices) == 0:
        read_chunk_dir = os.path.join(self.buffer_dir, f"chunk_{read_chunk_idx}")
        try:
            read_chunk_item_indices = get_clean_item(read_chunk_dir)
        except BaseException as e:
            # Print the error info and move on to the next chunk
            print("Error catched when searching a clean chunk:", e)
            traceback.print_exc()
            read_chunk_item_indices = []
        read_chunk_idx = (read_chunk_idx + 1) % self.num_chunks

    # Deterministic pick driven by `index` (earlier random strategies kept
    # below for reference):
    # read_chunk_item_index = random.choice(read_chunk_item_indices)
    # read_chunk_item_index = read_chunk_item_indices.pop()
    random_item_index = index % len(read_chunk_item_indices)
    read_chunk_item_index = read_chunk_item_indices[random_item_index]

    # Mark the chosen item dirty so other consumers skip it and the
    # producer knows it can be overwritten. Failure here is non-fatal.
    try:
        dirty_bit = read_dirty_bit(read_chunk_dir)
        dirty_bit[read_chunk_item_index] = 1
        save_dirty_bit(read_chunk_dir, dirty_bit)
    except BaseException as e:
        # Print the error info
        print("Error catched when modifying the dirty bit:", e)
        traceback.print_exc()

    # load the sample
    try:
        content, meta = self._load_data_from_chunk(read_chunk_dir, read_chunk_item_index)
        self.last_content, self.last_meta = content, meta
    except BaseException as e:
        # Print the error info
        print("Error catched when loading sample:", e)
        traceback.print_exc()
        # If failed to load the data, return the last loaded data for robustness.
        # NOTE(review): if the very first load fails, last_meta is still None
        # and the unpack below raises TypeError — __getitem__'s
        # except-BaseException retry loop absorbs that.
        content, meta = self.last_content, self.last_meta

    return (content, *meta)
|
||||
|
||||
def __getitem__(self, index):
    """Fetch and fully preprocess one training sample.

    Pulls a raw sample from the HDF5 dataset or the disk buffer, then builds
    ``data_dict`` containing states/actions/masks, preprocessed camera
    images, and either token ids or a precomputed language embedding.
    Several conditions (control frequency, states, state-element mask,
    images, instruction) are independently dropped out with probability
    ``cond_mask_prob`` — presumably for classifier-free guidance; confirm
    against the model's conditioning scheme.

    Any exception is caught, logged, and the next index is tried, so this
    method only returns on success (it can loop indefinitely if every index
    fails).
    """
    # For robustness, we will try to load the data until we succeed
    while True:
        data_dict = None
        try:
            if self.use_hdf5:
                # HDF5 path: the dataset hands back a dict of ready arrays.
                res = self.hdf5_dataset.get_item()
                content = res["meta"]
                states = res["state"]
                actions = res["actions"]
                state_elem_mask = res["state_indicator"]
                # Flat list of (images, mask) pairs, one pair per camera;
                # regrouped by pairwise() below.
                image_metas = [
                    res["cam_high"],
                    res["cam_high_mask"],
                    res["cam_right_wrist"],
                    res["cam_right_wrist_mask"],
                    res["cam_left_wrist"],
                    res["cam_left_wrist_mask"],
                ]
                state_std = res["state_std"]
                state_mean = res["state_mean"]
                state_norm = res["state_norm"]
            else:
                # Buffer path: _safe_load returns a flat tuple; the `_`
                # slots are fields this consumer does not use.
                (
                    content,
                    _,
                    states,
                    _,
                    actions,
                    _,
                    state_elem_mask,
                    *image_metas,
                    state_std,
                    state_mean,
                    state_norm,
                ) = self._safe_load(index)

            data_dict = {}
            data_dict["dataset_name"] = content["dataset_name"]
            data_dict["data_idx"] = self.dataset_name2id[data_dict["dataset_name"]]
            # Control frequency is randomly masked to 0 (condition dropout).
            data_dict["ctrl_freq"] = (self.control_freq[data_dict["dataset_name"]]
                                      if random.random() > self.cond_mask_prob else 0)

            # Optionally corrupt states with Gaussian noise at the requested
            # signal-to-noise ratio (snr given in dB).
            if self.state_noise_snr is not None:
                states += np.random.normal(
                    0.0,
                    state_std / np.sqrt(10**(self.state_noise_snr / 10)),
                    states.shape,
                )
            ds_state_mean = np.array(self.dataset_stat[data_dict["dataset_name"]]["state_mean"])
            ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
            # Randomly mask the states by the mean state
            data_dict["states"] = (states if random.random() > self.cond_mask_prob else ds_state_mean)
            data_dict["actions"] = actions
            # Randomly zero the per-element state mask (condition dropout).
            data_dict["state_elem_mask"] = (state_elem_mask if random.random() > self.cond_mask_prob else
                                            np.zeros_like(state_elem_mask))

            # Stat for the episode that the step belongs to
            data_dict["state_norm"] = state_norm

            # We replace the invalid images with the background image
            # and also randomly mask images by the background image.
            # Background color = the processor's per-channel mean, so a
            # masked image normalizes to (roughly) zeros.
            background_color = np.array(
                [int(x * 255) for x in self.image_processor.image_mean],
                dtype=np.uint8,
            ).reshape(1, 1, 3)
            background_image = (np.ones(
                (
                    self.image_processor.size["height"],
                    self.image_processor.size["width"],
                    3,
                ),
                dtype=np.uint8,
            ) * background_color)

            # Regroup the flat [imgs, mask, imgs, mask, ...] list into
            # (images, mask) pairs, one per camera.
            image_metas = list(self.pairwise(image_metas))
            mask_probs = [self.cond_mask_prob] * self.num_cameras
            # Camera 0 (exterior) may use its own masking probability.
            if self.cam_ext_mask_prob >= 0.0:
                mask_probs[0] = self.cam_ext_mask_prob
            # Rearrange to time-major order: for each history step, all cameras.
            rearranged_images = []
            for i in range(self.img_history_size):
                for j in range(self.num_cameras):
                    images, image_mask = image_metas[j]
                    image, valid = images[i], image_mask[i]
                    # Keep the frame only if it is valid, non-empty, and
                    # survives the random mask draw.
                    if (valid and (math.prod(image.shape) > 0) and (random.random() > mask_probs[j])):
                        rearranged_images.append((image, True))
                    else:
                        rearranged_images.append((background_image.copy(), False))

            preprocessed_images = []
            processor = self.image_processor
            for image, valid in rearranged_images:
                image = Image.fromarray(image)
                if self.image_size is not None:
                    image = transforms.Resize(self.image_size)(image)  # (1008, 336)
                # assert image.height == 336, "We haven't prepare for training with images of different resolutions."

                # Brighten very dark (but valid) frames.
                if valid and self.auto_adjust_image_brightness:
                    pixel_values = list(image.getdata())
                    average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                    if average_brightness <= 0.15:
                        image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

                # Only apply image augmentation to 50% of the images.
                # (The "corrput_only" spelling is intentional here: both the
                # choice list and the comparisons use the same string.)
                if valid and self.image_aug and (random.random() > 0.5):
                    aug_type = random.choice(["corrput_only", "color_only", "both"])
                    if aug_type != "corrput_only":
                        image = transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5,
                                                       hue=0.03)(image)
                    if aug_type != "color_only":
                        image = image_corrupt(image)

                if self.image_aspect_ratio == "pad":

                    # Pad the shorter side with the processor's mean color so
                    # the image becomes square without distortion.
                    def expand2square(pil_img, background_color):
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result

                    image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
                preprocessed_images.append(image)
            data_dict["images"] = preprocessed_images

            if self.use_precomp_lang_embed:
                # Here content["instruction"] is a path to a saved embedding;
                # strip a trailing "." before loading.
                if content["instruction"][-1] == ".":
                    content["instruction"] = content["instruction"][:-1]
                data_dict["lang_embed"] = (torch.load(content["instruction"])
                                           if random.random() > self.cond_mask_prob else self.empty_lang_embed)
            else:
                # Instruction dropout: replaced by the empty string.
                instruction = (content["instruction"] if random.random() > self.cond_mask_prob else "")
                data_dict["input_ids"] = self.tokenizer(
                    instruction,
                    return_tensors="pt",
                    padding="longest",
                    truncation=False,
                ).input_ids[0]

                assert (
                    len(data_dict["input_ids"]) <= self.tokenizer_max_length
                ), f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."

            # Convert every remaining numpy array to a torch tensor.
            for k, v in data_dict.items():
                if isinstance(v, np.ndarray):
                    data_dict[k] = torch.from_numpy(v)

            # Sanity check: nothing numpy should survive the conversion.
            for k, v in data_dict.items():
                assert not isinstance(v, np.ndarray), f"key: {k}, value: {v}"
                # data_dict[k] = torch.from_numpy(v)

            return data_dict
        except BaseException as e:
            # Print the error info
            if data_dict is not None:
                print(
                    f"Error catched when processing sample from {data_dict.get('dataset_name')}:",
                    e,
                )
            else:
                print(f"Error catched when processing sample:", e)
            traceback.print_exc()
            # Try incresing the index
            index = (index + 1) % len(self)
|
||||
|
||||
|
||||
class DataCollatorForVLAConsumerDataset(object):
    """Collate examples for supervised training.

    Stacks the per-sample tensors from ``VLAConsumerDataset.__getitem__``
    into batch tensors, and pads the language condition: either raw token
    ids (padded with the tokenizer's pad id, plus an attention mask) or
    precomputed language embeddings (zero-padded, plus an attention mask).
    """

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer) -> None:
        # Only pad_token_id is used, and only when samples carry "input_ids".
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        tensor_keys = ("states", "actions", "state_elem_mask", "state_norm")
        batch = {key: [] for key in (*tensor_keys, "images", "data_indices", "ctrl_freqs")}
        token_id_seqs = []
        precomp_embeds = []
        precomp_embed_lens = []

        for sample in instances:
            # Convert any lingering numpy arrays to tensors on the fly.
            for key in tensor_keys:
                value = sample[key]
                if not isinstance(value, torch.Tensor):
                    value = torch.from_numpy(value)
                batch[key].append(value)

            # A sample carries exactly one language representation.
            if "input_ids" in sample:
                token_id_seqs.append(sample["input_ids"])
            else:
                embed = sample["lang_embed"]
                precomp_embeds.append(embed)
                precomp_embed_lens.append(embed.shape[0])

            batch["images"].append(torch.stack(sample["images"], dim=0))
            batch["data_indices"].append(sample["data_idx"])
            batch["ctrl_freqs"].append(sample["ctrl_freq"])

        for key in (*tensor_keys, "images"):
            batch[key] = torch.stack(batch[key], dim=0)

        batch["ctrl_freqs"] = torch.tensor(batch["ctrl_freqs"])

        if len(token_id_seqs) > 0:
            padded_ids = torch.nn.utils.rnn.pad_sequence(token_id_seqs,
                                                         batch_first=True,
                                                         padding_value=self.tokenizer.pad_token_id)
            batch["input_ids"] = padded_ids
            # True on real tokens, False on padding.
            batch["lang_attn_mask"] = padded_ids.ne(self.tokenizer.pad_token_id)
        else:
            padded_embeds = torch.nn.utils.rnn.pad_sequence(precomp_embeds, batch_first=True, padding_value=0)
            attn_mask = torch.zeros(padded_embeds.shape[0], padded_embeds.shape[1], dtype=torch.bool)
            for row, length in enumerate(precomp_embed_lens):
                attn_mask[row, :length] = True
            batch["lang_embeds"] = padded_embeds
            batch["lang_attn_mask"] = attn_mask

        return batch
|
||||
45
RDT-1B/train/image_corrupt.py
Normal file
45
RDT-1B/train/image_corrupt.py
Normal file
@ -0,0 +1,45 @@
|
||||
import warnings

# Silence FutureWarning spam emitted at import time (notably by imgaug).
warnings.simplefilter(action="ignore", category=FutureWarning)

import numpy as np

# NOTE(review): restores the removed `np.bool` alias before importing imgaug
# — presumably because imgaug still references it on newer NumPy; confirm
# against the pinned imgaug/NumPy versions.
np.bool = np.bool_
import imgaug.augmenters as iaa
from PIL import Image

# Define our sequence of augmentation steps that will be applied to every image.
seq = iaa.Sequential(
    [
        # Execute one of the following noise augmentations
        iaa.OneOf([
            iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
            iaa.AdditiveLaplaceNoise(scale=(0.0, 0.05 * 255), per_channel=0.5),
            iaa.AdditivePoissonNoise(lam=(0.0, 0.05 * 255), per_channel=0.5),
        ]),
        # Execute one or none of the following blur augmentations
        iaa.SomeOf(
            (0, 1),
            [
                iaa.OneOf([
                    iaa.GaussianBlur((0, 3.0)),
                    iaa.AverageBlur(k=(2, 7)),
                    iaa.MedianBlur(k=(3, 11)),
                ]),
                iaa.MotionBlur(k=(3, 36)),
            ],
        ),
    ],
    # do all of the above augmentations in random order
    random_order=True,
)
|
||||
|
||||
|
||||
def image_corrupt(image: Image):
    """Apply the module-level `seq` augmentation pipeline to one PIL image.

    The image is wrapped into a one-element numpy batch (imgaug operates on
    batches), augmented, and converted back to a PIL image.
    """
    as_batch = np.array(image)[None, ...]
    augmented = seq(images=as_batch)
    return Image.fromarray(augmented[0])
|
||||
101
RDT-1B/train/sample.py
Normal file
101
RDT-1B/train/sample.py
Normal file
@ -0,0 +1,101 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@torch.no_grad()
def log_sample_res(
    text_encoder,
    vision_encoder,
    rdt,
    args,
    accelerator,
    weight_dtype,
    dataset_id2name,
    dataloader,
    logger,
):
    """Run the RDT policy on ``args.num_sample_batches`` eval batches and
    return a dict of logging metrics.

    Produces per-dataset averages ("<dataset>_sample_mse" /
    "<dataset>_sample_l2err", accumulated only on the main process) and
    cross-process overall averages ("overall_avg_sample_mse" /
    "overall_avg_sample_l2err"). Restores ``rdt`` to train mode before
    returning. Runs under fp16 CUDA autocast and ``torch.no_grad``.
    """
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logger.info(f"Running sampling for {args.num_sample_batches} batches...")

        rdt.eval()

        loss_for_log = defaultdict(float)
        loss_counter = defaultdict(int)
        for step, batch in enumerate(dataloader):
            if step >= args.num_sample_batches:
                break

            data_indices = batch["data_indices"]
            ctrl_freqs = batch["ctrl_freqs"]
            state_norm = batch["state_norm"].to(dtype=weight_dtype)
            images = batch["images"].to(dtype=weight_dtype)
            states = batch["states"].to(dtype=weight_dtype)
            # We only use the last state as input
            states = states[:, -1:, :]
            actions = batch["actions"].to(dtype=weight_dtype)
            state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)

            # Flatten (batch, n_imgs) so the vision encoder sees one image
            # per row, then fold the patch tokens back per sample.
            batch_size, _, C, H, W = images.shape
            image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
            image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

            # Language condition: precomputed embeddings or on-the-fly T5.
            lang_attn_mask = batch["lang_attn_mask"]
            text_embeds = (batch["lang_embeds"].to(dtype=weight_dtype) if args.precomp_lang_embed else text_encoder(
                input_ids=batch["input_ids"], attention_mask=lang_attn_mask)["last_hidden_state"].detach())

            pred_actions = rdt.predict_action(
                lang_tokens=text_embeds,
                lang_attn_mask=lang_attn_mask,
                img_tokens=image_embeds,
                state_tokens=states,
                action_mask=state_elem_mask.unsqueeze(1),
                ctrl_freqs=ctrl_freqs,
            )

            # Broadcast the per-element mask / norm over the predicted horizon.
            num_steps = pred_actions.shape[1]
            expanded_state_elem_mask = (state_elem_mask.unsqueeze(1).tile((1, num_steps, 1)).float())
            expanded_state_norm = (state_norm.unsqueeze(1).tile((1, num_steps, 1)).float())

            loss = F.mse_loss(pred_actions, actions, reduction="none").float()

            # Per-sample masked means: MSE, and L2 error normalized by the
            # episode's state norm (eps 1e-3 guards division by ~0).
            mse_loss_per_entry = (loss * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)
            l2_loss_per_entry = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss_per_entry = (l2_loss_per_entry * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)

            # Gather per-sample metrics (with their dataset ids) across processes.
            dataset_indices, mse_losses, l2_losses = accelerator.gather_for_metrics((
                torch.LongTensor(data_indices).to(device=pred_actions.device),
                mse_loss_per_entry,
                l2_loss_per_entry,
            ), )
            dataset_indices = dataset_indices.tolist()
            # Per-dataset accumulation happens only on the main process.
            if accelerator.is_main_process:
                for loss_suffix, losses in zip(["_sample_mse", "_sample_l2err"], [mse_losses, l2_losses]):
                    for dataset_idx, loss_tensor in zip(dataset_indices, losses):
                        loss_name = dataset_id2name[dataset_idx] + loss_suffix
                        loss_for_log[loss_name] += loss_tensor.item()
                        loss_counter[loss_name] += 1

            # Overall (cross-process) batch averages.
            mse_loss = (loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            mse_loss_scaler = accelerator.gather(mse_loss).mean().item()
            loss_for_log["overall_avg_sample_mse"] += mse_loss_scaler

            l2_loss = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss = (l2_loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            l2_loss_scaler = accelerator.gather(l2_loss).mean().item()
            loss_for_log["overall_avg_sample_l2err"] += l2_loss_scaler

        # Final averaging: overall metrics divide by the number of batches,
        # per-dataset metrics by their own sample counts.
        for name in loss_for_log:
            if name in ["overall_avg_sample_mse", "overall_avg_sample_l2err"]:
                loss_scaler = loss_for_log[name]
                loss_for_log[name] = round(loss_scaler / (args.num_sample_batches), 4)
            else:
                loss_for_log[name] = round(loss_for_log[name] / loss_counter[name], 4)

        rdt.train()
        torch.cuda.empty_cache()

        return dict(loss_for_log)
|
||||
521
RDT-1B/train/train.py
Normal file
521
RDT-1B/train/train.py
Normal file
@ -0,0 +1,521 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import diffusers
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
import yaml
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DeepSpeedPlugin, ProjectConfiguration, set_seed
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import is_wandb_available
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from tqdm.auto import tqdm
|
||||
from safetensors.torch import load_model
|
||||
|
||||
from models.ema_model import EMAModel
|
||||
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
||||
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||||
from models.rdt_runner import RDTRunner
|
||||
from train.dataset import DataCollatorForVLAConsumerDataset, VLAConsumerDataset
|
||||
from train.sample import log_sample_res
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, base_model=str, repo_folder=None):
    """Write a Hugging Face model card (README.md) into *repo_folder*.

    Args:
        repo_id: Repository id shown in the card title.
        base_model: Name of the base model the weights derive from.
            NOTE(review): the default is the `str` *type* (likely a typo for
            an annotation); kept for signature compatibility — callers should
            always pass it explicitly.
        repo_folder: Directory to write README.md into.
    """
    # Renamed from `yaml` to avoid shadowing the `yaml` module imported at
    # the top of this file.
    card_header = f"""
---
license: mit
base_model: {base_model}
language:
- en
pipeline_tag: robotics
library_name: transformers
tags:
- robotics
- pytorch
- multimodal
- pretraining
- vla
- diffusion
- rdt
---
"""
    model_card = f"""
# RDT - {repo_id}

This is a RDT model derived from {base_model}. The weights were trained using [RDT](https://rdt-robotics.github.io/rdt-robotics/).
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(card_header + model_card)
|
||||
|
||||
|
||||
def train(args, logger):
|
||||
# Read the config
|
||||
with open(args.config_path, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
|
||||
with open(args.model_config_path, "r") as f:
|
||||
model_config = yaml.safe_load(f)
|
||||
# print(model_config)
|
||||
args.output_dir = model_config["checkpoint_path"]
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=(DeepSpeedPlugin(hf_ds_config=args.deepspeed) if args.deepspeed is not None else None),
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_dir=logging_dir,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
filename=args.output_log_path,
|
||||
filemode='w',
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
||||
exist_ok=True,
|
||||
token=args.hub_token,
|
||||
).repo_id
|
||||
|
||||
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
if args.precomp_lang_embed:
|
||||
tokenizer, text_encoder = None, None
|
||||
else:
|
||||
text_embedder = T5Embedder(
|
||||
from_pretrained=args.pretrained_text_encoder_name_or_path,
|
||||
model_max_length=config["dataset"]["tokenizer_max_length"],
|
||||
device=accelerator.device,
|
||||
)
|
||||
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
||||
|
||||
vision_encoder = SiglipVisionTower(vision_tower=args.pretrained_vision_encoder_name_or_path, args=None)
|
||||
image_processor = vision_encoder.image_processor
|
||||
|
||||
# Load from a pretrained checkpoint
|
||||
if args.pretrained_model_name_or_path is not None and not os.path.isfile(args.pretrained_model_name_or_path):
|
||||
logger.info("Constructing model from pretrained checkpoint.")
|
||||
rdt = RDTRunner.from_pretrained(args.pretrained_model_name_or_path)
|
||||
else:
|
||||
logger.info("Constructing model from provided config.")
|
||||
# Calculate the image condition length
|
||||
img_cond_len = (config["common"]["img_history_size"] * config["common"]["num_cameras"] *
|
||||
vision_encoder.num_patches)
|
||||
rdt = RDTRunner(
|
||||
action_dim=config["common"]["state_dim"],
|
||||
pred_horizon=config["common"]["action_chunk_size"],
|
||||
config=config["model"],
|
||||
lang_token_dim=config["model"]["lang_token_dim"],
|
||||
img_token_dim=config["model"]["img_token_dim"],
|
||||
state_token_dim=config["model"]["state_token_dim"],
|
||||
max_lang_cond_len=config["dataset"]["tokenizer_max_length"],
|
||||
img_cond_len=img_cond_len,
|
||||
img_pos_embed_config=[
|
||||
# No initial pos embed in the last grid size
|
||||
# since we've already done in ViT
|
||||
(
|
||||
"image",
|
||||
(
|
||||
config["common"]["img_history_size"],
|
||||
config["common"]["num_cameras"],
|
||||
-vision_encoder.num_patches,
|
||||
),
|
||||
),
|
||||
],
|
||||
lang_pos_embed_config=[
|
||||
# Similarly, no initial pos embed for language
|
||||
("lang", -config["dataset"]["tokenizer_max_length"]),
|
||||
],
|
||||
dtype=weight_dtype,
|
||||
)
|
||||
|
||||
ema_rdt = copy.deepcopy(rdt)
|
||||
ema_model = EMAModel(
|
||||
ema_rdt,
|
||||
update_after_step=config["model"]["ema"]["update_after_step"],
|
||||
inv_gamma=config["model"]["ema"]["inv_gamma"],
|
||||
power=config["model"]["ema"]["power"],
|
||||
min_value=config["model"]["ema"]["min_value"],
|
||||
max_value=config["model"]["ema"]["max_value"],
|
||||
)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
# which ensure saving model in huggingface format (config.json + pytorch_model.bin)
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
for model in models:
|
||||
model_to_save = model.module if hasattr(model, "module") else model # type: ignore
|
||||
if isinstance(model_to_save, type(accelerator.unwrap_model(rdt))):
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
|
||||
if args.gradient_checkpointing:
    # TODO: implement activation checkpointing for the RDT backbone.
    raise NotImplementedError("Gradient checkpointing is not yet implemented.")

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

if args.scale_lr:
    # Scale the base LR linearly with the effective global batch size.
    args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size *
                          accelerator.num_processes)

# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") from err

    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

# Optimizer creation: all RDT parameters are trained (encoders stay frozen elsewhere).
params_to_optimize = rdt.parameters()
optimizer = optimizer_class(
    params_to_optimize,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
|
||||
|
||||
# Dataset and DataLoaders creation:
# The two datasets share most constructor arguments; build them once (DRY) and
# override only what differs: the sampling dataset disables augmentation,
# condition masking, and state noise so evaluation is deterministic.
common_dataset_kwargs = dict(
    model_config_path=args.model_config_path,  # TODO
    config=config["dataset"],
    tokenizer=tokenizer,
    image_processor=image_processor,
    num_cameras=config["common"]["num_cameras"],
    img_history_size=config["common"]["img_history_size"],
    dataset_type=args.dataset_type,
    use_hdf5=args.load_from_hdf5,
    use_precomp_lang_embed=args.precomp_lang_embed,
)
train_dataset = VLAConsumerDataset(
    image_aug=args.image_aug,
    cond_mask_prob=args.cond_mask_prob,
    cam_ext_mask_prob=args.cam_ext_mask_prob,
    state_noise_snr=args.state_noise_snr,
    **common_dataset_kwargs,
)
sample_dataset = VLAConsumerDataset(
    image_aug=False,
    cond_mask_prob=0,
    cam_ext_mask_prob=-1,
    state_noise_snr=None,
    **common_dataset_kwargs,
)
|
||||
|
||||
data_collator = DataCollatorForVLAConsumerDataset(tokenizer)

# Both loaders share everything except the dataset and the batch size.
shared_loader_kwargs = dict(
    shuffle=True,
    collate_fn=data_collator,
    num_workers=args.dataloader_num_workers,
    pin_memory=True,
    persistent_workers=True,
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    **shared_loader_kwargs,
)
sample_dataloader = torch.utils.data.DataLoader(
    sample_dataset,
    batch_size=args.sample_batch_size,
    **shared_loader_kwargs,
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
# If --max_train_steps was not given, derive it from epochs; remember that we
# did so, because it must be re-derived after `accelerator.prepare` reshards
# the dataloader.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

# Warmup/total are expressed in micro-steps, hence the accumulation factor.
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    num_cycles=args.lr_num_cycles,
    power=args.lr_power,
)
|
||||
|
||||
# Prepare everything with our `accelerator` (wraps for DDP/DeepSpeed, moves to device).
rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler = accelerator.prepare(
    rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler)

# The EMA copy is not trained, so it is moved manually rather than prepared.
ema_rdt.to(accelerator.device, dtype=weight_dtype)

if text_encoder is not None:
    text_encoder.to(accelerator.device, dtype=weight_dtype)

if vision_encoder is not None:
    vision_encoder.vision_tower.to(accelerator.device, dtype=weight_dtype)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initialize automatically on the main process.
if accelerator.is_main_process:
    run_name = f"RoboTwin_RDT_{args.CONFIG_NAME}"
    accelerator.init_trackers(
        "VLA",
        config=vars(args),
        init_kwargs={"wandb": {"name": run_name}},
    )

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
logger.info(f"  Num Epochs = {args.num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
|
||||
|
||||
# Load from a pretrained checkpoint
if (args.resume_from_checkpoint is None and args.pretrained_model_name_or_path is not None
        and os.path.isfile(args.pretrained_model_name_or_path)):
    # Since EMA is deprecated, we do not load EMA from the pretrained checkpoint
    logger.info("Loading from a pretrained checkpoint.")
    checkpoint = torch.load(args.pretrained_model_name_or_path)
    rdt.module.load_state_dict(checkpoint["module"])

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)
    else:
        # Get the most recent checkpoint (dirs are named "checkpoint-<step>").
        dirs = os.listdir(args.output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
        args.resume_from_checkpoint = None
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        try:
            accelerator.load_state(os.path.join(args.output_dir, path))  # load_module_strict=False
        except Exception:
            # Was a bare `except:`, which would also swallow KeyboardInterrupt /
            # SystemExit; narrow it to Exception and fall back to loading
            # DeepSpeed's raw state_dict from the model checkpoint.
            logger.info("Resuming training state failed. Attempting to only load from model checkpoint.")
            checkpoint = torch.load(
                os.path.join(
                    args.output_dir,
                    path,
                    "pytorch_model",
                    "mp_rank_00_model_states.pt",
                ))
            rdt.module.load_state_dict(checkpoint["module"])

        load_model(ema_rdt, os.path.join(args.output_dir, path, "ema", "model.safetensors"))
        global_step = int(path.split("-")[1])

        # Translate the resumed optimizer step back into micro-steps so the
        # progress bar and epoch counter line up with the previous run.
        resume_global_step = global_step * args.gradient_accumulation_steps
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
progress_bar = tqdm(
    range(global_step, args.max_train_steps),
    disable=not accelerator.is_local_main_process,
)
progress_bar.set_description("Steps")

# Extra per-dataset loss entries merged into the step logs below.
loss_for_log = {}
|
||||
for epoch in range(first_epoch, args.num_train_epochs):

    rdt.train()

    # Set the progress_bar to correct position when resuming mid-epoch.
    if args.resume_from_checkpoint and epoch == first_epoch:
        progress_bar.update(resume_step // args.gradient_accumulation_steps)

    # Forward and backward...
    for batch in train_dataloader:
        with accelerator.accumulate(rdt):
            images = batch["images"].to(dtype=weight_dtype)
            states = batch["states"].to(dtype=weight_dtype)  # (B, T, D_a) — assumed; confirm against collator
            # We only use the last state as input
            states = states[:, -1:, :]
            actions = batch["actions"].to(dtype=weight_dtype)
            state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)
            ctrl_freqs = batch["ctrl_freqs"]

            # Condition encoders are not trained here, so skip autograd tracking.
            with torch.no_grad():
                batch_size, _, C, H, W = images.shape
                image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
                image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

                lang_attn_mask = batch["lang_attn_mask"]
                # Either use precomputed language embeddings or encode on the fly.
                if args.precomp_lang_embed:
                    text_embeds = batch["lang_embeds"].to(dtype=weight_dtype)
                else:
                    text_embeds = text_encoder(
                        input_ids=batch["input_ids"],
                        attention_mask=lang_attn_mask)["last_hidden_state"].detach()

            state_elem_mask = state_elem_mask.unsqueeze(1)
            # The RDT forward pass returns the training loss directly.
            loss = rdt(
                lang_tokens=text_embeds,
                lang_attn_mask=lang_attn_mask,
                img_tokens=image_embeds,
                state_tokens=states,
                action_gt=actions,
                action_mask=state_elem_mask,
                ctrl_freqs=ctrl_freqs,
            )

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = rdt.parameters()
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=args.set_grads_to_none)

        ema_model.step(accelerator.unwrap_model(rdt))

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            if global_step % args.checkpointing_period == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)
                # Plain literal: the former f"ema" had no placeholders (ruff F541).
                ema_save_path = os.path.join(save_path, "ema")
                accelerator.save_model(ema_rdt, ema_save_path)
                logger.info(f"Saved state to {save_path}")

            if args.sample_period > 0 and global_step % args.sample_period == 0:
                sample_loss_for_log = log_sample_res(
                    text_encoder,
                    vision_encoder,
                    rdt,  # We do not use EMA currently
                    args,
                    accelerator,
                    weight_dtype,
                    sample_dataset.get_dataset_id2name(),
                    sample_dataloader,
                    logger,
                )
                logger.info(sample_loss_for_log)
                accelerator.log(sample_loss_for_log, step=global_step)

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        logs.update(loss_for_log)
        accelerator.log(logs, step=global_step)

        if global_step >= args.max_train_steps:
            break
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    accelerator.unwrap_model(rdt).save_pretrained(args.output_dir)
    # Plain literal: the former f"ema" had no placeholders (ruff F541).
    ema_save_path = os.path.join(args.output_dir, "ema")
    accelerator.save_model(ema_rdt, ema_save_path)

    logger.info(f"Saved Model to {args.output_dir}")

    if args.push_to_hub:
        save_model_card(
            repo_id,
            base_model=args.pretrained_model_name_or_path,
            repo_folder=args.output_dir,
        )
        upload_folder(
            repo_id=repo_id,
            folder_path=args.output_dir,
            commit_message="End of training",
            token=args.hub_token,
            allow_patterns=["pytorch_model.bin", "*.json", "*.md"],
            # ignore_patterns=["step_*", "epoch_*"],
        )

accelerator.end_training()
|
||||
Loading…
x
Reference in New Issue
Block a user