finish train
This commit is contained in:
parent
3168acad93
commit
78702c7f47
2
RDT-170M/.dockerignore
Normal file
2
RDT-170M/.dockerignore
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
input/*
|
||||||
|
output/*
|
||||||
7
RDT-170M/.gitignore
vendored
Normal file
7
RDT-170M/.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
processed_data/
|
||||||
|
training_data/
|
||||||
|
checkpoints/
|
||||||
|
model_config/*.yml
|
||||||
|
wandb/*
|
||||||
|
!models/
|
||||||
|
!data/
|
||||||
48
RDT-170M/Dockerfile
Normal file
48
RDT-170M/Dockerfile
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
|
||||||
|
FROM registry.d-robotics.cc/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
|
||||||
|
# ccr-29eug8s3-pub.cnc.bj.baidubce.com/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
ENV TZ=Asia/Shanghai
|
||||||
|
|
||||||
|
RUN sed -i 's/archive.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list && \
|
||||||
|
sed -i 's/security.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list
|
||||||
|
|
||||||
|
RUN apt-get update --allow-unauthenticated && apt-get install -y \
|
||||||
|
software-properties-common \
|
||||||
|
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||||
|
&& apt-get update \
|
||||||
|
&& apt-get install -y \
|
||||||
|
python3.10 \
|
||||||
|
python3.10-dev \
|
||||||
|
python3-pip \
|
||||||
|
python3.10-distutils \
|
||||||
|
libgl1-mesa-glx \
|
||||||
|
libglib2.0-0 \
|
||||||
|
wget \
|
||||||
|
ffmpeg \
|
||||||
|
libsm6 \
|
||||||
|
libxext6 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
|
||||||
|
|
||||||
|
COPY . /app/
|
||||||
|
|
||||||
|
RUN python3 -m pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
|
||||||
|
# RUN pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
RUN pip install torch==2.1.0 torchvision==0.16.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
|
||||||
|
RUN pip install packaging==24.0
|
||||||
|
|
||||||
|
RUN pip install tfds-nightly==4.9.4.dev202402070044
|
||||||
|
|
||||||
|
RUN pip install flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||||
|
|
||||||
|
RUN mkdir -p /app/dataset/input /app/dataset/output
|
||||||
|
|
||||||
|
ENTRYPOINT ["bash", "finetune.sh"]
|
||||||
1
RDT-170M/__init__.py
Normal file
1
RDT-170M/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .deploy_policy import *
|
||||||
BIN
RDT-170M/assets/head.png
Normal file
BIN
RDT-170M/assets/head.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 726 KiB |
72
RDT-170M/configs/base.yaml
Normal file
72
RDT-170M/configs/base.yaml
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
common:
|
||||||
|
# The number of historical images
|
||||||
|
img_history_size: 2
|
||||||
|
# The number of future actions to predict
|
||||||
|
action_chunk_size: 64
|
||||||
|
# The number of cameras to be used in the model
|
||||||
|
num_cameras: 3
|
||||||
|
# Dimension for state/action, we use the same space for both state and action
|
||||||
|
# This MUST be equal to configs/state_vec.py
|
||||||
|
state_dim: 128
|
||||||
|
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
# We will extract the data from raw dataset
|
||||||
|
# and store them in the disk buffer by producer
|
||||||
|
# When training, we will read the data
|
||||||
|
# randomly from the buffer by consumer
|
||||||
|
# The producer will replace the data which has been
|
||||||
|
# read by the consumer with new data
|
||||||
|
|
||||||
|
# The path to the buffer (at least 400GB)
|
||||||
|
buf_path: /path/to/buffer
|
||||||
|
# The number of chunks in the buffer
|
||||||
|
buf_num_chunks: 512
|
||||||
|
# The number of samples (step rather than episode) in each chunk
|
||||||
|
buf_chunk_size: 512
|
||||||
|
|
||||||
|
# We will filter the episodes with length less than `epsd_len_thresh_low`
|
||||||
|
epsd_len_thresh_low: 32
|
||||||
|
# For those more than `epsd_len_thresh_high`,
|
||||||
|
# we will randomly sample `epsd_len_thresh_high` steps each time we load the episode
|
||||||
|
# to better balance the training datasets
|
||||||
|
epsd_len_thresh_high: 2048
|
||||||
|
# How to fit the image size
|
||||||
|
image_aspect_ratio: pad
|
||||||
|
# Maximum number of language tokens
|
||||||
|
tokenizer_max_length: 1024
|
||||||
|
|
||||||
|
model:
|
||||||
|
# Config for condition adpators
|
||||||
|
lang_adaptor: mlp2x_gelu
|
||||||
|
img_adaptor: mlp2x_gelu
|
||||||
|
state_adaptor: mlp3x_gelu
|
||||||
|
lang_token_dim: 4096
|
||||||
|
img_token_dim: 1152
|
||||||
|
# Dim of action or proprioception vector
|
||||||
|
# A `state` refers to an action or a proprioception vector
|
||||||
|
state_token_dim: 128
|
||||||
|
# Config for RDT structure
|
||||||
|
rdt:
|
||||||
|
# 1B: num_head 32 hidden_size 2048 depth: 28
|
||||||
|
# 170M: num_head 32 hidden_size 1024 depth: 14
|
||||||
|
hidden_size: 1024
|
||||||
|
depth: 14
|
||||||
|
num_heads: 32
|
||||||
|
cond_pos_embed_type: multimodal
|
||||||
|
# For noise scheduler
|
||||||
|
noise_scheduler:
|
||||||
|
type: ddpm
|
||||||
|
num_train_timesteps: 1000
|
||||||
|
num_inference_timesteps: 5
|
||||||
|
beta_schedule: squaredcos_cap_v2 # Critical choice
|
||||||
|
prediction_type: sample
|
||||||
|
clip_sample: False
|
||||||
|
# For EMA (params averaging)
|
||||||
|
# We do not use EMA currently
|
||||||
|
ema:
|
||||||
|
update_after_step: 0
|
||||||
|
inv_gamma: 1.0
|
||||||
|
power: 0.75
|
||||||
|
min_value: 0.0
|
||||||
|
max_value: 0.9999
|
||||||
@ -0,0 +1,50 @@
|
|||||||
|
{
|
||||||
|
"A": [
|
||||||
|
[
|
||||||
|
-0.2691913843154907,
|
||||||
|
-0.21995729207992554,
|
||||||
|
-0.182277649641037
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.35127854347229004,
|
||||||
|
0.2769763469696045,
|
||||||
|
0.17159393429756165
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"B": [
|
||||||
|
[
|
||||||
|
-0.2576896846294403,
|
||||||
|
-0.22244493663311005,
|
||||||
|
-0.20557966828346252
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.32854634523391724,
|
||||||
|
0.2922680974006653,
|
||||||
|
0.17373555898666382
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"C": [
|
||||||
|
[
|
||||||
|
-0.29205888509750366,
|
||||||
|
-0.24688798189163208,
|
||||||
|
-0.17577645182609558
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.25053921341896057,
|
||||||
|
0.3277084231376648,
|
||||||
|
0.16431939601898193
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"D": [
|
||||||
|
[
|
||||||
|
-0.25131964683532715,
|
||||||
|
-0.15233077108860016,
|
||||||
|
-0.13294968008995056
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.19209328293800354,
|
||||||
|
0.19344553351402283,
|
||||||
|
0.1370421051979065
|
||||||
|
]
|
||||||
|
]
|
||||||
|
}
|
||||||
65
RDT-170M/configs/dataset_control_freq.json
Normal file
65
RDT-170M/configs/dataset_control_freq.json
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
{
|
||||||
|
"fractal20220817_data": 3,
|
||||||
|
"taco_play": 15,
|
||||||
|
"jaco_play": 10,
|
||||||
|
"berkeley_cable_routing": 10,
|
||||||
|
"nyu_door_opening_surprising_effectiveness": 3,
|
||||||
|
"viola": 20,
|
||||||
|
"berkeley_autolab_ur5": 5,
|
||||||
|
"toto": 30,
|
||||||
|
"kuka": 10,
|
||||||
|
"language_table": 10,
|
||||||
|
"columbia_cairlab_pusht_real": 10,
|
||||||
|
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20,
|
||||||
|
"nyu_rot_dataset_converted_externally_to_rlds":3,
|
||||||
|
"stanford_hydra_dataset_converted_externally_to_rlds": 10,
|
||||||
|
"austin_buds_dataset_converted_externally_to_rlds": 20,
|
||||||
|
"nyu_franka_play_dataset_converted_externally_to_rlds": 3,
|
||||||
|
"maniskill_dataset_converted_externally_to_rlds": 20,
|
||||||
|
"furniture_bench_dataset_converted_externally_to_rlds": 10,
|
||||||
|
"ucsd_kitchen_dataset_converted_externally_to_rlds": 2,
|
||||||
|
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3,
|
||||||
|
"austin_sailor_dataset_converted_externally_to_rlds": 20,
|
||||||
|
"austin_sirius_dataset_converted_externally_to_rlds": 20,
|
||||||
|
"bc_z": 10,
|
||||||
|
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": 10,
|
||||||
|
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 10,
|
||||||
|
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
|
||||||
|
"utokyo_xarm_bimanual_converted_externally_to_rlds": 10,
|
||||||
|
"berkeley_mvp_converted_externally_to_rlds": 5,
|
||||||
|
"berkeley_rpt_converted_externally_to_rlds": 30,
|
||||||
|
"kaist_nonprehensile_converted_externally_to_rlds": 10,
|
||||||
|
"stanford_mask_vit_converted_externally_to_rlds": 0,
|
||||||
|
"tokyo_u_lsmo_converted_externally_to_rlds": 10,
|
||||||
|
"dlr_sara_pour_converted_externally_to_rlds": 10,
|
||||||
|
"dlr_sara_grid_clamp_converted_externally_to_rlds": 10,
|
||||||
|
"dlr_edan_shared_control_converted_externally_to_rlds": 5,
|
||||||
|
"asu_table_top_converted_externally_to_rlds": 12.5,
|
||||||
|
"stanford_robocook_converted_externally_to_rlds": 5,
|
||||||
|
"eth_agent_affordances": 66.6,
|
||||||
|
"imperialcollege_sawyer_wrist_cam": 10,
|
||||||
|
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20,
|
||||||
|
"uiuc_d3field": 1,
|
||||||
|
"utaustin_mutex": 20,
|
||||||
|
"berkeley_fanuc_manipulation": 10,
|
||||||
|
"cmu_play_fusion": 5,
|
||||||
|
"cmu_stretch": 10,
|
||||||
|
"berkeley_gnm_recon": 3,
|
||||||
|
"berkeley_gnm_cory_hall": 5,
|
||||||
|
"berkeley_gnm_sac_son": 10,
|
||||||
|
"robo_net": 1,
|
||||||
|
"roboturk_real_towercreation": 10,
|
||||||
|
"roboturk_real_laundrylayout": 10,
|
||||||
|
"roboturk_real_objectsearch": 10,
|
||||||
|
"aloha_mobile": 50,
|
||||||
|
"aloha_static": 50,
|
||||||
|
"roboset": 5,
|
||||||
|
"droid": 15,
|
||||||
|
"fmb": 10,
|
||||||
|
"dobbe": 30,
|
||||||
|
"qut_dexterous_manpulation": 30,
|
||||||
|
"agilex": 25,
|
||||||
|
"rh20t": 10,
|
||||||
|
"calvin": 30,
|
||||||
|
"bridgev2": 5
|
||||||
|
}
|
||||||
575
RDT-170M/configs/dataset_img_keys.json
Normal file
575
RDT-170M/configs/dataset_img_keys.json
Normal file
@ -0,0 +1,575 @@
|
|||||||
|
{
|
||||||
|
"fractal20220817_data": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[
|
||||||
|
1,0,0,0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"taco_play": {
|
||||||
|
"image_keys": [
|
||||||
|
"rgb_static",
|
||||||
|
"rgb_gripper",
|
||||||
|
"rgb_static",
|
||||||
|
"rgb_static"
|
||||||
|
],
|
||||||
|
"image_mask":[
|
||||||
|
1,1,0,0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"jaco_play": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image_wrist",
|
||||||
|
"image_wrist",
|
||||||
|
"image_wrist"
|
||||||
|
],
|
||||||
|
"image_mask":[
|
||||||
|
1,1,0,0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"berkeley_cable_routing": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist45_image",
|
||||||
|
"wrist225_image",
|
||||||
|
"top_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,1]
|
||||||
|
},
|
||||||
|
"nyu_door_opening_surprising_effectiveness": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"viola": {
|
||||||
|
"image_keys": [
|
||||||
|
"agentview_rgb",
|
||||||
|
"eye_in_hand_rgb",
|
||||||
|
"eye_in_hand_rgb",
|
||||||
|
"eye_in_hand_rgb"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"berkeley_autolab_ur5": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"hand_image",
|
||||||
|
"hand_image",
|
||||||
|
"hand_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"toto": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"kuka": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"language_table": {
|
||||||
|
"image_keys": [
|
||||||
|
"rgb",
|
||||||
|
"rgb",
|
||||||
|
"rgb",
|
||||||
|
"rgb"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"columbia_cairlab_pusht_real": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"nyu_rot_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"stanford_hydra_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"austin_buds_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"nyu_franka_play_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image_additional_view",
|
||||||
|
"image_additional_view",
|
||||||
|
"image_additional_view"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,1]
|
||||||
|
},
|
||||||
|
"maniskill_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"furniture_bench_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"austin_sailor_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"austin_sirius_dataset_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"bc_z": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"hand_image",
|
||||||
|
"hand_image",
|
||||||
|
"image2"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,1]
|
||||||
|
},
|
||||||
|
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"berkeley_mvp_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"hand_image",
|
||||||
|
"hand_image",
|
||||||
|
"hand_image",
|
||||||
|
"hand_image"
|
||||||
|
],
|
||||||
|
"image_mask":[0,1,0,0]
|
||||||
|
},
|
||||||
|
"berkeley_rpt_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"hand_image",
|
||||||
|
"hand_image",
|
||||||
|
"hand_image",
|
||||||
|
"hand_image"
|
||||||
|
],
|
||||||
|
"image_mask":[0,1,0,0]
|
||||||
|
},
|
||||||
|
"kaist_nonprehensile_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"stanford_mask_vit_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"tokyo_u_lsmo_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"dlr_sara_pour_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"dlr_sara_grid_clamp_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"dlr_edan_shared_control_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"asu_table_top_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"stanford_robocook_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image_2",
|
||||||
|
"image_1",
|
||||||
|
"image_3",
|
||||||
|
"image_4"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,1]
|
||||||
|
},
|
||||||
|
"eth_agent_affordances": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"imperialcollege_sawyer_wrist_cam": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[0,1,0,0]
|
||||||
|
},
|
||||||
|
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"uiuc_d3field": {
|
||||||
|
"image_keys": [
|
||||||
|
"image_1",
|
||||||
|
"image_2",
|
||||||
|
"image_3",
|
||||||
|
"image_4"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,1]
|
||||||
|
},
|
||||||
|
"utaustin_mutex": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"berkeley_fanuc_manipulation": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"cmu_play_fusion": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"cmu_stretch": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"berkeley_gnm_recon": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"berkeley_gnm_cory_hall": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"berkeley_gnm_sac_son": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"robo_net": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image1",
|
||||||
|
"image2",
|
||||||
|
"image2"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,1]
|
||||||
|
},
|
||||||
|
"roboturk_real_towercreation": {
|
||||||
|
"image_keys": [
|
||||||
|
"top_rgb_frame",
|
||||||
|
"front_rgb_frame",
|
||||||
|
"front_rgb_frame",
|
||||||
|
"front_rgb_frame"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,1]
|
||||||
|
},
|
||||||
|
"roboturk_real_laundrylayout": {
|
||||||
|
"image_keys": [
|
||||||
|
"top_rgb_frame",
|
||||||
|
"front_rgb_frame",
|
||||||
|
"front_rgb_frame",
|
||||||
|
"front_rgb_frame"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,1]
|
||||||
|
},
|
||||||
|
"roboturk_real_objectsearch": {
|
||||||
|
"image_keys": [
|
||||||
|
"top_rgb_frame",
|
||||||
|
"front_rgb_frame",
|
||||||
|
"front_rgb_frame",
|
||||||
|
"front_rgb_frame"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,1]
|
||||||
|
},
|
||||||
|
"aloha_mobile": {
|
||||||
|
"image_keys": [
|
||||||
|
"cam_high",
|
||||||
|
"cam_right_wrist",
|
||||||
|
"cam_left_wrist",
|
||||||
|
"cam_right_wrist"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,1,0]
|
||||||
|
},
|
||||||
|
"aloha_static": {
|
||||||
|
"image_keys": [
|
||||||
|
"cam_high",
|
||||||
|
"cam_right_wrist",
|
||||||
|
"cam_left_wrist",
|
||||||
|
"cam_low"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,1,1]
|
||||||
|
},
|
||||||
|
"roboset": {
|
||||||
|
"image_keys": [
|
||||||
|
"rgb_top",
|
||||||
|
"rgb_right",
|
||||||
|
"rgb_left",
|
||||||
|
"rgb_right"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,1,0]
|
||||||
|
},
|
||||||
|
"droid": {
|
||||||
|
"image_keys": [
|
||||||
|
"exterior_image_1_left",
|
||||||
|
"wrist_image_left",
|
||||||
|
"wrist_image_left",
|
||||||
|
"exterior_image_2_left"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,1]
|
||||||
|
},
|
||||||
|
"fmb": {
|
||||||
|
"image_keys": [
|
||||||
|
"image_side_1",
|
||||||
|
"image_wrist_1",
|
||||||
|
"image_wrist_1",
|
||||||
|
"image_side_2"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,1]
|
||||||
|
},
|
||||||
|
"dobbe": {
|
||||||
|
"image_keys": [
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[0,1,0,0]
|
||||||
|
},
|
||||||
|
"qut_dexterous_manpulation": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image",
|
||||||
|
"wrist_image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"agilex": {
|
||||||
|
"image_keys": [
|
||||||
|
"cam_high",
|
||||||
|
"cam_right_wrist",
|
||||||
|
"cam_left_wrist",
|
||||||
|
"cam_right_wrist"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,1,0]
|
||||||
|
},
|
||||||
|
"rh20t": {
|
||||||
|
"image_keys": [
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image",
|
||||||
|
"image"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
},
|
||||||
|
"calvin": {
|
||||||
|
"image_keys": [
|
||||||
|
"rgb_static",
|
||||||
|
"rgb_gripper",
|
||||||
|
"rgb_gripper",
|
||||||
|
"rgb_gripper"
|
||||||
|
],
|
||||||
|
"image_mask":[1,1,0,0]
|
||||||
|
},
|
||||||
|
"bridgev2": {
|
||||||
|
"image_keys": [
|
||||||
|
"images0",
|
||||||
|
"images0",
|
||||||
|
"images0",
|
||||||
|
"images0"
|
||||||
|
],
|
||||||
|
"image_mask":[1,0,0,0]
|
||||||
|
}
|
||||||
|
}
|
||||||
525
RDT-170M/configs/dataset_stat.json
Normal file
525
RDT-170M/configs/dataset_stat.json
Normal file
@ -0,0 +1,525 @@
|
|||||||
|
{
|
||||||
|
"agilex": {
|
||||||
|
"dataset_name": "agilex",
|
||||||
|
"state_mean": [
|
||||||
|
-0.0036545392947090432,
|
||||||
|
-0.2773659935760079,
|
||||||
|
0.3147616748061523,
|
||||||
|
0.3813313179910183,
|
||||||
|
0.04028575944090457,
|
||||||
|
0.034888520819083294,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
"state_std": [
|
||||||
|
0.05763674563578847,
|
||||||
|
0.2580181064167735,
|
||||||
|
0.19785840483767897,
|
||||||
|
0.05020347749331385,
|
||||||
|
0.054529239104671424,
|
||||||
|
0.05020521339363586,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
"state_min": [
|
||||||
|
-0.17447535196940103,
|
||||||
|
-0.5522612677680121,
|
||||||
|
-0.3340397516886393,
|
||||||
|
0.21861712137858072,
|
||||||
|
-0.09725829230414497,
|
||||||
|
0.003396739231215583,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
"state_max": [
|
||||||
|
0.21961932712131077,
|
||||||
|
0.30613206227620443,
|
||||||
|
0.5444545321994357,
|
||||||
|
0.4866888682047526,
|
||||||
|
0.31486290825737845,
|
||||||
|
0.3355223337809245,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
3
RDT-170M/configs/finetune_datasets.json
Normal file
3
RDT-170M/configs/finetune_datasets.json
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[
|
||||||
|
"agilex"
|
||||||
|
]
|
||||||
3
RDT-170M/configs/finetune_sample_weights.json
Normal file
3
RDT-170M/configs/finetune_sample_weights.json
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"agilex": 100
|
||||||
|
}
|
||||||
48
RDT-170M/configs/pretrain_datasets.json
Normal file
48
RDT-170M/configs/pretrain_datasets.json
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
[
|
||||||
|
"fractal20220817_data",
|
||||||
|
"jaco_play",
|
||||||
|
"taco_play",
|
||||||
|
"berkeley_cable_routing",
|
||||||
|
"viola",
|
||||||
|
"berkeley_autolab_ur5",
|
||||||
|
"toto",
|
||||||
|
"nyu_door_opening_surprising_effectiveness",
|
||||||
|
"columbia_cairlab_pusht_real",
|
||||||
|
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds",
|
||||||
|
"austin_buds_dataset_converted_externally_to_rlds",
|
||||||
|
"kuka",
|
||||||
|
"utokyo_xarm_bimanual_converted_externally_to_rlds",
|
||||||
|
"stanford_hydra_dataset_converted_externally_to_rlds",
|
||||||
|
"maniskill_dataset_converted_externally_to_rlds",
|
||||||
|
"ucsd_kitchen_dataset_converted_externally_to_rlds",
|
||||||
|
"ucsd_pick_and_place_dataset_converted_externally_to_rlds",
|
||||||
|
"austin_sailor_dataset_converted_externally_to_rlds",
|
||||||
|
"austin_sirius_dataset_converted_externally_to_rlds",
|
||||||
|
"bc_z",
|
||||||
|
"utokyo_pr2_opening_fridge_converted_externally_to_rlds",
|
||||||
|
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds",
|
||||||
|
"utokyo_xarm_pick_and_place_converted_externally_to_rlds",
|
||||||
|
"berkeley_mvp_converted_externally_to_rlds",
|
||||||
|
"berkeley_rpt_converted_externally_to_rlds",
|
||||||
|
"kaist_nonprehensile_converted_externally_to_rlds",
|
||||||
|
"tokyo_u_lsmo_converted_externally_to_rlds",
|
||||||
|
"dlr_sara_grid_clamp_converted_externally_to_rlds",
|
||||||
|
"stanford_robocook_converted_externally_to_rlds",
|
||||||
|
"imperialcollege_sawyer_wrist_cam",
|
||||||
|
"iamlab_cmu_pickup_insert_converted_externally_to_rlds",
|
||||||
|
"utaustin_mutex",
|
||||||
|
"berkeley_fanuc_manipulation",
|
||||||
|
"cmu_play_fusion",
|
||||||
|
"language_table",
|
||||||
|
"furniture_bench_dataset_converted_externally_to_rlds",
|
||||||
|
"droid",
|
||||||
|
"fmb",
|
||||||
|
"dobbe",
|
||||||
|
"qut_dexterous_manpulation",
|
||||||
|
"aloha_mobile",
|
||||||
|
"aloha_static",
|
||||||
|
"roboset",
|
||||||
|
"rh20t",
|
||||||
|
"calvin",
|
||||||
|
"bridgev2"
|
||||||
|
]
|
||||||
48
RDT-170M/configs/pretrain_sample_weights.json
Normal file
48
RDT-170M/configs/pretrain_sample_weights.json
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
{
|
||||||
|
"fractal20220817_data": 271,
|
||||||
|
"taco_play": 60,
|
||||||
|
"jaco_play": 33,
|
||||||
|
"berkeley_cable_routing": 8,
|
||||||
|
"nyu_door_opening_surprising_effectiveness": 10,
|
||||||
|
"viola": 12,
|
||||||
|
"berkeley_autolab_ur5": 32,
|
||||||
|
"toto": 32,
|
||||||
|
"kuka": 50,
|
||||||
|
"language_table": 100,
|
||||||
|
"columbia_cairlab_pusht_real": 12,
|
||||||
|
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 55,
|
||||||
|
"stanford_hydra_dataset_converted_externally_to_rlds": 24,
|
||||||
|
"austin_buds_dataset_converted_externally_to_rlds": 7,
|
||||||
|
"maniskill_dataset_converted_externally_to_rlds": 174,
|
||||||
|
"furniture_bench_dataset_converted_externally_to_rlds": 71,
|
||||||
|
"ucsd_kitchen_dataset_converted_externally_to_rlds": 12,
|
||||||
|
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": 37,
|
||||||
|
"austin_sailor_dataset_converted_externally_to_rlds": 15,
|
||||||
|
"austin_sirius_dataset_converted_externally_to_rlds": 24,
|
||||||
|
"bc_z": 208,
|
||||||
|
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": 9,
|
||||||
|
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 15,
|
||||||
|
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
|
||||||
|
"utokyo_xarm_bimanual_converted_externally_to_rlds": 1,
|
||||||
|
"berkeley_mvp_converted_externally_to_rlds": 22,
|
||||||
|
"berkeley_rpt_converted_externally_to_rlds": 30,
|
||||||
|
"kaist_nonprehensile_converted_externally_to_rlds": 14,
|
||||||
|
"tokyo_u_lsmo_converted_externally_to_rlds": 7,
|
||||||
|
"dlr_sara_grid_clamp_converted_externally_to_rlds": 1,
|
||||||
|
"stanford_robocook_converted_externally_to_rlds": 50,
|
||||||
|
"imperialcollege_sawyer_wrist_cam": 13,
|
||||||
|
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": 25,
|
||||||
|
"utaustin_mutex": 39,
|
||||||
|
"berkeley_fanuc_manipulation": 20,
|
||||||
|
"cmu_play_fusion": 24,
|
||||||
|
"droid": 303,
|
||||||
|
"fmb": 42,
|
||||||
|
"dobbe": 36,
|
||||||
|
"qut_dexterous_manpulation": 14,
|
||||||
|
"aloha_mobile": 150,
|
||||||
|
"aloha_static": 150,
|
||||||
|
"roboset": 135,
|
||||||
|
"rh20t": 331,
|
||||||
|
"calvin": 100,
|
||||||
|
"bridgev2": 224
|
||||||
|
}
|
||||||
126
RDT-170M/configs/state_vec.py
Normal file
126
RDT-170M/configs/state_vec.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
STATE_VEC_IDX_MAPPING = {
|
||||||
|
# [0, 10): right arm joint positions
|
||||||
|
**{
|
||||||
|
"arm_joint_{}_pos".format(i): i
|
||||||
|
for i in range(10)
|
||||||
|
},
|
||||||
|
**{
|
||||||
|
"right_arm_joint_{}_pos".format(i): i
|
||||||
|
for i in range(10)
|
||||||
|
},
|
||||||
|
# [10, 15): right gripper joint positions
|
||||||
|
**{
|
||||||
|
"gripper_joint_{}_pos".format(i): i + 10
|
||||||
|
for i in range(5)
|
||||||
|
},
|
||||||
|
**{
|
||||||
|
"right_gripper_joint_{}_pos".format(i): i + 10
|
||||||
|
for i in range(5)
|
||||||
|
},
|
||||||
|
"gripper_open": 10, # alias of right_gripper_joint_0_pos
|
||||||
|
"right_gripper_open": 10,
|
||||||
|
# [15, 25): right arm joint velocities
|
||||||
|
**{
|
||||||
|
"arm_joint_{}_vel".format(i): i + 15
|
||||||
|
for i in range(10)
|
||||||
|
},
|
||||||
|
**{
|
||||||
|
"right_arm_joint_{}_vel".format(i): i + 15
|
||||||
|
for i in range(10)
|
||||||
|
},
|
||||||
|
# [25, 30): right gripper joint velocities
|
||||||
|
**{
|
||||||
|
"gripper_joint_{}_vel".format(i): i + 25
|
||||||
|
for i in range(5)
|
||||||
|
},
|
||||||
|
**{
|
||||||
|
"right_gripper_joint_{}_vel".format(i): i + 25
|
||||||
|
for i in range(5)
|
||||||
|
},
|
||||||
|
"gripper_open_vel": 25, # alias of right_gripper_joint_0_vel
|
||||||
|
"right_gripper_open_vel": 25,
|
||||||
|
# [30, 33): right end effector positions
|
||||||
|
"eef_pos_x": 30,
|
||||||
|
"right_eef_pos_x": 30,
|
||||||
|
"eef_pos_y": 31,
|
||||||
|
"right_eef_pos_y": 31,
|
||||||
|
"eef_pos_z": 32,
|
||||||
|
"right_eef_pos_z": 32,
|
||||||
|
# [33, 39): right end effector 6D pose
|
||||||
|
"eef_angle_0": 33,
|
||||||
|
"right_eef_angle_0": 33,
|
||||||
|
"eef_angle_1": 34,
|
||||||
|
"right_eef_angle_1": 34,
|
||||||
|
"eef_angle_2": 35,
|
||||||
|
"right_eef_angle_2": 35,
|
||||||
|
"eef_angle_3": 36,
|
||||||
|
"right_eef_angle_3": 36,
|
||||||
|
"eef_angle_4": 37,
|
||||||
|
"right_eef_angle_4": 37,
|
||||||
|
"eef_angle_5": 38,
|
||||||
|
"right_eef_angle_5": 38,
|
||||||
|
# [39, 42): right end effector velocities
|
||||||
|
"eef_vel_x": 39,
|
||||||
|
"right_eef_vel_x": 39,
|
||||||
|
"eef_vel_y": 40,
|
||||||
|
"right_eef_vel_y": 40,
|
||||||
|
"eef_vel_z": 41,
|
||||||
|
"right_eef_vel_z": 41,
|
||||||
|
# [42, 45): right end effector angular velocities
|
||||||
|
"eef_angular_vel_roll": 42,
|
||||||
|
"right_eef_angular_vel_roll": 42,
|
||||||
|
"eef_angular_vel_pitch": 43,
|
||||||
|
"right_eef_angular_vel_pitch": 43,
|
||||||
|
"eef_angular_vel_yaw": 44,
|
||||||
|
"right_eef_angular_vel_yaw": 44,
|
||||||
|
# [45, 50): reserved
|
||||||
|
# [50, 60): left arm joint positions
|
||||||
|
**{
|
||||||
|
"left_arm_joint_{}_pos".format(i): i + 50
|
||||||
|
for i in range(10)
|
||||||
|
},
|
||||||
|
# [60, 65): left gripper joint positions
|
||||||
|
**{
|
||||||
|
"left_gripper_joint_{}_pos".format(i): i + 60
|
||||||
|
for i in range(5)
|
||||||
|
},
|
||||||
|
"left_gripper_open": 60, # alias of left_gripper_joint_0_pos
|
||||||
|
# [65, 75): left arm joint velocities
|
||||||
|
**{
|
||||||
|
"left_arm_joint_{}_vel".format(i): i + 65
|
||||||
|
for i in range(10)
|
||||||
|
},
|
||||||
|
# [75, 80): left gripper joint velocities
|
||||||
|
**{
|
||||||
|
"left_gripper_joint_{}_vel".format(i): i + 75
|
||||||
|
for i in range(5)
|
||||||
|
},
|
||||||
|
"left_gripper_open_vel": 75, # alias of left_gripper_joint_0_vel
|
||||||
|
# [80, 83): left end effector positions
|
||||||
|
"left_eef_pos_x": 80,
|
||||||
|
"left_eef_pos_y": 81,
|
||||||
|
"left_eef_pos_z": 82,
|
||||||
|
# [83, 89): left end effector 6D pose
|
||||||
|
"left_eef_angle_0": 83,
|
||||||
|
"left_eef_angle_1": 84,
|
||||||
|
"left_eef_angle_2": 85,
|
||||||
|
"left_eef_angle_3": 86,
|
||||||
|
"left_eef_angle_4": 87,
|
||||||
|
"left_eef_angle_5": 88,
|
||||||
|
# [89, 92): left end effector velocities
|
||||||
|
"left_eef_vel_x": 89,
|
||||||
|
"left_eef_vel_y": 90,
|
||||||
|
"left_eef_vel_z": 91,
|
||||||
|
# [92, 95): left end effector angular velocities
|
||||||
|
"left_eef_angular_vel_roll": 92,
|
||||||
|
"left_eef_angular_vel_pitch": 93,
|
||||||
|
"left_eef_angular_vel_yaw": 94,
|
||||||
|
# [95, 100): reserved
|
||||||
|
# [100, 102): base linear velocities
|
||||||
|
"base_vel_x": 100,
|
||||||
|
"base_vel_y": 101,
|
||||||
|
# [102, 103): base angular velocities
|
||||||
|
"base_angular_vel": 102,
|
||||||
|
# [103, 128): reserved
|
||||||
|
}
|
||||||
|
STATE_VEC_LEN = 128
|
||||||
14
RDT-170M/configs/zero2.json
Normal file
14
RDT-170M/configs/zero2.json
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"bf16": {
|
||||||
|
"enabled": "auto"
|
||||||
|
},
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 2,
|
||||||
|
"overlap_comm": true,
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"sub_group_size": 1e9
|
||||||
|
}
|
||||||
|
}
|
||||||
2
RDT-170M/data/.gitignore
vendored
Normal file
2
RDT-170M/data/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# Ignore data files
|
||||||
|
datasets
|
||||||
154
RDT-170M/data/agilex/hdf5totfrecords.py
Normal file
154
RDT-170M/data/agilex/hdf5totfrecords.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
import h5py
|
||||||
|
import os
|
||||||
|
import fnmatch
|
||||||
|
import shutil
|
||||||
|
from tqdm import tqdm
|
||||||
|
from multiprocessing import Pool
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def _bytes_feature(value):
|
||||||
|
"""Returns a bytes_list from a string / byte."""
|
||||||
|
if isinstance(value, type(tf.constant(0))):
|
||||||
|
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
|
||||||
|
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
|
||||||
|
|
||||||
|
|
||||||
|
def _bool_feature(value):
|
||||||
|
"""Returns a bool_list from a boolean."""
|
||||||
|
return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_example(
|
||||||
|
action,
|
||||||
|
base_action,
|
||||||
|
qpos,
|
||||||
|
qvel,
|
||||||
|
cam_high,
|
||||||
|
cam_left_wrist,
|
||||||
|
cam_right_wrist,
|
||||||
|
instruction,
|
||||||
|
terminate_episode,
|
||||||
|
):
|
||||||
|
feature = {
|
||||||
|
"action":
|
||||||
|
_bytes_feature(tf.io.serialize_tensor(action)),
|
||||||
|
"base_action":
|
||||||
|
_bytes_feature(tf.io.serialize_tensor(base_action)),
|
||||||
|
"qpos":
|
||||||
|
_bytes_feature(tf.io.serialize_tensor(qpos)),
|
||||||
|
"qvel":
|
||||||
|
_bytes_feature(tf.io.serialize_tensor(qvel)),
|
||||||
|
"cam_high":
|
||||||
|
_bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_high.tobytes(), dtype=tf.string))),
|
||||||
|
"cam_left_wrist":
|
||||||
|
_bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_left_wrist.tobytes(), dtype=tf.string))),
|
||||||
|
"cam_right_wrist":
|
||||||
|
_bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_right_wrist.tobytes(), dtype=tf.string))),
|
||||||
|
"instruction":
|
||||||
|
_bytes_feature(instruction),
|
||||||
|
"terminate_episode":
|
||||||
|
_bool_feature(terminate_episode),
|
||||||
|
}
|
||||||
|
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
|
||||||
|
return example_proto.SerializeToString()
|
||||||
|
|
||||||
|
|
||||||
|
def process_hdf5_file(args):
|
||||||
|
filepath, root_dir, out_dir = args
|
||||||
|
output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir))
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
filename = os.path.basename(filepath)
|
||||||
|
tfrecord_path = os.path.join(output_dir, filename.replace(".hdf5", ".tfrecord"))
|
||||||
|
|
||||||
|
if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
|
||||||
|
return f"TFRecords already exist at {tfrecord_path}"
|
||||||
|
try:
|
||||||
|
with h5py.File(filepath, "r") as f, tf.io.TFRecordWriter(tfrecord_path) as writer:
|
||||||
|
num_episodes = f["action"].shape[0]
|
||||||
|
# Remove the first few still steps
|
||||||
|
EPS = 1e-2
|
||||||
|
qpos = f["observations"]["qpos"][:]
|
||||||
|
# Get the idx of the first qpos whose delta exceeds the threshold
|
||||||
|
qpos_delta = np.abs(qpos - qpos[0:1])
|
||||||
|
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
|
||||||
|
if len(indices) > 0:
|
||||||
|
first_idx = indices[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("Found no qpos that exceeds the threshold.")
|
||||||
|
|
||||||
|
for i in range(first_idx - 1, num_episodes):
|
||||||
|
action = f["action"][i]
|
||||||
|
base_action = f["base_action"][i]
|
||||||
|
qpos = f["observations"]["qpos"][i]
|
||||||
|
qvel = f["observations"]["qvel"][i]
|
||||||
|
cam_high = f["observations"]["images"]["cam_high"][i]
|
||||||
|
cam_left_wrist = f["observations"]["images"]["cam_left_wrist"][i]
|
||||||
|
cam_right_wrist = f["observations"]["images"]["cam_right_wrist"][i]
|
||||||
|
instruction = f["instruction"][()]
|
||||||
|
terminate_episode = i == num_episodes - 1
|
||||||
|
serialized_example = serialize_example(
|
||||||
|
action,
|
||||||
|
base_action,
|
||||||
|
qpos,
|
||||||
|
qvel,
|
||||||
|
cam_high,
|
||||||
|
cam_left_wrist,
|
||||||
|
cam_right_wrist,
|
||||||
|
instruction,
|
||||||
|
terminate_episode,
|
||||||
|
)
|
||||||
|
writer.write(serialized_example)
|
||||||
|
except Exception as e:
|
||||||
|
with open("error_log.txt", "a") as f:
|
||||||
|
f.write(f"{filepath}\n")
|
||||||
|
print(f"error at {filepath}: {e}")
|
||||||
|
return f"TFRecords written to {tfrecord_path}"
|
||||||
|
|
||||||
|
|
||||||
|
def write_tfrecords(root_dir, out_dir):
|
||||||
|
if not os.path.exists(out_dir):
|
||||||
|
os.makedirs(out_dir)
|
||||||
|
|
||||||
|
hdf5_files = []
|
||||||
|
for root, dirs, files in os.walk(root_dir):
|
||||||
|
if os.path.exists(os.path.join(root, "expanded_instruction_gpt-4-turbo.json")):
|
||||||
|
# copy the instruction file
|
||||||
|
target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
|
||||||
|
os.makedirs(target_path, exist_ok=True)
|
||||||
|
shutil.copy(os.path.join(root, "expanded_instruction_gpt-4-turbo.json"), target_path)
|
||||||
|
elif os.path.exists(os.path.join(root, "expanded_instruction.json")):
|
||||||
|
print(root)
|
||||||
|
target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
|
||||||
|
os.makedirs(target_path, exist_ok=True)
|
||||||
|
shutil.copy(os.path.join(root, "expanded_instruction.json"), target_path)
|
||||||
|
# rename into expanded_instruction_gpt-4-turbo.json
|
||||||
|
os.rename(
|
||||||
|
os.path.join(
|
||||||
|
out_dir,
|
||||||
|
os.path.relpath(root, root_dir),
|
||||||
|
"expanded_instruction.json",
|
||||||
|
),
|
||||||
|
os.path.join(
|
||||||
|
out_dir,
|
||||||
|
os.path.relpath(root, root_dir),
|
||||||
|
"expanded_instruction_gpt-4-turbo.json",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for filename in fnmatch.filter(files, "*.hdf5"):
|
||||||
|
filepath = os.path.join(root, filename)
|
||||||
|
hdf5_files.append((filepath, root_dir, out_dir))
|
||||||
|
|
||||||
|
with Pool(16) as pool:
|
||||||
|
max_count = len(hdf5_files)
|
||||||
|
with tqdm(total=max_count) as pbar:
|
||||||
|
for _ in pool.imap_unordered(process_hdf5_file, hdf5_files):
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
print(f"TFRecords written to {out_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
root_dir = "../datasets/agilex/rdt_data/"
|
||||||
|
out_dir = "../datasets/agilex/tfrecords/"
|
||||||
|
write_tfrecords(root_dir, out_dir)
|
||||||
256
RDT-170M/data/compute_dataset_stat.py
Normal file
256
RDT-170M/data/compute_dataset_stat.py
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
"""
|
||||||
|
This file will compute the min, max, mean, and standard deviation of each datasets
|
||||||
|
in `pretrain_datasets.json` or `pretrain_datasets.json`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
# from multiprocessing import Pool, Manager
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from data.vla_dataset import VLADataset
|
||||||
|
from data.hdf5_vla_dataset import HDF5VLADataset
|
||||||
|
from data.preprocess import generate_json_state
|
||||||
|
|
||||||
|
|
||||||
|
# Process each dataset to get the statistics
|
||||||
|
@tf.autograph.experimental.do_not_convert
|
||||||
|
def process_dataset(name_dataset_pair):
|
||||||
|
# print(f"PID {os.getpid()} processing {name_dataset_pair[0]}")
|
||||||
|
dataset_iter = name_dataset_pair[1]
|
||||||
|
|
||||||
|
MAX_EPISODES = 100000
|
||||||
|
EPS = 1e-8
|
||||||
|
# For debugging
|
||||||
|
# MAX_EPISODES = 10
|
||||||
|
episode_cnt = 0
|
||||||
|
state_sum = 0
|
||||||
|
state_sum_sq = 0
|
||||||
|
z_state_sum = 0
|
||||||
|
z_state_sum_sq = 0
|
||||||
|
state_cnt = 0
|
||||||
|
nz_state_cnt = None
|
||||||
|
state_max = None
|
||||||
|
state_min = None
|
||||||
|
for episode in dataset_iter:
|
||||||
|
episode_cnt += 1
|
||||||
|
if episode_cnt % 1000 == 0:
|
||||||
|
print(f"Processing episodes {episode_cnt}/{MAX_EPISODES}")
|
||||||
|
if episode_cnt > MAX_EPISODES:
|
||||||
|
break
|
||||||
|
episode_dict = episode["episode_dict"]
|
||||||
|
dataset_name = episode["dataset_name"]
|
||||||
|
|
||||||
|
res_tup = generate_json_state(episode_dict, dataset_name)
|
||||||
|
states = res_tup[1]
|
||||||
|
|
||||||
|
# Convert to numpy
|
||||||
|
states = states.numpy()
|
||||||
|
|
||||||
|
# Zero the values that are close to zero
|
||||||
|
z_states = states.copy()
|
||||||
|
z_states[np.abs(states) <= EPS] = 0
|
||||||
|
# Compute the non-zero count
|
||||||
|
if nz_state_cnt is None:
|
||||||
|
nz_state_cnt = np.zeros(states.shape[1])
|
||||||
|
nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
state_sum += np.sum(states, axis=0)
|
||||||
|
state_sum_sq += np.sum(states**2, axis=0)
|
||||||
|
z_state_sum += np.sum(z_states, axis=0)
|
||||||
|
z_state_sum_sq += np.sum(z_states**2, axis=0)
|
||||||
|
state_cnt += states.shape[0]
|
||||||
|
if state_max is None:
|
||||||
|
state_max = np.max(states, axis=0)
|
||||||
|
state_min = np.min(states, axis=0)
|
||||||
|
else:
|
||||||
|
state_max = np.maximum(state_max, np.max(states, axis=0))
|
||||||
|
state_min = np.minimum(state_min, np.min(states, axis=0))
|
||||||
|
|
||||||
|
# Add one to avoid division by zero
|
||||||
|
nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"dataset_name":
|
||||||
|
name_dataset_pair[0],
|
||||||
|
"state_mean": (state_sum / state_cnt).tolist(),
|
||||||
|
"state_std":
|
||||||
|
np.sqrt(
|
||||||
|
np.maximum(
|
||||||
|
(z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
|
||||||
|
np.zeros_like(state_sum_sq),
|
||||||
|
)).tolist(),
|
||||||
|
"state_min":
|
||||||
|
state_min.tolist(),
|
||||||
|
"state_max":
|
||||||
|
state_max.tolist(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def process_hdf5_dataset(vla_dataset):
|
||||||
|
EPS = 1e-8
|
||||||
|
episode_cnt = 0
|
||||||
|
state_sum = 0
|
||||||
|
state_sum_sq = 0
|
||||||
|
z_state_sum = 0
|
||||||
|
z_state_sum_sq = 0
|
||||||
|
state_cnt = 0
|
||||||
|
nz_state_cnt = None
|
||||||
|
state_max = None
|
||||||
|
state_min = None
|
||||||
|
for i in tqdm(range(len(vla_dataset))):
|
||||||
|
episode = vla_dataset.get_item(i, state_only=True)
|
||||||
|
episode_cnt += 1
|
||||||
|
|
||||||
|
states = episode["state"]
|
||||||
|
|
||||||
|
# Zero the values that are close to zero
|
||||||
|
z_states = states.copy()
|
||||||
|
z_states[np.abs(states) <= EPS] = 0
|
||||||
|
# Compute the non-zero count
|
||||||
|
if nz_state_cnt is None:
|
||||||
|
nz_state_cnt = np.zeros(states.shape[1])
|
||||||
|
nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
state_sum += np.sum(states, axis=0)
|
||||||
|
state_sum_sq += np.sum(states**2, axis=0)
|
||||||
|
z_state_sum += np.sum(z_states, axis=0)
|
||||||
|
z_state_sum_sq += np.sum(z_states**2, axis=0)
|
||||||
|
state_cnt += states.shape[0]
|
||||||
|
if state_max is None:
|
||||||
|
state_max = np.max(states, axis=0)
|
||||||
|
state_min = np.min(states, axis=0)
|
||||||
|
else:
|
||||||
|
state_max = np.maximum(state_max, np.max(states, axis=0))
|
||||||
|
state_min = np.minimum(state_min, np.min(states, axis=0))
|
||||||
|
|
||||||
|
# Add one to avoid division by zero
|
||||||
|
nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"dataset_name":
|
||||||
|
vla_dataset.get_dataset_name(),
|
||||||
|
"state_mean": (state_sum / state_cnt).tolist(),
|
||||||
|
"state_std":
|
||||||
|
np.sqrt(
|
||||||
|
np.maximum(
|
||||||
|
(z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
|
||||||
|
np.zeros_like(state_sum_sq),
|
||||||
|
)).tolist(),
|
||||||
|
"state_min":
|
||||||
|
state_min.tolist(),
|
||||||
|
"state_max":
|
||||||
|
state_max.tolist(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# Multiprocessing currently with bugs
|
||||||
|
# parser.add_argument('--n_workers', type=int, default=1,
|
||||||
|
# help="Number of parallel workers.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_type",
|
||||||
|
type=str,
|
||||||
|
default="pretrain",
|
||||||
|
help="Whether to load the pretrain dataset or finetune dataset.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="configs/dataset_stat.json",
|
||||||
|
help="JSON file path to save the dataset statistics.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_exist",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to skip the existing dataset statistics.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hdf5_dataset",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to load the dataset from the HDF5 files.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.hdf5_dataset:
|
||||||
|
vla_dataset = HDF5VLADataset()
|
||||||
|
dataset_name = vla_dataset.get_dataset_name()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(args.save_path, "r") as f:
|
||||||
|
results = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
results = {}
|
||||||
|
if args.skip_exist and dataset_name in results:
|
||||||
|
print(f"Skipping existed {dataset_name} dataset statistics")
|
||||||
|
else:
|
||||||
|
print(f"Processing {dataset_name} dataset")
|
||||||
|
result = process_hdf5_dataset(vla_dataset)
|
||||||
|
results[result["dataset_name"]] = result
|
||||||
|
with open(args.save_path, "w") as f:
|
||||||
|
json.dump(results, f, indent=4)
|
||||||
|
print("All datasets have been processed.")
|
||||||
|
os._exit(0)
|
||||||
|
|
||||||
|
vla_dataset = VLADataset(seed=0, dataset_type=args.dataset_type, repeat=False)
|
||||||
|
name_dataset_pairs = vla_dataset.name2dataset.items()
|
||||||
|
# num_workers = args.n_workers
|
||||||
|
|
||||||
|
for name_dataset_pair in tqdm(name_dataset_pairs):
|
||||||
|
try:
|
||||||
|
with open(args.save_path, "r") as f:
|
||||||
|
results = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
if args.skip_exist and name_dataset_pair[0] in results:
|
||||||
|
print(f"Skipping existed {name_dataset_pair[0]} dataset statistics")
|
||||||
|
continue
|
||||||
|
print(f"Processing {name_dataset_pair[0]} dataset")
|
||||||
|
|
||||||
|
result = process_dataset(name_dataset_pair)
|
||||||
|
|
||||||
|
results[result["dataset_name"]] = result
|
||||||
|
|
||||||
|
# Save the results in the json file after each dataset (for resume)
|
||||||
|
with open(args.save_path, "w") as f:
|
||||||
|
json.dump(results, f, indent=4)
|
||||||
|
|
||||||
|
print("All datasets have been processed.")
|
||||||
|
|
||||||
|
# with Manager() as manager:
|
||||||
|
# # Create shared dictionary and lock through the manager, accessible by all processes
|
||||||
|
# progress = manager.dict(processed=0, results={})
|
||||||
|
# progress_lock = manager.Lock()
|
||||||
|
|
||||||
|
# # Callback function to update progress
|
||||||
|
# def update_progress(result):
|
||||||
|
# with progress_lock:
|
||||||
|
# progress['processed'] += 1
|
||||||
|
# print(f"{result['dataset_name']} - {progress['processed']}/{len(name_dataset_pairs)} datasets have been processed")
|
||||||
|
# # Append the result to the shared dictionary
|
||||||
|
# progress['results'][result["dataset_name"]] = result
|
||||||
|
|
||||||
|
# with Pool(num_workers) as p:
|
||||||
|
# for name_dataset_pair in name_dataset_pairs:
|
||||||
|
# p.apply_async(process_dataset, args=(name_dataset_pair,), callback=update_progress)
|
||||||
|
|
||||||
|
# # Close the pool and wait for the work to finish
|
||||||
|
# p.close()
|
||||||
|
# p.join()
|
||||||
|
|
||||||
|
# # Save the results in the json file
|
||||||
|
# with open(args.save_path, 'w') as f:
|
||||||
|
# json.dump(progress['results'], f, indent=4)
|
||||||
112
RDT-170M/data/compute_dataset_stat_hdf5.py
Normal file
112
RDT-170M/data/compute_dataset_stat_hdf5.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
"""
|
||||||
|
This file will compute the min, max, mean, and standard deviation of each datasets
|
||||||
|
in `pretrain_datasets.json` or `pretrain_datasets.json`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from data.hdf5_vla_dataset import HDF5VLADataset
|
||||||
|
|
||||||
|
|
||||||
|
def process_hdf5_dataset(vla_dataset):
|
||||||
|
EPS = 1e-8
|
||||||
|
episode_cnt = 0
|
||||||
|
state_sum = 0
|
||||||
|
state_sum_sq = 0
|
||||||
|
z_state_sum = 0
|
||||||
|
z_state_sum_sq = 0
|
||||||
|
state_cnt = 0
|
||||||
|
nz_state_cnt = None
|
||||||
|
state_max = None
|
||||||
|
state_min = None
|
||||||
|
for i in tqdm(range(len(vla_dataset))):
|
||||||
|
episode = vla_dataset.get_item(i, state_only=True)
|
||||||
|
episode_cnt += 1
|
||||||
|
|
||||||
|
states = episode["state"]
|
||||||
|
|
||||||
|
# Zero the values that are close to zero
|
||||||
|
z_states = states.copy()
|
||||||
|
z_states[np.abs(states) <= EPS] = 0
|
||||||
|
# Compute the non-zero count
|
||||||
|
if nz_state_cnt is None:
|
||||||
|
nz_state_cnt = np.zeros(states.shape[1])
|
||||||
|
nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
state_sum += np.sum(states, axis=0)
|
||||||
|
state_sum_sq += np.sum(states**2, axis=0)
|
||||||
|
z_state_sum += np.sum(z_states, axis=0)
|
||||||
|
z_state_sum_sq += np.sum(z_states**2, axis=0)
|
||||||
|
state_cnt += states.shape[0]
|
||||||
|
if state_max is None:
|
||||||
|
state_max = np.max(states, axis=0)
|
||||||
|
state_min = np.min(states, axis=0)
|
||||||
|
else:
|
||||||
|
state_max = np.maximum(state_max, np.max(states, axis=0))
|
||||||
|
state_min = np.minimum(state_min, np.min(states, axis=0))
|
||||||
|
|
||||||
|
# Add one to avoid division by zero
|
||||||
|
nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"dataset_name":
|
||||||
|
vla_dataset.get_dataset_name(),
|
||||||
|
"state_mean": (state_sum / state_cnt).tolist(),
|
||||||
|
"state_std":
|
||||||
|
np.sqrt(
|
||||||
|
np.maximum(
|
||||||
|
(z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
|
||||||
|
np.zeros_like(state_sum_sq),
|
||||||
|
)).tolist(),
|
||||||
|
"state_min":
|
||||||
|
state_min.tolist(),
|
||||||
|
"state_max":
|
||||||
|
state_max.tolist(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--task_name",
|
||||||
|
type=str,
|
||||||
|
default="configs/dataset_stat.json",
|
||||||
|
help="JSON file path to save the dataset statistics.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="configs/dataset_stat.json",
|
||||||
|
help="JSON file path to save the dataset statistics.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_exist",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to skip the existing dataset statistics.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
vla_dataset = HDF5VLADataset(f"model_config/{args.task_name}.yml")
|
||||||
|
dataset_name = vla_dataset.get_dataset_name()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(args.save_path, "r") as f:
|
||||||
|
results = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
results = {}
|
||||||
|
if args.skip_exist and dataset_name in results:
|
||||||
|
print(f"Skipping existed {dataset_name} dataset statistics")
|
||||||
|
else:
|
||||||
|
print(f"Processing {dataset_name} dataset")
|
||||||
|
result = process_hdf5_dataset(vla_dataset)
|
||||||
|
results[result["dataset_name"]] = result
|
||||||
|
with open(args.save_path, "w") as f:
|
||||||
|
json.dump(results, f, indent=4)
|
||||||
|
print("All datasets have been processed.")
|
||||||
BIN
RDT-170M/data/empty_lang_embed.pt
Normal file
BIN
RDT-170M/data/empty_lang_embed.pt
Normal file
Binary file not shown.
398
RDT-170M/data/episode_transform.py
Normal file
398
RDT-170M/data/episode_transform.py
Normal file
@ -0,0 +1,398 @@
|
|||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from data.preprocess import generate_json_state
|
||||||
|
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
||||||
|
|
||||||
|
# Read the config
|
||||||
|
with open("configs/base.yaml", "r") as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
# Load some constants from the config
|
||||||
|
IMG_HISTORY_SIZE = config["common"]["img_history_size"]
|
||||||
|
if IMG_HISTORY_SIZE < 1:
|
||||||
|
raise ValueError("Config `img_history_size` must be at least 1.")
|
||||||
|
ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
|
||||||
|
if ACTION_CHUNK_SIZE < 1:
|
||||||
|
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
def process_episode(epsd: dict, dataset_name: str, image_keys: list, image_mask: list) -> dict:
    """
    Process an episode to extract per-camera image histories.

    Args:
        epsd: episode dict with a "steps" iterable; each step exposes
            step["observation"][image_keys[k]] for camera k.
        dataset_name: name of the source dataset (passed through unchanged).
        image_mask: per-camera 0/1 flags; a masked-out camera contributes
            empty (0, 0, 0)-shaped uint8 tensors instead of real frames.

    Returns:
        dict carrying the episode, per-step ids, and for each camera k:
        "past_frames_k" — the sliding window of the IMG_HISTORY_SIZE frames
        ending at each step (front-padded with the first frame), and
        "past_frames_k_time_mask" — boolean mask that is False on the
        front-padded slots.
    """
    # Frames of each camera.
    # The four TensorArrays are written one by one inside a single pass over
    # the steps. This unrolled form is kept deliberately: AutoGraph handles
    # reassignment of TensorArrays held in Python containers poorly inside
    # a loop over a dataset iterator.
    frames_0 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_1 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_2 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_3 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    for step in iter(epsd["steps"]):
        frames_0 = frames_0.write(
            frames_0.size(),
            tf.cond(
                tf.equal(image_mask[0], 1),
                lambda: step["observation"][image_keys[0]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_1 = frames_1.write(
            frames_1.size(),
            tf.cond(
                tf.equal(image_mask[1], 1),
                lambda: step["observation"][image_keys[1]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_2 = frames_2.write(
            frames_2.size(),
            tf.cond(
                tf.equal(image_mask[2], 1),
                lambda: step["observation"][image_keys[2]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_3 = frames_3.write(
            frames_3.size(),
            tf.cond(
                tf.equal(image_mask[3], 1),
                lambda: step["observation"][image_keys[3]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )

    def _frame_history(frames_ta):
        """Stack a TensorArray and build the per-step sliding frame history.

        Returns (frames, past_frames, past_time_mask): past_frames[i] is the
        window of IMG_HISTORY_SIZE frames ending at step i, front-padded with
        the first frame; past_time_mask is False on the padded slots.
        """
        frames = frames_ta.stack()
        first_frame = tf.expand_dims(frames[0], axis=0)
        first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
        padded_frames = tf.concat([first_frame, frames], axis=0)
        indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames)[0] + IMG_HISTORY_SIZE)
        past_frames = tf.map_fn(lambda i: padded_frames[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
        time_mask = tf.ones([tf.shape(frames)[0]], dtype=tf.bool)
        padded_time_mask = tf.pad(
            time_mask,
            [[IMG_HISTORY_SIZE - 1, 0]],
            "CONSTANT",
            constant_values=False,
        )
        past_time_mask = tf.map_fn(
            lambda i: padded_time_mask[i - IMG_HISTORY_SIZE:i],
            indices,
            dtype=tf.bool,
        )
        return frames, past_frames, past_time_mask

    # The same window construction used to be copy-pasted four times;
    # it is factored into _frame_history above.
    frames_0, past_frames_0, past_frames_0_time_mask = _frame_history(frames_0)
    _, past_frames_1, past_frames_1_time_mask = _frame_history(frames_1)
    _, past_frames_2, past_frames_2_time_mask = _frame_history(frames_2)
    _, past_frames_3, past_frames_3_time_mask = _frame_history(frames_3)

    # Create the ids for each step
    step_id = tf.range(0, tf.shape(frames_0)[0])

    return {
        "dataset_name": dataset_name,
        "episode_dict": epsd,
        "step_id": step_id,
        "past_frames_0": past_frames_0,
        "past_frames_0_time_mask": past_frames_0_time_mask,
        "past_frames_1": past_frames_1,
        "past_frames_1_time_mask": past_frames_1_time_mask,
        "past_frames_2": past_frames_2,
        "past_frames_2_time_mask": past_frames_2_time_mask,
        "past_frames_3": past_frames_3,
        "past_frames_3_time_mask": past_frames_3_time_mask,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
def bgr_to_rgb(epsd: dict):
    """
    Convert BGR images to RGB images.

    Channels are swapped only when the last dimension is 3; the empty
    (0, 0, 0)-shaped placeholders of disabled cameras pass through
    unchanged. All non-image keys are forwarded as-is.
    """

    def _swap_channels(frames):
        # Reverse channel order (BGR -> RGB) iff there are exactly 3 channels.
        return tf.cond(
            tf.equal(tf.shape(frames)[-1], 3),
            lambda: tf.stack(
                [frames[..., 2], frames[..., 1], frames[..., 0]],
                axis=-1,
            ),
            lambda: frames,
        )

    # The same tf.cond used to be copy-pasted once per camera;
    # it is factored into _swap_channels above.
    return {
        "dataset_name": epsd["dataset_name"],
        "episode_dict": epsd["episode_dict"],
        "step_id": epsd["step_id"],
        "past_frames_0": _swap_channels(epsd["past_frames_0"]),
        "past_frames_0_time_mask": epsd["past_frames_0_time_mask"],
        "past_frames_1": _swap_channels(epsd["past_frames_1"]),
        "past_frames_1_time_mask": epsd["past_frames_1_time_mask"],
        "past_frames_2": _swap_channels(epsd["past_frames_2"]),
        "past_frames_2_time_mask": epsd["past_frames_2_time_mask"],
        "past_frames_3": _swap_channels(epsd["past_frames_3"]),
        "past_frames_3_time_mask": epsd["past_frames_3_time_mask"],
    }
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_episode(episode: dict) -> tf.data.Dataset:
    """
    Flatten the episode to a list of per-step sample dicts.

    For every timestep the sample carries the past state window, the future
    state (action) window, the per-camera frame histories produced by
    process_episode, and whole-trajectory state statistics.

    NOTE(review): despite the annotation, this returns a Python list of
    dicts, not a tf.data.Dataset — annotation kept for interface
    compatibility; confirm against callers.
    """
    episode_dict = episode["episode_dict"]
    dataset_name = episode["dataset_name"]

    json_content, states, masks = generate_json_state(episode_dict, dataset_name)

    def _past_window(seq):
        """(T, D) -> (T, ACTION_CHUNK_SIZE, D) windows ending at each step,
        front-padded with seq[0]; the mask is False on padded slots."""
        first = tf.expand_dims(seq[0], axis=0)
        first = tf.repeat(first, ACTION_CHUNK_SIZE - 1, axis=0)
        padded = tf.concat([first, seq], axis=0)
        indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(seq)[0] + ACTION_CHUNK_SIZE)
        windows = tf.map_fn(lambda i: padded[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
        time_mask = tf.ones([tf.shape(seq)[0]], dtype=tf.bool)
        padded_mask = tf.pad(time_mask, [[ACTION_CHUNK_SIZE - 1, 0]], "CONSTANT", constant_values=False)
        windows_mask = tf.map_fn(lambda i: padded_mask[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.bool)
        return windows, windows_mask

    def _future_window(seq):
        """(T, D) -> (T, ACTION_CHUNK_SIZE, D) windows starting at t+1,
        back-padded with seq[-1]; the mask is False on padded slots."""
        last = tf.expand_dims(seq[-1], axis=0)
        last = tf.repeat(last, ACTION_CHUNK_SIZE, axis=0)
        padded = tf.concat([seq, last], axis=0)
        indices = tf.range(1, tf.shape(seq)[0] + 1)
        windows = tf.map_fn(lambda i: padded[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
        time_mask = tf.ones([tf.shape(seq)[0]], dtype=tf.bool)
        padded_mask = tf.pad(time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
        windows_mask = tf.map_fn(lambda i: padded_mask[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.bool)
        return windows, windows_mask

    past_states, past_states_time_mask = _past_window(states)
    future_states, future_states_time_mask = _future_window(states)

    # Per-dimension statistics over the whole trajectory, broadcast to
    # every step so each sample is self-contained.
    num_steps = tf.shape(states)[0]
    state_std = tf.repeat(tf.math.reduce_std(states, axis=0, keepdims=True), num_steps, axis=0)
    state_mean = tf.repeat(tf.math.reduce_mean(states, axis=0, keepdims=True), num_steps, axis=0)
    state_norm = tf.repeat(
        tf.math.sqrt(tf.math.reduce_mean(tf.math.square(states), axis=0, keepdims=True)),
        num_steps,
        axis=0,
    )

    # Create a list of steps.
    # NOTE(review): range() over a scalar tensor only works in eager mode —
    # this function is presumably never traced; confirm with callers.
    step_data = []
    for i in range(num_steps):
        step_data.append({
            "step_id": episode["step_id"][i],
            "json_content": json_content,
            "state_chunk": past_states[i],
            "state_chunk_time_mask": past_states_time_mask[i],
            "action_chunk": future_states[i],
            "action_chunk_time_mask": future_states_time_mask[i],
            "state_vec_mask": masks[i],
            "past_frames_0": episode["past_frames_0"][i],
            "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
            "past_frames_1": episode["past_frames_1"][i],
            "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
            "past_frames_2": episode["past_frames_2"][i],
            "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
            "past_frames_3": episode["past_frames_3"][i],
            "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
            "state_std": state_std[i],
            "state_mean": state_mean[i],
            "state_norm": state_norm[i],
        })

    return step_data
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_episode_agilex(episode: dict) -> tf.data.Dataset:
    """
    Flatten the episode to a list of per-step sample dicts (Agilex variant).

    Unlike flatten_episode, the future chunk is built from the recorded
    actions rather than future states, and the norm statistic is computed
    over the actions.

    NOTE(review): despite the annotation, this returns a Python list of
    dicts, not a tf.data.Dataset — annotation kept for interface
    compatibility; confirm against callers.
    """
    episode_dict = episode["episode_dict"]
    dataset_name = episode["dataset_name"]

    json_content, states, masks, acts = generate_json_state(episode_dict, dataset_name)

    def _past_window(seq):
        """(T, D) -> (T, ACTION_CHUNK_SIZE, D) windows ending at each step,
        front-padded with seq[0]; the mask is False on padded slots."""
        first = tf.expand_dims(seq[0], axis=0)
        first = tf.repeat(first, ACTION_CHUNK_SIZE - 1, axis=0)
        padded = tf.concat([first, seq], axis=0)
        indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(seq)[0] + ACTION_CHUNK_SIZE)
        windows = tf.map_fn(lambda i: padded[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
        time_mask = tf.ones([tf.shape(seq)[0]], dtype=tf.bool)
        padded_mask = tf.pad(time_mask, [[ACTION_CHUNK_SIZE - 1, 0]], "CONSTANT", constant_values=False)
        windows_mask = tf.map_fn(lambda i: padded_mask[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.bool)
        return windows, windows_mask

    past_states, past_states_time_mask = _past_window(states)

    # NOTE bg: the future chunk comes from the actions. The action at time t
    # is treated as the state at time t + 1, so windows start at index t
    # (not t + 1 as in flatten_episode), back-padded with the last action.
    last_act = tf.expand_dims(acts[-1], axis=0)
    last_act = tf.repeat(last_act, ACTION_CHUNK_SIZE, axis=0)
    padded_acts = tf.concat([acts, last_act], axis=0)
    indices = tf.range(0, tf.shape(acts)[0])
    future_states = tf.map_fn(lambda i: padded_acts[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
    acts_time_mask = tf.ones([tf.shape(acts)[0]], dtype=tf.bool)
    padded_acts_time_mask = tf.pad(acts_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
    future_states_time_mask = tf.map_fn(
        lambda i: padded_acts_time_mask[i:i + ACTION_CHUNK_SIZE],
        indices,
        dtype=tf.bool,
    )

    # Per-dimension statistics, broadcast to the state length.
    # NOTE(review): the step loop below indexes future_states (built from
    # acts) with state indices — assumes len(acts) >= len(states); confirm.
    num_steps = tf.shape(states)[0]
    state_std = tf.repeat(tf.math.reduce_std(states, axis=0, keepdims=True), num_steps, axis=0)
    state_mean = tf.repeat(tf.math.reduce_mean(states, axis=0, keepdims=True), num_steps, axis=0)
    # The norm statistic is intentionally computed from the actions.
    state_norm = tf.repeat(
        tf.math.sqrt(tf.math.reduce_mean(tf.math.square(acts), axis=0, keepdims=True)),
        num_steps,
        axis=0,
    )

    # Create a list of steps (eager-only: range() over a scalar tensor).
    step_data = []
    for i in range(num_steps):
        step_data.append({
            "step_id": episode["step_id"][i],
            "json_content": json_content,
            "state_chunk": past_states[i],
            "state_chunk_time_mask": past_states_time_mask[i],
            "action_chunk": future_states[i],
            "action_chunk_time_mask": future_states_time_mask[i],
            "state_vec_mask": masks[i],
            "past_frames_0": episode["past_frames_0"][i],
            "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
            "past_frames_1": episode["past_frames_1"][i],
            "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
            "past_frames_2": episode["past_frames_2"][i],
            "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
            "past_frames_3": episode["past_frames_3"][i],
            "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
            "state_std": state_std[i],
            "state_mean": state_mean[i],
            "state_norm": state_norm[i],
        })

    return step_data
|
||||||
25
RDT-170M/data/filelock.py
Normal file
25
RDT-170M/data/filelock.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import fcntl
|
||||||
|
|
||||||
|
|
||||||
|
class FileLock:
    """Advisory, non-blocking file lock built on fcntl.flock.

    The lock is taken on a sidecar file "<filename>.lock". Both acquire
    methods are non-blocking (LOCK_NB): they raise OSError/BlockingIOError
    if the lock is already held by another process.
    """

    def __init__(self, filename):
        # Path of the protected resource; the actual lock file is
        # filename + ".lock".
        self.filename = filename
        self.handle = None

    def acquire_read_lock(self):
        """Take a shared (read) lock; raises OSError if contended.

        NOTE: opens the lock file in "r" mode, so the file must already
        exist (e.g. created by a prior acquire_write_lock).
        """
        handle = open(self.filename + ".lock", "r")
        try:
            fcntl.flock(handle, fcntl.LOCK_SH | fcntl.LOCK_NB)
        except OSError:
            # Fix: don't leak the descriptor (or leave self.handle pointing
            # at an unlocked file) when the non-blocking flock fails.
            handle.close()
            raise
        self.handle = handle

    def acquire_write_lock(self):
        """Take an exclusive (write) lock; raises OSError if contended.

        Opens the lock file in "w" mode, creating (and truncating) it.
        """
        handle = open(self.filename + ".lock", "w")
        try:
            fcntl.flock(handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except OSError:
            handle.close()
            raise
        self.handle = handle

    def release_lock(self):
        """Release the lock and close the descriptor (no-op if not held)."""
        if self.handle is not None:
            fcntl.flock(self.handle, fcntl.LOCK_UN)
            self.handle.close()
            self.handle = None
|
||||||
372
RDT-170M/data/hdf5_vla_dataset.py
Normal file
372
RDT-170M/data/hdf5_vla_dataset.py
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
import os
|
||||||
|
import fnmatch
|
||||||
|
import json
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import yaml
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
||||||
|
|
||||||
|
class HDF5VLADataset:
|
||||||
|
"""
|
||||||
|
This class is used to sample episodes from the embododiment dataset
|
||||||
|
stored in HDF5.
|
||||||
|
"""
|
||||||
|
|
||||||
|
    def __init__(self, model_config_path) -> None:
        """Index every episode HDF5 file under the configured data path.

        Args:
            model_config_path: path to a YAML file whose "data_path" entry
                points at the HDF5 dataset directory (one episode per file).
        """
        # [Modify] The path to the HDF5 dataset directory
        # Each HDF5 file contains one episode
        with open(model_config_path, "r") as f:
            model_config = yaml.safe_load(f)
        HDF5_DIR = model_config["data_path"]
        self.DATASET_NAME = "agilex"

        # Recursively collect every *.hdf5 file under the dataset root.
        self.file_paths = []
        for root, _, files in os.walk(HDF5_DIR):
            for filename in fnmatch.filter(files, "*.hdf5"):
                file_path = os.path.join(root, filename)
                self.file_paths.append(file_path)

        # Load the config
        with open("configs/base.yaml", "r") as file:
            config = yaml.safe_load(file)
        self.CHUNK_SIZE = config["common"]["action_chunk_size"]
        # (attribute name "IMG_HISORY_SIZE" is misspelled but kept: other
        # methods in this class read it under this exact name)
        self.IMG_HISORY_SIZE = config["common"]["img_history_size"]
        self.STATE_DIM = config["common"]["state_dim"]

        # Get each episode's len (use original length, not standardized length)
        # Lengths become sampling weights, so longer episodes are drawn
        # proportionally more often by get_item().
        episode_lens = []
        for file_path in self.file_paths:
            try:
                with h5py.File(file_path, "r") as f:
                    qpos = f["observations"]["qpos"][:]
                    num_steps = qpos.shape[0]
                    episode_lens.append(num_steps)
            except Exception as e:
                # Unreadable files get weight 0 but remain in file_paths;
                # NOTE(review): they can still be reached via an explicit
                # index in get_item() — confirm whether that is intended.
                print(f"Warning: Could not read {file_path}: {e}")
                episode_lens.append(0)
        self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.file_paths)
|
||||||
|
|
||||||
|
def get_dataset_name(self):
|
||||||
|
return self.DATASET_NAME
|
||||||
|
|
||||||
|
def get_item(self, index: int = None, state_only=False):
|
||||||
|
"""Get a training sample at a random timestep.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index (int, optional): the index of the episode.
|
||||||
|
If not provided, a random episode will be selected.
|
||||||
|
state_only (bool, optional): Whether to return only the state.
|
||||||
|
In this way, the sample will contain a complete trajectory rather
|
||||||
|
than a single timestep. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sample (dict): a dictionary containing the training sample.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
if index is None:
|
||||||
|
file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
|
||||||
|
else:
|
||||||
|
file_path = self.file_paths[index]
|
||||||
|
valid, sample = (self.parse_hdf5_file(file_path)
|
||||||
|
if not state_only else self.parse_hdf5_file_state_only(file_path))
|
||||||
|
if valid:
|
||||||
|
return sample
|
||||||
|
else:
|
||||||
|
index = np.random.randint(0, len(self.file_paths))
|
||||||
|
|
||||||
|
def parse_hdf5_file(self, file_path):
|
||||||
|
"""[Modify] Parse a hdf5 file to generate a training sample at
|
||||||
|
a random timestep.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): the path to the hdf5 file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
valid (bool): whether the episode is valid, which is useful for filtering.
|
||||||
|
If False, this episode will be dropped.
|
||||||
|
dict: a dictionary containing the training sample,
|
||||||
|
{
|
||||||
|
"meta": {
|
||||||
|
"dataset_name": str, # the name of your dataset.
|
||||||
|
"#steps": int, # the number of steps in the episode,
|
||||||
|
# also the total timesteps.
|
||||||
|
"instruction": str # the language instruction for this episode.
|
||||||
|
},
|
||||||
|
"step_id": int, # the index of the sampled step,
|
||||||
|
# also the timestep t.
|
||||||
|
"state": ndarray, # state[t], (1, STATE_DIM).
|
||||||
|
"state_std": ndarray, # std(state[:]), (STATE_DIM,).
|
||||||
|
"state_mean": ndarray, # mean(state[:]), (STATE_DIM,).
|
||||||
|
"state_norm": ndarray, # norm(state[:]), (STATE_DIM,).
|
||||||
|
"actions": ndarray, # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
|
||||||
|
"state_indicator", ndarray, # indicates the validness of each dim, (STATE_DIM,).
|
||||||
|
"cam_high": ndarray, # external camera image, (IMG_HISORY_SIZE, H, W, 3)
|
||||||
|
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
||||||
|
"cam_high_mask": ndarray, # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array.
|
||||||
|
# For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False.
|
||||||
|
"cam_left_wrist": ndarray, # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
|
||||||
|
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
||||||
|
"cam_left_wrist_mask": ndarray,
|
||||||
|
"cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
|
||||||
|
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
||||||
|
# If only one wrist, make it right wrist, plz.
|
||||||
|
"cam_right_wrist_mask": ndarray
|
||||||
|
} or None if the episode is invalid.
|
||||||
|
"""
|
||||||
|
with h5py.File(file_path, "r") as f:
|
||||||
|
qpos = f["observations"]["qpos"][:]
|
||||||
|
left_arm_dim = f["observations"]["left_arm_dim"][:]
|
||||||
|
right_arm_dim = f["observations"]["right_arm_dim"][:]
|
||||||
|
num_steps = qpos.shape[0]
|
||||||
|
action_dim = qpos
|
||||||
|
# [Optional] We drop too-short episode
|
||||||
|
# if num_steps < 128:
|
||||||
|
# return False, None
|
||||||
|
|
||||||
|
# [Optional] We skip the first few still steps
|
||||||
|
EPS = 1e-2
|
||||||
|
# Get the idx of the first qpos whose delta exceeds the threshold
|
||||||
|
qpos_delta = np.abs(qpos - qpos[0:1])
|
||||||
|
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
|
||||||
|
if len(indices) > 0:
|
||||||
|
first_idx = indices[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("Found no qpos that exceeds the threshold.")
|
||||||
|
|
||||||
|
# We randomly sample a timestep
|
||||||
|
step_id = np.random.randint(first_idx - 1, num_steps)
|
||||||
|
|
||||||
|
# Load the instruction
|
||||||
|
dir_path = os.path.dirname(file_path)
|
||||||
|
|
||||||
|
# with open(os.path.join(dir_path, 'instruction.json'), 'r') as f_instr:
|
||||||
|
# instruction_dict = json.load(f_instr)
|
||||||
|
# # We have 1/3 prob to use original instruction,
|
||||||
|
# # 1/3 to use simplified instruction,
|
||||||
|
# # and 1/3 to use expanded instruction.
|
||||||
|
# instruction_type = np.random.choice([
|
||||||
|
# 'instruction', 'expanded_instruction'])
|
||||||
|
# instruction = instruction_dict[instruction_type]
|
||||||
|
# if isinstance(instruction, list):
|
||||||
|
# instruction = np.random.choice(instruction)
|
||||||
|
|
||||||
|
# You can also use precomputed language embeddings (recommended)
|
||||||
|
# instruction = "path/to/lang_embed.pt"
|
||||||
|
instructions_path = os.path.join(dir_path, "instructions")
|
||||||
|
instructions_names = []
|
||||||
|
|
||||||
|
for filename in os.listdir(instructions_path):
|
||||||
|
# 检查文件名是否以.pt结尾
|
||||||
|
if filename.endswith(".pt"):
|
||||||
|
instructions_names.append(os.path.join(instructions_path, filename))
|
||||||
|
instruction = np.random.choice(instructions_names)
|
||||||
|
# print(f"choose {instruction} file as instruction.")
|
||||||
|
# Assemble the meta
|
||||||
|
meta = {
|
||||||
|
"dataset_name": self.DATASET_NAME,
|
||||||
|
"#steps": num_steps,
|
||||||
|
"step_id": step_id,
|
||||||
|
"instruction": instruction,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Rescale gripper to [0, 1]
|
||||||
|
# qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])
|
||||||
|
# target_qpos = f["action"][step_id:step_id + self.CHUNK_SIZE] / np.array(
|
||||||
|
# [[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])
|
||||||
|
|
||||||
|
qpos = qpos / np.array(
|
||||||
|
# [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]
|
||||||
|
[[180, 180, 180, 180, 180, 180]]
|
||||||
|
)
|
||||||
|
target_qpos = f['action'][step_id:step_id + self.CHUNK_SIZE] / np.array(
|
||||||
|
# [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]]
|
||||||
|
[[180, 180, 180, 180, 180, 180]]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the state and action
|
||||||
|
state = qpos[step_id:step_id + 1]
|
||||||
|
state_std = np.std(qpos, axis=0)
|
||||||
|
state_mean = np.mean(qpos, axis=0)
|
||||||
|
state_norm = np.sqrt(np.mean(qpos**2, axis=0))
|
||||||
|
actions = target_qpos
|
||||||
|
if actions.shape[0] < self.CHUNK_SIZE:
|
||||||
|
# Pad the actions using the last action
|
||||||
|
actions = np.concatenate(
|
||||||
|
[
|
||||||
|
actions,
|
||||||
|
np.tile(actions[-1:], (self.CHUNK_SIZE - actions.shape[0], 1)),
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fill the state/action into the unified vector
|
||||||
|
|
||||||
|
def fill_in_state(values):
|
||||||
|
# Target indices corresponding to your state space
|
||||||
|
# In this example: 6 joints + 1 gripper for each arm
|
||||||
|
UNI_STATE_INDICES = [
|
||||||
|
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
|
||||||
|
# ] + [
|
||||||
|
# STATE_VEC_IDX_MAPPING["right_gripper_open"]
|
||||||
|
]
|
||||||
|
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
|
||||||
|
uni_vec[..., UNI_STATE_INDICES] = values
|
||||||
|
return uni_vec
|
||||||
|
|
||||||
|
state = fill_in_state(state)
|
||||||
|
state_indicator = fill_in_state(np.ones_like(state_std))
|
||||||
|
state_std = fill_in_state(state_std)
|
||||||
|
state_mean = fill_in_state(state_mean)
|
||||||
|
state_norm = fill_in_state(state_norm)
|
||||||
|
# If action's format is different from state's,
|
||||||
|
# you may implement fill_in_action()
|
||||||
|
actions = fill_in_state(actions)
|
||||||
|
|
||||||
|
# Parse the images
|
||||||
|
def parse_img(key):
|
||||||
|
imgs = []
|
||||||
|
for i in range(max(step_id - self.IMG_HISORY_SIZE + 1, 0), step_id + 1):
|
||||||
|
img_bits = f["observations"]["images"][key][i]
|
||||||
|
img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)
|
||||||
|
imgs.append(img)
|
||||||
|
imgs = np.stack(imgs)
|
||||||
|
if imgs.shape[0] < self.IMG_HISORY_SIZE:
|
||||||
|
# Pad the images using the first image
|
||||||
|
imgs = np.concatenate(
|
||||||
|
[
|
||||||
|
np.tile(
|
||||||
|
imgs[:1],
|
||||||
|
(self.IMG_HISORY_SIZE - imgs.shape[0], 1, 1, 1),
|
||||||
|
),
|
||||||
|
imgs,
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
return imgs
|
||||||
|
|
||||||
|
# `cam_high` is the external camera image
|
||||||
|
cam_high = parse_img("cam_high")
|
||||||
|
# For step_id = first_idx - 1, the valid_len should be one
|
||||||
|
valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
|
||||||
|
cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len)
|
||||||
|
# cam_left_wrist = parse_img("cam_left_wrist")
|
||||||
|
# cam_left_wrist_mask = cam_high_mask.copy()
|
||||||
|
cam_left_wrist = np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0))#parse_img('cam_right_wrist')
|
||||||
|
cam_left_wrist_mask = np.array([False] * self.IMG_HISORY_SIZE)#cam_high_mask.copy()
|
||||||
|
cam_right_wrist = parse_img("cam_right_wrist")
|
||||||
|
cam_right_wrist_mask = cam_high_mask.copy() # 使用相同的掩码逻辑
|
||||||
|
|
||||||
|
# Return the resulting sample
|
||||||
|
# For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)
|
||||||
|
# E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist",
|
||||||
|
# if the left-wrist camera is unavailable on your robot
|
||||||
|
return True, {
|
||||||
|
"meta": meta,
|
||||||
|
"state": state,
|
||||||
|
"state_std": state_std,
|
||||||
|
"state_mean": state_mean,
|
||||||
|
"state_norm": state_norm,
|
||||||
|
"actions": actions,
|
||||||
|
"state_indicator": state_indicator,
|
||||||
|
"cam_high": cam_high,
|
||||||
|
"cam_high_mask": cam_high_mask,
|
||||||
|
"cam_left_wrist": cam_left_wrist,
|
||||||
|
"cam_left_wrist_mask": cam_left_wrist_mask,
|
||||||
|
"cam_right_wrist": cam_right_wrist,
|
||||||
|
"cam_right_wrist_mask": cam_right_wrist_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
def parse_hdf5_file_state_only(self, file_path):
    """Parse a hdf5 file to generate a state-only trajectory.

    Args:
        file_path (str): the path to the hdf5 file.

    Returns:
        valid (bool): whether the episode is valid, which is useful for
            filtering. If False, this episode will be dropped.
        dict: a dictionary containing the training sample,
            {
                "state": ndarray,  # (T, STATE_DIM)
                "action": ndarray, # (T, STATE_DIM)
            } or None if the episode is invalid.

    Raises:
        ValueError: if no qpos ever deviates from the first frame by more
            than the stillness threshold (i.e., the arm never moves).
    """
    with h5py.File(file_path, "r") as f:
        qpos = f["observations"]["qpos"][:]
        left_arm_dim = f["observations"]["left_arm_dim"][:]
        right_arm_dim = f["observations"]["right_arm_dim"][:]

        num_steps = qpos.shape[0]
        # [Optional] We drop too-short episode
        # if num_steps < 128:
        #     return False, None

        # Skip the first few still steps: find the first frame whose qpos
        # deviates from frame 0 by more than EPS in any dimension.
        EPS = 1e-2
        qpos_delta = np.abs(qpos - qpos[0:1])
        indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
        if len(indices) > 0:
            first_idx = indices[0]
        else:
            raise ValueError("Found no qpos that exceeds the threshold.")

        # Rescale joint angles (degrees / 180 -> roughly [-1, 1]).
        # NOTE(review): hard-coded for a single 6-DoF arm with no gripper
        # channel — confirm against the robot configuration.
        scale = np.array([[180, 180, 180, 180, 180, 180]])
        qpos = qpos / scale
        # BUGFIX: previously `f['action']` was sliced by `first_idx - 1` at
        # read time AND again below, so actions were shifted relative to
        # states. Read the full array here and slice exactly once below,
        # mirroring the qpos handling.
        target_qpos = f["action"][:] / scale

        # Parse the state and action, both starting at first_idx - 1 so the
        # state at step t pairs with the action commanded at step t.
        state = qpos[first_idx - 1:]
        action = target_qpos[first_idx - 1:]

        # Standardize trajectory length to avoid batch size mismatch.
        target_length = 128  # You can adjust this value
        if state.shape[0] > target_length:
            # Truncate to target length
            state = state[:target_length]
            action = action[:target_length]
        elif state.shape[0] < target_length:
            # Pad by repeating the last state/action
            pad_length = target_length - state.shape[0]
            state = np.concatenate([state, np.tile(state[-1:], (pad_length, 1))], axis=0)
            action = np.concatenate([action, np.tile(action[-1:], (pad_length, 1))], axis=0)

        # Fill the state/action into the unified vector
        def fill_in_state(values):
            # Scatter the 6 right-arm joint values into the unified
            # STATE_DIM-wide vector; all other slots stay zero.
            UNI_STATE_INDICES = [
                STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
            ]
            uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
            uni_vec[..., UNI_STATE_INDICES] = values
            return uni_vec

        state = fill_in_state(state)
        action = fill_in_state(action)

        # Return the resulting sample
        return True, {"state": state, "action": action}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: instantiate the dataset and parse every episode once so
    # that any malformed hdf5 file surfaces as an error here.
    ds = HDF5VLADataset()
    for i in range(len(ds)):
        print(f"Processing episode {i}/{len(ds)}...")
        ds.get_item(i)
|
||||||
299
RDT-170M/data/preprocess.py
Normal file
299
RDT-170M/data/preprocess.py
Normal file
@ -0,0 +1,299 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from data.preprocess_scripts import *
|
||||||
|
from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN
|
||||||
|
from data.utils import capitalize_and_period
|
||||||
|
|
||||||
|
# Datasets whose observations carry no proprioceptive state; for these,
# the previous action is substituted as the current state downstream.
DATASET_NAMES_NO_STATE = [
    "nyu_door_opening_surprising_effectiveness",
    "usc_cloth_sim_converted_externally_to_rlds",
    "cmu_franka_exploration_dataset_converted_externally_to_rlds",
    "imperialcollege_sawyer_wrist_cam",
]

# Read the image keys of each dataset
with open("configs/dataset_img_keys.json", "r") as file:
    IMAGE_KEYS = json.load(file)
# Read the global preprocessing/training config
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
|
||||||
|
def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str, base_concat=None, base_format=None) -> tf.Tensor:
    """
    Assemble the state/action vector from the arm and base.

    Each comma-separated name in the format strings is looked up in
    STATE_VEC_IDX_MAPPING and the corresponding value is scattered into a
    STATE_VEC_LEN-wide vector; a parallel mask vector marks filled slots.
    Returns (state_vec, mask_vec).
    """

    def _scatter(target, names, values):
        # scatter_nd-based update avoids problems with duplicate indices.
        index_list = [[STATE_VEC_IDX_MAPPING[name]] for name in names]
        return tf.tensor_scatter_nd_update(target, index_list, values)

    state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
    mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)

    # Scatter the arm values and mark their slots in the mask.
    arm_names = arm_format.split(",")
    state_vec = _scatter(state_vec, arm_names, tf.cast(arm_concat, tf.float32))
    mask_vec = _scatter(mask_vec, arm_names, tf.ones(len(arm_names), dtype=tf.float32))

    # Do the same for the mobile base, when present.
    if base_concat is not None:
        base_names = base_format.split(",")
        state_vec = _scatter(state_vec, base_names, tf.cast(base_concat, tf.float32))
        mask_vec = _scatter(mask_vec, base_names, tf.ones(len(base_names), dtype=tf.float32))

    return state_vec, mask_vec
|
||||||
|
|
||||||
|
|
||||||
|
@tf.autograph.experimental.do_not_convert
def _generate_json_state_agilex(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for a given episode.

    agilex-specific variant: in addition to per-step states and masks it
    also assembles per-step action vectors. Returns
    (episode_metadata, episode_states, episode_masks, episode_acts).
    """
    # Load some constants from the config (validated eagerly so a bad
    # config fails fast rather than mid-pipeline).
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    base_act = None
    last_base_act = None
    episode_states = []
    episode_acts = []
    episode_masks = []
    has_base = None  # decided on the first step, then fixed for the episode
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]

        # Parse the state
        state = step["observation"]

        arm_format = state["format"].numpy().decode("utf-8")
        base_format = None
        if has_base:
            # The base portion of the format string starts at "base".
            act_format = action["format"].numpy().decode("utf-8")
            base_formate_idx = act_format.find("base")
            base_format = act_format[base_formate_idx:]

        arm_state = state["arm_concat"]
        base_state = None
        if has_base:
            # The base state is the previous step's base action
            # (zero on the first step).
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
            last_base_act = base_act

        # Assemble the state vector
        state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)

        # NOTE(review): this overwrites mask_vec with the action's mask;
        # both calls use the same formats so the masks should coincide —
        # confirm. Also, base_state (the *previous* base action) is reused
        # for the action vector rather than base_act — verify intended.
        act_vec, mask_vec = assemble_state_vec(action["arm_concat"], arm_format, base_state, base_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)
        episode_acts.append(act_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata (first step's instruction wins)
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

    episode_metadata["#steps"] = step_id

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)
    episode_acts = tf.stack(episode_acts)

    return episode_metadata, episode_states, episode_masks, episode_acts
|
||||||
|
|
||||||
|
|
||||||
|
@tf.autograph.experimental.do_not_convert
def _generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for a given episode.

    Generic variant for datasets that provide a proprioceptive state.
    Returns (episode_metadata, episode_states, episode_masks).
    """
    # Load some constants from the config (fail fast on bad values).
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    base_act = None
    last_base_act = None
    episode_states = []
    episode_masks = []
    has_base = None  # decided on the first step, then fixed for the episode
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]

        # Parse the state
        state = step["observation"]

        arm_format = state["format"].numpy().decode("utf-8")
        base_format = None
        if has_base:
            # The base portion of the format string starts at "base".
            act_format = action["format"].numpy().decode("utf-8")
            base_formate_idx = act_format.find("base")
            base_format = act_format[base_formate_idx:]

        arm_state = state["arm_concat"]
        base_state = None
        if has_base:
            # The base state is the previous step's base action
            # (zero on the first step).
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
            last_base_act = base_act

        # Assemble the state vector
        state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata (first step's instruction wins)
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

    episode_metadata["#steps"] = step_id
    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks
|
||||||
|
|
||||||
|
|
||||||
|
@tf.autograph.experimental.do_not_convert
def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode in the dataset without state.
    If not state, we use the last action as current state.

    Returns (episode_metadata, episode_states, episode_masks).
    """
    # Load some constants from the config (fail fast on bad values).
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    last_base_act = None
    last_arm_act = None
    episode_states = []
    episode_masks = []
    has_base = None  # decided on the first step, then fixed for the episode
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]
            if last_base_act is None:
                last_base_act = base_act * 0  # Initialize

        # Parse the arm action
        arm_act = action["arm_concat"]
        if last_arm_act is None:
            last_arm_act = arm_act * 0  # Initialize

        # Parse the act format
        # Action format as the state format
        act_format = action["format"].numpy().decode("utf-8")

        # Assemble the state vector from the *previous* action (the dataset
        # has no proprioceptive state of its own).
        if has_base:
            last_act_concat = tf.concat([last_arm_act, last_base_act], axis=0)
        else:
            last_act_concat = last_arm_act
        state_vec, mask_vec = assemble_state_vec(last_act_concat, act_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata (first step's instruction wins)
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

        # Update the last_arm_act and last_base_act
        last_arm_act = arm_act
        if has_base:
            last_base_act = base_act

    episode_metadata["#steps"] = step_id
    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks
|
||||||
|
|
||||||
|
|
||||||
|
@tf.autograph.experimental.do_not_convert
def generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode.

    Dispatches to the agilex-specific, no-state, or generic parser based
    on the dataset name after applying per-dataset step preprocessing.
    """
    if isinstance(dataset_name, tf.Tensor):
        dataset_name = dataset_name.numpy().decode("utf-8")

    # Apply the dataset-specific per-step preprocessing from the module
    # of the same name in data.preprocess_scripts.
    episode["steps"] = episode["steps"].map(globals()[dataset_name].process_step, )

    # Choose the parser for this dataset.
    if dataset_name == "agilex":
        handler = _generate_json_state_agilex
    elif dataset_name in DATASET_NAMES_NO_STATE:
        handler = _generate_json_state_nostate_ds
    else:
        handler = _generate_json_state

    return handler(episode, dataset_name)
|
||||||
313
RDT-170M/data/producer.py
Normal file
313
RDT-170M/data/producer.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
import time
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import signal
|
||||||
|
import random
|
||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from data.vla_dataset import VLADataset
|
||||||
|
from data.filelock import FileLock
|
||||||
|
|
||||||
|
# Producer does not need GPU; hide all devices so TF stays on CPU.
tf.config.set_visible_devices([], "GPU")

# Read the config
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
# Load some constants from the config
BUF_PATH = config["dataset"]["buf_path"]  # root directory of the sample buffer
BUF_NUM_CHUNKS = config["dataset"]["buf_num_chunks"]  # number of chunk directories
if BUF_NUM_CHUNKS < 1:
    raise ValueError("Config `buf_num_chunks` must be at least 1.")
BUF_CHUNK_SIZE = config["dataset"]["buf_chunk_size"]  # samples stored per chunk
if BUF_CHUNK_SIZE < 1:
    raise ValueError("Config `buf_chunk_size` must be at least 1.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_dirty_item(chunk_dir):
    """
    Get indexes of dirty items in a chunk.

    A dirty item is one whose dirty bit is nonzero (already consumed and
    eligible for replacement).
    """
    bits = read_dirty_bit(chunk_dir)
    return np.flatnonzero(bits).tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def get_clean_item(chunk_dir):
    """
    Get indexes of clean items in a chunk.

    A clean item is one whose dirty bit is zero (not yet consumed).
    """
    bits = read_dirty_bit(chunk_dir)
    return np.flatnonzero(bits == 0).tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def save_dirty_bit(chunk_dir, dirty_bit):
    """
    Save the dirty bit to the chunk directory.

    Writes `dirty_bit` (a uint8 array, one flag per chunk slot) to the
    `dirty_bit` file under `chunk_dir`, guarded by a write file-lock.
    Retries on any failure for up to 10 seconds, then gives up with a
    printed warning — deliberately non-fatal (see commented-out raise).
    """
    # Retry loop: keep trying until success or the 10-second budget runs out.
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                file.write(dirty_bit.tobytes())
            lock.release_lock()
            return
        except KeyboardInterrupt:
            # Release the lock before propagating so other workers unblock.
            lock.release_lock()
            raise KeyboardInterrupt
        except BaseException:
            # NOTE(review): if FileLock() itself raised on the very first
            # attempt, `lock` would be unbound here — confirm FileLock's
            # constructor cannot raise.
            lock.release_lock()
            continue
    # raise RuntimeError("Failed to save dirty bit.")
    print("Failed to save dirty bit.")
|
||||||
|
|
||||||
|
|
||||||
|
def read_dirty_bit(chunk_dir):
    """
    Read the dirty bit from the chunk directory.

    Reads the `dirty_bit` file under `chunk_dir` (uint8, BUF_CHUNK_SIZE
    entries) under a read file-lock, retrying failures for up to 10
    seconds. On persistent failure, returns all ones so every slot is
    treated as dirty (safe fallback for the producer).
    """
    # If error occurs, retry
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_read_lock()
            with open(file_path, "rb") as file:
                # .copy() detaches the array from the read-only buffer.
                dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
            lock.release_lock()
            assert len(dirty_bit) == BUF_CHUNK_SIZE
            return dirty_bit
        except KeyboardInterrupt:
            # Release the lock before propagating so other workers unblock.
            lock.release_lock()
            raise KeyboardInterrupt
        except BaseException:
            # NOTE(review): `lock` may be unbound if FileLock() raised —
            # confirm the constructor cannot raise.
            lock.release_lock()
            continue
    # If failed to read the dirty bit, return all ones for robustness
    return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def save_sample(step_dict, chunk_dir, chunk_item_idx):
    """
    Save a sample to the chunk directory.

    Writes two files for slot `chunk_item_idx` under `chunk_dir`:
      * json_content_<idx>.json — the sample's JSON metadata;
      * sample_<idx>.npz       — all tensors converted to numpy arrays.
    Each write is guarded by a write file-lock. Retries the whole pair
    for up to 10 seconds on failure; gives up with a printed warning
    instead of raising (see commented-out raise).
    """
    # Save the json content
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            # Track every lock acquired this attempt so the except blocks
            # can release them all.
            locks = []
            json_content = step_dict["json_content"]
            file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, "w") as file:
                json.dump(json_content, file, indent=4)
            lock.release_lock()
            # Save all other tensors in a npz
            file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                np.savez(
                    file,
                    step_id=step_dict["step_id"].numpy(),
                    state_chunk=step_dict["state_chunk"].numpy(),
                    state_chunk_time_mask=step_dict["state_chunk_time_mask"].numpy(),
                    action_chunk=step_dict["action_chunk"].numpy(),
                    action_chunk_time_mask=step_dict["action_chunk_time_mask"].numpy(),
                    state_vec_mask=step_dict["state_vec_mask"].numpy(),
                    past_frames_0=step_dict["past_frames_0"].numpy(),
                    past_frames_0_time_mask=step_dict["past_frames_0_time_mask"].numpy(),
                    past_frames_1=step_dict["past_frames_1"].numpy(),
                    past_frames_1_time_mask=step_dict["past_frames_1_time_mask"].numpy(),
                    past_frames_2=step_dict["past_frames_2"].numpy(),
                    past_frames_2_time_mask=step_dict["past_frames_2_time_mask"].numpy(),
                    past_frames_3=step_dict["past_frames_3"].numpy(),
                    past_frames_3_time_mask=step_dict["past_frames_3_time_mask"].numpy(),
                    state_std=step_dict["state_std"].numpy(),
                    state_mean=step_dict["state_mean"].numpy(),
                    state_norm=step_dict["state_norm"].numpy(),
                )
            lock.release_lock()
            return
        except KeyboardInterrupt:
            # Release every lock acquired so far before propagating.
            for lock in locks:
                lock.release_lock()
            raise KeyboardInterrupt
        except BaseException:
            # Release every lock acquired so far, then retry.
            for lock in locks:
                lock.release_lock()
            continue
    # raise RuntimeError("Failed to save sample.")
    print("Failed to save sample.")
|
||||||
|
|
||||||
|
|
||||||
|
def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type):
    """
    Run the producer.

    The producer will first fill up the buffer with samples.
    Then it will keep replacing dirty samples
    (i.e., samples that have been read by the consumer)
    with new samples.

    Each worker owns the contiguous chunk range
    [worker_id * BUF_NUM_CHUNKS // num_workers,
     (worker_id + 1) * BUF_NUM_CHUNKS // num_workers).
    """
    vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type)
    # This worker's slice of the chunk range.
    chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers
    chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers
    if fill_up:
        print(f"Worker {worker_id}: Start filling up the buffer...")
    elif clean_dirty:
        # Only refresh the dirty bits
        print(f"Worker {worker_id}: Start refreshing the dirty bits...")
        for chunk_idx in range(chunk_start_idx, chunk_end_idx):
            chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}")
            dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
            save_dirty_bit(chunk_dir, dirty_bit)
        print(f"Worker {worker_id}: Refreshed the dirty bits.")

    # Cursors for the fill-up phase and the dirty-replacement phase.
    fill_chunk_idx = chunk_start_idx
    fill_chunk_item_idx = 0
    dirty_chunk_idx = chunk_start_idx
    dirty_chunk_item_idxs = []
    time_stmp = time.time()  # throttles the dirty-ratio progress prints
    for episode_steps in vla_dataset:
        for step in episode_steps:
            if fill_up and fill_chunk_idx < chunk_end_idx:
                # Fill up the buffer
                chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}")
                if fill_chunk_item_idx == 0:
                    # Create a new chunk
                    os.makedirs(chunk_dir, exist_ok=True)
                    # Write the dirty bit of size BUF_CHUNK_SIZE
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(chunk_dir, dirty_bit)

                # Save the sample
                save_sample(step, chunk_dir, fill_chunk_item_idx)

                # print(f"Filled up chunk {fill_chunk_item_idx+1}/{BUF_CHUNK_SIZE} {fill_chunk_idx+1}/{BUF_NUM_CHUNKS}")
                # Progress report every 10 chunks (and for the last one).
                local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx
                local_num_chunks = chunk_end_idx - chunk_start_idx
                if (local_fill_chunk_idx % 10 == 0
                        or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0:
                    print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx+1}/{local_num_chunks}")
                fill_chunk_item_idx += 1
                if fill_chunk_item_idx == BUF_CHUNK_SIZE:
                    # Chunk full — advance to the next one.
                    fill_chunk_idx += 1
                    fill_chunk_item_idx = 0
                if fill_chunk_idx == BUF_NUM_CHUNKS:
                    print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...")

            else:
                # Search for the dirty chunk to replace
                while len(dirty_chunk_item_idxs) == 0:
                    dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
                    dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir)
                    # Print the dirty ratio (throttled to every 2 seconds)
                    if time.time() - time_stmp > 2.0:
                        dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE
                        print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}")
                        time_stmp = time.time()

                    if len(dirty_chunk_item_idxs) > 0:
                        # Lock the chunk by marking every slot dirty so the
                        # consumer skips it while we overwrite items.
                        dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
                        save_dirty_bit(dirty_chunk_dir, dirty_bit)

                    # Iterate over the chunks (wrap within this worker's range)
                    dirty_chunk_idx += 1
                    if dirty_chunk_idx == chunk_end_idx:
                        dirty_chunk_idx = chunk_start_idx

                # Replace the dirty item
                dirty_item_idx = dirty_chunk_item_idxs.pop()
                chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
                # Save the sample
                save_sample(step, chunk_dir, dirty_item_idx)

                # If we have replaced all dirty items in the chunk
                if len(dirty_chunk_item_idxs) == 0:
                    # Unlock the chunk (all slots clean again)
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(dirty_chunk_dir, dirty_bit)
                    print(f"Worker {worker_id}: Replaced dirty chunk {dirty_chunk_idx}.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Args: n_workers, fill_up
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_workers",
        type=int,
        default=2,
        help="Number of parallel workers. It should be less than or equal to the number of chunks.",
    )
    parser.add_argument(
        "--fill_up",
        action="store_true",
        help="Whether to fill up the buffer before replacing dirty samples.",
    )
    parser.add_argument(
        "--clean_dirty",
        action="store_true",
        help=
        "Whether to clean the dirty bits before replacing dirty samples. This option is ignored when `fill_up` is set.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed. If not set, the seed will be randomly generated.",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        help="Whether to load the pretrain dataset or finetune dataset.",
    )

    # Run the producer
    args = parser.parse_args()
    if args.seed is not None:
        print(f"Base seed: {args.seed}")
        random.seed(args.seed)

    processes = []
    # Derive an independent seed for each worker process from the base seed.
    process_seeds = [random.randint(0, 2**32) for _ in range(args.n_workers)]
    print(f"Process seeds: {process_seeds}")

    def signal_handler(sig, frame):
        # Forward Ctrl+C to all children so no orphan producers remain.
        print("Ctrl+C received. Terminating child processes...")
        for p in processes:
            p.terminate()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    # Spawn one producer process per worker, each owning a chunk range.
    for worker_id in range(args.n_workers):
        p = Process(
            target=run_producer,
            args=(
                process_seeds[worker_id],
                args.n_workers,
                worker_id,
                args.fill_up,
                args.clean_dirty,
                args.dataset_type,
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
|
||||||
242
RDT-170M/data/utils.py
Normal file
242
RDT-170M/data/utils.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_graphics.geometry.transformation.euler as tf_euler
|
||||||
|
import tensorflow_graphics.geometry.transformation.quaternion as tf_quat
|
||||||
|
import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as tf_rotmat
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_to_path(dataset_name: str, dir_name: str) -> str:
    """
    Return the path to the dataset.

    Each dataset has a pinned TFDS version string (empty for a few
    datasets); unknown datasets default to "0.1.0".
    """
    # Datasets with a non-default version, keyed by name.
    special_versions = {
        "robo_net": "1.0.0",
        "cmu_playing_with_food": "1.0.0",
        "droid": "1.0.0",
        "language_table": "0.0.1",
        "fmb": "0.0.1",
        "dobbe": "0.0.1",
        "nyu_door_opening_surprising_effectiveness": "",
        "cmu_play_fusion": "",
        "berkeley_gnm_recon": "",
    }
    version = special_versions.get(dataset_name, "0.1.0")
    return f"{dir_name}/{dataset_name}/{version}"
|
||||||
|
|
||||||
|
|
||||||
|
def clean_task_instruction(task_instruction: tf.Tensor, replacements: dict) -> tf.Tensor:
    """
    Clean up the natural language task instruction.

    Applies every (pattern -> replacement) pair in `replacements` via
    regex substitution, then strips leading/trailing whitespace.
    """
    cleaned = task_instruction
    # Apply each replacement in dict order.
    for pattern, substitute in replacements.items():
        cleaned = tf.strings.regex_replace(cleaned, pattern, substitute)
    # Strip leading and trailing spaces.
    return tf.strings.strip(cleaned)
|
||||||
|
|
||||||
|
|
||||||
|
def quaternion_to_euler(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to Euler angles (roll, pitch, yaw).
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
    """
    # Normalize so tf-graphics receives a unit quaternion.
    quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
    return tf_euler.from_quaternion(quaternion)
|
||||||
|
|
||||||
|
|
||||||
|
def euler_to_quaternion(euler: tf.Tensor) -> tf.Tensor:
    """
    Convert Euler angles (roll, pitch, yaw) to a quaternion (x, y, z, w).
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
    """
    quaternion = tf_quat.from_euler(euler)
    # Normalize the result to guard against numerical drift.
    return tf.nn.l2_normalize(quaternion, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def rotation_matrix_to_euler(matrix: tf.Tensor) -> tf.Tensor:
    """
    Convert a 3x3 rotation matrix to Euler angles (roll, pitch, yaw).
    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
    """
    # Thin wrapper over tf-graphics; assumes `matrix` is a valid rotation.
    return tf_euler.from_rotation_matrix(matrix)
|
||||||
|
|
||||||
|
|
||||||
|
def rotation_matrix_to_quaternion(matrix: tf.Tensor) -> tf.Tensor:
    """
    Convert a 3x3 rotation matrix to a quaternion (x, y, z, w).

    Args:
        matrix: tensor whose last two dimensions form a 3x3 rotation matrix.

    Returns:
        Unit quaternion tensor (normalized along the last axis).
    """
    quaternion = tf_quat.from_rotation_matrix(matrix)
    # Re-normalize to guard against numerical drift in the conversion.
    return tf.nn.l2_normalize(quaternion, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def euler_to_rotation_matrix(euler: tf.Tensor) -> tf.Tensor:
    """
    Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix.

    The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.

    Args:
        euler: tensor of Euler angles with last dimension 3.

    Returns:
        Rotation-matrix tensor produced by `tf_rotmat.from_euler`.
    """
    return tf_rotmat.from_euler(euler)
|
||||||
|
|
||||||
|
|
||||||
|
def quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.

    Args:
        quaternion: tensor with last dimension 4 in (x, y, z, w) order.

    Returns:
        Rotation-matrix tensor produced by `tf_rotmat.from_quaternion`.
    """
    # Normalize the quaternion so the conversion sees a unit quaternion.
    quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
    return tf_rotmat.from_quaternion(quaternion)
|
||||||
|
|
||||||
|
|
||||||
|
def quaternion_to_rotation_matrix_wo_static_check(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.

    This function is used to make tensorflow happy: unlike
    `quaternion_to_rotation_matrix` it expands the conversion by hand
    instead of calling the library helper, avoiding its static shape check.

    Args:
        quaternion: tensor of shape (..., 4) in (x, y, z, w) order.

    Returns:
        Tensor of shape (..., 3, 3).
    """
    # Normalize the quaternion so the formula below yields a pure rotation.
    quaternion = tf.nn.l2_normalize(quaternion, axis=-1)

    x = quaternion[..., 0]
    y = quaternion[..., 1]
    z = quaternion[..., 2]
    w = quaternion[..., 3]

    # Common subexpressions of the standard quaternion-to-matrix formula.
    tx = 2.0 * x
    ty = 2.0 * y
    tz = 2.0 * z
    twx = tx * w
    twy = ty * w
    twz = tz * w
    txx = tx * x
    txy = ty * x
    txz = tz * x
    tyy = ty * y
    tyz = tz * y
    tzz = tz * z
    # The nine matrix entries, stacked row-major into a flat (..., 9) tensor.
    matrix = tf.stack(
        (
            1.0 - (tyy + tzz),
            txy - twz,
            txz + twy,
            txy + twz,
            1.0 - (txx + tzz),
            tyz - twx,
            txz - twy,
            tyz + twx,
            1.0 - (txx + tyy),
        ),
        axis=-1,
    )  # pyformat: disable
    # Reshape the flat 9-vector back to (..., 3, 3) using the dynamic shape.
    output_shape = tf.concat((tf.shape(input=quaternion)[:-1], (3, 3)), axis=-1)
    return tf.reshape(matrix, shape=output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Below is a continuous 6D rotation representation adapted from
|
||||||
|
On the Continuity of Rotation Representations in Neural Networks
|
||||||
|
https://arxiv.org/pdf/1812.07035.pdf
|
||||||
|
https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def rotation_matrix_to_ortho6d(matrix: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d represents the first two column vectors a1 and a2 of the
    rotation matrix: [ | , |, | ]
                     [ a1, a2, a3]
                     [ | , |, | ]
    Input: (A1, ..., An, 3, 3)
    Output: (A1, ..., An, 6)
    """
    # Keep only the first two columns (a1, a2).
    ortho6d = matrix[..., :, :2]
    # Transpose the last two dimension so the flatten below yields
    # [a1; a2] rather than interleaved components.
    perm = list(range(len(ortho6d.shape)))
    perm[-2], perm[-1] = perm[-1], perm[-2]
    ortho6d = tf.transpose(ortho6d, perm)
    # Flatten the last two dimension.
    # NOTE(review): this relies on `ortho6d.shape` being statically known;
    # it will fail under fully dynamic shapes — confirm callers always pass
    # tensors with a static rank/shape.
    ortho6d = tf.reshape(ortho6d, ortho6d.shape[:-2] + [6])
    return ortho6d
|
||||||
|
|
||||||
|
|
||||||
|
def rotation_matrix_to_ortho6d_1d(matrix: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d represents the first two column vectors a1 and a2 of the
    rotation matrix: [ | , |, | ]
                     [ a1, a2, a3]
                     [ | , |, | ]
    Input: (3, 3)
    Output: (6,)
    This function is used to make tensorflow happy (no batch dimensions,
    no static-shape arithmetic).
    """
    # Keep only the first two columns (a1, a2).
    ortho6d = matrix[:, :2]
    # Transpose the last two dimension so flattening yields [a1; a2].
    ortho6d = tf.transpose(ortho6d)
    # Flatten the last two dimension into a fixed-size 6-vector.
    ortho6d = tf.reshape(ortho6d, [6])
    return ortho6d
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_vector(v):
    """
    L2-normalize `v` along its last axis.

    v: (..., N)

    The magnitude is clamped below at 1e-8, so (near-)zero vectors do not
    cause a division by zero.
    """
    v_mag = tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
    v_mag = tf.maximum(v_mag, 1e-8)
    v_normalized = v / v_mag

    return v_normalized
|
||||||
|
|
||||||
|
|
||||||
|
def cross_product(u, v):
    """
    Cross product of the last-axis 3-vectors of `u` and `v`.

    u: (..., 3)
    v: (..., 3)
    u x v: (..., 3)
    """
    # Component-wise expansion of the 3D cross product.
    i = u[..., 1] * v[..., 2] - u[..., 2] * v[..., 1]
    j = u[..., 2] * v[..., 0] - u[..., 0] * v[..., 2]
    k = u[..., 0] * v[..., 1] - u[..., 1] * v[..., 0]
    out = tf.stack([i, j, k], axis=-1)
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def ortho6d_to_rotation_matrix(ortho6d: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d represents the first two column vectors a1 and a2 of the
    rotation matrix: [ | , |, | ]
                     [ a1, a2, a3]
                     [ | , |, | ]
    Input: (A1, ..., An, 6)
    Output: (A1, ..., An, 3, 3)

    The matrix is re-orthogonalized (Gram–Schmidt style) so the output is a
    valid rotation even if the 6D input is not exactly orthonormal.
    """
    x_raw = ortho6d[..., 0:3]
    y_raw = ortho6d[..., 3:6]

    # x is the normalized first column; z is orthogonal to x and y_raw;
    # y is recomputed so that (x, y, z) is right-handed and orthonormal.
    x = normalize_vector(x_raw)
    z = cross_product(x, y_raw)
    z = normalize_vector(z)
    y = cross_product(z, x)

    # Stack x, y, z to form the matrix (as columns, via axis=-1).
    matrix = tf.stack([x, y, z], axis=-1)
    return matrix
|
||||||
|
|
||||||
|
|
||||||
|
def capitalize_and_period(instr: str) -> str:
    """
    Capitalize the first letter of a string and add a period to the end if
    it's not there.

    Args:
        instr: the instruction text; may be empty.

    Returns:
        The normalized string. Empty strings are returned unchanged.
    """
    # Empty input: nothing to capitalize or terminate.
    if not instr:
        return instr
    # If the first letter is not capital, make it so.
    if not instr[0].isupper():
        instr = instr[0].upper() + instr[1:]
    # Add a period to the end if it's not there.
    if instr[-1] != ".":
        instr = instr + "."
    return instr
|
||||||
149
RDT-170M/data/vla_dataset.py
Normal file
149
RDT-170M/data/vla_dataset.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
import json
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from data.episode_transform import (
|
||||||
|
process_episode,
|
||||||
|
flatten_episode,
|
||||||
|
flatten_episode_agilex,
|
||||||
|
bgr_to_rgb,
|
||||||
|
)
|
||||||
|
from data.utils import dataset_to_path
|
||||||
|
from data.preprocess_scripts import *
|
||||||
|
|
||||||
|
# Producer does not need GPU
|
||||||
|
tf.config.set_visible_devices([], "GPU")
|
||||||
|
|
||||||
|
OPENX_EMBOD_DIR = "data/datasets/openx_embod"
|
||||||
|
|
||||||
|
DATASET_NAMES_NOOPENX = [
|
||||||
|
"aloha_mobile",
|
||||||
|
"aloha_static",
|
||||||
|
"roboset",
|
||||||
|
"agilex",
|
||||||
|
"rh20t",
|
||||||
|
"calvin",
|
||||||
|
"bridgev2",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Read the config
|
||||||
|
with open("configs/base.yaml", "r") as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
# Load some constants from the config
|
||||||
|
EPSD_LEN_THRESH_LOW = config["dataset"]["epsd_len_thresh_low"]
|
||||||
|
EPSD_LEN_THRESH_HIGH = config["dataset"]["epsd_len_thresh_high"]
|
||||||
|
# Read the image keys of each dataset
|
||||||
|
with open("configs/dataset_img_keys.json", "r") as file:
|
||||||
|
IMAGE_KEYS = json.load(file)
|
||||||
|
|
||||||
|
|
||||||
|
class VLADataset:
    """
    This class is used to sample episodes from the embodiment dataset
    collection.

    Each dataset is wrapped as an infinite (when `repeat=True`) iterator;
    `__iter__` yields lists of flattened episode steps, sampling datasets
    according to the configured per-dataset weights.
    """

    def __init__(self, seed, dataset_type, repeat=True):
        """
        seed: the random seed
        dataset_type: 'pretrain' or 'finetune', which dataset to load
        repeat: whether to repeat to infinite length
        """
        # Choose the dataset-name / sample-weight configs by dataset_type.
        dataset_names_cfg = ("configs/pretrain_datasets.json"
                             if dataset_type == "pretrain" else "configs/finetune_datasets.json")
        with open(dataset_names_cfg, "r") as file:
            DATASET_NAMES = json.load(file)
        self.dataset_names = DATASET_NAMES
        sample_weights_cfg = ("configs/pretrain_sample_weights.json"
                              if dataset_type == "pretrain" else "configs/finetune_sample_weights.json")
        # Load the sample weights
        with open(sample_weights_cfg, "r") as file:
            SAMPLE_WEIGHTS = json.load(file)
        self.openx_dir = OPENX_EMBOD_DIR
        self.epsd_len_thresh_low = EPSD_LEN_THRESH_LOW
        self.epsd_len_thresh_high = EPSD_LEN_THRESH_HIGH
        self.repeat = repeat

        # Set the random seed
        tf.random.set_seed(seed)
        np.random.seed(seed)

        # Weights of the each dataset in the collection to sample from
        sample_weights = []

        # Map dataset name -> infinite iterator over processed episodes.
        self.name2dataset = {}
        for dataset_name in self.dataset_names:
            if dataset_name in DATASET_NAMES_NOOPENX:
                # Non-OpenX datasets ship their own loader module; look it
                # up by name among the imported preprocess scripts.
                dataset = globals()[dataset_name].load_dataset(seed)
            else:
                dataset_path = dataset_to_path(dataset_name, self.openx_dir)
                dataset = tfds.builder_from_directory(builder_dir=dataset_path)
                dataset = dataset.as_dataset(split="all", shuffle_files=True)

            # You can add filter for other datasets
            if dataset_name == "kuka":
                dataset = dataset.filter(lambda x: x["success"])
            elif dataset_name == "bc_z":
                dataset = dataset.filter(lambda x: tf.math.greater(
                    next(iter(x["steps"]))["observation"]["episode_success"],
                    0.5,
                ))
            elif (dataset_name == "ucsd_pick_and_place_dataset_converted_externally_to_rlds"):
                dataset = dataset.filter(lambda x: x["episode_metadata"]["success"])
            elif (dataset_name == "utokyo_xarm_bimanual_converted_externally_to_rlds"):
                # Only preserve the meaningful episodes
                dataset = dataset.filter(lambda x: tf.math.equal(
                    next(iter(x["steps"]))["language_instruction"],
                    tf.constant("Unfold a wrinkled towel."),
                ))

            # Note: use cache() will cause the unexpected crash
            # dataset = dataset.map().cache().shuffle().repeat()
            dataset = dataset.map(lambda x: process_episode(
                x,
                dataset_name,
                IMAGE_KEYS[dataset_name]["image_keys"],
                IMAGE_KEYS[dataset_name]["image_mask"],
            ))

            # Change BGR to RGB if needed
            if dataset_name == "fmb":
                dataset = dataset.map(bgr_to_rgb)

            if self.repeat:
                dataset = dataset.repeat()
            self.name2dataset[dataset_name] = iter(dataset)
            sample_weights.append(SAMPLE_WEIGHTS[dataset_name])
        # Normalize the sample weights
        sample_weights = np.array(sample_weights)
        self.sample_weights = sample_weights / np.sum(sample_weights)

    def __iter__(self):
        """
        Sample batches of episodes for an epoch.
        """
        while True:
            # Pick a dataset by weight, then take its next episode.
            dataset_name = np.random.choice(self.dataset_names, p=self.sample_weights)
            episode = next(self.name2dataset[dataset_name])
            if dataset_name == "agilex":
                episode_steps = flatten_episode_agilex(episode)
            else:
                episode_steps = flatten_episode(episode)
            # Filter too short
            if len(episode_steps) < self.epsd_len_thresh_low:
                continue
            # Randomly sample too long
            # NOTE(review): random.sample returns the chosen steps in random
            # order, so temporal ordering is not preserved here — confirm
            # downstream treats steps as independent samples.
            if len(episode_steps) > self.epsd_len_thresh_high:
                episode_steps = random.sample(episode_steps, self.epsd_len_thresh_high)

            yield episode_steps
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: build the finetune mixture with seed 0, print the first
    # step of one sampled episode, and stop.
    dataset = VLADataset(0, "finetune")
    for episode in dataset:
        print(episode[0])
        break
|
||||||
106
RDT-170M/finetune.sh
Normal file
106
RDT-170M/finetune.sh
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
#!/bin/bash
# Finetune launcher for RDT-170M: reads input/config.json, regenerates the
# model-config YAML, then launches training via `accelerate`.

BEGIN_TIME=$(date +%s)

CONFIG_NAME="Train_Config_Default"
CONFIG_FILE="model_config/$CONFIG_NAME.yml"
echo "CONFIG_FILE_PATH: $CONFIG_FILE"

### ============Read Input Config and ReLoad Config YAML===================

TRAIN_CONFIG_FILE="input/config.json"
echo "TRAIN_CONFIG_FILE_PATH: $TRAIN_CONFIG_FILE"
python3 scripts/read_config.py "$TRAIN_CONFIG_FILE" "$CONFIG_FILE"

### ============Read Input Config and ReLoad Config YAML===================

# NCCL / encoder / wandb environment setup.
export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0
export NCCL_SOCKET_IFNAME=eth0
# export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
# NOTE(review): TEXT_ENCODER_NAME stays unset (the export above is
# commented out), so --pretrained_text_encoder_name_or_path below expands
# to an empty value — confirm this is intended (precomputed lang embeds).
export VISION_ENCODER_NAME="../weights/siglip-so400m-patch14-384"
export CFLAGS="-I/usr/include"
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
export WANDB_PROJECT="RDT-170M"
export WANDB_DEFAULT_RUN_NAME=$CONFIG_NAME
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1

# check if YAML exist
if [ ! -f "$CONFIG_FILE" ]; then
    echo "Config file $CONFIG_FILE does not exist!"
    exit 1
fi

# Pull individual hyperparameters out of the YAML.
PRETRAINED_MODEL_NAME=$(python3 scripts/read_yaml.py "$CONFIG_FILE" pretrained_model_name_or_path)
TRAIN_BATCH_SIZE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" train_batch_size)
SAMPLE_BATCH_SIZE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" sample_batch_size)
MAX_TRAIN_STEPS=$(python3 scripts/read_yaml.py "$CONFIG_FILE" max_train_steps)
CHECKPOINTING_PERIOD=$(python3 scripts/read_yaml.py "$CONFIG_FILE" checkpointing_period)
SAMPLE_PERIOD=$(python3 scripts/read_yaml.py "$CONFIG_FILE" sample_period)
CHECKPOINTS_TOTAL_LIMIT=$(python3 scripts/read_yaml.py "$CONFIG_FILE" checkpoints_total_limit)
# NOTE(review): LR_SCHEDULER and DATASET_TYPE are read here but the launch
# below hardcodes --lr_scheduler="constant" and --dataset_type="finetune";
# confirm the YAML values are meant to be ignored.
LR_SCHEDULER=$(python3 scripts/read_yaml.py "$CONFIG_FILE" lr_scheduler)
LEARNING_RATE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" learning_rate)
DATALOADER_NUM_WORKERS=$(python3 scripts/read_yaml.py "$CONFIG_FILE" dataloader_num_workers)
DATASET_TYPE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" dataset_type)
STATE_NOISE_SNR=$(python3 scripts/read_yaml.py "$CONFIG_FILE" state_noise_snr)
GRAD_ACCUM_STEPS=$(python3 scripts/read_yaml.py "$CONFIG_FILE" gradient_accumulation_steps)
OUTPUT_DIR=$(python3 scripts/read_yaml.py "$CONFIG_FILE" checkpoint_path)
CUDA_USE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" cuda_visible_device)

export WANDB_MODE=disabled
# Strip the surrounding quotes that read_yaml.py emits for string values.
PRETRAINED_MODEL_NAME=$(echo "$PRETRAINED_MODEL_NAME" | tr -d '"')
CUDA_USE=$(echo "$CUDA_USE" | tr -d '"')
OUTPUT_DIR=$(echo "$OUTPUT_DIR" | tr -d '"')

# create output path
if [ ! -d "$OUTPUT_DIR" ]; then
    mkdir -p "$OUTPUT_DIR"
    echo "Created output directory: $OUTPUT_DIR"
else
    echo "Output directory already exists: $OUTPUT_DIR"
fi

export CUDA_VISIBLE_DEVICES=$CUDA_USE

# Precompute dataset statistics before training.
python3 -m data.compute_dataset_stat_hdf5 --task_name $CONFIG_NAME

accelerate launch --main_process_port=28499 main.py \
    --deepspeed="./configs/zero2.json" \
    --pretrained_model_name_or_path=$PRETRAINED_MODEL_NAME \
    --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
    --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
    --output_dir=$OUTPUT_DIR \
    --train_batch_size=$TRAIN_BATCH_SIZE \
    --sample_batch_size=$SAMPLE_BATCH_SIZE \
    --max_train_steps=$MAX_TRAIN_STEPS \
    --checkpointing_period=$CHECKPOINTING_PERIOD \
    --sample_period=$SAMPLE_PERIOD \
    --checkpoints_total_limit=$CHECKPOINTS_TOTAL_LIMIT \
    --lr_scheduler="constant" \
    --learning_rate=$LEARNING_RATE \
    --mixed_precision="bf16" \
    --dataloader_num_workers=$DATALOADER_NUM_WORKERS \
    --image_aug \
    --dataset_type="finetune" \
    --state_noise_snr=$STATE_NOISE_SNR \
    --load_from_hdf5 \
    --report_to=wandb \
    --precomp_lang_embed \
    --gradient_accumulation_steps=$GRAD_ACCUM_STEPS \
    --model_config_path=$CONFIG_FILE \
    --CONFIG_NAME=$CONFIG_NAME \
    --output_log_path=$OUTPUT_DIR/output.log

END_TIME=$(date +%s)
RUNTIME=$((END_TIME - BEGIN_TIME))
echo "Total runtime: $RUNTIME seconds"

### ============Generate Output JSON===================
sleep 10
python3 scripts/generate_output_json.py "$TRAIN_CONFIG_FILE" "$OUTPUT_DIR" "$RUNTIME"

### ============Generate Output JSON===================

||||||
Binary file not shown.
5
RDT-170M/generate.sh
Normal file
5
RDT-170M/generate.sh
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash
# Generate a model-config YAML for the given model name.
#
# Usage: ./generate.sh <model_name>

model_name=${1}

# Fail fast with a usage message instead of invoking the generator with an
# empty argument.
if [ -z "$model_name" ]; then
    echo "Usage: $0 <model_name>" >&2
    exit 1
fi

# Quote the expansion so names with spaces are passed as one argument.
python ./model_config/_generate_model_config.py "$model_name"
||||||
351
RDT-170M/main.py
Normal file
351
RDT-170M/main.py
Normal file
@ -0,0 +1,351 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from train.train import train
|
||||||
|
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(input_args=None):
    """
    Build the CLI argument parser for RDT training and parse the arguments.

    Args:
        input_args: optional list of argument strings; when None, arguments
            are taken from sys.argv.

    Returns:
        argparse.Namespace with every training option. `local_rank` is
        overridden by the LOCAL_RANK environment variable when that is set
        and differs from the parsed value.
    """
    parser = argparse.ArgumentParser(description="Main script for training RDT.")
    # --- Configuration files and pretrained component paths ---
    parser.add_argument(
        "--model_config_path",
        type=str,
        default="model_config/sjoe_place_D435_100_finetune_config.yaml",
        help=
        "Path to the finetune data and model configuration file. Default is `model_config/sjoe_place_D435_100_finetune_config.yaml`.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/base.yaml",
        help="Path to the configuration file. Default is `configs/base.yaml`.",
    )
    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help=
        "Enable DeepSpeed and pass the path to its config file or an already initialized DeepSpeed config dictionary",
    )
    parser.add_argument(
        "--pretrained_text_encoder_name_or_path",
        type=str,
        default=None,
        help="Pretrained text encoder name or path if not the same as model_name",
    )
    parser.add_argument(
        "--pretrained_vision_encoder_name_or_path",
        type=str,
        default=None,
        help="Pretrained vision encoder name or path if not the same as model_name",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default="checkpoints",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")

    # --- Data loading ---
    parser.add_argument(
        "--load_from_hdf5",
        action="store_true",
        default=False,
        help=("Whether to load the dataset directly from HDF5 files. "
              "If False, the dataset will be loaded using producer-consumer pattern, "
              "where the producer reads TFRecords and saves them to buffer, and the consumer reads from buffer."),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the sampling dataloader.",
    )
    parser.add_argument(
        "--num_sample_batches",
        type=int,
        default=2,
        help="Number of batches to sample from the dataset.",
    )
    # --- Training schedule and checkpointing ---
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_period",
        type=int,
        default=500,
        help=
        ("Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
         "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
         "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
         "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
         "instructions."),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=
        ("Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
         " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
         " for more details"),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
              ' `--checkpointing_period`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        # FIX: the original help had a trailing comma after the first string,
        # which made `help` a tuple and broke --help rendering; the strings
        # are now concatenated into a single str.
        help=("Path or name of a pretrained checkpoint to load the model from.\n"
              " This can be either:\n"
              " - a string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co, e.g., `robotics-diffusion-transformer/rdt-1b`,\n"
              " - a path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method, e.g., `./my_model_directory/`.\n"
              " - a path to model checkpoint (*.pt), .e.g, `my_model_directory/checkpoint-10000/pytorch_model/mp_rank_00_model_states.pt`"
              " - `None` if you are randomly initializing model using configuration at `config_path`."),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    # --- Optimization ---
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--cond_mask_prob",
        type=float,
        default=0.1,
        help=("The probability to randomly mask the conditions (except states) during training. "
              "If set to 0, the conditions are not masked."),
    )
    parser.add_argument(
        "--cam_ext_mask_prob",
        type=float,
        default=-1.0,
        help=("The probability to randomly mask the external camera image during training. "
              "If set to < 0, the external camera image is masked with the probability of `cond_mask_prob`."),
    )
    parser.add_argument(
        "--state_noise_snr",
        type=float,
        default=None,
        help=("The signal-to-noise ratio (SNR, unit: dB) for adding noise to the states. "
              "Default is None, which means no noise is added."),
    )
    parser.add_argument(
        "--image_aug",
        action="store_true",
        default=False,
        help="Whether or not to apply image augmentation (ColorJitter, blur, noise, etc) to the input images.",
    )
    parser.add_argument(
        "--precomp_lang_embed",
        action="store_true",
        default=False,
        help="Whether or not to use precomputed language embeddings.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.9,
        help="The moving average coefficient for each dataset's loss.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    # --- Hub / logging / reporting ---
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
              " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
              ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'),
    )
    parser.add_argument(
        "--sample_period",
        type=int,
        default=-1,
        help=("Run sampling every X steps. During the sampling phase, the model will sample a trajectory"
              " and report the error between the sampled trajectory and groud-truth trajectory"
              " in the training batch."),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
    )

    # --- Distributed / misc ---
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=("Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
              " behaviors, so disable this argument if it causes any problems. More info:"
              " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"),
    )

    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        required=False,
        help="Whether to load the pretrain dataset or finetune dataset.",
    )

    parser.add_argument(
        "--CONFIG_NAME",
        type=str,
        default="Null",
        required=True,
    )

    parser.add_argument(
        "--output_log_path",
        type=str,
        default="output/output.log",
        help="The path to the output log file.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Let the launcher-provided LOCAL_RANK win over the CLI value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Entry point: parse CLI options and hand off to the training loop with
    # an accelerate-aware logger.
    logger = get_logger(__name__)
    args = parse_args()
    train(args, logger)
|
||||||
269
RDT-170M/model.py
Normal file
269
RDT-170M/model.py
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
#!/home/lin/software/miniconda3/envs/aloha/bin/python
|
||||||
|
# -- coding: UTF-8
|
||||||
|
"""
|
||||||
|
#!/usr/bin/python3
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# get current workspace
|
||||||
|
current_file = Path(__file__)
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
parent_dir = current_file.parent
|
||||||
|
sys.path.append(str(parent_dir))
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image as PImage
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
import sys, os
|
||||||
|
|
||||||
|
# get current workspace
|
||||||
|
current_file = Path(__file__)
|
||||||
|
sys.path.append(os.path.join(current_file.parent, "models"))
|
||||||
|
|
||||||
|
from scripts.agilex_model import create_model
|
||||||
|
from multimodal_encoder.t5_encoder import T5Embedder
|
||||||
|
|
||||||
|
global_path = parent_dir.parent
|
||||||
|
|
||||||
|
|
||||||
|
class RDT:
    """Inference wrapper around an RDT (Robotics Diffusion Transformer) policy.

    Owns three things:
      * the diffusion policy itself (built by :meth:`make_policy`),
      * a T5 text encoder used to embed language instructions,
      * a two-slot observation window (previous + current frame) that the
        policy consumes on every inference step.
    """

    def __init__(
        self,
        pretrained_model_name_or_path,  # checkpoint dir / hub id of the RDT policy
        task_name,  # task identifier, used when caching language embeddings
        left_arm_dim,  # joint count of the left arm (gripper excluded)
        right_arm_dim,  # joint count of the right arm (gripper excluded)
        rdt_step,  # caller-defined step value, stored as-is
    ):
        # set path: the repo root is assumed to be two levels above this file
        current_file = Path(__file__)
        self.global_path = current_file.parent.parent
        # load the config
        self.config = {
            "episode_len": 10000,  # args.max_publish_step
            "state_dim": left_arm_dim + 1 + right_arm_dim +
            1,  # 14 dims action:[left joint angles,left gripper,right joint angles,right gripper]
            "chunk_size": 64,  # args.chunk_size
            "camera_names": ["cam_high", "cam_right_wrist", "cam_left_wrist"],
        }
        # setup config
        self.args = {
            "max_publish_step": 10000,  # Maximum number of action publishing steps
            "seed": None,  # Random seed
            "ctrl_freq": 25,  # The control frequency of the robot
            "chunk_size": 64,  # Action chunk size
            # 'disable_puppet_arm': False, # Whether to disable the puppet arm
            "config_path": os.path.join(self.global_path, "RDT/configs/base.yaml"),
            "pretrained_model_name_or_path": pretrained_model_name_or_path,
        }

        # Load rdt model
        self.left_arm_dim, self.right_arm_dim = left_arm_dim, right_arm_dim
        self.policy = self.make_policy(self.args)
        self.max_publish_step = self.config["episode_len"]
        self.chunk_size = self.config["chunk_size"]
        self.task_name = task_name
        self.observation_window = None  # deque(maxlen=2), created lazily
        self.img_size = (640, 480)  # (width, height) as expected by cv2.resize
        self.set_language_embed()
        self.rdt_step = rdt_step

    # set img_size
    def set_img_size(self, img_size):
        """Override the (width, height) target used when resizing camera frames."""
        self.img_size = img_size

    def set_language_embed(self):
        """Load the T5 tokenizer/encoder used to embed language instructions.

        NOTE(review): hard-codes cuda:0 for the text encoder — confirm this is
        acceptable on multi-GPU hosts.
        """
        GPU = 0
        MODEL_PATH = os.path.join(self.global_path, "weights/RDT/t5-v1_1-xxl")
        CONFIG_PATH = os.path.join(self.global_path, "RDT/configs/base.yaml")
        with open(CONFIG_PATH, "r") as fp:
            config = yaml.safe_load(fp)
        device = torch.device(f"cuda:{GPU}")
        text_embedder = T5Embedder(
            from_pretrained=MODEL_PATH,
            model_max_length=config["dataset"]["tokenizer_max_length"],
            device=device,
            use_offload_folder=None,
        )
        self.tokenizer, self.text_encoder = text_embedder.tokenizer, text_embedder.model
        self.text_encoder.eval()

    # set language randomly
    def random_set_language(self, instruction=None):
        """Embed the given instruction; despite the name, nothing is random here."""
        assert instruction is not None, "Missing input instruction"
        self.set_language_instruction(instruction)

    # encoding language
    def set_language_instruction(self, language_instruction, save_dir=None, task_name=None):
        """Set `self.lang_embeddings` from an instruction string or a pre-embedded .pt file.

        `save_dir` and `task_name` must be given together (or both omitted); when
        given, the freshly computed embedding is cached to `<save_dir>/<task_name>.pt`.
        """
        # XOR == False  <=>  save_dir and task_name are both None or both set
        assert ((save_dir is None) ^ (task_name is None)) == False, "input error"

        if os.path.isfile(language_instruction):
            # The "instruction" is actually a path to a pre-embedded torch file.
            lang_dict = torch.load(language_instruction)
            print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
            self.lang_embeddings = lang_dict["embeddings"]
            print("loading instruction from pre-embed path")
        else:
            # Embed the raw instruction text with the T5 encoder.
            device = next(self.text_encoder.parameters()).device
            with torch.no_grad():
                tokens = self.tokenizer(
                    language_instruction,
                    return_tensors="pt",
                    padding="longest",
                    truncation=True,
                )["input_ids"].to(device)
                tokens = tokens.view(1, -1)
                output = self.text_encoder(tokens)
                pred = output.last_hidden_state.detach().cpu()

            if save_dir is not None:
                # Cache the embedding so future runs can load it from disk.
                save_path = os.path.join(save_dir, f"{task_name}.pt")
                torch.save({
                    "name": task_name,
                    "instruction": language_instruction,
                    "embeddings": pred,
                }, save_path)

            # Free tokenization/encoder buffers promptly to reduce GPU pressure.
            del tokens, output
            torch.cuda.empty_cache()
            self.lang_embeddings = pred

        print(f"successfully set instruction: {language_instruction}")

    # Update the observation window buffer
    def update_observation_window(self, img_arr, state):
        """Push the latest frames and joint state into the 2-slot window.

        `img_arr` is indexed as [front, right, left] below — confirm ordering
        against the caller. `state` is the proprioceptive vector (qpos).
        """

        # JPEG transformation
        # Align with training: reproduce the dataset's JPEG compression
        # artifacts by round-tripping each frame through encode/decode.
        def jpeg_mapping(img):
            if img is None:
                return None
            img = cv2.imencode(".jpg", img)[1].tobytes()
            img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
            return img

        def resize_img(img, size):
            return cv2.resize(img, size)

        if self.observation_window is None:
            self.observation_window = deque(maxlen=2)

            # Append the first dummy image so index -2 is valid on the
            # very first real observation.
            self.observation_window.append({
                "qpos": None,
                "images": {
                    self.config["camera_names"][0]: None,
                    self.config["camera_names"][1]: None,
                    self.config["camera_names"][2]: None,
                },
            })

        img_front, img_right, img_left, puppet_arm = (
            img_arr[0],
            img_arr[1],
            img_arr[2],
            state,
        )
        # img resize
        img_front = resize_img(img_front, self.img_size)
        img_left = resize_img(img_left, self.img_size)
        img_right = resize_img(img_right, self.img_size)
        # img jpeg encoding round-trip (see jpeg_mapping above)
        img_front = jpeg_mapping(img_front)
        img_left = jpeg_mapping(img_left)
        img_right = jpeg_mapping(img_right)

        qpos = np.array(puppet_arm)
        qpos = torch.from_numpy(qpos).float().cuda()
        self.observation_window.append({
            "qpos": qpos,
            "images": {
                self.config["camera_names"][0]: img_front,
                self.config["camera_names"][1]: img_right,
                self.config["camera_names"][2]: img_left,
            },
        })

    def get_action(self, img_arr=None, state=None):
        """Run one policy inference and return the predicted action chunk.

        `img_arr` and `state` must be passed together (then the observation
        window is refreshed first) or both omitted (reuse the current window).
        """
        # NOTE(review): `==` binds tighter than `^`, so this parses as
        # (img_arr is None) ^ ((state is None) == False). Because
        # A ^ (not B) == not (A ^ B), it still enforces "both or neither",
        # but the parenthesization is fragile — consider making it explicit.
        assert (img_arr is None) ^ (state is None) == False, "input error"
        if (img_arr is not None) and (state is not None):
            self.update_observation_window(img_arr, state)

        with torch.inference_mode():
            action_buffer = inference_fn(self.config, self.policy, self.lang_embeddings, self.observation_window).copy()

        return action_buffer

    def reset_obsrvationwindows(self):
        """Drop the cached language embedding and observation window."""
        self.lang_embeddings = None
        self.observation_window = None
        print("successfully unset obs and language intruction")

    # Initialize the model
    def make_policy(self, args):
        """Build the RDT policy from the base YAML config plus arm dimensions."""
        with open(args["config_path"], "r") as fp:
            config_base_yaml = yaml.safe_load(fp)
        args["config"] = config_base_yaml
        args["config"]["arm_dim"] = {
            "left_arm_dim": self.left_arm_dim,
            "right_arm_dim": self.right_arm_dim,
        }
        # pretrained_text_encoder_name_or_path = "weights/RDT/t5-v1_1-xxl"
        pretrained_vision_encoder_name_or_path = os.path.join(self.global_path, "weights/RDT/siglip-so400m-patch14-384")
        model = create_model(
            args=args["config"],
            dtype=torch.bfloat16,
            pretrained=args["pretrained_model_name_or_path"],
            # pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
            pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
            control_frequency=args["ctrl_freq"],
        )

        return model
|
||||||
|
|
||||||
|
|
||||||
|
# RDT inference
def inference_fn(config, policy, lang_embeddings, observation_window):
    """Run a single RDT policy inference over the current observation window.

    Args:
        config: dict providing at least ``camera_names`` (ordered camera keys).
        policy: RDT policy exposing ``step(proprio=, images=, text_embeds=)``.
        lang_embeddings: pre-computed language-instruction embedding tensor.
        observation_window: 2-slot deque of ``{"qpos": ..., "images": {...}}``
            observations (index -2 = previous frame, -1 = current frame).

    Returns:
        numpy array of the predicted action chunk (batch dim squeezed off).
    """
    # The original body wrapped this in a `while True:` loop that always
    # returned on the first iteration; the dead loop (and an unused
    # `time.time()` timestamp) have been removed — behavior is unchanged.

    # fetch images in sequence [front, right, left] for the previous
    # observation first, then the current one.
    image_arrs = [
        observation_window[-2]["images"][config["camera_names"][0]],
        observation_window[-2]["images"][config["camera_names"][1]],
        observation_window[-2]["images"][config["camera_names"][2]],
        observation_window[-1]["images"][config["camera_names"][0]],
        observation_window[-1]["images"][config["camera_names"][1]],
        observation_window[-1]["images"][config["camera_names"][2]],
    ]

    # Dummy (None) frames from the initial padding slot are passed through.
    images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]

    # get last qpos and unsqueeze it to a batch of one, e.g. [14] -> [1, 14]
    proprio = observation_window[-1]["qpos"].unsqueeze(0)

    # actions come back batched (e.g. [1, chunk_size, state_dim]); squeeze
    # the batch dim and move to CPU/numpy for the caller.
    actions = (policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy())

    return actions
|
||||||
40
RDT-170M/model_config/_generate_model_config.py
Normal file
40
RDT-170M/model_config/_generate_model_config.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import argparse
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Generate a per-model finetune YAML config and ensure the matching
    # training-data directory exists.
    parser = argparse.ArgumentParser(description="Generate finetune config.")
    parser.add_argument("model_name", type=str, help="The name of the task (e.g., beat_block_hammer)")
    args = parser.parse_args()
    model_name = args.model_name

    # Paths derived from the model/task name.
    finetune_data_path = os.path.join("training_data/", f"{model_name}")
    checkpoint_path = os.path.join("checkpoints/", f"{model_name}")
    data = {
        "model": model_name,
        "data_path": finetune_data_path,
        "checkpoint_path": checkpoint_path,
        "pretrained_model_name_or_path": "../weights/RDT/rdt-1b",
        "cuda_visible_device": "...",  # args.gpu_use,
        "train_batch_size": 32,
        "sample_batch_size": 64,
        "max_train_steps": 20000,
        "checkpointing_period": 2500,
        "sample_period": 100,
        "checkpoints_total_limit": 40,
        "learning_rate": 1e-4,
        "dataloader_num_workers": 8,
        "state_noise_snr": 40,
        "gradient_accumulation_steps": 1,
    }
    task_config_path = os.path.join("model_config/", f"{model_name}.yml")

    # Stamp the file so regenerated configs are distinguishable.
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    time_comment = f"# Generated on {current_time}\n"

    with open(task_config_path, "w") as f:
        f.write(time_comment)
        yaml.dump(data, f, default_flow_style=False, sort_keys=False)

    # exist_ok avoids the race between an existence check and creation
    # (the original checked os.path.exists first).
    os.makedirs(finetune_data_path, exist_ok=True)
||||||
0
RDT-170M/models/__init__.py
Normal file
0
RDT-170M/models/__init__.py
Normal file
82
RDT-170M/models/ema_model.py
Normal file
82
RDT-170M/models/ema_model.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
# Reference: DiffusionPolicy [https://github.com/real-stanford/diffusion_policy]
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
|
||||||
|
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        Args:
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
        """
        # The averaged copy is frozen and kept in eval mode; it is updated
        # only through in-place ops inside step().
        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.decay = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        warmup_step = max(0, optimization_step - self.update_after_step - 1)
        if warmup_step <= 0:
            # Still inside the warm-up window: no averaging yet.
            return 0.0

        raw_decay = 1 - (1 + warmup_step / self.inv_gamma)**-self.power
        return max(self.min_value, min(raw_decay, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        """Blend new_model's parameters into the averaged copy in place."""
        self.decay = self.get_decay(self.optimization_step)

        # Walk both module trees in lock-step and update each module's
        # immediate parameters only, so every parameter is visited once.
        for src_module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):
            src_params = src_module.parameters(recurse=False)
            ema_params = ema_module.parameters(recurse=False)
            for src_param, ema_param in zip(src_params, ema_params):
                if isinstance(src_param, dict):
                    raise RuntimeError('Dict parameter not supported')

                copy_verbatim = isinstance(src_module, _BatchNorm) or not src_param.requires_grad
                if copy_verbatim:
                    # BatchNorm parameters and frozen parameters are mirrored
                    # directly rather than averaged.
                    ema_param.copy_(src_param.to(dtype=ema_param.dtype).data)
                else:
                    # ema = decay * ema + (1 - decay) * new
                    ema_param.mul_(self.decay)
                    ema_param.add_(src_param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)

        self.optimization_step += 1
|
||||||
75
RDT-170M/models/hub_mixin.py
Normal file
75
RDT-170M/models/hub_mixin.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
|
from huggingface_hub.constants import (PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE)
|
||||||
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
|
from huggingface_hub.utils import EntryNotFoundError, is_torch_available
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class CompatiblePyTorchModelHubMixin(PyTorchModelHubMixin):
    """Mixin class to load Pytorch models from the Hub.

    Overrides the stock mixin so that saving always writes the legacy pickle
    checkpoint (``pytorch_model.bin``) instead of safetensors, while loading
    accepts either format — safetensors is tried first, pickle is the fallback.
    """

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        # To bypass saving into safetensor by default
        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
        torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""
        model = cls(**model_kwargs)
        if os.path.isdir(model_id):
            # Local checkpoint directory: try safetensors, fall back to pickle.
            print("Loading weights from local directory")
            try:
                model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
                return cls._load_as_safetensor(model, model_file, map_location, strict)
            except FileNotFoundError:
                model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
                return cls._load_as_pickle(model, model_file, map_location, strict)
        else:
            # Hub repo id: download safetensors if the repo has it, otherwise
            # fall back to downloading the pickle weights.
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                return cls._load_as_safetensor(model, model_file, map_location, strict)
            except EntryNotFoundError:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=PYTORCH_WEIGHTS_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                return cls._load_as_pickle(model, model_file, map_location, strict)
|
||||||
0
RDT-170M/models/multimodal_encoder/__init__.py
Normal file
0
RDT-170M/models/multimodal_encoder/__init__.py
Normal file
159
RDT-170M/models/multimodal_encoder/clip_encoder.py
Normal file
159
RDT-170M/models/multimodal_encoder/clip_encoder.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision encoder wrapper used as a multimodal vision tower.

    Supports deferred weight loading: when ``delay_load`` is set (and the
    tower is not marked for unfreezing), only the model config is fetched so
    size/patch properties stay queryable without downloading weights.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        # Which hidden-states layer to read features from (can be negative).
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        elif getattr(args, 'unfreeze_mm_vision_tower', False):
            # An unfrozen tower must be materialized even when loading is deferred.
            self.load_model()
        else:
            # Defer the download; keep only the config for the size properties.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the image processor and the frozen CLIP encoder (idempotent)."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the configured feature slice from the encoder's hidden states."""
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            # Drop the leading CLS token, keep only patch tokens.
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            # Keep CLS + patch tokens unchanged.
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor (or a list of single images) into features."""
        if type(images) is list:
            # Per-image path: each entry is encoded with an added batch dim.
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                                      output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                                   output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        # Zero feature placeholder matching the tower's hidden size.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the cached config when weights were never loaded.
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size)**2
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionTowerS2(CLIPVisionTower):
    """CLIP vision tower with S2 multi-scale forwarding (scaling_on_scales).

    Runs the base CLIP encoder at several image scales and concatenates the
    resulting features, so ``hidden_size`` is multiplied by the scale count.
    Requires the optional ``s2wrapper`` package.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        # Parse the comma-separated scale list, e.g. '336,672,1008'.
        self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
        self.s2_scales = list(map(int, self.s2_scales.split(',')))
        self.s2_scales.sort()
        # Smallest scale is the tile size; largest is the working image size.
        self.s2_split_size = self.s2_scales[0]
        self.s2_image_size = self.s2_scales[-1]

        try:
            from s2wrapper import forward as multiscale_forward
        except ImportError:
            raise ImportError(
                'Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git'
            )
        self.multiscale_forward = multiscale_forward

        # change resize/crop size in preprocessing to the largest image size in s2_scale
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.image_processor.size['shortest_edge'] = self.s2_image_size
            self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    def load_model(self, device_map=None):
        """Load processor + frozen encoder, then retarget preprocessing to the S2 size."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        # Preprocessing must produce the largest S2 scale.
        self.image_processor.size['shortest_edge'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

        self.is_loaded = True

    @torch.no_grad()
    def forward_feature(self, images):
        """Single-scale encode used as the callback for multiscale_forward."""
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                               output_hidden_states=True)
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Multi-scale encode of a batched tensor or a list of single images."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.multiscale_forward(self.forward_feature,
                                                        image.unsqueeze(0),
                                                        img_sizes=self.s2_scales,
                                                        max_split_size=self.s2_split_size)
                image_features.append(image_feature)
        else:
            image_features = self.multiscale_forward(self.forward_feature,
                                                     images,
                                                     img_sizes=self.s2_scales,
                                                     max_split_size=self.s2_split_size)

        return image_features

    @property
    def hidden_size(self):
        # Features from all scales are concatenated along the channel dim.
        return self.config.hidden_size * len(self.s2_scales)
||||||
87
RDT-170M/models/multimodal_encoder/dinov2_encoder.py
Normal file
87
RDT-170M/models/multimodal_encoder/dinov2_encoder.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import AutoConfig, AutoImageProcessor, AutoModel, Dinov2Model
|
||||||
|
|
||||||
|
|
||||||
|
class DinoV2VisionTower(nn.Module):
    """Frozen DINOv2 vision encoder wrapper used as a multimodal vision tower.

    Mirrors CLIPVisionTower but reads features from ``last_hidden_state``
    (no layer selection) and loads via the transformers Auto* classes.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        elif getattr(args, 'unfreeze_mm_vision_tower', False):
            # An unfrozen tower must be materialized even when loading is deferred.
            self.load_model()
        else:
            # Defer the download; keep only the config for the size properties.
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the image processor and the frozen DINOv2 encoder (idempotent)."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        # FIXME: freezes unconditionally — revisit if the tower should be trainable.
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the configured feature slice from the encoder's last hidden state."""
        image_features = image_forward_outs.last_hidden_state
        if self.select_feature == 'patch':
            # Drop the leading token, keep patch tokens. (B, 1369, 1536)
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            # Keep the full token sequence unchanged.
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor (or a list of single images) into features."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        # Zero feature placeholder matching the tower's hidden size.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the cached config when weights were never loaded.
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size)**2
|
||||||
86
RDT-170M/models/multimodal_encoder/siglip_encoder.py
Normal file
86
RDT-170M/models/multimodal_encoder/siglip_encoder.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import AutoConfig, SiglipImageProcessor, SiglipVisionModel
|
||||||
|
|
||||||
|
|
||||||
|
class SiglipVisionTower(nn.Module):
    """Vision tower backed by a pretrained SigLIP encoder.

    The encoder is put in eval mode on load. When ``delay_load`` is set (and
    the tower is not marked for unfreezing) only the config is fetched, so
    the size-related properties remain usable without downloading weights.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load eagerly unless deferral was requested; a tower that is going
        # to be unfrozen is materialized immediately regardless.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the image processor and SigLIP encoder (idempotent)."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.eval()

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Return patch-token features or the pooled embedding, per config."""
        if self.select_feature == 'patch':
            return image_forward_outs.last_hidden_state  # patch tokens
        if self.select_feature == 'cls_patch':
            return image_forward_outs.pooler_output  # pooled embedding
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor (or a list of single images) into features."""
        if type(images) is not list:
            batch = images.to(device=self.device, dtype=self.dtype)
            return self.feature_select(self.vision_tower(batch)).to(images.dtype)

        # List path: encode each image separately with an added batch dim.
        features = []
        for img in images:
            single = img.to(device=self.device, dtype=self.dtype).unsqueeze(0)
            features.append(self.feature_select(self.vision_tower(single)).to(img.dtype))
        return features

    @property
    def dummy_feature(self):
        """Zero placeholder feature with the tower's hidden size."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Parameter dtype of the loaded encoder."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device the loaded encoder lives on."""
        return self.vision_tower.device

    @property
    def config(self):
        """Model config — from the loaded encoder, or the cached deferred config."""
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        """Feature dimension of the encoder."""
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        """Patch grid width/height at the configured image size."""
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        """Total number of patches at the configured image size."""
        return (self.config.image_size // self.config.patch_size)**2
|
||||||
111
RDT-170M/models/multimodal_encoder/t5_encoder.py
Normal file
111
RDT-170M/models/multimodal_encoder/t5_encoder.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, T5EncoderModel
|
||||||
|
|
||||||
|
|
||||||
|
class T5Embedder:
    """Load a T5 encoder + tokenizer and produce text embeddings.

    `get_text_embeddings` returns the encoder's last hidden state together
    with the attention mask for the tokenized batch.
    """
    # available_models = ["google/t5-v1_1-xxl"]

    def __init__(
        self,
        device,
        from_pretrained=None,
        *,
        cache_dir=None,
        hf_token=None,
        use_text_preprocessing=True,
        t5_model_kwargs=None,
        torch_dtype=None,
        use_offload_folder=None,
        model_max_length=120,
        local_files_only=False,
    ):
        # device: torch device spec (e.g. 'cuda:0') the encoder runs on.
        # from_pretrained: HF hub name or local path of the T5 checkpoint.
        # t5_model_kwargs: if given, used verbatim and the offload logic
        #   below is skipped entirely.
        # use_offload_folder: when set, only the first 12 encoder blocks stay
        #   on the target device; the rest are offloaded to disk.
        # from_pretrained="google/t5-v1_1-xxl" # zijian
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        self.cache_dir = cache_dir

        if t5_model_kwargs is None:
            t5_model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": self.torch_dtype,
            }

            if use_offload_folder is not None:
                # Split the encoder: embeddings + blocks 0-11 on device,
                # blocks 12-23 and the final norm/dropout offloaded to disk.
                t5_model_kwargs["offload_folder"] = use_offload_folder
                t5_model_kwargs["device_map"] = {
                    "shared": self.device,
                    "encoder.embed_tokens": self.device,
                    "encoder.block.0": self.device,
                    "encoder.block.1": self.device,
                    "encoder.block.2": self.device,
                    "encoder.block.3": self.device,
                    "encoder.block.4": self.device,
                    "encoder.block.5": self.device,
                    "encoder.block.6": self.device,
                    "encoder.block.7": self.device,
                    "encoder.block.8": self.device,
                    "encoder.block.9": self.device,
                    "encoder.block.10": self.device,
                    "encoder.block.11": self.device,
                    "encoder.block.12": "disk",
                    "encoder.block.13": "disk",
                    "encoder.block.14": "disk",
                    "encoder.block.15": "disk",
                    "encoder.block.16": "disk",
                    "encoder.block.17": "disk",
                    "encoder.block.18": "disk",
                    "encoder.block.19": "disk",
                    "encoder.block.20": "disk",
                    "encoder.block.21": "disk",
                    "encoder.block.22": "disk",
                    "encoder.block.23": "disk",
                    "encoder.final_layer_norm": "disk",
                    "encoder.dropout": "disk",
                }
            else:
                # No offloading: place the whole encoder on the target device.
                t5_model_kwargs["device_map"] = {
                    "shared": self.device,
                    "encoder": self.device,
                }

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token

        # assert from_pretrained in self.available_models
        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            model_max_length=model_max_length,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.model = T5EncoderModel.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            **t5_model_kwargs,
        ).eval()
        self.model_max_length = model_max_length

    def get_text_embeddings(self, texts):
        """Tokenize `texts` and return (last_hidden_state, attention_mask).

        Padding is to the longest sample in the batch; inputs are truncated
        at `model_max_length`. Runs under `torch.no_grad()`.
        """
        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding="longest",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        input_ids = text_tokens_and_mask["input_ids"].to(self.device)
        attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
        with torch.no_grad():
            text_encoder_embs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )["last_hidden_state"].detach()
        return text_encoder_embs, attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: instantiating the embedder loads T5-XXL onto GPU 7.
    T5Embedder(from_pretrained="google/t5-v1_1-xxl", device='cuda:7')
|
||||||
304
RDT-170M/models/rdt/blocks.py
Normal file
304
RDT-170M/models/rdt/blocks.py
Normal file
@ -0,0 +1,304 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# DiT: https://github.com/facebookresearch/DiT
|
||||||
|
# GLIDE: https://github.com/openai/glide-text2im
|
||||||
|
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.jit import Final
|
||||||
|
from timm.models.vision_transformer import Attention, Mlp, RmsNorm, use_fused_attn
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Embedding Layers for Timesteps and Condition Inptus #
|
||||||
|
#################################################################################
|
||||||
|
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal frequency embedding is computed first, then lifted to the
    model width by a two-layer SiLU MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=torch.bfloat16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.dtype = dtype

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element
            (may be fractional).
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings, cast to self.dtype.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(max_period) * exponents)
        phase = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
        if dim % 2:
            # Odd target dim: pad one zero column so the width matches.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding.to(self.dtype)

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Cross Attention Layers #
|
||||||
|
#################################################################################
|
||||||
|
class CrossAttention(nn.Module):
    """
    A cross-attention layer with flash attention.

    Queries come from the token sequence `x`; keys and values come from the
    condition sequence `c`. An optional (B, L) boolean mask marks which
    condition tokens are valid.
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0,
        proj_drop: float = 0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = use_fused_attn()

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, c: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        B, N, C = x.shape
        _, L, _ = c.shape
        # Project queries from x, keys/values from the condition c.
        query = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key_value = self.kv(c).reshape(B, L, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        key, value = key_value.unbind(0)
        query = self.q_norm(query)
        key = self.k_norm(key)

        # Broadcast the (B, L) condition mask over heads and query positions.
        if mask is not None:
            mask = mask.reshape(B, 1, 1, L).expand(-1, -1, N, -1)

        if self.fused_attn:
            out = F.scaled_dot_product_attention(query=query,
                                                 key=key,
                                                 value=value,
                                                 dropout_p=self.attn_drop.p if self.training else 0.,
                                                 attn_mask=mask)
        else:
            # Manual fallback: scaled scores, mask invalid keys with -inf.
            scores = (query * self.scale) @ key.transpose(-2, -1)
            if mask is not None:
                scores = scores.masked_fill_(mask.logical_not(), float('-inf'))
            weights = scores.softmax(dim=-1)
            if self.attn_drop.p > 0:
                weights = self.attn_drop(weights)
            out = weights @ value

        out = out.permute(0, 2, 1, 3).reshape(B, N, C)
        out = self.proj(out)
        if self.proj_drop.p > 0:
            out = self.proj_drop(out)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# RDT Block #
|
||||||
|
#################################################################################
|
||||||
|
class RDTBlock(nn.Module):
    """
    A RDT block with cross-attention conditioning.

    Pre-norm self-attention, pre-norm cross-attention against the condition
    tokens, then a pre-norm feed-forward network — each with a residual add.
    """

    def __init__(self, hidden_size, num_heads, **block_kwargs):
        super().__init__()
        self.norm1 = RmsNorm(hidden_size, eps=1e-6)
        self.attn = Attention(dim=hidden_size,
                              num_heads=num_heads,
                              qkv_bias=True,
                              qk_norm=True,
                              norm_layer=RmsNorm,
                              **block_kwargs)
        self.cross_attn = CrossAttention(hidden_size,
                                         num_heads=num_heads,
                                         qkv_bias=True,
                                         qk_norm=True,
                                         norm_layer=RmsNorm,
                                         **block_kwargs)

        self.norm2 = RmsNorm(hidden_size, eps=1e-6)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.ffn = Mlp(in_features=hidden_size, hidden_features=hidden_size, act_layer=approx_gelu, drop=0)
        self.norm3 = RmsNorm(hidden_size, eps=1e-6)

    def forward(self, x, c, mask=None):
        # Self-attention sub-layer.
        x = x + self.attn(self.norm1(x))
        # Cross-attention sub-layer conditioned on c (masked by `mask`).
        x = x + self.cross_attn(self.norm2(x), c, mask)
        # Feed-forward sub-layer.
        x = x + self.ffn(self.norm3(x))
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
    """
    The final layer of RDT.

    An RMSNorm followed by an MLP that projects hidden states down to the
    output channel count.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = RmsNorm(hidden_size, eps=1e-6)
        gelu_tanh = lambda: nn.GELU(approximate="tanh")
        self.ffn_final = Mlp(in_features=hidden_size,
                             hidden_features=hidden_size,
                             out_features=out_channels,
                             act_layer=gelu_tanh,
                             drop=0)

    def forward(self, x):
        return self.ffn_final(self.norm_final(x))
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Sine/Cosine Positional Embedding Functions #
|
||||||
|
#################################################################################
|
||||||
|
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
||||||
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: a list of positions to be encoded: size (M,)
    out: (M, D) array, [sin | cos] halves over a geometric frequency ladder
    """
    assert embed_dim % 2 == 0
    # Standard Transformer sin-cos frequencies: 1 / 10000^(2i/D).
    half_dim = embed_dim // 2
    inv_freq = 1. / 10000**(np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.))  # (D/2,)

    if not isinstance(pos, np.ndarray):
        pos = np.array(pos, dtype=np.float64)
    # Outer product over the flattened positions: (M, D/2).
    angles = np.outer(pos.reshape(-1), inv_freq)

    # Sine half first, cosine half second.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
||||||
|
|
||||||
|
|
||||||
|
def get_nd_sincos_pos_embed_from_grid(embed_dim, grid_sizes):
    """
    embed_dim: output dimension for each position
    grid_sizes: the grid sizes in each dimension (K,); any sequence of ints
    out: (grid_sizes[0], ..., grid_sizes[K-1], D)

    Each axis with size > 1 gets an even share of the embedding dimension,
    filled with a 1-D sin-cos embedding broadcast along the other axes.
    Axes of size 1 contribute no positional signal.
    """
    # Robustness: the original `grid_sizes + (embed_dim,)` raised TypeError
    # when callers passed a list; normalize to a tuple first.
    grid_sizes = tuple(grid_sizes)
    num_sizes = len(grid_sizes)
    # For grid size of 1, we do not need to add any positional embedding
    num_valid_sizes = len([x for x in grid_sizes if x > 1])
    emb = np.zeros(grid_sizes + (embed_dim, ))
    # Uniformly divide the embedding dimension for each grid size
    dim_for_each_grid = embed_dim // num_valid_sizes
    # To make it even (the 1-D helper requires an even dimension)
    if dim_for_each_grid % 2 != 0:
        dim_for_each_grid -= 1
    valid_size_idx = 0
    for size_idx in range(num_sizes):
        grid_size = grid_sizes[size_idx]
        if grid_size <= 1:
            continue
        pos = np.arange(grid_size)
        # Shape with -1 on this axis so the 1-D embedding broadcasts over
        # every other grid dimension.
        posemb_shape = [1] * len(grid_sizes) + [dim_for_each_grid]
        posemb_shape[size_idx] = -1
        emb[..., valid_size_idx * dim_for_each_grid:(valid_size_idx + 1) * dim_for_each_grid] += \
            get_1d_sincos_pos_embed_from_grid(dim_for_each_grid, pos).reshape(posemb_shape)
        valid_size_idx += 1
    return emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_multimodal_cond_pos_embed(embed_dim, mm_cond_lens: OrderedDict, embed_modality=True):
    """
    Generate position embeddings for multimodal conditions.

    mm_cond_lens: an OrderedDict containing
        (modality name, modality token length) pairs.
        For `"image"` modality, the value can be a multi-dimensional tuple.
        If the length < 0, it means there is no position embedding for the modality or grid.
    embed_modality: whether to embed the modality information. Default is True.

    Returns an (sum of token lengths, embed_dim) numpy array.
    """
    num_modalities = len(mm_cond_lens)
    modality_pos_embed = np.zeros((num_modalities, embed_dim))
    if embed_modality:
        # Get embeddings for various modalities
        # We put it in the first half
        modality_sincos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, torch.arange(num_modalities))
        modality_pos_embed[:, :embed_dim // 2] = modality_sincos_embed
        # The second half is for position embeddings
        pos_embed_dim = embed_dim // 2
    else:
        # The whole embedding is for position embeddings
        pos_embed_dim = embed_dim

    # Get embeddings for positions inside each modality
    c_pos_emb = np.zeros((0, embed_dim))
    for idx, (modality, cond_len) in enumerate(mm_cond_lens.items()):
        if modality == "image" and \
                (isinstance(cond_len, tuple) or isinstance(cond_len, list)):
            # Multi-axis image grid: negative sizes still occupy tokens
            # (abs) but get no positional signal along that axis (size 1).
            all_grid_sizes = tuple([abs(x) for x in cond_len])
            embed_grid_sizes = tuple([x if x > 0 else 1 for x in cond_len])
            cond_sincos_embed = get_nd_sincos_pos_embed_from_grid(pos_embed_dim, embed_grid_sizes)
            cond_pos_embed = np.zeros(all_grid_sizes + (embed_dim, ))
            # Position embedding lives in the last pos_embed_dim channels.
            cond_pos_embed[..., -pos_embed_dim:] += cond_sincos_embed
            cond_pos_embed = cond_pos_embed.reshape((-1, embed_dim))
        else:
            # 1-D modality: a negative length yields a single broadcast row.
            cond_sincos_embed = get_1d_sincos_pos_embed_from_grid(pos_embed_dim,
                                                                  torch.arange(cond_len if cond_len > 0 else 1))
            cond_pos_embed = np.zeros((abs(cond_len), embed_dim))
            cond_pos_embed[:, -pos_embed_dim:] += cond_sincos_embed
        # Add the per-modality embedding (zeros when embed_modality=False).
        cond_pos_embed += modality_pos_embed[idx]
        c_pos_emb = np.concatenate([c_pos_emb, cond_pos_embed], axis=0)

    return c_pos_emb
|
||||||
156
RDT-170M/models/rdt/model.py
Normal file
156
RDT-170M/models/rdt/model.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# References:
|
||||||
|
# DiT: https://github.com/facebookresearch/DiT
|
||||||
|
# GLIDE: https://github.com/openai/glide-text2im
|
||||||
|
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
||||||
|
# --------------------------------------------------------
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import sys, os
|
||||||
|
# get current workspace
|
||||||
|
current_file = Path(__file__)
|
||||||
|
sys.path.append(str(current_file.parent.parent))
|
||||||
|
|
||||||
|
from rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder, get_1d_sincos_pos_embed_from_grid,
|
||||||
|
get_multimodal_cond_pos_embed)
|
||||||
|
|
||||||
|
|
||||||
|
class RDT(nn.Module):
    """
    Class for Robotics Diffusion Transformers.
    """

    def __init__(self,
                 output_dim=128,
                 horizon=32,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 max_lang_cond_len=1024,
                 img_cond_len=4096,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        # output_dim: per-step output dimension produced by the final layer.
        # horizon: number of action tokens; the input sequence additionally
        #   carries [timestep; ctrl_freq; state] tokens in front (hence +3).
        # max_lang_cond_len / img_cond_len: token budgets for the language
        #   (variable-length) and image (fixed-length) conditions.
        # *_pos_embed_config: optional multimodal position-embedding layouts;
        #   when None, a plain 1-D sin-cos embedding is used instead.
        super().__init__()
        self.horizon = horizon
        self.hidden_size = hidden_size
        self.max_lang_cond_len = max_lang_cond_len
        self.img_cond_len = img_cond_len
        self.dtype = dtype
        self.lang_pos_embed_config = lang_pos_embed_config
        self.img_pos_embed_config = img_pos_embed_config

        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype)
        self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype)

        # We will use trainable sin-cos embeddings
        # [timestep; state; action]
        self.x_pos_embed = nn.Parameter(torch.zeros(1, horizon + 3, hidden_size))
        # Language conditions
        self.lang_cond_pos_embed = nn.Parameter(torch.zeros(1, max_lang_cond_len, hidden_size))
        # Image conditions
        self.img_cond_pos_embed = nn.Parameter(torch.zeros(1, img_cond_len, hidden_size))

        self.blocks = nn.ModuleList([RDTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, output_dim)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            # Xavier-uniform on every Linear; zero biases.
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize pos_embed by sin-cos embedding
        x_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                    mm_cond_lens=OrderedDict([
                                                        ('timestep', 1),
                                                        ('ctrl_freq', 1),
                                                        ('state', 1),
                                                        ('action', self.horizon),
                                                    ]))
        self.x_pos_embed.data.copy_(torch.from_numpy(x_pos_embed).float().unsqueeze(0))

        if self.lang_pos_embed_config is None:
            lang_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size,
                                                                    torch.arange(self.max_lang_cond_len))
        else:
            lang_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                                mm_cond_lens=OrderedDict(self.lang_pos_embed_config),
                                                                embed_modality=False)
        self.lang_cond_pos_embed.data.copy_(torch.from_numpy(lang_cond_pos_embed).float().unsqueeze(0))

        if self.img_pos_embed_config is None:
            img_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.img_cond_len))
        else:
            img_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                               mm_cond_lens=OrderedDict(self.img_pos_embed_config),
                                                               embed_modality=False)
        self.img_cond_pos_embed.data.copy_(torch.from_numpy(img_cond_pos_embed).float().unsqueeze(0))

        # Initialize timestep and control freq embedding MLP
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02)

        # Initialize the final layer: zero-out the final linear layer
        nn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0)
        nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0)

        # Move all the params to given data type:
        self.to(self.dtype)

    def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
        """
        Forward pass of RDT.

        x: (B, T, D), state + action token sequence, T = horizon + 1,
            dimension D is assumed to be the same as the hidden size.
        freq: (B,), a scalar indicating control frequency.
        t: (B,) or (1,), diffusion timesteps.
        lang_c: (B, L_lang, D) or None, language condition tokens (variable length),
            dimension D is assumed to be the same as the hidden size.
        img_c: (B, L_img, D) or None, image condition tokens (fixed length),
            dimension D is assumed to be the same as the hidden size.
        lang_mask: (B, L_lang) or None, language condition mask (True for valid).
        img_mask: (B, L_img) or None, image condition mask (True for valid).
        """
        t = self.t_embedder(t).unsqueeze(1)  # (B, 1, D) or (1, 1, D)
        freq = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
        # Append timestep to the input tokens
        if t.shape[0] == 1:
            # A single shared timestep is broadcast across the batch.
            t = t.expand(x.shape[0], -1, -1)
        x = torch.cat([t, freq, x], dim=1)  # (B, T+1, D)

        # Add multimodal position embeddings
        x = x + self.x_pos_embed
        # Note the lang is of variable length
        lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
        img_c = img_c + self.img_cond_pos_embed

        # Forward pass
        conds = [lang_c, img_c]
        masks = [lang_mask, img_mask]
        for i, block in enumerate(self.blocks):
            # Alternate conditioning: even blocks attend to language,
            # odd blocks attend to images.
            c, mask = conds[i % 2], masks[i % 2]
            x = block(x, c, mask)  # (B, T+1, D)
        # Inject the language condition at the final layer
        x = self.final_layer(x)  # (B, T+1, out_channels)

        # Only preserve the action tokens
        x = x[:, -self.horizon:]
        return x
|
||||||
246
RDT-170M/models/rdt_runner.py
Normal file
246
RDT-170M/models/rdt_runner.py
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
import re, sys, os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
from diffusers.schedulers.scheduling_dpmsolver_multistep import \
|
||||||
|
DPMSolverMultistepScheduler
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
# get current workspace
|
||||||
|
current_file = Path(__file__)
|
||||||
|
sys.path.append(os.path.join(current_file.parent))
|
||||||
|
from hub_mixin import CompatiblePyTorchModelHubMixin
|
||||||
|
from rdt.model import RDT
|
||||||
|
|
||||||
|
|
||||||
|
class RDTRunner(nn.Module,
|
||||||
|
CompatiblePyTorchModelHubMixin,
|
||||||
|
repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):
|
||||||
|
|
||||||
|
    def __init__(self,
                 *,
                 action_dim,
                 pred_horizon,
                 config,
                 lang_token_dim,
                 img_token_dim,
                 state_token_dim,
                 max_lang_cond_len,
                 img_cond_len,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        # action_dim / pred_horizon: dimension and length of the predicted
        #   action chunk.
        # config: dict with 'rdt', 'lang_adaptor', 'img_adaptor',
        #   'state_adaptor' and 'noise_scheduler' sections.
        # *_token_dim: raw token widths before adaptation to hidden_size.
        super(RDTRunner, self).__init__()
        # Create diffusion model
        hidden_size = config['rdt']['hidden_size']
        self.model = RDT(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=config['rdt']['depth'],
            num_heads=config['rdt']['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Create adaptors for the various conditional inputs
        self.lang_adaptor = self.build_condition_adapter(config['lang_adaptor'],
                                                         in_features=lang_token_dim,
                                                         out_features=hidden_size)
        self.img_adaptor = self.build_condition_adapter(config['img_adaptor'],
                                                        in_features=img_token_dim,
                                                        out_features=hidden_size)
        # A `state` refers to an action or a proprioception vector
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,  # state + state mask (indicator)
            out_features=hidden_size)

        # Create the noise scheduler (DDPM for training, DPM-Solver++ for sampling)
        noise_scheduler_config = config['noise_scheduler']
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
            clip_sample=noise_scheduler_config['clip_sample'],
        )
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
        )

        self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.prediction_type = noise_scheduler_config['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        # Report total trainable parameter count across model + adaptors.
        print("Diffusion params: %e" %
              sum([p.numel() for p in self.model.parameters()] + [p.numel() for p in self.lang_adaptor.parameters()] +
                  [p.numel()
                   for p in self.img_adaptor.parameters()] + [p.numel() for p in self.state_adaptor.parameters()]))
|
||||||
|
|
||||||
|
def build_condition_adapter(self, projector_type, in_features, out_features):
|
||||||
|
projector = None
|
||||||
|
if projector_type == 'linear':
|
||||||
|
projector = nn.Linear(in_features, out_features)
|
||||||
|
else:
|
||||||
|
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
||||||
|
if mlp_gelu_match:
|
||||||
|
mlp_depth = int(mlp_gelu_match.group(1))
|
||||||
|
modules = [nn.Linear(in_features, out_features)]
|
||||||
|
for _ in range(1, mlp_depth):
|
||||||
|
modules.append(nn.GELU(approximate="tanh"))
|
||||||
|
modules.append(nn.Linear(out_features, out_features))
|
||||||
|
projector = nn.Sequential(*modules)
|
||||||
|
|
||||||
|
if projector is None:
|
||||||
|
raise ValueError(f'Unknown projector type: {projector_type}')
|
||||||
|
|
||||||
|
return projector
|
||||||
|
|
||||||
|
def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
|
||||||
|
'''
|
||||||
|
lang_tokens: (batch_size, lang_len, lang_token_dim)
|
||||||
|
img_tokens: (batch_size, img_len, img_token_dim)
|
||||||
|
state_tokens: (batch_size, state_len, state_token_dim)
|
||||||
|
|
||||||
|
return: adpated (..., hidden_size) for all input tokens
|
||||||
|
'''
|
||||||
|
adpated_lang = self.lang_adaptor(lang_tokens)
|
||||||
|
adpated_img = self.img_adaptor(img_tokens)
|
||||||
|
adpated_state = self.state_adaptor(state_tokens)
|
||||||
|
|
||||||
|
return adpated_lang, adpated_img, adpated_state
|
||||||
|
|
||||||
|
    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond, state_traj, action_mask, ctrl_freqs):
        '''
        Run the reverse diffusion process to sample an action trajectory.

        lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_cond: image conditional data, (batch_size, img_len, hidden_size).
        state_traj: (batch_size, 1, hidden_size), state trajectory.
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
            indicating the valid action dimensions.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim)
        '''
        device = state_traj.device
        dtype = state_traj.dtype
        # Start from pure Gaussian noise; the loop below iteratively denoises it.
        noisy_action = torch.randn(size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
                                   dtype=dtype,
                                   device=device)
        # Broadcast the per-sample validity mask over every step of the horizon.
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        # Set step values
        self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)

        for t in self.noise_scheduler_sample.timesteps:
            # Prepare state-action trajectory: the mask is appended as extra
            # feature channels (dim=2) before projecting into the hidden space.
            action_traj = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(action_traj)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # Predict the model output
            model_output = self.model(state_action_traj,
                                      ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond,
                                      img_cond,
                                      lang_mask=lang_attn_mask)

            # Compute previous actions: x_t -> x_t-1
            noisy_action = self.noise_scheduler_sample.step(model_output, t, noisy_action).prev_sample
            # The scheduler step may change the dtype; cast back to match the state.
            noisy_action = noisy_action.to(state_traj.dtype)

        # Finally apply the action mask to mask invalid action dimensions
        noisy_action = noisy_action * action_mask
        return noisy_action
|
||||||
|
|
||||||
|
# ========= Train ============
|
||||||
|
    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_gt, action_mask,
                     ctrl_freqs) -> torch.Tensor:
        '''
        One training step of the forward-diffusion denoising objective.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: loss_value, a scalar tensor
        '''
        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device
        # Sample noise that we'll add to the actions
        noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)
        # Sample random diffusion timesteps, one per batch element
        timesteps = torch.randint(0, self.num_train_timesteps, (batch_size, ), device=device).long()
        # Add noise to the clean actions according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps)

        # Concatenate the state and action tokens to form the input sequence
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        # Append the action mask to the input sequence (as extra feature channels)
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
        # Align the dimension with the hidden size
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj)
        # Predict the denoised result
        pred = self.model(state_action_traj, ctrl_freqs, timesteps, lang_cond, img_cond, lang_mask=lang_attn_mask)

        # Target depends on the scheduler's parameterization:
        # 'epsilon' regresses the injected noise, 'sample' regresses the clean actions.
        pred_type = self.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = action_gt
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")
        loss = F.mse_loss(pred, target)
        return loss
|
||||||
|
|
||||||
|
# ========= Inference ============
|
||||||
|
    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_mask, ctrl_freqs):
        '''
        Inference entry point: sample an action sequence conditioned on
        language, images, and the current robot state.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim),
            which should be a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), predicted action sequence
        '''
        # Prepare the state and conditions: the mask rides along as extra
        # feature channels of the state token, mirroring training.
        state_tokens = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(lang_tokens, img_tokens, state_tokens)

        # Run sampling (reverse diffusion)
        action_pred = self.conditional_sample(
            lang_cond,
            lang_attn_mask,
            img_cond,
            state_traj,
            action_mask,
            ctrl_freqs,
        )

        return action_pred
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs) -> torch.Tensor:
|
||||||
|
return self.compute_loss(*args, **kwargs)
|
||||||
49
RDT-170M/pretrain.sh
Normal file
49
RDT-170M/pretrain.sh
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
#!/bin/bash
# Multi-node pre-training launcher for the Robotics Diffusion Transformer.

# NCCL / InfiniBand fabric configuration for the training cluster.
export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=bond0
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0

# Frozen encoder checkpoints and training output location.
export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
export OUTPUT_DIR="./checkpoints/rdt-pretrain-1b"
export CFLAGS="-I/usr/include"
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
export CUTLASS_PATH="/path/to/cutlass"

export WANDB_PROJECT="robotics_diffusion_transformer"

if [ ! -d "$OUTPUT_DIR" ]; then
    # FIX: -p also creates missing parent directories (e.g. ./checkpoints);
    # plain mkdir fails with "No such file or directory" on a fresh checkout.
    mkdir -p "$OUTPUT_DIR"
    echo "Folder '$OUTPUT_DIR' created"
else
    echo "Folder '$OUTPUT_DIR' already exists"
fi

# For run in a single node/machine
# accelerate launch main.py \
#    --deepspeed="./configs/zero2.json" \
#    ...

deepspeed --hostfile=hostfile.txt main.py \
    --deepspeed="./configs/zero2.json" \
    --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
    --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
    --output_dir=$OUTPUT_DIR \
    --train_batch_size=32 \
    --sample_batch_size=64 \
    --max_train_steps=1000000 \
    --checkpointing_period=1000 \
    --sample_period=500 \
    --checkpoints_total_limit=40 \
    --lr_scheduler="constant" \
    --learning_rate=1e-4 \
    --mixed_precision="bf16" \
    --dataloader_num_workers=8 \
    --dataset_type="pretrain" \
    --report_to=wandb

# Use this to resume training from some previous checkpoint
# --resume_from_checkpoint="checkpoint-1000" \
|
||||||
9
RDT-170M/process_data_rdt.sh
Normal file
9
RDT-170M/process_data_rdt.sh
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#!/bin/bash
# Convert expert demonstrations into RDT training data.
# Usage: process_data_rdt.sh <task_name> <task_config> <expert_data_num> <gpu_id>

task_name=${1}
task_config=${2}
expert_data_num=${3}
gpu_id=${4}

export CUDA_VISIBLE_DEVICES=${gpu_id}
# FIX: quote expansions so arguments containing spaces or glob characters
# are passed to the script intact instead of being word-split.
python scripts/process_data.py "$task_name" "$task_config" "$expert_data_num"
|
||||||
24
RDT-170M/requirements.txt
Normal file
24
RDT-170M/requirements.txt
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
numpy<2.0
|
||||||
|
packaging==24.0
|
||||||
|
wandb==0.17.0
|
||||||
|
deepspeed==0.14.2
|
||||||
|
accelerate==0.30.1
|
||||||
|
diffusers==0.27.2
|
||||||
|
timm==1.0.3
|
||||||
|
transformers==4.41.0
|
||||||
|
sentencepiece==0.2.0
|
||||||
|
h5py==3.11.0
|
||||||
|
opencv-python==4.9.0.80
|
||||||
|
imgaug==0.4.0
|
||||||
|
pytz==2022.1
|
||||||
|
huggingface_hub==0.23.0
|
||||||
|
|
||||||
|
# requirements_data.txt
|
||||||
|
# tfds-nightly==4.9.4.dev202402070044
|
||||||
|
gsutil==5.27
|
||||||
|
tensorflow==2.15.0.post1
|
||||||
|
pillow==10.2.0
|
||||||
|
pyyaml==6.0.1
|
||||||
|
tensorflow-graphics==2021.12.3
|
||||||
|
imageio==2.34.0
|
||||||
|
imageio-ffmpeg==0.4.9
|
||||||
941
RDT-170M/scripts/agilex_inference.py
Normal file
941
RDT-170M/scripts/agilex_inference.py
Normal file
@ -0,0 +1,941 @@
|
|||||||
|
#!/home/lin/software/miniconda3/envs/aloha/bin/python
|
||||||
|
# -- coding: UTF-8
|
||||||
|
"""
|
||||||
|
#!/usr/bin/python3
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import rospy
|
||||||
|
import torch
|
||||||
|
from cv_bridge import CvBridge
|
||||||
|
from geometry_msgs.msg import Twist
|
||||||
|
from nav_msgs.msg import Odometry
|
||||||
|
from PIL import Image as PImage
|
||||||
|
from sensor_msgs.msg import Image, JointState
|
||||||
|
from std_msgs.msg import Header
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from scripts.agilex_model import create_model
|
||||||
|
|
||||||
|
# sys.path.append("./")
|
||||||
|
|
||||||
|
CAMERA_NAMES = ["cam_high", "cam_right_wrist", "cam_left_wrist"]
|
||||||
|
|
||||||
|
observation_window = None
|
||||||
|
|
||||||
|
lang_embeddings = None
|
||||||
|
|
||||||
|
# debug
|
||||||
|
preload_images = None
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
def make_policy(args):
    """Instantiate the RDT policy described by the YAML file at args.config_path.

    Side effect: stores the parsed config on args.config. The model is built
    via scripts.agilex_model.create_model with bf16 weights, the checkpoint at
    args.pretrained_model_name_or_path, a SigLIP vision encoder, and the
    control frequency taken from args.ctrl_freq.
    """
    with open(args.config_path, "r") as fp:
        config = yaml.safe_load(fp)
    args.config = config

    # pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
    pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
    model = create_model(
        args=args.config,
        dtype=torch.bfloat16,
        pretrained=args.pretrained_model_name_or_path,
        # pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
        pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
        control_frequency=args.ctrl_freq,
    )

    return model
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed):
    """Seed the torch and numpy RNGs so rollouts are reproducible."""
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)
|
||||||
|
|
||||||
|
|
||||||
|
# Interpolate the actions to make the robot move smoothly
|
||||||
|
def interpolate_action(args, prev_action, cur_action):
    """Linearly subdivide a 14-dof action so no joint moves more than its
    per-tick limit (args.arm_steps_length, duplicated for the two arms).

    Returns an array of shape (k, dof): just cur_action when it is already
    within one step, otherwise the intermediate waypoints excluding
    prev_action itself.
    """
    # Per-joint step limits for [left arm, right arm].
    per_joint_limit = np.tile(np.asarray(args.arm_steps_length), 2)
    num_steps = int(np.max(np.ceil(np.abs(cur_action - prev_action) / per_joint_limit)))
    if num_steps <= 1:
        return cur_action[np.newaxis, :]
    # linspace includes prev_action as row 0; drop it.
    return np.linspace(prev_action, cur_action, num_steps + 1)[1:]
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(args):
    """Bundle the runtime settings the inference loop reads into one dict."""
    return {
        "episode_len": args.max_publish_step,
        "state_dim": 14,
        "chunk_size": args.chunk_size,
        "camera_names": CAMERA_NAMES,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Get the observation from the ROS topic
|
||||||
|
# Get the observation from the ROS topic
def get_ros_observation(args, ros_operator):
    """Block until a time-synchronized frame is available, then return it.

    Polls ros_operator.get_frame() at args.publish_rate until the camera and
    arm streams line up, then returns the subset used for inference:
    (img_front, img_left, img_right, puppet_arm_left, puppet_arm_right).
    Depth images and the robot-base reading are fetched but discarded here.
    """
    rate = rospy.Rate(args.publish_rate)
    print_flag = True  # warn only once per stretch of failed synchronizations

    while True and not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if not result:
            if print_flag:
                print("syn fail when get_ros_observation")
                print_flag = False
            rate.sleep()
            continue
        print_flag = True
        (
            img_front,
            img_left,
            img_right,
            img_front_depth,
            img_left_depth,
            img_right_depth,
            puppet_arm_left,
            puppet_arm_right,
            robot_base,
        ) = result
        # print(f"sync success when get_ros_observation")
        return (img_front, img_left, img_right, puppet_arm_left, puppet_arm_right)
|
||||||
|
|
||||||
|
|
||||||
|
# Update the observation window buffer
|
||||||
|
# Update the observation window buffer
def update_observation_window(args, config, ros_operator):
    """Push the newest synchronized observation into the two-slot window.

    The global `observation_window` holds the two most recent observations;
    on first use it is seeded with a dummy entry so index -2 is always valid.
    """

    # JPEG transformation
    # Align with training: re-encode/decode so the images match the JPEG
    # compression artifacts present in the training data.
    def jpeg_mapping(img):
        img = cv2.imencode(".jpg", img)[1].tobytes()
        img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
        return img

    global observation_window
    if observation_window is None:
        observation_window = deque(maxlen=2)

        # Append the first dummy image
        observation_window.append({
            "qpos": None,
            "images": {
                config["camera_names"][0]: None,
                config["camera_names"][1]: None,
                config["camera_names"][2]: None,
            },
        })

    img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = (get_ros_observation(args, ros_operator))
    img_front = jpeg_mapping(img_front)
    img_left = jpeg_mapping(img_left)
    img_right = jpeg_mapping(img_right)

    # Joint positions of both arms, concatenated [left, right] -> (14,)
    qpos = np.concatenate(
        (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)),
        axis=0,
    )
    qpos = torch.from_numpy(qpos).float().cuda()
    # NOTE: camera_names order here is [front(high), right_wrist, left_wrist].
    observation_window.append({
        "qpos": qpos,
        "images": {
            config["camera_names"][0]: img_front,
            config["camera_names"][1]: img_right,
            config["camera_names"][2]: img_left,
        },
    })
|
||||||
|
|
||||||
|
|
||||||
|
# RDT inference
|
||||||
|
# RDT inference
def inference_fn(args, config, policy, t):
    """Run one RDT inference pass over the current observation window.

    Returns an action chunk as a numpy array in [left, right] joint order
    (per the comment below, shape (chunk_size, 14) after the squeeze).
    The while loop is only a shutdown guard; the function returns on the
    first successful pass.
    """
    global observation_window
    global lang_embeddings

    # print(f"Start inference_thread_fn: t={t}")
    while True and not rospy.is_shutdown():
        time1 = time.time()

        # fetch images in sequence [front, right, left]: previous frame then
        # current frame for each of the three cameras.
        image_arrs = [
            observation_window[-2]["images"][config["camera_names"][0]],
            observation_window[-2]["images"][config["camera_names"][1]],
            observation_window[-2]["images"][config["camera_names"][2]],
            observation_window[-1]["images"][config["camera_names"][0]],
            observation_window[-1]["images"][config["camera_names"][1]],
            observation_window[-1]["images"][config["camera_names"][2]],
        ]

        # fetch debug images in sequence [front, right, left]
        # image_arrs = [
        #     preload_images[config['camera_names'][0]][max(t - 1, 0)],
        #     preload_images[config['camera_names'][2]][max(t - 1, 0)],
        #     preload_images[config['camera_names'][1]][max(t - 1, 0)],
        #     preload_images[config['camera_names'][0]][t],
        #     preload_images[config['camera_names'][2]][t],
        #     preload_images[config['camera_names'][1]][t]
        # ]
        # # encode the images
        # for i in range(len(image_arrs)):
        #     image_arrs[i] = cv2.imdecode(np.frombuffer(image_arrs[i], np.uint8), cv2.IMREAD_COLOR)
        # proprio = torch.from_numpy(preload_images['qpos'][t]).float().cuda()

        # Dummy slots from the seeded window stay None and are passed through.
        images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]

        # for i, pos in enumerate(['f', 'r', 'l'] * 2):
        #     images[i].save(f'{t}-{i}-{pos}.png')

        # get last qpos in shape [14, ]
        proprio = observation_window[-1]["qpos"]
        # unsqueeze to [1, 14]
        proprio = proprio.unsqueeze(0)

        # actions shaped as [1, 64, 14] in format [left, right]
        actions = (policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy())
        # print(f"inference_actions: {actions.squeeze()}")

        # print(f"Model inference time: {time.time() - time1} s")

        # print(f"Finish inference_thread_fn: t={t}")
        return actions
|
||||||
|
|
||||||
|
|
||||||
|
# Main loop for the manipulation task
|
||||||
|
# Main loop for the manipulation task
def model_inference(args, config, ros_operator):
    """Main control loop.

    Moves the arms to a start pose, then repeatedly: updates the observation
    window, runs RDT inference once per chunk of `chunk_size` steps, and
    publishes each (optionally interpolated) action at args.publish_rate.
    """
    global lang_embeddings

    # Load rdt model
    policy = make_policy(args)

    # Pre-computed language embedding for the task instruction.
    lang_dict = torch.load(args.lang_embeddings_path)
    print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
    lang_embeddings = lang_dict["embeddings"]

    max_publish_step = config["episode_len"]
    chunk_size = config["chunk_size"]

    # Initialize position of the puppet arm.
    # NOTE(review): left0/right0 and left1/right1 differ mainly in the last
    # joint (~3.56 vs ~-0.34) — presumably the gripper; confirm on hardware.
    left0 = [
        -0.00133514404296875,
        0.00209808349609375,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        3.557830810546875,
    ]
    right0 = [
        -0.00133514404296875,
        0.00438690185546875,
        0.034523963928222656,
        -0.053597450256347656,
        -0.00476837158203125,
        -0.00209808349609375,
        3.557830810546875,
    ]
    left1 = [
        -0.00133514404296875,
        0.00209808349609375,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3393220901489258,
    ]
    right1 = [
        -0.00133514404296875,
        0.00247955322265625,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3397035598754883,
    ]
    ros_operator.puppet_arm_publish_continuous(left0, right0)
    input("Press enter to continue")
    ros_operator.puppet_arm_publish_continuous(left1, right1)
    # Initialize the previous action to be the initial robot state
    pre_action = np.zeros(config["state_dim"])
    pre_action[:14] = np.array([
        -0.00133514404296875,
        0.00209808349609375,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3393220901489258,
    ] + [
        -0.00133514404296875,
        0.00247955322265625,
        0.01583099365234375,
        -0.032616615295410156,
        -0.00286102294921875,
        0.00095367431640625,
        -0.3397035598754883,
    ])
    action = None
    # Inference loop
    with torch.inference_mode():
        while True and not rospy.is_shutdown():
            # The current time step
            t = 0
            rate = rospy.Rate(args.publish_rate)

            action_buffer = np.zeros([chunk_size, config["state_dim"]])

            while t < max_publish_step and not rospy.is_shutdown():
                # Update observation window
                update_observation_window(args, config, ros_operator)

                # When coming to the end of the action chunk, run the model
                # again to refill the buffer.
                if t % chunk_size == 0:
                    # Start inference
                    action_buffer = inference_fn(args, config, policy, t).copy()

                raw_action = action_buffer[t % chunk_size]
                action = raw_action
                # Interpolate the original action sequence
                if args.use_actions_interpolation:
                    # print(f"Time {t}, pre {pre_action}, act {action}")
                    interp_actions = interpolate_action(args, pre_action, action)
                else:
                    interp_actions = action[np.newaxis, :]
                # Execute the interpolated actions one by one
                for act in interp_actions:
                    # Action layout: [0:7] left arm, [7:14] right arm,
                    # [14:16] base velocity (only used with --use_robot_base).
                    left_action = act[:7]
                    right_action = act[7:14]

                    if not args.disable_puppet_arm:
                        ros_operator.puppet_arm_publish(left_action,
                                                        right_action)  # puppet_arm_publish_continuous_thread

                    if args.use_robot_base:
                        vel_action = act[14:16]
                        ros_operator.robot_base_publish(vel_action)
                    rate.sleep()
                    # print(f"doing action: {act}")
                t += 1

                print("Published Step", t)
                pre_action = action.copy()
|
||||||
|
|
||||||
|
|
||||||
|
# ROS operator class
|
||||||
|
class RosOperator:
|
||||||
|
|
||||||
|
    def __init__(self, args):
        """Set up message buffers, publishers, and the ROS node.

        Attributes are first declared as None, then populated by init()
        (deques, CvBridge, the publish-control lock) and init_ros()
        (node, subscribers, publishers).
        """
        # Incoming-message buffers (filled by the topic callbacks).
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        # ROS <-> OpenCV image converter.
        self.bridge = None
        # Outgoing publishers.
        self.puppet_arm_left_publisher = None
        self.puppet_arm_right_publisher = None
        self.robot_base_publisher = None
        # Background publisher thread and its stop-signal lock.
        self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_lock = None
        self.args = args
        self.init()
        self.init_ros()
|
||||||
|
|
||||||
|
    def init(self):
        """Create the per-topic FIFO buffers and the publish-control lock."""
        self.bridge = CvBridge()
        self.img_left_deque = deque()
        self.img_right_deque = deque()
        self.img_front_deque = deque()
        self.img_left_depth_deque = deque()
        self.img_right_depth_deque = deque()
        self.img_front_depth_deque = deque()
        self.puppet_arm_left_deque = deque()
        self.puppet_arm_right_deque = deque()
        self.robot_base_deque = deque()
        self.puppet_arm_publish_lock = threading.Lock()
        # The lock starts out held: puppet_arm_publish_continuous() exits as
        # soon as it can acquire it, so releasing the lock elsewhere (see
        # puppet_arm_publish_continuous_thread) acts as a stop signal.
        self.puppet_arm_publish_lock.acquire()
|
||||||
|
|
||||||
|
    def puppet_arm_publish(self, left, right):
        """Publish one JointState target to each puppet arm.

        left / right: sequences of 7 joint positions. The same message object
        is reused for both arms; only `position` is swapped between publishes.
        """
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()  # Set timestep
        joint_state_msg.name = [
            "joint0",
            "joint1",
            "joint2",
            "joint3",
            "joint4",
            "joint5",
            "joint6",
        ]  # Set joint names
        joint_state_msg.position = left
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = right
        self.puppet_arm_right_publisher.publish(joint_state_msg)
|
||||||
|
|
||||||
|
    def robot_base_publish(self, vel):
        """Publish a base velocity command.

        vel: (linear_x, angular_z); all other Twist components are zeroed.
        """
        vel_msg = Twist()
        vel_msg.linear.x = vel[0]
        vel_msg.linear.y = 0
        vel_msg.linear.z = 0
        vel_msg.angular.x = 0
        vel_msg.angular.y = 0
        vel_msg.angular.z = vel[1]
        self.robot_base_publisher.publish(vel_msg)
|
||||||
|
|
||||||
|
    def puppet_arm_publish_continuous(self, left, right):
        """Ramp both arms from their current positions to the given targets.

        Publishes intermediate JointState targets at args.publish_rate, moving
        each joint by at most args.arm_steps_length[i] per tick. Returns early
        if the control lock becomes acquirable (it is released by
        puppet_arm_publish_continuous_thread as a stop signal).
        """
        rate = rospy.Rate(self.args.publish_rate)
        left_arm = None
        right_arm = None
        # Wait until we have at least one reading of each arm's current pose.
        while True and not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break
        # Per-joint direction of travel (+1 / -1) toward the target.
        left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
        right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
        flag = True
        step = 0
        while flag and not rospy.is_shutdown():
            # Non-blocking acquire succeeds only once the lock has been
            # released elsewhere — treated as a request to stop ramping.
            if self.puppet_arm_publish_lock.acquire(False):
                return
            left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
            right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
            flag = False
            for i in range(len(left)):
                if left_diff[i] < self.args.arm_steps_length[i]:
                    # Within one step of the target: snap to it.
                    left_arm[i] = left[i]
                else:
                    left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
                    flag = True  # at least one joint is still moving
            for i in range(len(right)):
                if right_diff[i] < self.args.arm_steps_length[i]:
                    right_arm[i] = right[i]
                else:
                    right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
                    flag = True
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # Set the timestep
            joint_state_msg.name = [
                "joint0",
                "joint1",
                "joint2",
                "joint3",
                "joint4",
                "joint5",
                "joint6",
            ]  # Set joint names
            joint_state_msg.position = left_arm
            self.puppet_arm_left_publisher.publish(joint_state_msg)
            joint_state_msg.position = right_arm
            self.puppet_arm_right_publisher.publish(joint_state_msg)
            step += 1
            print("puppet_arm_publish_continuous:", step)
            rate.sleep()
|
||||||
|
|
||||||
|
    def puppet_arm_publish_linear(self, left, right):
        """Move both arms to the targets along a fixed 100-waypoint linear path.

        Unlike puppet_arm_publish_continuous this ignores the per-joint step
        limits and publishes at a fixed 200 Hz; the last joint is pinned to
        its final value at every waypoint instead of being interpolated.
        """
        num_step = 100
        rate = rospy.Rate(200)

        left_arm = None
        right_arm = None

        # Wait until we have a current reading of each arm's pose.
        while True and not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break

        traj_left_list = np.linspace(left_arm, left, num_step)
        traj_right_list = np.linspace(right_arm, right, num_step)

        for i in range(len(traj_left_list)):
            traj_left = traj_left_list[i]
            traj_right = traj_right_list[i]
            # Hold the last joint at its final value for the whole move.
            traj_left[-1] = left[-1]
            traj_right[-1] = right[-1]
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # Set timestamp
            joint_state_msg.name = [
                "joint0",
                "joint1",
                "joint2",
                "joint3",
                "joint4",
                "joint5",
                "joint6",
            ]  # Set joint names
            joint_state_msg.position = traj_left
            self.puppet_arm_left_publisher.publish(joint_state_msg)
            joint_state_msg.position = traj_right
            self.puppet_arm_right_publisher.publish(joint_state_msg)
            rate.sleep()
|
||||||
|
|
||||||
|
    def puppet_arm_publish_continuous_thread(self, left, right):
        """Run puppet_arm_publish_continuous in a background thread.

        If a previous publisher thread exists, release the control lock so it
        exits, join it, then re-take the lock (non-blocking) before starting
        the new thread.
        """
        if self.puppet_arm_publish_thread is not None:
            self.puppet_arm_publish_lock.release()
            self.puppet_arm_publish_thread.join()
            self.puppet_arm_publish_lock.acquire(False)
            self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
        self.puppet_arm_publish_thread.start()
|
||||||
|
|
||||||
|
    def get_frame(self):
        """Try to assemble one time-aligned frame from all buffered topics.

        Returns False when the streams cannot be synchronized yet, otherwise
        (img_front, img_left, img_right, img_front_depth, img_left_depth,
        img_right_depth, puppet_arm_left, puppet_arm_right, robot_base), with
        the depth / base entries left as None when the matching option is off.
        """
        # Nothing to do until every required stream has at least one message.
        if (len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or
                (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0
                                                or len(self.img_front_depth_deque) == 0))):
            return False
        # Synchronize on the earliest of the streams' newest timestamps
        # (i.e. how far the most-lagging stream has gotten).
        if self.args.use_depth_image:
            frame_time = min([
                self.img_left_deque[-1].header.stamp.to_sec(),
                self.img_right_deque[-1].header.stamp.to_sec(),
                self.img_front_deque[-1].header.stamp.to_sec(),
                self.img_left_depth_deque[-1].header.stamp.to_sec(),
                self.img_right_depth_deque[-1].header.stamp.to_sec(),
                self.img_front_depth_deque[-1].header.stamp.to_sec(),
            ])
        else:
            frame_time = min([
                self.img_left_deque[-1].header.stamp.to_sec(),
                self.img_right_deque[-1].header.stamp.to_sec(),
                self.img_front_deque[-1].header.stamp.to_sec(),
            ])

        # Every stream must have reached frame_time, otherwise report failure.
        if (len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if (len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if (len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if (len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if (len(self.puppet_arm_right_deque) == 0
                or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0
                                          or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0
                                          or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0
                                          or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_robot_base and (len(self.robot_base_deque) == 0
                                         or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
            return False

        # Drop messages older than frame_time, then consume the first message
        # at or after it from each stream.
        while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
            self.img_left_deque.popleft()
        img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), "passthrough")

        while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
            self.img_right_deque.popleft()
        img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), "passthrough")

        while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
            self.img_front_deque.popleft()
        img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), "passthrough")

        while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_left_deque.popleft()
        puppet_arm_left = self.puppet_arm_left_deque.popleft()

        while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_right_deque.popleft()
        puppet_arm_right = self.puppet_arm_right_deque.popleft()

        img_left_depth = None
        if self.args.use_depth_image:
            while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_left_depth_deque.popleft()
            img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), "passthrough")

        img_right_depth = None
        if self.args.use_depth_image:
            while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_right_depth_deque.popleft()
            img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), "passthrough")

        img_front_depth = None
        if self.args.use_depth_image:
            while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_front_depth_deque.popleft()
            img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), "passthrough")

        robot_base = None
        if self.args.use_robot_base:
            while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
                self.robot_base_deque.popleft()
            robot_base = self.robot_base_deque.popleft()

        return (
            img_front,
            img_left,
            img_right,
            img_front_depth,
            img_left_depth,
            img_right_depth,
            puppet_arm_left,
            puppet_arm_right,
            robot_base,
        )
|
||||||
|
|
||||||
|
def img_left_callback(self, msg):
    """Buffer the latest left RGB image message, keeping at most 2000 entries."""
    buf = self.img_left_deque
    if len(buf) >= 2000:
        buf.popleft()  # drop the oldest frame to bound memory
    buf.append(msg)
|
||||||
|
|
||||||
|
def img_right_callback(self, msg):
    """Buffer the latest right RGB image message, keeping at most 2000 entries."""
    buf = self.img_right_deque
    if len(buf) >= 2000:
        buf.popleft()  # drop the oldest frame to bound memory
    buf.append(msg)
|
||||||
|
|
||||||
|
def img_front_callback(self, msg):
    """Buffer the latest front RGB image message, keeping at most 2000 entries."""
    buf = self.img_front_deque
    if len(buf) >= 2000:
        buf.popleft()  # drop the oldest frame to bound memory
    buf.append(msg)
|
||||||
|
|
||||||
|
def img_left_depth_callback(self, msg):
    """Buffer the latest left depth image message, keeping at most 2000 entries."""
    buf = self.img_left_depth_deque
    if len(buf) >= 2000:
        buf.popleft()  # drop the oldest frame to bound memory
    buf.append(msg)
|
||||||
|
|
||||||
|
def img_right_depth_callback(self, msg):
    """Buffer the latest right depth image message, keeping at most 2000 entries."""
    buf = self.img_right_depth_deque
    if len(buf) >= 2000:
        buf.popleft()  # drop the oldest frame to bound memory
    buf.append(msg)
|
||||||
|
|
||||||
|
def img_front_depth_callback(self, msg):
    """Buffer the latest front depth image message, keeping at most 2000 entries."""
    buf = self.img_front_depth_deque
    if len(buf) >= 2000:
        buf.popleft()  # drop the oldest frame to bound memory
    buf.append(msg)
|
||||||
|
|
||||||
|
def puppet_arm_left_callback(self, msg):
    """Buffer the latest left-arm joint state, keeping at most 2000 entries."""
    buf = self.puppet_arm_left_deque
    if len(buf) >= 2000:
        buf.popleft()  # drop the oldest sample to bound memory
    buf.append(msg)
|
||||||
|
|
||||||
|
def puppet_arm_right_callback(self, msg):
    """Buffer the latest right-arm joint state, keeping at most 2000 entries."""
    buf = self.puppet_arm_right_deque
    if len(buf) >= 2000:
        buf.popleft()  # drop the oldest sample to bound memory
    buf.append(msg)
|
||||||
|
|
||||||
|
def robot_base_callback(self, msg):
    """Buffer the latest base odometry message, keeping at most 2000 entries."""
    buf = self.robot_base_deque
    if len(buf) >= 2000:
        buf.popleft()  # drop the oldest sample to bound memory
    buf.append(msg)
|
||||||
|
|
||||||
|
def init_ros(self):
    """Create the ROS node, all topic subscribers, and the command publishers."""
    rospy.init_node("joint_state_publisher", anonymous=True)

    # RGB cameras are always subscribed; depth cameras only on demand.
    image_subscriptions = [
        (self.args.img_left_topic, self.img_left_callback),
        (self.args.img_right_topic, self.img_right_callback),
        (self.args.img_front_topic, self.img_front_callback),
    ]
    if self.args.use_depth_image:
        image_subscriptions += [
            (self.args.img_left_depth_topic, self.img_left_depth_callback),
            (self.args.img_right_depth_topic, self.img_right_depth_callback),
            (self.args.img_front_depth_topic, self.img_front_depth_callback),
        ]
    for topic, callback in image_subscriptions:
        rospy.Subscriber(topic, Image, callback, queue_size=1000, tcp_nodelay=True)

    # Arm joint-state feedback.
    rospy.Subscriber(self.args.puppet_arm_left_topic, JointState,
                     self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
    rospy.Subscriber(self.args.puppet_arm_right_topic, JointState,
                     self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
    # Mobile-base odometry.
    rospy.Subscriber(self.args.robot_base_topic, Odometry,
                     self.robot_base_callback, queue_size=1000, tcp_nodelay=True)

    # Command publishers for both arms and the base.
    self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
    self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
    self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_arguments():
    """Build and parse the command-line arguments for the RDT inference script.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_publish_step", type=int, default=10000,
                        help="Maximum number of action publishing steps")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")

    # RGB camera topics.
    parser.add_argument("--img_front_topic", type=str,
                        default="/camera_f/color/image_raw", help="img_front_topic")
    parser.add_argument("--img_left_topic", type=str,
                        default="/camera_l/color/image_raw", help="img_left_topic")
    parser.add_argument("--img_right_topic", type=str,
                        default="/camera_r/color/image_raw", help="img_right_topic")

    # Depth camera topics (only consumed when --use_depth_image is set).
    parser.add_argument("--img_front_depth_topic", type=str,
                        default="/camera_f/depth/image_raw", help="img_front_depth_topic")
    parser.add_argument("--img_left_depth_topic", type=str,
                        default="/camera_l/depth/image_raw", help="img_left_depth_topic")
    parser.add_argument("--img_right_depth_topic", type=str,
                        default="/camera_r/depth/image_raw", help="img_right_depth_topic")

    # Arm command / feedback topics.
    parser.add_argument("--puppet_arm_left_cmd_topic", type=str,
                        default="/master/joint_left", help="puppet_arm_left_cmd_topic")
    parser.add_argument("--puppet_arm_right_cmd_topic", type=str,
                        default="/master/joint_right", help="puppet_arm_right_cmd_topic")
    parser.add_argument("--puppet_arm_left_topic", type=str,
                        default="/puppet/joint_left", help="puppet_arm_left_topic")
    parser.add_argument("--puppet_arm_right_topic", type=str,
                        default="/puppet/joint_right", help="puppet_arm_right_topic")

    # Mobile-base topics.
    parser.add_argument("--robot_base_topic", type=str, default="/odom_raw",
                        help="robot_base_topic")
    # BUG FIX: help text previously said "robot_base_topic" (copy-paste error).
    parser.add_argument("--robot_base_cmd_topic", type=str, default="/cmd_vel",
                        help="robot_base_cmd_topic")
    parser.add_argument("--use_robot_base", action="store_true", default=False,
                        help="Whether to use the robot base to move around")
    parser.add_argument("--publish_rate", type=int, default=30,
                        help="The rate at which to publish the actions")
    parser.add_argument("--ctrl_freq", type=int, default=25,
                        help="The control frequency of the robot")

    parser.add_argument("--chunk_size", type=int, default=64, help="Action chunk size")
    # BUG FIX: added nargs="+" so a user-supplied value parses to a LIST of
    # floats, matching the list default that downstream code expects;
    # previously type=float alone produced a single scalar.
    parser.add_argument("--arm_steps_length", type=float, nargs="+",
                        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2],
                        help="The maximum change allowed for each joint per timestep")

    parser.add_argument("--use_actions_interpolation", action="store_true", default=False,
                        help="Whether to interpolate the actions if the difference is too large")
    parser.add_argument("--use_depth_image", action="store_true", default=False,
                        help="Whether to use depth images")

    parser.add_argument("--disable_puppet_arm", action="store_true", default=False,
                        help="Whether to disable the puppet arm. This is useful for safely debugging")

    parser.add_argument("--config_path", type=str, default="configs/base.yaml",
                        help="Path to the config file")
    # parser.add_argument('--cfg_scale', type=float, default=2.0,
    #                     help='the scaling factor used to modify the magnitude of the control features during denoising')
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True,
                        help="Name or path to the pretrained model")

    parser.add_argument("--lang_embeddings_path", type=str, required=True,
                        help="Path to the pre-encoded language instruction embeddings")

    args = parser.parse_args()
    return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI arguments, build the ROS operator, and run model inference."""
    args = get_arguments()
    ros_operator = RosOperator(args)
    # Seed before any stochastic components are constructed or run.
    if args.seed is not None:
        set_seed(args.seed)
    config = get_config(args)
    model_inference(args, config, ros_operator)


if __name__ == "__main__":
    main()
|
||||||
344
RDT-170M/scripts/agilex_model.py
Normal file
344
RDT-170M/scripts/agilex_model.py
Normal file
@ -0,0 +1,344 @@
|
|||||||
|
import os
import sys
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from configs.state_vec import STATE_VEC_IDX_MAPPING

# Make the sibling "models" package importable when running as a script.
# BUG FIX: the append was previously duplicated, polluting sys.path.
current_file = Path(__file__)
sys.path.append(os.path.join(current_file.parent.parent, "models"))

from multimodal_encoder.siglip_encoder import SiglipVisionTower
from multimodal_encoder.t5_encoder import T5Embedder
from rdt_runner import RDTRunner
|
||||||
|
|
||||||
|
# The indices that the raw vector should be mapped to in the unified action vector
|
||||||
|
# AGILEX_STATE_INDICES = [
|
||||||
|
# STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(1)
|
||||||
|
# ] + [
|
||||||
|
# STATE_VEC_IDX_MAPPING["left_gripper_open"]
|
||||||
|
# ] + [
|
||||||
|
# STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(1)
|
||||||
|
# ] + [
|
||||||
|
# STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
|
||||||
|
# ]
|
||||||
|
# AGILEX_STATE_INDICES = None
|
||||||
|
|
||||||
|
|
||||||
|
# Create the RDT model
|
||||||
|
def create_model(args, **kwargs):
    """Instantiate the RDT wrapper and optionally load a checkpoint file.

    Args:
        args: configuration dict; must contain args["arm_dim"] with
            "left_arm_dim" / "right_arm_dim".
        **kwargs: forwarded to RoboticDiffusionTransformerModel; may include
            "pretrained", a path to a .pt / .safetensors checkpoint file.

    Returns:
        RoboticDiffusionTransformerModel: the ready-to-use model wrapper.
    """
    # NOTE: the previous version also built AGILEX_STATE_INDICES here, but the
    # result was never used (the model recomputes it per call) — dead code removed.
    model = RoboticDiffusionTransformerModel(args, **kwargs)
    pretrained = kwargs.get("pretrained", None)
    # Only checkpoint *files* are loaded here; HF-style directories are
    # handled inside get_policy() via RDTRunner.from_pretrained.
    if pretrained is not None and os.path.isfile(pretrained):
        model.load_pretrained_weights(pretrained)

    return model
|
||||||
|
|
||||||
|
|
||||||
|
class RoboticDiffusionTransformerModel(object):
    """A wrapper for the RDT model, which handles

    1. Model initialization
    2. Encodings of instructions
    3. Model inference
    """

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        # We do not use the text encoder due to limited GPU memory
        # self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy(pretrained)
        self.left_arm_dim, self.right_arm_dim = (
            args["arm_dim"]["left_arm_dim"],
            args["arm_dim"]["right_arm_dim"],
        )

        self.reset()

    def get_policy(self, pretrained):
        """Initialize the RDT policy network.

        A fresh RDTRunner is built when `pretrained` is None or a checkpoint
        *file* (weights are loaded later via load_pretrained_weights);
        otherwise `pretrained` is treated as an HF-style model directory.
        """
        if pretrained is None or os.path.isfile(pretrained):
            img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                            self.vision_model.num_patches)

            _model = RDTRunner(
                action_dim=self.args["common"]["state_dim"],
                pred_horizon=self.args["common"]["action_chunk_size"],
                config=self.args["model"],
                lang_token_dim=self.args["model"]["lang_token_dim"],
                img_token_dim=self.args["model"]["img_token_dim"],
                state_token_dim=self.args["model"]["state_token_dim"],
                max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
                img_cond_len=img_cond_len,
                img_pos_embed_config=[
                    # No initial pos embed in the last grid size
                    # since we've already done in ViT
                    (
                        "image",
                        (
                            self.args["common"]["img_history_size"],
                            self.args["common"]["num_cameras"],
                            -self.vision_model.num_patches,
                        ),
                    ),
                ],
                lang_pos_embed_config=[
                    # Similarly, no initial pos embed for language
                    ("lang", -self.args["dataset"]["tokenizer_max_length"]),
                ],
                dtype=self.dtype,
            )
        else:
            _model = RDTRunner.from_pretrained(pretrained)

        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        """Load the T5 text encoder; returns (tokenizer, encoder model)."""
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
        return tokenizer, text_encoder

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        """Load the SigLIP vision tower; returns (image processor, encoder)."""
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        image_processor = vision_encoder.image_processor
        return image_processor, vision_encoder

    def reset(self):
        """Set model to evaluation mode and move it to the target device/dtype."""
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        # self.text_model.eval()
        self.vision_model.eval()

        self.policy = self.policy.to(device, dtype=weight_dtype)
        # self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        """Load policy weights from a .pt (DeepSpeed-style) or .safetensors file.

        Raises:
            NotImplementedError: for any other file extension.
        """
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            # NOTE(review): torch.load without map_location assumes the
            # checkpoint's original device is available — confirm.
            checkpoint = torch.load(pretrained)
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model

            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        """Encode string instruction to latent embeddings.

        NOTE(review): requires self.text_tokenizer / self.text_model, which
        are NOT created in __init__ (text encoding is disabled there to save
        GPU memory) — assign them via get_text_encoder() before calling.

        Args:
            instruction: a string of instruction
            device: a string of device

        Returns:
            pred: a tensor of latent embeddings of shape (text_max_length, 512)
        """
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)

        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()

        return pred

    def _format_joint_to_state(self, joints):
        """
        Format the joint proprioception into the unified state vector.

        Args:
            joints (torch.Tensor): The joint proprioception to be formatted,
                qpos ([B, N, left_arm_dim + 1 + right_arm_dim + 1]).

        Returns:
            state (torch.Tensor): The formatted state for RDT
                ([B, N, state_token_dim]).
            state_elem_mask (torch.Tensor): 1 at the populated dimensions
                ([B, state_token_dim]).
        """
        # Indices of this robot's joints inside the unified state vector.
        AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"]
                                 for i in range(self.left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] +
                                [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
                                 for i in range(self.right_arm_dim)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]])
        # Placeholder rescaling: all factors are 1 (identity) for this robot;
        # kept so per-joint scaling can be plugged in without reshaping code.
        joints = joints / torch.tensor(
            [[[1 for i in range(self.left_arm_dim + 1 + self.right_arm_dim + 1)]]],
            device=joints.device,
            dtype=joints.dtype,
        )

        B, N, _ = joints.shape
        state = torch.zeros(
            (B, N, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        # Fill into the unified state vector
        state[:, :, AGILEX_STATE_INDICES] = joints
        # Assemble the mask indicating each dimension's availability
        state_elem_mask = torch.zeros(
            (B, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        state_elem_mask[:, AGILEX_STATE_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        """
        Unformat the unified action vector into the joint action to be executed.

        Args:
            action (torch.Tensor): The unified action vector
                ([B, N, state_token_dim]).

        Returns:
            joints (torch.Tensor): The robot joint action
                ([B, N, left_arm_dim + 1 + right_arm_dim + 1]).
        """
        AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"]
                                 for i in range(self.left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] +
                                [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
                                 for i in range(self.right_arm_dim)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]])
        action_indices = AGILEX_STATE_INDICES
        joints = action[:, :, action_indices]

        # Placeholder rescaling (identity); note that the action range and
        # proprioception range can differ for the Mobile ALOHA robot.
        joints = joints * torch.tensor(
            [[[1 for i in range(self.left_arm_dim + 1 + self.right_arm_dim + 1)]]],
            device=joints.device,
            dtype=joints.dtype,
        )

        return joints

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        """
        Predict the next action chunk given the
        proprioceptive states, images, and instruction embeddings.

        Args:
            proprio: proprioceptive states
            images: RGB images, the order should be
                [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},
                ext_{t}, right_wrist_{t}, left_wrist_{t}]
            text_embeds: instruction embeddings

        Returns:
            action: predicted action
        """
        device = self.device
        dtype = self.dtype

        # The background image used for padding
        background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                    dtype=np.uint8).reshape(1, 1, 3)
        background_image = (np.ones(
            (
                self.image_processor.size["height"],
                self.image_processor.size["width"],
                3,
            ),
            dtype=np.uint8,
        ) * background_color)

        # Preprocess the images by order and encode them
        image_tensor_list = []
        for image in images:
            if image is None:
                # Replace it with the background image
                image = Image.fromarray(background_image)

            if self.image_size is not None:
                # BUG FIX: was `self.data_args.image_size`, but `data_args`
                # is never defined on this class — use self.image_size.
                image = transforms.Resize(self.image_size)(image)

            if self.args["dataset"].get("auto_adjust_image_brightness", False):
                pixel_values = list(image.getdata())
                average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                # Brighten very dark frames so the encoder sees usable input.
                if average_brightness <= 0.15:
                    image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

            if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

                def expand2square(pil_img, background_color):
                    # Pad the shorter side with the dataset mean color.
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result

                image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            image_tensor_list.append(image)

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)

        # Flatten all camera/time patches into one conditioning sequence.
        image_embeds = self.vision_model(image_tensor).detach()
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # Prepare the proprioception states and the control frequency
        joints = proprio.to(device).unsqueeze(0)  # (1, 1, 14)
        states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        states = states[:, -1:, :]  # (1, 1, 128)
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        # Predict the next action chunk given the inputs
        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)

        return trajectory
|
||||||
53
RDT-170M/scripts/encode_lang.py
Normal file
53
RDT-170M/scripts/encode_lang.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||||||
|
|
||||||
|
# CUDA device index used for the text encoder.
GPU = 0
# Hugging Face id (or local path) of the T5 encoder checkpoint.
MODEL_PATH = "google/t5-v1_1-xxl"
# YAML config providing dataset.tokenizer_max_length.
CONFIG_PATH = "configs/base.yaml"
# Directory the encoded instruction embedding is written to.
SAVE_DIR = "outs/"

# Modify this to your task name and instruction
# NOTE(review): the instruction text mentions a marker and a box while
# TASK_NAME says "handover_pan" — confirm these are meant to go together.
TASK_NAME = "handover_pan"
INSTRUCTION = "Pick up the black marker on the right and put it into the packaging box on the left."

# Note: if your GPU VRAM is less than 24GB,
# it is recommended to enable offloading by specifying an offload directory.
OFFLOAD_DIR = (
    None  # Specify your offload directory here, ensuring the directory exists.
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Encode INSTRUCTION with the T5 text encoder and save the embedding.

    The result is written to SAVE_DIR/<TASK_NAME>.pt as a dict with keys
    "name", "instruction", and "embeddings".
    """
    with open(CONFIG_PATH, "r") as fp:
        config = yaml.safe_load(fp)

    device = torch.device(f"cuda:{GPU}")
    text_embedder = T5Embedder(
        from_pretrained=MODEL_PATH,
        model_max_length=config["dataset"]["tokenizer_max_length"],
        device=device,
        use_offload_folder=OFFLOAD_DIR,
    )
    tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    tokens = tokenizer(INSTRUCTION, return_tensors="pt", padding="longest", truncation=True)["input_ids"].to(device)

    tokens = tokens.view(1, -1)
    with torch.no_grad():
        pred = text_encoder(tokens).last_hidden_state.detach().cpu()

    # BUG FIX: ensure the output directory exists before saving;
    # previously torch.save failed on a fresh checkout without "outs/".
    os.makedirs(SAVE_DIR, exist_ok=True)
    save_path = os.path.join(SAVE_DIR, f"{TASK_NAME}.pt")
    # We save the embeddings in a dictionary format
    torch.save({"name": TASK_NAME, "instruction": INSTRUCTION, "embeddings": pred}, save_path)

    print(
        f'"{INSTRUCTION}" from "{TASK_NAME}" is encoded by "{MODEL_PATH}" into shape {pred.shape} and saved to "{save_path}"'
    )


if __name__ == "__main__":
    main()
|
||||||
57
RDT-170M/scripts/encode_lang_batch_once.py
Normal file
57
RDT-170M/scripts/encode_lang_batch_once.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||||||
|
|
||||||
|
|
||||||
|
def encode_lang(
    DATA_FILE_PATH,
    TARGET_DIR,
    GPU,
    desc_type="seen",
    tokenizer=None,
    text_encoder=None,
):
    """Encode a task's instruction list with T5 and save per-instruction embeddings.

    Args:
        DATA_FILE_PATH: JSON file mapping description types to instruction lists.
        TARGET_DIR: output directory; embeddings go to TARGET_DIR/instructions/.
        GPU: CUDA device index to run the encoder on (CPU fallback if no CUDA).
        desc_type: which instruction group to encode (e.g. "seen").
        tokenizer, text_encoder: optional pre-loaded T5 components; loaded
            from the local weights directory when either is None.

    Returns:
        (tokenizer, text_encoder) so callers can reuse them across tasks.
    """
    current_dir = os.path.dirname(__file__)

    # Robustness: fall back to CPU when CUDA is unavailable instead of
    # crashing on the first .to(device) call.
    device = torch.device(f"cuda:{GPU}" if torch.cuda.is_available() else "cpu")
    if tokenizer is None or text_encoder is None:
        # Config is only needed to build the embedder, so read it lazily;
        # previously it was read even when both components were supplied.
        with open(os.path.join(current_dir, "../configs/base.yaml"), "r") as fp:
            config = yaml.safe_load(fp)
        text_embedder = T5Embedder(
            from_pretrained=os.path.join(current_dir, "../../weights/RDT/t5-v1_1-xxl"),
            model_max_length=config["dataset"]["tokenizer_max_length"],
            device=device,
            use_offload_folder=None,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    with open(DATA_FILE_PATH, "r") as f_instr:
        instruction_dict = json.load(f_instr)

    instructions = instruction_dict[desc_type]

    # Encode the instructions in one padded batch.
    tokenized_res = tokenizer(instructions, return_tensors="pt", padding="longest", truncation=True)
    tokens = tokenized_res["input_ids"].to(device)
    attn_mask = tokenized_res["attention_mask"].to(device)

    with torch.no_grad():
        text_embeds = (text_encoder(input_ids=tokens, attention_mask=attn_mask)["last_hidden_state"].detach().cpu())

    attn_mask = attn_mask.cpu().bool()
    # Race-safe directory creation (replaces the exists()+makedirs pattern).
    os.makedirs(f"{TARGET_DIR}/instructions", exist_ok=True)
    # Save the embeddings for training use; padding positions are masked out,
    # so each saved tensor has shape (num_real_tokens, hidden_dim).
    for i in range(len(instructions)):
        text_embed = text_embeds[i][attn_mask[i]]
        save_path = os.path.join(TARGET_DIR, f"instructions/lang_embed_{i}.pt")
        torch.save(text_embed, save_path)

    return tokenizer, text_encoder
|
||||||
84
RDT-170M/scripts/generate_output_json.py
Normal file
84
RDT-170M/scripts/generate_output_json.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
|
||||||
|
def extract_metrics_from_log(log_file_path):
    """Scan a training log for metric dicts and return the best value of each.

    Returns a 4-tuple (agilex_sample_mse, agilex_sample_l2err,
    overall_avg_sample_mse, overall_avg_sample_l2err), each the minimum seen
    across the log, or (None, None, None, None) when the log is unreadable
    or contains no metric lines.
    """
    metric_re = re.compile(
        r"\{'agilex_sample_mse':\s*([0-9.eE+-]+),\s*'agilex_sample_l2err':\s*([0-9.eE+-]+),\s*'overall_avg_sample_mse':\s*([0-9.eE+-]+),\s*'overall_avg_sample_l2err':\s*([0-9.eE+-]+)\}"
    )
    found = []
    try:
        with open(log_file_path, 'r', encoding='utf-8') as log_file:
            for line in log_file:
                match = metric_re.search(line)
                if match is None:
                    continue
                row = tuple(float(match.group(i)) for i in range(1, 5))
                found.append(row)
                print(f"Find Metrics: agilex_sample_mse={row[0]}, agilex_sample_l2err={row[1]}, "
                      f"overall_avg_sample_mse={row[2]}, overall_avg_sample_l2err={row[3]}")
    except Exception as e:
        print(f"Failed to read log: {e}")
        return (None, None, None, None)

    if not found:
        print("No metrics found in the log file")
        return (None, None, None, None)

    print(f"\nTotal {len(found)} metrics found in the log file")

    # Column-wise minimum over all matched rows.
    best = tuple(min(column) for column in zip(*found))

    print(f"\nBest metrics:")
    print(f"  agilex_sample_mse: {best[0]}")
    print(f"  agilex_sample_l2err: {best[1]}")
    print(f"  overall_avg_sample_mse: {best[2]}")
    print(f"  overall_avg_sample_l2err: {best[3]}")

    return best
|
||||||
|
|
||||||
|
def generate_output_json(input_config_file, output_dir, runtime):
    """Assemble <output_dir>/output.json for a finished training run.

    Args:
        input_config_file: path to the JSON config used to launch training
            (expects task_id, gpu_id and a "train" section).
        output_dir: directory holding output.log and the trained model.
        runtime: wall-clock runtime value to record verbatim.

    Side effects:
        Writes <output_dir>/output.json summarising config + best metrics.
    """
    with open(input_config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)

    log_file = os.path.join(output_dir, 'output.log')
    (agilex_sample_mse, agilex_sample_l2err,
     overall_avg_sample_mse, overall_avg_sample_l2err) = extract_metrics_from_log(log_file)

    # Metrics are legitimately None when the log is missing or has no
    # metric lines; the JSON is still written with nulls in that case.
    if None in (agilex_sample_mse, agilex_sample_l2err,
                overall_avg_sample_mse, overall_avg_sample_l2err):
        print("Warning: Some metrics are missing in the log file.")

    output_json = {
        "task_id": config.get("task_id"),
        "model_type": "RDT-170M",
        # Prefer a top-level model_name; fall back to train.model.
        "model_name": config.get("model_name") if "model_name" in config else config.get("train", {}).get("model"),
        "gpu_id": config.get("gpu_id"),
        "runtime": runtime,
        "log_path": log_file,
        "output_dir": output_dir,
        "model_path": os.path.join(output_dir, 'pytorch_model.bin'),
        "metrics": {
            "agilex_sample_mse": agilex_sample_mse,
            "agilex_sample_l2err": agilex_sample_l2err,
            "overall_avg_sample_mse": overall_avg_sample_mse,
            "overall_avg_sample_l2err": overall_avg_sample_l2err
        }
    }

    # Write output.json with pretty formatting. FIX: open with an explicit
    # utf-8 encoding — ensure_ascii=False may emit non-ASCII characters,
    # which breaks under a non-UTF-8 locale default codec.
    output_json_path = os.path.join(output_dir, 'output.json')
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(output_json, f, indent=4, ensure_ascii=False)
||||||
|
if __name__ == "__main__":
    # CLI entry point: expects exactly three positional arguments
    # (config path, output directory, runtime string).
    if len(sys.argv) != 4:
        print("Usage: python generate_output_json.py <input_config_file> <output_dir> <runtime>")
        sys.exit(1)
    generate_output_json(sys.argv[1], sys.argv[2], sys.argv[3])
|
||||||
325
RDT-170M/scripts/maniskill_model.py
Normal file
325
RDT-170M/scripts/maniskill_model.py
Normal file
@ -0,0 +1,325 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
||||||
|
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
||||||
|
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||||||
|
from models.rdt_runner import RDTRunner
|
||||||
|
|
||||||
|
# Indices into the unified RDT state vector used for ManiSkill data:
# the 7 right-arm joint positions followed by the right gripper opening.
MANISKILL_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
                     for i in range(7)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]]
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(args, pretrained, **kwargs):
    """Build a RoboticDiffusionTransformerModel, optionally loading weights.

    Args:
        args: nested config dict forwarded to the model wrapper.
        pretrained: checkpoint path, or None to keep random initialisation.
        **kwargs: extra keyword arguments for the wrapper constructor.

    Returns:
        The constructed (and possibly weight-loaded) model wrapper.
    """
    wrapper = RoboticDiffusionTransformerModel(args, **kwargs)
    if pretrained is not None:
        wrapper.load_pretrained_weights(pretrained)
    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
# Per-dimension min/max statistics for the 8-dim ManiSkill state/action
# space (7 arm joint positions + gripper), used to scale values into
# [-1, 1] and back. Presumably computed over the training dataset —
# TODO(review): confirm how these constants were gathered.
DATA_STAT = {
    "state_min": [
        -0.7463043928146362,
        -0.0801204964518547,
        -0.4976441562175751,
        -2.657780647277832,
        -0.5742632150650024,
        1.8309762477874756,
        -2.2423808574676514,
        0.0,
    ],
    "state_max": [
        0.7645499110221863,
        1.4967026710510254,
        0.4650936424732208,
        -0.3866899907588959,
        0.5505855679512024,
        3.2900545597076416,
        2.5737812519073486,
        0.03999999910593033,
    ],
    "action_min": [
        -0.7472005486488342,
        -0.08631071448326111,
        -0.4995281398296356,
        -2.658363103866577,
        -0.5751323103904724,
        1.8290787935256958,
        -2.245187997817993,
        -1.0,
    ],
    "action_max": [
        0.7654682397842407,
        1.4984270334243774,
        0.46786263585090637,
        -0.38181185722351074,
        0.5517147779464722,
        3.291581630706787,
        2.575840711593628,
        1.0,
    ],
}
|
||||||
|
|
||||||
|
|
||||||
|
class RoboticDiffusionTransformerModel(object):
    """A wrapper for the RDT model, which handles
    1. Model initialization
    2. Encodings of instructions
    3. Model inference
    """

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained_text_encoder_name_or_path=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        """Build the encoders and the policy and move them to `device`.

        Args:
            args: nested config dict with "common", "model", "dataset" keys.
            device: torch device string for all submodules.
            dtype: parameter/activation dtype (bfloat16 by default).
            image_size: optional resize target applied to input images.
            control_frequency: control frequency fed to the policy.
            pretrained_text_encoder_name_or_path: T5 checkpoint location.
            pretrained_vision_encoder_name_or_path: SigLIP checkpoint location.
        """
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy()

        # Per-dimension normalisation bounds for states/actions.
        self.state_min = torch.tensor(DATA_STAT["state_min"]).to(device)
        self.state_max = torch.tensor(DATA_STAT["state_max"]).to(device)
        self.action_min = torch.tensor(DATA_STAT["action_min"]).to(device)
        self.action_max = torch.tensor(DATA_STAT["action_max"]).to(device)

        self.reset()

    def get_policy(self):
        """Initialize the RDT policy network from the config."""
        # Total number of image-condition tokens:
        # image history x cameras x ViT patches per frame.
        img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                        self.vision_model.num_patches)

        _model = RDTRunner(
            action_dim=self.args["common"]["state_dim"],
            pred_horizon=self.args["common"]["action_chunk_size"],
            config=self.args["model"],
            lang_token_dim=self.args["model"]["lang_token_dim"],
            img_token_dim=self.args["model"]["img_token_dim"],
            state_token_dim=self.args["model"]["state_token_dim"],
            max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
            img_cond_len=img_cond_len,
            img_pos_embed_config=[
                # No initial pos embed in the last grid size
                # since we've already done in ViT
                (
                    "image",
                    (
                        self.args["common"]["img_history_size"],
                        self.args["common"]["num_cameras"],
                        -self.vision_model.num_patches,
                    ),
                ),
            ],
            lang_pos_embed_config=[
                # Similarly, no initial pos embed for language
                ("lang", -self.args["dataset"]["tokenizer_max_length"]),
            ],
            dtype=self.dtype,
        )

        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        """Load the T5 text encoder; returns (tokenizer, encoder)."""
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
        return tokenizer, text_encoder

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        """Load the SigLIP vision tower; returns (image_processor, encoder)."""
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        image_processor = vision_encoder.image_processor
        return image_processor, vision_encoder

    def reset(self):
        """Set model to evaluation mode and move every submodule to the
        configured device/dtype."""
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        self.text_model.eval()
        self.vision_model.eval()

        self.policy = self.policy.to(device, dtype=weight_dtype)
        self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        """Load policy weights from a checkpoint; no-op when None.

        Supports DeepSpeed-style .pt files (state dict under "module")
        and .safetensors files.

        Raises:
            NotImplementedError: for any other checkpoint extension.
        """
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            checkpoint = torch.load(pretrained)
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model

            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        """Encode string instruction to latent embeddings.

        Args:
            instruction: a string of instruction
            device: a string of device

        Returns:
            pred: a tensor of latent embeddings of shape (text_max_length, 512)
        """
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)

        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()

        return pred

    def _format_joint_to_state(self, joints):
        """
        Format the robot joint state into the unified state vector.

        Args:
            joints (torch.Tensor): The joint state to be formatted.
                qpos ([B, N, 14]).

        Returns:
            state (torch.Tensor): The formatted state for RDT ([B, N, 128]).
            state_elem_mask (torch.Tensor): which of the 128 dims are valid.
        """
        # Normalize the raw joints into [-1, 1] using dataset statistics.
        joints = (joints - self.state_min) / (self.state_max - self.state_min) * 2 - 1
        B, N, _ = joints.shape
        state = torch.zeros(
            (B, N, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        # Assemble the unified state vector: only the ManiSkill slots are
        # populated; the mask marks those slots for the policy.
        state[:, :, MANISKILL_INDICES] = joints
        state_elem_mask = torch.zeros(
            (B, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        state_elem_mask[:, MANISKILL_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        """Extract the ManiSkill dims from the unified action vector and
        map them from [-1, 1] back to the physical action range."""
        action_indices = MANISKILL_INDICES
        joints = action[:, :, action_indices]

        # Denormalize to the action space.
        joints = (joints + 1) / 2 * (self.action_max - self.action_min) + self.action_min

        return joints

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        """
        Run one inference step of the policy.

        Args:
            proprio: proprioceptive states
            images: RGB images (PIL), None entries are replaced by a
                mean-colour background frame
            text_embeds: instruction embeddings

        Returns:
            action: predicted action trajectory (float32)
        """
        device = self.device
        dtype = self.dtype

        # Background used for missing cameras: a solid frame of the image
        # processor's mean colour.
        background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                    dtype=np.uint8).reshape(1, 1, 3)
        background_image = (np.ones(
            (
                self.image_processor.size["height"],
                self.image_processor.size["width"],
                3,
            ),
            dtype=np.uint8,
        ) * background_color)

        image_tensor_list = []
        for image in images:
            if image is None:
                # Replace it with the background image
                image = Image.fromarray(background_image)

            if self.image_size is not None:
                # BUG FIX: the original read self.data_args.image_size, but
                # this class has no data_args attribute (AttributeError at
                # runtime); the configured size is self.image_size.
                image = transforms.Resize(self.image_size)(image)

            if self.args["dataset"].get("auto_adjust_image_brightness", False):
                # Brighten very dark frames (mean brightness <= 0.15).
                pixel_values = list(image.getdata())
                average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                if average_brightness <= 0.15:
                    image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

            if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

                def expand2square(pil_img, background_color):
                    # Pad the shorter side with the mean colour so the
                    # image becomes square without distortion.
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result

                image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            image_tensor_list.append(image)

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)

        image_embeds = self.vision_model(image_tensor).detach()
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # history of actions
        joints = proprio.to(device).unsqueeze(0)  # (1, 1, 14)
        states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        states = states[:, -1:, :]  # (1, 1, 128)
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)

        return trajectory
|
||||||
169
RDT-170M/scripts/process_data.py
Normal file
169
RDT-170M/scripts/process_data.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append("./")
|
||||||
|
|
||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import pickle
|
||||||
|
import cv2
|
||||||
|
import argparse
|
||||||
|
import yaml
|
||||||
|
from scripts.encode_lang_batch_once import encode_lang
|
||||||
|
|
||||||
|
|
||||||
|
def load_hdf5(dataset_path):
    """Read one raw episode HDF5 file.

    Returns:
        (left_gripper, left_arm, right_gripper, right_arm, image_dict)
        where image_dict maps camera name -> encoded RGB frame stream.

    Exits the process when the file does not exist.
    """
    if not os.path.isfile(dataset_path):
        print(f"Dataset does not exist at \n{dataset_path}\n")
        exit()

    with h5py.File(dataset_path, "r") as root:
        left_gripper = root["/joint_action/left_gripper"][()]
        left_arm = root["/joint_action/left_arm"][()]
        right_gripper = root["/joint_action/right_gripper"][()]
        right_arm = root["/joint_action/right_arm"][()]
        # Load every camera stream under /observation/<cam>/rgb.
        image_dict = {
            cam_name: root[f"/observation/{cam_name}/rgb"][()]
            for cam_name in root[f"/observation/"].keys()
        }

    return left_gripper, left_arm, right_gripper, right_arm, image_dict
|
||||||
|
|
||||||
|
|
||||||
|
def images_encoding(imgs):
    """JPEG-encode a list of images and pad each to a common byte length.

    Args:
        imgs: iterable of images in a format cv2.imencode accepts
            (presumably HxWx3 uint8 BGR — confirm against the caller).

    Returns:
        (padded_data, max_len): JPEG byte strings, each right-padded with
        NUL bytes to max_len so they fit a fixed-size HDF5 string dataset
        (dtype f"S{max_len}").

    Raises:
        ValueError: if any image fails to JPEG-encode.
    """
    encode_data = []
    max_len = 0
    for i in range(len(imgs)):
        success, encoded_image = cv2.imencode(".jpg", imgs[i])
        # FIX: the success flag was previously ignored; fail fast rather
        # than writing a corrupt dataset.
        if not success:
            raise ValueError(f"Failed to JPEG-encode image {i}")
        jpeg_data = encoded_image.tobytes()
        encode_data.append(jpeg_data)
        max_len = max(max_len, len(jpeg_data))
    # FIX: the padded list was computed but the UNPADDED list was returned
    # (dead code). Return the padded entries, matching the S{max_len}
    # dtype they are stored with (h5py would NUL-pad identically, so the
    # stored bytes are unchanged).
    padded_data = [data.ljust(max_len, b"\0") for data in encode_data]
    return padded_data, max_len
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_config(task_name):
    """Load ./task_config/<task_name>.yml and return it as a dict."""
    config_path = f"./task_config/{task_name}.yml"
    with open(config_path, "r", encoding="utf-8") as fh:
        contents = fh.read()
    return yaml.load(contents, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
|
||||||
|
def data_transform(path, episode_num, save_path):
    """Convert `episode_num` raw episode<i>.hdf5 files under `path` into
    the processed training layout under `save_path`: one
    episode_<i>/episode_<i>.hdf5 per episode containing qpos/action
    arrays and JPEG-encoded camera streams.

    Returns:
        int: number of episodes processed.
    """
    begin = 0
    floders = os.listdir(path)
    assert episode_num <= len(floders), "data num not enough"

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    for i in range(episode_num):
        # Raw joint trajectories and per-camera encoded image streams.
        left_gripper_all, left_arm_all, right_gripper_all, right_arm_all, image_dict = (load_hdf5(
            os.path.join(path, f"episode{i}.hdf5")))
        qpos = []
        actions = []
        cam_high = []
        cam_right_wrist = []
        cam_left_wrist = []
        left_arm_dim = []
        right_arm_dim = []

        last_state = None  # NOTE(review): never read afterwards; candidate for removal
        for j in range(0, left_gripper_all.shape[0]):

            left_gripper, left_arm, right_gripper, right_arm = (
                left_gripper_all[j],
                left_arm_all[j],
                right_gripper_all[j],
                right_arm_all[j],
            )

            # Full joint state: left arm + left gripper + right arm +
            # right gripper, concatenated in that order.
            state = np.concatenate((left_arm, [left_gripper], right_arm, [right_gripper]), axis=0)  # joint
            state = state.astype(np.float32)

            if j != left_gripper_all.shape[0] - 1:
                # Every frame except the last contributes an observation
                # (qpos plus the three camera views, resized to 640x480).
                qpos.append(state)

                camera_high_bits = image_dict["head_camera"][j]
                camera_high = cv2.imdecode(np.frombuffer(camera_high_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_high_resized = cv2.resize(camera_high, (640, 480))
                cam_high.append(camera_high_resized)

                camera_right_wrist_bits = image_dict["right_camera"][j]
                camera_right_wrist = cv2.imdecode(np.frombuffer(camera_right_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_right_wrist_resized = cv2.resize(camera_right_wrist, (640, 480))
                cam_right_wrist.append(camera_right_wrist_resized)

                camera_left_wrist_bits = image_dict["left_camera"][j]
                camera_left_wrist = cv2.imdecode(np.frombuffer(camera_left_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_left_wrist_resized = cv2.resize(camera_left_wrist, (640, 480))
                cam_left_wrist.append(camera_left_wrist_resized)

            if j != 0:
                # Every frame except the first is the action label for the
                # previous observation (next-state supervision).
                action = state
                actions.append(action)
                left_arm_dim.append(left_arm.shape[0])
                right_arm_dim.append(right_arm.shape[0])

        if not os.path.exists(os.path.join(save_path, f"episode_{i}")):
            os.makedirs(os.path.join(save_path, f"episode_{i}"))
        hdf5path = os.path.join(save_path, f"episode_{i}/episode_{i}.hdf5")

        with h5py.File(hdf5path, "w") as f:
            f.create_dataset("action", data=np.array(actions))
            obs = f.create_group("observations")
            obs.create_dataset("qpos", data=np.array(qpos))
            obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))
            obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))
            image = obs.create_group("images")
            # JPEG-encode each camera stream; len_* is the fixed string
            # length used for the HDF5 dataset dtype.
            cam_high_enc, len_high = images_encoding(cam_high)
            cam_right_wrist_enc, len_right = images_encoding(cam_right_wrist)
            cam_left_wrist_enc, len_left = images_encoding(cam_left_wrist)
            image.create_dataset("cam_high", data=cam_high_enc, dtype=f"S{len_high}")
            image.create_dataset("cam_right_wrist", data=cam_right_wrist_enc, dtype=f"S{len_right}")
            image.create_dataset("cam_left_wrist", data=cam_left_wrist_enc, dtype=f"S{len_left}")

        begin += 1
        print(f"proccess {i} success!")

    return begin
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI: process_data.py <task_name> <task_config> <expert_data_num>
    parser = argparse.ArgumentParser(description="Process some episodes.")
    parser.add_argument("task_name", type=str)
    parser.add_argument("task_config", type=str)
    parser.add_argument("expert_data_num", type=int)
    args = parser.parse_args()

    task_name = args.task_name
    task_config = args.task_config
    expert_data_num = args.expert_data_num

    # Raw episodes live under ../../data/<task>/<config>/data relative to
    # the working directory this script is launched from.
    load_dir = os.path.join("../../data", str(task_name), str(task_config), "data")

    print(f"read data from path: {load_dir}")
    begin = data_transform(
        load_dir,
        expert_data_num,
        f"./processed_data/{task_name}-{task_config}-{expert_data_num}",
    )
    # Encode language instructions per episode; tokenizer/text_encoder are
    # created lazily on the first call and reused afterwards.
    tokenizer, text_encoder = None, None
    for idx in range(expert_data_num):
        print(f"Processing Language: {idx}", end="\r")
        data_file_path = (f"../../data/{task_name}/{task_config}/instructions/episode{idx}.json")
        target_dir = (f"processed_data/{task_name}-{task_config}-{expert_data_num}/episode_{idx}")
        tokenizer, text_encoder = encode_lang(
            DATA_FILE_PATH=data_file_path,
            TARGET_DIR=target_dir,
            GPU=0,
            desc_type="seen",
            tokenizer=tokenizer,
            text_encoder=text_encoder,
        )
|
||||||
31
RDT-170M/scripts/read_config.py
Normal file
31
RDT-170M/scripts/read_config.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import json
|
||||||
|
import yaml
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def read_config(config_file, yaml_file):
    """Merge a task JSON config into a training YAML config, in place.

    Reads config_file (JSON with task_id, gpu_id and a "train" section)
    and yaml_file (YAML training config), derives the training settings,
    then overwrites yaml_file with the merged result.

    Args:
        config_file: path to the task JSON.
        yaml_file: path to the YAML training config, updated in place.
    """
    with open(config_file, 'r') as f:
        json_config = json.load(f)
    with open(yaml_file, 'r') as f:
        yaml_config = yaml.load(f, Loader=yaml.FullLoader)

    yaml_config["model"] = json_config["train"]["model"] + json_config["task_id"]
    yaml_config["data_path"] = json_config["train"]["input_data_path"] + "/data"
    yaml_config["checkpoint_path"] = json_config["train"]["checkpoint_path"] + "/" + json_config["task_id"]
    yaml_config["pretrained_model_name_or_path"] = "/weights/rdt-170m"
    yaml_config["cuda_visible_device"] = str(json_config["gpu_id"])
    print(f"cuda_visible_device: {yaml_config['cuda_visible_device']}")
    yaml_config["train_batch_size"] = int(json_config["train"]["batch_size"])
    yaml_config["sample_batch_size"] = int(json_config["train"]["batch_size"]) * 2
    yaml_config["max_train_steps"] = int(json_config["train"]["epochs"])
    # Checkpoint roughly 10 times per run. FIX: clamp to >= 1 — the
    # original int(epochs / 10) yields 0 for runs under 10 steps, and a
    # period of 0 breaks any `step % checkpointing_period` check.
    yaml_config["checkpointing_period"] = max(1, int(json_config["train"]["epochs"]) // 10)
    yaml_config["sample_period"] = 200
    yaml_config["checkpoints_total_limit"] = 50

    with open(yaml_file, 'w') as f:
        yaml.dump(yaml_config, f, default_flow_style=False)

    print("Config YAML file updated successfully")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI: read_config.py <json_config_file> <yaml_config_file>
    read_config(sys.argv[1], sys.argv[2])
|
||||||
22
RDT-170M/scripts/read_yaml.py
Normal file
22
RDT-170M/scripts/read_yaml.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import sys
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def read_yaml_value(file_path, key):
    """Print the value stored under `key` in a YAML file (shell helper)."""
    with open(file_path, "r") as stream:
        loaded = yaml.safe_load(stream)
        found = loaded.get(key)
    if found is None:
        # Missing key (or an explicit null value) is reported, not raised.
        print(f"Key '{key}' not found in {file_path}")
    else:
        print(found)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI: read_yaml.py <file_path> <key> — prints the value for shell use.
    if len(sys.argv) != 3:
        print("Usage: python read_yaml.py <file_path> <key>")
        sys.exit(1)

    file_path = sys.argv[1]
    key = sys.argv[2]
    read_yaml_value(file_path, key)
|
||||||
0
RDT-170M/train/__init__.py
Normal file
0
RDT-170M/train/__init__.py
Normal file
479
RDT-170M/train/dataset.py
Normal file
479
RDT-170M/train/dataset.py
Normal file
@ -0,0 +1,479 @@
|
|||||||
|
import traceback
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
from typing import Dict, Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from data.filelock import FileLock
|
||||||
|
from data.hdf5_vla_dataset import HDF5VLADataset
|
||||||
|
from train.image_corrupt import image_corrupt
|
||||||
|
|
||||||
|
|
||||||
|
def get_clean_item(chunk_dir):
    """Return the indexes of clean (dirty bit clear) items in a chunk."""
    bits = read_dirty_bit(chunk_dir)
    clean_positions = np.where(1 - bits)[0]
    return clean_positions.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def save_dirty_bit(chunk_dir, dirty_bit):
    """Write the dirty-bit array for a chunk under a write lock.

    Retries for up to 10 seconds on lock contention or I/O errors.

    Args:
        chunk_dir: directory containing the chunk's "dirty_bit" file.
        dirty_bit: np.ndarray (uint8) of per-item dirty flags.

    Raises:
        RuntimeError: if the bits cannot be written within the timeout.
        KeyboardInterrupt: propagated after releasing the lock.
    """
    file_path = os.path.join(chunk_dir, "dirty_bit")
    deadline = time.time() + 10.0
    while time.time() < deadline:
        lock = None
        try:
            lock = FileLock(file_path)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                file.write(dirty_bit.tobytes())
            lock.release_lock()
            return
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise
        except BaseException:
            # FIX: guard against `lock` being unbound when FileLock()
            # itself raised — the original hit a NameError here instead
            # of retrying. Any other failure is retried until timeout.
            if lock is not None:
                lock.release_lock()
            continue
    raise RuntimeError("Failed to save dirty bit.")
|
||||||
|
|
||||||
|
|
||||||
|
def read_dirty_bit(chunk_dir):
    """Read the dirty-bit array for a chunk under a read lock.

    Retries for up to 10 seconds on lock contention or I/O errors.

    Args:
        chunk_dir: directory containing the chunk's "dirty_bit" file.

    Returns:
        np.ndarray (uint8) copy of the dirty bits; guaranteed non-empty.

    Raises:
        RuntimeError: if the bits cannot be read within the timeout.
        KeyboardInterrupt: propagated after releasing the lock.
    """
    file_path = os.path.join(chunk_dir, "dirty_bit")
    deadline = time.time() + 10.0
    while time.time() < deadline:
        lock = None
        try:
            lock = FileLock(file_path)
            lock.acquire_read_lock()
            with open(file_path, "rb") as file:
                dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
            lock.release_lock()
            assert len(dirty_bit) > 0
            return dirty_bit
        except KeyboardInterrupt:
            if lock is not None:
                lock.release_lock()
            raise
        except BaseException:
            # FIX: guard against `lock` being unbound when FileLock()
            # itself raised — the original hit a NameError here instead
            # of retrying. Any other failure is retried until timeout.
            if lock is not None:
                lock.release_lock()
            continue
    raise RuntimeError("Failed to read dirty bit.")
|
||||||
|
|
||||||
|
|
||||||
|
class VLAConsumerDataset(Dataset):
|
||||||
|
"""A vision-languange-action Dataset for supervised training.
|
||||||
|
This dataset will load data from the buffer directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_config_path,
|
||||||
|
config,
|
||||||
|
tokenizer,
|
||||||
|
image_processor,
|
||||||
|
num_cameras,
|
||||||
|
img_history_size,
|
||||||
|
image_size=None,
|
||||||
|
auto_adjust_image_brightness=False,
|
||||||
|
image_aug=False,
|
||||||
|
dataset_type="pretrain",
|
||||||
|
cond_mask_prob=0.1,
|
||||||
|
cam_ext_mask_prob=-1.0,
|
||||||
|
state_noise_snr=None,
|
||||||
|
use_hdf5=False,
|
||||||
|
use_precomp_lang_embed=False,
|
||||||
|
):
|
||||||
|
super(VLAConsumerDataset, self).__init__()
|
||||||
|
|
||||||
|
# Load the control frequency for each dataset
|
||||||
|
with open("configs/dataset_control_freq.json", "r") as fp:
|
||||||
|
self.control_freq = json.load(fp)
|
||||||
|
# Load the dataset names
|
||||||
|
dataset_names_cfg = ("configs/pretrain_datasets.json"
|
||||||
|
if dataset_type == "pretrain" else "configs/finetune_datasets.json")
|
||||||
|
with open(dataset_names_cfg, "r") as file:
|
||||||
|
DATASET_NAMES = json.load(file)
|
||||||
|
# Create the mapping between dataset name and id
|
||||||
|
self.dataset_name2id = {name: i for i, name in enumerate(DATASET_NAMES)}
|
||||||
|
self.dataset_id2name = {i: name for i, name in enumerate(DATASET_NAMES)}
|
||||||
|
|
||||||
|
self.image_processor = image_processor
|
||||||
|
self.model_config_path = model_config_path
|
||||||
|
self.buffer_dir = config["buf_path"]
|
||||||
|
self.num_chunks = config["buf_num_chunks"]
|
||||||
|
self.chunk_size = config["buf_chunk_size"]
|
||||||
|
self.tokenizer_max_length = config["tokenizer_max_length"]
|
||||||
|
self.image_aspect_ratio = config["image_aspect_ratio"]
|
||||||
|
self.state_noise_snr = state_noise_snr
|
||||||
|
self.num_cameras = num_cameras
|
||||||
|
self.img_history_size = img_history_size
|
||||||
|
self.cond_mask_prob = cond_mask_prob
|
||||||
|
self.cam_ext_mask_prob = cam_ext_mask_prob
|
||||||
|
self.use_hdf5 = use_hdf5
|
||||||
|
self.hdf5_dataset = None
|
||||||
|
if use_hdf5:
|
||||||
|
self.hdf5_dataset = HDF5VLADataset(self.model_config_path)
|
||||||
|
self.use_precomp_lang_embed = use_precomp_lang_embed
|
||||||
|
if use_precomp_lang_embed:
|
||||||
|
self.empty_lang_embed = torch.load("data/empty_lang_embed.pt")
|
||||||
|
|
||||||
|
# Load dataset stat
|
||||||
|
with open("configs/dataset_stat.json", "r") as f:
|
||||||
|
dataset_stat = json.load(f)
|
||||||
|
self.dataset_stat = dataset_stat
|
||||||
|
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.image_size = image_size
|
||||||
|
self.auto_adjust_image_brightness = auto_adjust_image_brightness
|
||||||
|
self.image_aug = image_aug
|
||||||
|
|
||||||
|
self.last_content = None
|
||||||
|
self.last_meta = None
|
||||||
|
|
||||||
|
def get_dataset_name2id(self):
|
||||||
|
return self.dataset_name2id
|
||||||
|
|
||||||
|
def get_dataset_id2name(self):
|
||||||
|
return self.dataset_id2name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pairwise(iterable):
|
||||||
|
a = iter(iterable)
|
||||||
|
return zip(a, a)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_data_from_chunk(chunk_dir, chunk_item_idx):
|
||||||
|
# If error occurs, retry
|
||||||
|
time_stmp = time.time()
|
||||||
|
while time.time() - time_stmp < 10.0:
|
||||||
|
try:
|
||||||
|
locks = []
|
||||||
|
file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
|
||||||
|
lock = FileLock(file_path)
|
||||||
|
locks.append(lock)
|
||||||
|
lock.acquire_read_lock()
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
json_content = json.load(file)
|
||||||
|
lock.release_lock()
|
||||||
|
file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
|
||||||
|
lock = FileLock(file_path)
|
||||||
|
locks.append(lock)
|
||||||
|
lock.acquire_read_lock()
|
||||||
|
with open(file_path, "rb") as file:
|
||||||
|
sample_dict = np.load(file)
|
||||||
|
meta = tuple(sample_dict.values())
|
||||||
|
lock.release_lock()
|
||||||
|
return json_content, meta
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
for lock in locks:
|
||||||
|
lock.release_lock()
|
||||||
|
raise KeyboardInterrupt
|
||||||
|
except BaseException:
|
||||||
|
for lock in locks:
|
||||||
|
lock.release_lock()
|
||||||
|
continue
|
||||||
|
raise RuntimeError("Failed to load sample.")
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
if self.use_hdf5:
|
||||||
|
return len(self.hdf5_dataset)
|
||||||
|
else:
|
||||||
|
return self.num_chunks * self.chunk_size
|
||||||
|
|
||||||
|
def _safe_load(self, index):
    """Load one sample from the chunked disk buffer, tolerating failures.

    Walks the chunks (starting from the one `index` maps to) until one with
    clean items is found, marks the chosen item's dirty bit so the producer
    may overwrite it, and loads it.  If loading fails, the previously loaded
    sample is returned instead, so a transient disk error never kills the
    epoch.
    """
    read_chunk_item_indices = []
    # Start searching from a random chunk
    # (in practice: the chunk this index maps to, then scan forward cyclically)
    read_chunk_idx = index // self.chunk_size
    while len(read_chunk_item_indices) == 0:
        read_chunk_dir = os.path.join(self.buffer_dir, f"chunk_{read_chunk_idx}")
        try:
            # get_clean_item returns the item indices not yet consumed.
            read_chunk_item_indices = get_clean_item(read_chunk_dir)
        except BaseException as e:
            # Print the error info
            print("Error catched when searching a clean chunk:", e)
            traceback.print_exc()
            read_chunk_item_indices = []
        # Advance cyclically; harmless when the loop is about to exit since
        # read_chunk_dir already points at the chunk that was found.
        read_chunk_idx = (read_chunk_idx + 1) % self.num_chunks

    # read_chunk_item_index = random.choice(read_chunk_item_indices)
    # read_chunk_item_index = read_chunk_item_indices.pop()
    # Deterministic pick driven by the requested index.
    random_item_index = index % len(read_chunk_item_indices)
    read_chunk_item_index = read_chunk_item_indices[random_item_index]

    # Modify the dirty bit (best effort: a failure here only means the item
    # may be served again before being refreshed by the producer)
    try:
        dirty_bit = read_dirty_bit(read_chunk_dir)
        dirty_bit[read_chunk_item_index] = 1
        save_dirty_bit(read_chunk_dir, dirty_bit)
    except BaseException as e:
        # Print the error info
        print("Error catched when modifying the dirty bit:", e)
        traceback.print_exc()

    # load the sample
    try:
        content, meta = self._load_data_from_chunk(read_chunk_dir, read_chunk_item_index)
        self.last_content, self.last_meta = content, meta
    except BaseException as e:
        # Print the error info
        print("Error catched when loading sample:", e)
        traceback.print_exc()
        # If failed to load the data, return the last loaded data for robustness
        # NOTE(review): if the very first load of the process fails,
        # last_content/last_meta are still None and the unpack below raises
        # TypeError — presumably acceptable, but worth confirming.
        content, meta = self.last_content, self.last_meta

    return (content, *meta)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
    """Fetch and preprocess one training sample.

    Loads raw arrays either from the HDF5 dataset or from the disk buffer,
    then applies conditioning dropout (states / images / language are each
    independently replaced by a "masked" surrogate with probability
    ``cond_mask_prob``), optional state noise, image resizing/augmentation,
    and tokenization or precomputed-language-embedding lookup.

    On any processing error the index is advanced and the load is retried,
    so a single corrupt sample cannot crash the epoch.  KeyboardInterrupt
    is re-raised so the worker stays interruptible.
    """
    # For robustness, we will try to load the data until we succeed
    while True:
        data_dict = None
        try:
            if self.use_hdf5:
                res = self.hdf5_dataset.get_item()
                content = res["meta"]
                states = res["state"]
                actions = res["actions"]
                state_elem_mask = res["state_indicator"]
                image_metas = [
                    res["cam_high"],
                    res["cam_high_mask"],
                    res["cam_right_wrist"],
                    res["cam_right_wrist_mask"],
                    res["cam_left_wrist"],
                    res["cam_left_wrist_mask"],
                ]
                state_std = res["state_std"]
                state_mean = res["state_mean"]
                state_norm = res["state_norm"]
            else:
                # Buffer layout: the `_` slots are unused timestamps/masks;
                # image_metas absorbs the variable-length middle section.
                (
                    content,
                    _,
                    states,
                    _,
                    actions,
                    _,
                    state_elem_mask,
                    *image_metas,
                    state_std,
                    state_mean,
                    state_norm,
                ) = self._safe_load(index)

            data_dict = {}
            data_dict["dataset_name"] = content["dataset_name"]
            data_dict["data_idx"] = self.dataset_name2id[data_dict["dataset_name"]]
            # ctrl_freq of 0 acts as the "masked" conditioning value.
            data_dict["ctrl_freq"] = (self.control_freq[data_dict["dataset_name"]]
                                      if random.random() > self.cond_mask_prob else 0)

            if self.state_noise_snr is not None:
                # Additive Gaussian noise scaled so the SNR (in dB) w.r.t.
                # the per-element state std equals state_noise_snr.
                states += np.random.normal(
                    0.0,
                    state_std / np.sqrt(10**(self.state_noise_snr / 10)),
                    states.shape,
                )
            ds_state_mean = np.array(self.dataset_stat[data_dict["dataset_name"]]["state_mean"])
            ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
            # Randomly mask the states by the mean state
            data_dict["states"] = (states if random.random() > self.cond_mask_prob else ds_state_mean)
            data_dict["actions"] = actions
            data_dict["state_elem_mask"] = (state_elem_mask if random.random() > self.cond_mask_prob else
                                            np.zeros_like(state_elem_mask))

            # Stat for the episode that the step belongs to
            data_dict["state_norm"] = state_norm

            # We replace the invalid images with the background image
            # and also randomly mask images by the background image
            background_color = np.array(
                [int(x * 255) for x in self.image_processor.image_mean],
                dtype=np.uint8,
            ).reshape(1, 1, 3)
            background_image = (np.ones(
                (
                    self.image_processor.size["height"],
                    self.image_processor.size["width"],
                    3,
                ),
                dtype=np.uint8,
            ) * background_color)

            # image_metas arrives flat as (images, mask, images, mask, ...);
            # pairwise groups it per camera.
            image_metas = list(self.pairwise(image_metas))
            mask_probs = [self.cond_mask_prob] * self.num_cameras
            if self.cam_ext_mask_prob >= 0.0:
                # Camera 0 (exterior) may use its own masking probability.
                mask_probs[0] = self.cam_ext_mask_prob
            rearranged_images = []
            for i in range(self.img_history_size):
                for j in range(self.num_cameras):
                    images, image_mask = image_metas[j]
                    image, valid = images[i], image_mask[i]
                    if (valid and (math.prod(image.shape) > 0) and (random.random() > mask_probs[j])):
                        rearranged_images.append((image, True))
                    else:
                        rearranged_images.append((background_image.copy(), False))

            preprocessed_images = []
            processor = self.image_processor
            for image, valid in rearranged_images:
                image = Image.fromarray(image)
                if self.image_size is not None:
                    image = transforms.Resize(self.image_size)(image)  # (1008, 336)
                # assert image.height == 336, "We haven't prepare for training with images of different resolutions."

                if valid and self.auto_adjust_image_brightness:
                    # Mean per-channel brightness in [0, 1]; brighten very
                    # dark frames so augmentation/encoding stays useful.
                    pixel_values = list(image.getdata())
                    average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                    if average_brightness <= 0.15:
                        image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

                # Only apply image augmentation to 50% of the images
                if valid and self.image_aug and (random.random() > 0.5):
                    aug_type = random.choice(["corrput_only", "color_only", "both"])
                    if aug_type != "corrput_only":
                        image = transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5,
                                                       hue=0.03)(image)
                    if aug_type != "color_only":
                        image = image_corrupt(image)

                if self.image_aspect_ratio == "pad":

                    def expand2square(pil_img, background_color):
                        # Pad the shorter side so the image becomes square,
                        # centered, filled with the processor's mean color.
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result

                    image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
                preprocessed_images.append(image)
            data_dict["images"] = preprocessed_images

            if self.use_precomp_lang_embed:
                # Here content["instruction"] is a path to a precomputed
                # language-embedding file; strip a stray trailing period.
                if content["instruction"][-1] == ".":
                    content["instruction"] = content["instruction"][:-1]
                data_dict["lang_embed"] = (torch.load(content["instruction"])
                                           if random.random() > self.cond_mask_prob else self.empty_lang_embed)
            else:
                instruction = (content["instruction"] if random.random() > self.cond_mask_prob else "")
                data_dict["input_ids"] = self.tokenizer(
                    instruction,
                    return_tensors="pt",
                    padding="longest",
                    truncation=False,
                ).input_ids[0]

                assert (
                    len(data_dict["input_ids"]) <= self.tokenizer_max_length
                ), f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."

            for k, v in data_dict.items():
                if isinstance(v, np.ndarray):
                    data_dict[k] = torch.from_numpy(v)

            # Sanity check: nothing numpy-typed may leave this method.
            for k, v in data_dict.items():
                assert not isinstance(v, np.ndarray), f"key: {k}, value: {v}"
                # data_dict[k] = torch.from_numpy(v)

            return data_dict
        except KeyboardInterrupt:
            # FIX: the BaseException handler below used to swallow Ctrl-C,
            # making the retry loop uninterruptible.  Re-raise it instead
            # (mirrors the special-casing in _load_data_from_chunk).
            raise
        except BaseException as e:
            # Print the error info
            if data_dict is not None:
                print(
                    f"Error catched when processing sample from {data_dict.get('dataset_name')}:",
                    e,
                )
            else:
                print(f"Error catched when processing sample:", e)
            traceback.print_exc()
            # Try incresing the index
            index = (index + 1) % len(self)
|
||||||
|
|
||||||
|
|
||||||
|
class DataCollatorForVLAConsumerDataset(object):
    """Collate examples for supervised training.

    Stacks the per-sample dicts produced by ``VLAConsumerDataset`` into one
    batch dict of tensors, padding the language inputs (token ids or
    precomputed embeddings) to the longest sequence in the batch.
    """

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer) -> None:
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        tensor_keys = ("states", "actions", "state_elem_mask", "state_norm")
        batch = {key: [] for key in (*tensor_keys, "images", "data_indices", "ctrl_freqs")}
        token_id_seqs = []
        embed_seqs = []
        embed_lens = []

        for sample in instances:
            # Convert any numpy arrays to tensors before stacking.
            for key in tensor_keys:
                value = sample[key]
                if not isinstance(value, torch.Tensor):
                    value = torch.from_numpy(value)
                batch[key].append(value)

            # A sample carries either tokenized text or a precomputed
            # language embedding, never both.
            if "input_ids" in sample:
                token_id_seqs.append(sample["input_ids"])
            else:
                embed = sample["lang_embed"]
                embed_seqs.append(embed)
                embed_lens.append(embed.shape[0])

            batch["images"].append(torch.stack(sample["images"], dim=0))
            batch["data_indices"].append(sample["data_idx"])
            batch["ctrl_freqs"].append(sample["ctrl_freq"])

        for key in (*tensor_keys, "images"):
            batch[key] = torch.stack(batch[key], dim=0)
        batch["ctrl_freqs"] = torch.tensor(batch["ctrl_freqs"])

        if len(token_id_seqs) > 0:
            pad_id = self.tokenizer.pad_token_id
            padded_ids = torch.nn.utils.rnn.pad_sequence(token_id_seqs,
                                                         batch_first=True,
                                                         padding_value=pad_id)
            batch["input_ids"] = padded_ids
            batch["lang_attn_mask"] = padded_ids.ne(pad_id)
        else:
            padded_embeds = torch.nn.utils.rnn.pad_sequence(embed_seqs, batch_first=True, padding_value=0)
            attn_mask = torch.zeros(padded_embeds.shape[0], padded_embeds.shape[1], dtype=torch.bool)
            for row, length in enumerate(embed_lens):
                attn_mask[row, :length] = True
            batch["lang_embeds"] = padded_embeds
            batch["lang_attn_mask"] = attn_mask

        return batch
|
||||||
45
RDT-170M/train/image_corrupt.py
Normal file
45
RDT-170M/train/image_corrupt.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import warnings

# imgaug triggers FutureWarnings on newer NumPy versions; silence them
# before importing it.
warnings.simplefilter(action="ignore", category=FutureWarning)

import numpy as np

# Compat shim: imgaug still references the long-deprecated np.bool alias —
# presumably removed in the installed NumPy; restore it before the import.
np.bool = np.bool_
import imgaug.augmenters as iaa
from PIL import Image

# Define our sequence of augmentation steps that will be applied to every image.
seq = iaa.Sequential(
    [
        # Execute one of the following noise augmentations
        iaa.OneOf([
            iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
            iaa.AdditiveLaplaceNoise(scale=(0.0, 0.05 * 255), per_channel=0.5),
            iaa.AdditivePoissonNoise(lam=(0.0, 0.05 * 255), per_channel=0.5),
        ]),
        # Execute one or none of the following blur augmentations
        iaa.SomeOf(
            (0, 1),
            [
                iaa.OneOf([
                    iaa.GaussianBlur((0, 3.0)),
                    iaa.AverageBlur(k=(2, 7)),
                    iaa.MedianBlur(k=(3, 11)),
                ]),
                iaa.MotionBlur(k=(3, 36)),
            ],
        ),
    ],
    # do all of the above augmentations in random order
    random_order=True,
)
|
||||||
|
|
||||||
|
|
||||||
|
def image_corrupt(image: Image):
    """Run the module-level ``seq`` augmentation pipeline on one PIL image.

    The pipeline operates on batches, so the image is wrapped in a
    singleton batch and unwrapped afterwards.
    """
    batch = np.array(image)[None, ...]
    augmented = seq(images=batch)
    return Image.fromarray(augmented[0])
|
||||||
101
RDT-170M/train/sample.py
Normal file
101
RDT-170M/train/sample.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def log_sample_res(
    text_encoder,
    vision_encoder,
    rdt,
    args,
    accelerator,
    weight_dtype,
    dataset_id2name,
    dataset_id2name,
    dataloader,
    logger,
):
    """Evaluate the RDT policy on sample batches and return loss metrics.

    Runs ``args.num_sample_batches`` batches through ``rdt.predict_action``
    under fp16 autocast, computing masked MSE and normalized-L2 action
    errors, both per source dataset (accumulated on the main process only)
    and averaged over all processes.  Restores train mode before returning.

    Returns:
        dict mapping metric name (e.g. "<dataset>_sample_mse",
        "overall_avg_sample_l2err") to a rounded float.
    """
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logger.info(f"Running sampling for {args.num_sample_batches} batches...")

        rdt.eval()

        loss_for_log = defaultdict(float)
        loss_counter = defaultdict(int)
        for step, batch in enumerate(dataloader):
            if step >= args.num_sample_batches:
                break

            data_indices = batch["data_indices"]
            ctrl_freqs = batch["ctrl_freqs"]
            state_norm = batch["state_norm"].to(dtype=weight_dtype)
            images = batch["images"].to(dtype=weight_dtype)
            states = batch["states"].to(dtype=weight_dtype)
            # We only use the last state as input
            states = states[:, -1:, :]
            actions = batch["actions"].to(dtype=weight_dtype)
            state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)

            # Flatten (batch, views) so the vision encoder sees one image
            # per row, then restore the batch dimension.
            batch_size, _, C, H, W = images.shape
            image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
            image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

            lang_attn_mask = batch["lang_attn_mask"]
            # Precomputed embeddings skip the text encoder entirely.
            text_embeds = (batch["lang_embeds"].to(dtype=weight_dtype) if args.precomp_lang_embed else text_encoder(
                input_ids=batch["input_ids"], attention_mask=lang_attn_mask)["last_hidden_state"].detach())

            pred_actions = rdt.predict_action(
                lang_tokens=text_embeds,
                lang_attn_mask=lang_attn_mask,
                img_tokens=image_embeds,
                state_tokens=states,
                action_mask=state_elem_mask.unsqueeze(1),
                ctrl_freqs=ctrl_freqs,
            )

            # Broadcast the per-element mask/norm over the prediction horizon.
            num_steps = pred_actions.shape[1]
            expanded_state_elem_mask = (state_elem_mask.unsqueeze(1).tile((1, num_steps, 1)).float())
            expanded_state_norm = (state_norm.unsqueeze(1).tile((1, num_steps, 1)).float())

            loss = F.mse_loss(pred_actions, actions, reduction="none").float()

            # Masked mean per sample (only valid state elements contribute).
            mse_loss_per_entry = (loss * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)
            # L2 error normalized by the episode's state norm (eps avoids /0).
            l2_loss_per_entry = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss_per_entry = (l2_loss_per_entry * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)

            # Gather per-sample metrics from all processes for logging.
            dataset_indices, mse_losses, l2_losses = accelerator.gather_for_metrics((
                torch.LongTensor(data_indices).to(device=pred_actions.device),
                mse_loss_per_entry,
                l2_loss_per_entry,
            ), )
            dataset_indices = dataset_indices.tolist()
            if accelerator.is_main_process:
                for loss_suffix, losses in zip(["_sample_mse", "_sample_l2err"], [mse_losses, l2_losses]):
                    for dataset_idx, loss_tensor in zip(dataset_indices, losses):
                        loss_name = dataset_id2name[dataset_idx] + loss_suffix
                        loss_for_log[loss_name] += loss_tensor.item()
                        loss_counter[loss_name] += 1

            mse_loss = (loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            mse_loss_scaler = accelerator.gather(mse_loss).mean().item()
            loss_for_log["overall_avg_sample_mse"] += mse_loss_scaler

            l2_loss = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss = (l2_loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            l2_loss_scaler = accelerator.gather(l2_loss).mean().item()
            loss_for_log["overall_avg_sample_l2err"] += l2_loss_scaler

        # NOTE(review): the overall averages divide by num_sample_batches even
        # if the dataloader yielded fewer batches — confirm this is intended.
        for name in loss_for_log:
            if name in ["overall_avg_sample_mse", "overall_avg_sample_l2err"]:
                loss_scaler = loss_for_log[name]
                loss_for_log[name] = round(loss_scaler / (args.num_sample_batches), 4)
            else:
                loss_for_log[name] = round(loss_for_log[name] / loss_counter[name], 4)

        rdt.train()
        torch.cuda.empty_cache()

        return dict(loss_for_log)
|
||||||
521
RDT-170M/train/train.py
Normal file
521
RDT-170M/train/train.py
Normal file
@ -0,0 +1,521 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
import transformers
|
||||||
|
import yaml
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from accelerate.utils import DeepSpeedPlugin, ProjectConfiguration, set_seed
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.utils import is_wandb_available
|
||||||
|
from huggingface_hub import create_repo, upload_folder
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from safetensors.torch import load_model
|
||||||
|
|
||||||
|
from models.ema_model import EMAModel
|
||||||
|
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
||||||
|
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||||||
|
from models.rdt_runner import RDTRunner
|
||||||
|
from train.dataset import DataCollatorForVLAConsumerDataset, VLAConsumerDataset
|
||||||
|
from train.sample import log_sample_res
|
||||||
|
|
||||||
|
if is_wandb_available():
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
|
def save_model_card(repo_id: str, base_model: str = "", repo_folder=None):
    """Write a Hugging Face model card (README.md) into *repo_folder*.

    Args:
        repo_id: Hub repository id, shown in the card title.
        base_model: Name of the base model the weights derive from.
            FIX: the original signature was ``base_model=str`` — an
            annotation typo that made the *type* ``str`` the default value,
            rendering "<class 'str'>" into the card.
        repo_folder: Directory to write README.md into.
    """
    # Renamed from `yaml` so it no longer shadows the imported yaml module.
    yaml_header = f"""
---
license: mit
base_model: {base_model}
language:
- en
pipeline_tag: robotics
library_name: transformers
tags:
- robotics
- pytorch
- multimodal
- pretraining
- vla
- diffusion
- rdt
---
    """
    model_card = f"""
# RDT - {repo_id}

This is a RDT model derived from {base_model}. The weights were trained using [RDT](https://rdt-robotics.github.io/rdt-robotics/).
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml_header + model_card)
|
||||||
|
|
||||||
|
|
||||||
|
def train(args, logger):
|
||||||
|
# Read the config
|
||||||
|
with open(args.config_path, "r") as fp:
|
||||||
|
config = yaml.safe_load(fp)
|
||||||
|
|
||||||
|
with open(args.model_config_path, "r") as f:
|
||||||
|
model_config = yaml.safe_load(f)
|
||||||
|
# print(model_config)
|
||||||
|
args.output_dir = model_config["checkpoint_path"]
|
||||||
|
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||||
|
|
||||||
|
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
|
||||||
|
accelerator = Accelerator(
|
||||||
|
deepspeed_plugin=(DeepSpeedPlugin(hf_ds_config=args.deepspeed) if args.deepspeed is not None else None),
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
mixed_precision=args.mixed_precision,
|
||||||
|
log_with=args.report_to,
|
||||||
|
project_dir=logging_dir,
|
||||||
|
project_config=accelerator_project_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.report_to == "wandb":
|
||||||
|
if not is_wandb_available():
|
||||||
|
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||||
|
|
||||||
|
# Make one log on every process with the configuration for debugging.
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
level=logging.INFO,
|
||||||
|
filename=args.output_log_path,
|
||||||
|
filemode='w',
|
||||||
|
)
|
||||||
|
logger.info(accelerator.state, main_process_only=False)
|
||||||
|
if accelerator.is_local_main_process:
|
||||||
|
transformers.utils.logging.set_verbosity_warning()
|
||||||
|
diffusers.utils.logging.set_verbosity_info()
|
||||||
|
else:
|
||||||
|
transformers.utils.logging.set_verbosity_error()
|
||||||
|
diffusers.utils.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
# If passed along, set the training seed now.
|
||||||
|
if args.seed is not None:
|
||||||
|
set_seed(args.seed)
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
if args.output_dir is not None:
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
repo_id = create_repo(
|
||||||
|
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
||||||
|
exist_ok=True,
|
||||||
|
token=args.hub_token,
|
||||||
|
).repo_id
|
||||||
|
|
||||||
|
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||||
|
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||||
|
weight_dtype = torch.float32
|
||||||
|
if accelerator.mixed_precision == "fp16":
|
||||||
|
weight_dtype = torch.float16
|
||||||
|
elif accelerator.mixed_precision == "bf16":
|
||||||
|
weight_dtype = torch.bfloat16
|
||||||
|
|
||||||
|
if args.precomp_lang_embed:
|
||||||
|
tokenizer, text_encoder = None, None
|
||||||
|
else:
|
||||||
|
text_embedder = T5Embedder(
|
||||||
|
from_pretrained=args.pretrained_text_encoder_name_or_path,
|
||||||
|
model_max_length=config["dataset"]["tokenizer_max_length"],
|
||||||
|
device=accelerator.device,
|
||||||
|
)
|
||||||
|
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
||||||
|
|
||||||
|
vision_encoder = SiglipVisionTower(vision_tower=args.pretrained_vision_encoder_name_or_path, args=None)
|
||||||
|
image_processor = vision_encoder.image_processor
|
||||||
|
|
||||||
|
# Load from a pretrained checkpoint
|
||||||
|
if args.pretrained_model_name_or_path is not None and not os.path.isfile(args.pretrained_model_name_or_path):
|
||||||
|
logger.info("Constructing model from pretrained checkpoint.")
|
||||||
|
rdt = RDTRunner.from_pretrained(args.pretrained_model_name_or_path)
|
||||||
|
else:
|
||||||
|
logger.info("Constructing model from provided config.")
|
||||||
|
# Calculate the image condition length
|
||||||
|
img_cond_len = (config["common"]["img_history_size"] * config["common"]["num_cameras"] *
|
||||||
|
vision_encoder.num_patches)
|
||||||
|
rdt = RDTRunner(
|
||||||
|
action_dim=config["common"]["state_dim"],
|
||||||
|
pred_horizon=config["common"]["action_chunk_size"],
|
||||||
|
config=config["model"],
|
||||||
|
lang_token_dim=config["model"]["lang_token_dim"],
|
||||||
|
img_token_dim=config["model"]["img_token_dim"],
|
||||||
|
state_token_dim=config["model"]["state_token_dim"],
|
||||||
|
max_lang_cond_len=config["dataset"]["tokenizer_max_length"],
|
||||||
|
img_cond_len=img_cond_len,
|
||||||
|
img_pos_embed_config=[
|
||||||
|
# No initial pos embed in the last grid size
|
||||||
|
# since we've already done in ViT
|
||||||
|
(
|
||||||
|
"image",
|
||||||
|
(
|
||||||
|
config["common"]["img_history_size"],
|
||||||
|
config["common"]["num_cameras"],
|
||||||
|
-vision_encoder.num_patches,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
lang_pos_embed_config=[
|
||||||
|
# Similarly, no initial pos embed for language
|
||||||
|
("lang", -config["dataset"]["tokenizer_max_length"]),
|
||||||
|
],
|
||||||
|
dtype=weight_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
ema_rdt = copy.deepcopy(rdt)
|
||||||
|
ema_model = EMAModel(
|
||||||
|
ema_rdt,
|
||||||
|
update_after_step=config["model"]["ema"]["update_after_step"],
|
||||||
|
inv_gamma=config["model"]["ema"]["inv_gamma"],
|
||||||
|
power=config["model"]["ema"]["power"],
|
||||||
|
min_value=config["model"]["ema"]["min_value"],
|
||||||
|
max_value=config["model"]["ema"]["max_value"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||||
|
# which ensure saving model in huggingface format (config.json + pytorch_model.bin)
|
||||||
|
def save_model_hook(models, weights, output_dir):
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
for model in models:
|
||||||
|
model_to_save = model.module if hasattr(model, "module") else model # type: ignore
|
||||||
|
if isinstance(model_to_save, type(accelerator.unwrap_model(rdt))):
|
||||||
|
model_to_save.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||||
|
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
# TODO:
|
||||||
|
raise NotImplementedError("Gradient checkpointing is not yet implemented.")
|
||||||
|
|
||||||
|
# Enable TF32 for faster training on Ampere GPUs,
|
||||||
|
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||||
|
if args.allow_tf32:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
|
if args.scale_lr:
|
||||||
|
args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size *
|
||||||
|
accelerator.num_processes)
|
||||||
|
|
||||||
|
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||||
|
if args.use_8bit_adam:
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
|
||||||
|
|
||||||
|
optimizer_class = bnb.optim.AdamW8bit
|
||||||
|
else:
|
||||||
|
optimizer_class = torch.optim.AdamW
|
||||||
|
|
||||||
|
# Optimizer creation
|
||||||
|
params_to_optimize = rdt.parameters()
|
||||||
|
optimizer = optimizer_class(
|
||||||
|
params_to_optimize,
|
||||||
|
lr=args.learning_rate,
|
||||||
|
betas=(args.adam_beta1, args.adam_beta2),
|
||||||
|
weight_decay=args.adam_weight_decay,
|
||||||
|
eps=args.adam_epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataset and DataLoaders creation:
|
||||||
|
train_dataset = VLAConsumerDataset(
|
||||||
|
model_config_path=args.model_config_path, # TODO
|
||||||
|
config=config["dataset"],
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
image_processor=image_processor,
|
||||||
|
num_cameras=config["common"]["num_cameras"],
|
||||||
|
img_history_size=config["common"]["img_history_size"],
|
||||||
|
dataset_type=args.dataset_type,
|
||||||
|
image_aug=args.image_aug,
|
||||||
|
cond_mask_prob=args.cond_mask_prob,
|
||||||
|
cam_ext_mask_prob=args.cam_ext_mask_prob,
|
||||||
|
state_noise_snr=args.state_noise_snr,
|
||||||
|
use_hdf5=args.load_from_hdf5,
|
||||||
|
use_precomp_lang_embed=args.precomp_lang_embed,
|
||||||
|
)
|
||||||
|
sample_dataset = VLAConsumerDataset(
|
||||||
|
model_config_path=args.model_config_path, # TODO
|
||||||
|
config=config["dataset"],
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
image_processor=image_processor,
|
||||||
|
num_cameras=config["common"]["num_cameras"],
|
||||||
|
img_history_size=config["common"]["img_history_size"],
|
||||||
|
dataset_type=args.dataset_type,
|
||||||
|
image_aug=False,
|
||||||
|
cond_mask_prob=0,
|
||||||
|
cam_ext_mask_prob=-1,
|
||||||
|
state_noise_snr=None,
|
||||||
|
use_hdf5=args.load_from_hdf5,
|
||||||
|
use_precomp_lang_embed=args.precomp_lang_embed,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_collator = DataCollatorForVLAConsumerDataset(tokenizer)
|
||||||
|
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=args.train_batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=data_collator,
|
||||||
|
num_workers=args.dataloader_num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
sample_dataloader = torch.utils.data.DataLoader(
|
||||||
|
sample_dataset,
|
||||||
|
batch_size=args.sample_batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=data_collator,
|
||||||
|
num_workers=args.dataloader_num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scheduler and math around the number of training steps.
|
||||||
|
overrode_max_train_steps = False
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
if args.max_train_steps is None:
|
||||||
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
|
overrode_max_train_steps = True
|
||||||
|
|
||||||
|
lr_scheduler = get_scheduler(
|
||||||
|
args.lr_scheduler,
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||||
|
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||||
|
num_cycles=args.lr_num_cycles,
|
||||||
|
power=args.lr_power,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare everything with our `accelerator`.
|
||||||
|
rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler = (accelerator.prepare(
|
||||||
|
rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler))
|
||||||
|
|
||||||
|
ema_rdt.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
if text_encoder is not None:
|
||||||
|
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
if vision_encoder is not None:
|
||||||
|
vision_encoder.vision_tower.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
if overrode_max_train_steps:
|
||||||
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
|
# Afterwards we recalculate our number of training epochs
|
||||||
|
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
|
||||||
|
# We need to initialize the trackers we use, and also store our configuration.
|
||||||
|
# The trackers initializes automatically on the main process.
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
accelerator.init_trackers(
|
||||||
|
"VLA",
|
||||||
|
config=vars(args),
|
||||||
|
init_kwargs={"wandb": {
|
||||||
|
"name": f"RoboTwin_RDT_{args.CONFIG_NAME}",
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train!
|
||||||
|
total_batch_size = (args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
logger.info("***** Running training *****")
|
||||||
|
logger.info(f" Num examples = {len(train_dataset)}")
|
||||||
|
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
||||||
|
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||||
|
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||||
|
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||||
|
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||||
|
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||||
|
global_step = 0
|
||||||
|
first_epoch = 0
|
||||||
|
|
||||||
|
# Load from a pretrained checkpoint
|
||||||
|
if (args.resume_from_checkpoint is None and args.pretrained_model_name_or_path is not None
|
||||||
|
and os.path.isfile(args.pretrained_model_name_or_path)):
|
||||||
|
# Since EMA is deprecated, we do not load EMA from the pretrained checkpoint
|
||||||
|
logger.info("Loading from a pretrained checkpoint.")
|
||||||
|
checkpoint = torch.load(args.pretrained_model_name_or_path)
|
||||||
|
rdt.module.load_state_dict(checkpoint["module"])
|
||||||
|
|
||||||
|
# Potentially load in the weights and states from a previous save
|
||||||
|
if args.resume_from_checkpoint:
|
||||||
|
if args.resume_from_checkpoint != "latest":
|
||||||
|
path = os.path.basename(args.resume_from_checkpoint)
|
||||||
|
else:
|
||||||
|
# Get the mos recent checkpoint
|
||||||
|
dirs = os.listdir(args.output_dir)
|
||||||
|
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||||
|
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||||
|
path = dirs[-1] if len(dirs) > 0 else None
|
||||||
|
|
||||||
|
if path is None:
|
||||||
|
accelerator.print(
|
||||||
|
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
|
||||||
|
args.resume_from_checkpoint = None
|
||||||
|
else:
|
||||||
|
accelerator.print(f"Resuming from checkpoint {path}")
|
||||||
|
try:
|
||||||
|
accelerator.load_state(os.path.join(args.output_dir, path)) # load_module_strict=False
|
||||||
|
except:
|
||||||
|
# load deepspeed's state_dict
|
||||||
|
logger.info("Resuming training state failed. Attempting to only load from model checkpoint.")
|
||||||
|
checkpoint = torch.load(
|
||||||
|
os.path.join(
|
||||||
|
args.output_dir,
|
||||||
|
path,
|
||||||
|
"pytorch_model",
|
||||||
|
"mp_rank_00_model_states.pt",
|
||||||
|
))
|
||||||
|
rdt.module.load_state_dict(checkpoint["module"])
|
||||||
|
|
||||||
|
load_model(ema_rdt, os.path.join(args.output_dir, path, "ema", "model.safetensors"))
|
||||||
|
global_step = int(path.split("-")[1])
|
||||||
|
|
||||||
|
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||||
|
first_epoch = global_step // num_update_steps_per_epoch
|
||||||
|
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
# Only show the progress bar once on each machine.
|
||||||
|
progress_bar = tqdm(
|
||||||
|
range(global_step, args.max_train_steps),
|
||||||
|
disable=not accelerator.is_local_main_process,
|
||||||
|
)
|
||||||
|
progress_bar.set_description("Steps")
|
||||||
|
|
||||||
|
loss_for_log = {}
|
||||||
|
for epoch in range(first_epoch, args.num_train_epochs):
|
||||||
|
|
||||||
|
rdt.train()
|
||||||
|
|
||||||
|
# Set the progress_bar to correct position
|
||||||
|
if args.resume_from_checkpoint and epoch == first_epoch:
|
||||||
|
progress_bar.update(resume_step // args.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
# Forward and backward...
|
||||||
|
for batch in train_dataloader:
|
||||||
|
with accelerator.accumulate(rdt):
|
||||||
|
images = batch["images"].to(dtype=weight_dtype)
|
||||||
|
states = batch["states"].to(dtype=weight_dtype) # (B, T, D_a)
|
||||||
|
# We only use the last state as input
|
||||||
|
states = states[:, -1:, :]
|
||||||
|
actions = batch["actions"].to(dtype=weight_dtype)
|
||||||
|
state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)
|
||||||
|
ctrl_freqs = batch["ctrl_freqs"]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_size, _, C, H, W = images.shape
|
||||||
|
image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
|
||||||
|
image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))
|
||||||
|
|
||||||
|
lang_attn_mask = batch["lang_attn_mask"]
|
||||||
|
text_embeds = (batch["lang_embeds"].to(
|
||||||
|
dtype=weight_dtype) if args.precomp_lang_embed else text_encoder(
|
||||||
|
input_ids=batch["input_ids"], attention_mask=lang_attn_mask)["last_hidden_state"].detach())
|
||||||
|
|
||||||
|
state_elem_mask = state_elem_mask.unsqueeze(1)
|
||||||
|
loss = rdt(
|
||||||
|
lang_tokens=text_embeds,
|
||||||
|
lang_attn_mask=lang_attn_mask,
|
||||||
|
img_tokens=image_embeds,
|
||||||
|
state_tokens=states,
|
||||||
|
action_gt=actions,
|
||||||
|
action_mask=state_elem_mask,
|
||||||
|
ctrl_freqs=ctrl_freqs,
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator.backward(loss)
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
params_to_clip = rdt.parameters()
|
||||||
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
||||||
|
|
||||||
|
ema_model.step(accelerator.unwrap_model(rdt))
|
||||||
|
|
||||||
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
progress_bar.update(1)
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
if global_step % args.checkpointing_period == 0:
|
||||||
|
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||||
|
accelerator.save_state(save_path)
|
||||||
|
ema_save_path = os.path.join(save_path, f"ema")
|
||||||
|
accelerator.save_model(ema_rdt, ema_save_path)
|
||||||
|
logger.info(f"Saved state to {save_path}")
|
||||||
|
|
||||||
|
if args.sample_period > 0 and global_step % args.sample_period == 0:
|
||||||
|
sample_loss_for_log = log_sample_res(
|
||||||
|
text_encoder,
|
||||||
|
vision_encoder,
|
||||||
|
rdt, # We do not use EMA currently
|
||||||
|
args,
|
||||||
|
accelerator,
|
||||||
|
weight_dtype,
|
||||||
|
sample_dataset.get_dataset_id2name(),
|
||||||
|
sample_dataloader,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
logger.info(sample_loss_for_log)
|
||||||
|
accelerator.log(sample_loss_for_log, step=global_step)
|
||||||
|
|
||||||
|
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
|
progress_bar.set_postfix(**logs)
|
||||||
|
logs.update(loss_for_log)
|
||||||
|
# logger.info(logs)
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
if global_step >= args.max_train_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Create the pipeline using using the trained modules and save it.
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
accelerator.unwrap_model(rdt).save_pretrained(args.output_dir)
|
||||||
|
ema_save_path = os.path.join(args.output_dir, f"ema")
|
||||||
|
accelerator.save_model(ema_rdt, ema_save_path)
|
||||||
|
|
||||||
|
logger.info(f"Saved Model to {args.output_dir}")
|
||||||
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
save_model_card(
|
||||||
|
repo_id,
|
||||||
|
base_model=args.pretrained_model_name_or_path,
|
||||||
|
repo_folder=args.output_dir,
|
||||||
|
)
|
||||||
|
upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
folder_path=args.output_dir,
|
||||||
|
commit_message="End of training",
|
||||||
|
token=args.hub_token,
|
||||||
|
allow_patterns=["pytorch_model.bin", "*.json", "*.md"],
|
||||||
|
# ignore_patterns=["step_*", "epoch_*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator.end_training()
|
||||||
@ -10,13 +10,14 @@ ENV TZ=Asia/Shanghai
|
|||||||
RUN sed -i 's/archive.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list && \
|
RUN sed -i 's/archive.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list && \
|
||||||
sed -i 's/security.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list
|
sed -i 's/security.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update --allow-unauthenticated && apt-get install -y \
|
||||||
software-properties-common \
|
software-properties-common \
|
||||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||||
&& apt-get update \
|
&& apt-get update \
|
||||||
&& apt-get install -y \
|
&& apt-get install -y \
|
||||||
python3.10 \
|
python3.10 \
|
||||||
python3.10-dev \
|
python3.10-dev \
|
||||||
|
python3-pip \
|
||||||
python3.10-distutils \
|
python3.10-distutils \
|
||||||
libgl1-mesa-glx \
|
libgl1-mesa-glx \
|
||||||
libglib2.0-0 \
|
libglib2.0-0 \
|
||||||
@ -30,14 +31,16 @@ RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
|
|||||||
|
|
||||||
COPY . /app/
|
COPY . /app/
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip
|
RUN python3 -m pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
|
||||||
RUN pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
|
# RUN pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
RUN pip install torch==2.1.0 torchvision==0.16.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
|
||||||
RUN pip install packaging==24.0
|
RUN pip install packaging==24.0
|
||||||
|
|
||||||
|
RUN pip install tfds-nightly==4.9.4.dev202402070044
|
||||||
|
|
||||||
RUN pip install flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
RUN pip install flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||||
|
|
||||||
RUN mkdir -p /app/dataset/input /app/dataset/output
|
RUN mkdir -p /app/dataset/input /app/dataset/output
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -7,11 +7,10 @@ CONFIG_FILE="model_config/$CONFIG_NAME.yml"
|
|||||||
echo "CONFIG_FILE_PATH: $CONFIG_FILE"
|
echo "CONFIG_FILE_PATH: $CONFIG_FILE"
|
||||||
|
|
||||||
### ============Read Input Config and ReLoad Config YAML===================
|
### ============Read Input Config and ReLoad Config YAML===================
|
||||||
ln -s /home/qi.xiong/Temp/RDT-1B/input/weights ../weights
|
|
||||||
|
|
||||||
TRAIN_CONFIG_FILE="input/config.json"
|
TRAIN_CONFIG_FILE="input/config.json"
|
||||||
echo "TRAIN_CONFIG_FILE_PATH: $TRAIN_CONFIG_FILE"
|
echo "TRAIN_CONFIG_FILE_PATH: $TRAIN_CONFIG_FILE"
|
||||||
python scripts/read_config.py "$TRAIN_CONFIG_FILE" "$CONFIG_FILE"
|
python3 scripts/read_config.py "$TRAIN_CONFIG_FILE" "$CONFIG_FILE"
|
||||||
|
|
||||||
### ============Read Input Config and ReLoad Config YAML===================
|
### ============Read Input Config and ReLoad Config YAML===================
|
||||||
|
|
||||||
@ -34,23 +33,23 @@ if [ ! -f "$CONFIG_FILE" ]; then
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
PRETRAINED_MODEL_NAME=$(python scripts/read_yaml.py "$CONFIG_FILE" pretrained_model_name_or_path)
|
PRETRAINED_MODEL_NAME=$(python3 scripts/read_yaml.py "$CONFIG_FILE" pretrained_model_name_or_path)
|
||||||
TRAIN_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" train_batch_size)
|
TRAIN_BATCH_SIZE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" train_batch_size)
|
||||||
SAMPLE_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_batch_size)
|
SAMPLE_BATCH_SIZE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" sample_batch_size)
|
||||||
MAX_TRAIN_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" max_train_steps)
|
MAX_TRAIN_STEPS=$(python3 scripts/read_yaml.py "$CONFIG_FILE" max_train_steps)
|
||||||
CHECKPOINTING_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpointing_period)
|
CHECKPOINTING_PERIOD=$(python3 scripts/read_yaml.py "$CONFIG_FILE" checkpointing_period)
|
||||||
SAMPLE_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_period)
|
SAMPLE_PERIOD=$(python3 scripts/read_yaml.py "$CONFIG_FILE" sample_period)
|
||||||
CHECKPOINTS_TOTAL_LIMIT=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpoints_total_limit)
|
CHECKPOINTS_TOTAL_LIMIT=$(python3 scripts/read_yaml.py "$CONFIG_FILE" checkpoints_total_limit)
|
||||||
LR_SCHEDULER=$(python scripts/read_yaml.py "$CONFIG_FILE" lr_scheduler)
|
LR_SCHEDULER=$(python3 scripts/read_yaml.py "$CONFIG_FILE" lr_scheduler)
|
||||||
LEARNING_RATE=$(python scripts/read_yaml.py "$CONFIG_FILE" learning_rate)
|
LEARNING_RATE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" learning_rate)
|
||||||
DATALOADER_NUM_WORKERS=$(python scripts/read_yaml.py "$CONFIG_FILE" dataloader_num_workers)
|
DATALOADER_NUM_WORKERS=$(python3 scripts/read_yaml.py "$CONFIG_FILE" dataloader_num_workers)
|
||||||
DATASET_TYPE=$(python scripts/read_yaml.py "$CONFIG_FILE" dataset_type)
|
DATASET_TYPE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" dataset_type)
|
||||||
STATE_NOISE_SNR=$(python scripts/read_yaml.py "$CONFIG_FILE" state_noise_snr)
|
STATE_NOISE_SNR=$(python3 scripts/read_yaml.py "$CONFIG_FILE" state_noise_snr)
|
||||||
GRAD_ACCUM_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" gradient_accumulation_steps)
|
GRAD_ACCUM_STEPS=$(python3 scripts/read_yaml.py "$CONFIG_FILE" gradient_accumulation_steps)
|
||||||
OUTPUT_DIR=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpoint_path)
|
OUTPUT_DIR=$(python3 scripts/read_yaml.py "$CONFIG_FILE" checkpoint_path)
|
||||||
CUDA_USE=$(python scripts/read_yaml.py "$CONFIG_FILE" cuda_visible_device)
|
CUDA_USE=$(python3 scripts/read_yaml.py "$CONFIG_FILE" cuda_visible_device)
|
||||||
|
|
||||||
|
|
||||||
|
export WANDB_MODE=disabled
|
||||||
PRETRAINED_MODEL_NAME=$(echo "$PRETRAINED_MODEL_NAME" | tr -d '"')
|
PRETRAINED_MODEL_NAME=$(echo "$PRETRAINED_MODEL_NAME" | tr -d '"')
|
||||||
CUDA_USE=$(echo "$CUDA_USE" | tr -d '"')
|
CUDA_USE=$(echo "$CUDA_USE" | tr -d '"')
|
||||||
OUTPUT_DIR=$(echo "$OUTPUT_DIR" | tr -d '"')
|
OUTPUT_DIR=$(echo "$OUTPUT_DIR" | tr -d '"')
|
||||||
@ -65,7 +64,7 @@ fi
|
|||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES=$CUDA_USE
|
export CUDA_VISIBLE_DEVICES=$CUDA_USE
|
||||||
|
|
||||||
python -m data.compute_dataset_stat_hdf5 --task_name $CONFIG_NAME
|
python3 -m data.compute_dataset_stat_hdf5 --task_name $CONFIG_NAME
|
||||||
|
|
||||||
accelerate launch --main_process_port=28499 main.py \
|
accelerate launch --main_process_port=28499 main.py \
|
||||||
--deepspeed="./configs/zero2.json" \
|
--deepspeed="./configs/zero2.json" \
|
||||||
@ -99,8 +98,8 @@ RUNTIME=$((END_TIME - BEGIN_TIME))
|
|||||||
echo "Total runtime: $RUNTIME seconds"
|
echo "Total runtime: $RUNTIME seconds"
|
||||||
|
|
||||||
### ============Generate Output JSON===================
|
### ============Generate Output JSON===================
|
||||||
|
sleep 10
|
||||||
python scripts/generate_output_json.py "$TRAIN_CONFIG_FILE" "$OUTPUT_DIR" "$RUNTIME"
|
python3 scripts/generate_output_json.py "$TRAIN_CONFIG_FILE" "$OUTPUT_DIR" "$RUNTIME"
|
||||||
|
|
||||||
### ============Generate Output JSON===================
|
### ============Generate Output JSON===================
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -10,10 +10,11 @@ sentencepiece==0.2.0
|
|||||||
h5py==3.11.0
|
h5py==3.11.0
|
||||||
opencv-python==4.9.0.80
|
opencv-python==4.9.0.80
|
||||||
imgaug==0.4.0
|
imgaug==0.4.0
|
||||||
pytz>=2020.1
|
pytz==2022.1
|
||||||
|
huggingface_hub==0.23.0
|
||||||
|
|
||||||
# requirements_data.txt
|
# requirements_data.txt
|
||||||
tfds-nightly==4.9.4.dev202402070044
|
# tfds-nightly==4.9.4.dev202402070044
|
||||||
gsutil==5.27
|
gsutil==5.27
|
||||||
tensorflow==2.15.0.post1
|
tensorflow==2.15.0.post1
|
||||||
pillow==10.2.0
|
pillow==10.2.0
|
||||||
|
|||||||
@ -11,7 +11,7 @@ def read_config(config_file, yaml_file):
|
|||||||
yaml_config["model"] = json_config["train"]["model"] + json_config["task_id"]
|
yaml_config["model"] = json_config["train"]["model"] + json_config["task_id"]
|
||||||
yaml_config["data_path"] = json_config["train"]["input_data_path"] + "/data"
|
yaml_config["data_path"] = json_config["train"]["input_data_path"] + "/data"
|
||||||
yaml_config["checkpoint_path"] = json_config["train"]["checkpoint_path"] + "/" + json_config["task_id"]
|
yaml_config["checkpoint_path"] = json_config["train"]["checkpoint_path"] + "/" + json_config["task_id"]
|
||||||
yaml_config["pretrained_model_name_or_path"] = json_config["train"]["input_data_path"] + "/weights/rdt-1b"
|
yaml_config["pretrained_model_name_or_path"] = "/weights/rdt-1b"
|
||||||
yaml_config["cuda_visible_device"] = str(json_config["gpu_id"])
|
yaml_config["cuda_visible_device"] = str(json_config["gpu_id"])
|
||||||
print(f"cuda_visible_device: {yaml_config['cuda_visible_device']}")
|
print(f"cuda_visible_device: {yaml_config['cuda_visible_device']}")
|
||||||
yaml_config["train_batch_size"] = int(json_config["train"]["batch_size"])
|
yaml_config["train_batch_size"] = int(json_config["train"]["batch_size"])
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user