Compare commits
No commits in common. "3168acad93403a3535dbfaca3e93f1d65807dafb" and "e37bf13e1bc7484419d980e28c97c500332db13c" have entirely different histories.
3168acad93
...
e37bf13e1b
7
.gitignore
vendored
7
.gitignore
vendored
@ -1,9 +1,3 @@
|
|||||||
<<<<<<< HEAD
|
|
||||||
input/
|
|
||||||
output/
|
|
||||||
Temp/
|
|
||||||
weights/
|
|
||||||
=======
|
|
||||||
# ---> Python
|
# ---> Python
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
@ -174,4 +168,3 @@ cython_debug/
|
|||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
>>>>>>> e37bf13e1bc7484419d980e28c97c500332db13c
|
|
||||||
|
|||||||
@ -1,2 +0,0 @@
|
|||||||
input/*
|
|
||||||
output/*
|
|
||||||
7
RDT-1B/.gitignore
vendored
7
RDT-1B/.gitignore
vendored
@ -1,7 +0,0 @@
|
|||||||
processed_data/
|
|
||||||
training_data/
|
|
||||||
checkpoints/
|
|
||||||
model_config/*.yml
|
|
||||||
wandb/*
|
|
||||||
!models/
|
|
||||||
!data/
|
|
||||||
@ -1,45 +0,0 @@
|
|||||||
|
|
||||||
FROM registry.d-robotics.cc/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
|
|
||||||
# ccr-29eug8s3-pub.cnc.bj.baidubce.com/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
|
||||||
ENV PYTHONUNBUFFERED=1
|
|
||||||
ENV TZ=Asia/Shanghai
|
|
||||||
|
|
||||||
RUN sed -i 's/archive.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list && \
|
|
||||||
sed -i 's/security.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list
|
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y \
|
|
||||||
software-properties-common \
|
|
||||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
|
||||||
&& apt-get update \
|
|
||||||
&& apt-get install -y \
|
|
||||||
python3.10 \
|
|
||||||
python3.10-dev \
|
|
||||||
python3.10-distutils \
|
|
||||||
libgl1-mesa-glx \
|
|
||||||
libglib2.0-0 \
|
|
||||||
wget \
|
|
||||||
ffmpeg \
|
|
||||||
libsm6 \
|
|
||||||
libxext6 \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
|
|
||||||
|
|
||||||
COPY . /app/
|
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip
|
|
||||||
|
|
||||||
RUN pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
|
|
||||||
|
|
||||||
RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
|
||||||
|
|
||||||
RUN pip install packaging==24.0
|
|
||||||
|
|
||||||
RUN pip install flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
|
||||||
|
|
||||||
RUN mkdir -p /app/dataset/input /app/dataset/output
|
|
||||||
|
|
||||||
ENTRYPOINT ["bash", "finetune.sh"]
|
|
||||||
@ -1 +0,0 @@
|
|||||||
from .deploy_policy import *
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 726 KiB |
Binary file not shown.
@ -1,71 +0,0 @@
|
|||||||
common:
|
|
||||||
# The number of historical images
|
|
||||||
img_history_size: 2
|
|
||||||
# The number of future actions to predict
|
|
||||||
action_chunk_size: 64
|
|
||||||
# The number of cameras to be used in the model
|
|
||||||
num_cameras: 3
|
|
||||||
# Dimension for state/action, we use the same space for both state and action
|
|
||||||
# This MUST be equal to configs/state_vec.py
|
|
||||||
state_dim: 128
|
|
||||||
|
|
||||||
|
|
||||||
dataset:
|
|
||||||
# We will extract the data from raw dataset
|
|
||||||
# and store them in the disk buffer by producer
|
|
||||||
# When training, we will read the data
|
|
||||||
# randomly from the buffer by consumer
|
|
||||||
# The producer will replace the data which has been
|
|
||||||
# read by the consumer with new data
|
|
||||||
|
|
||||||
# The path to the buffer (at least 400GB)
|
|
||||||
buf_path: /path/to/buffer
|
|
||||||
# The number of chunks in the buffer
|
|
||||||
buf_num_chunks: 512
|
|
||||||
# The number of samples (step rather than episode) in each chunk
|
|
||||||
buf_chunk_size: 512
|
|
||||||
|
|
||||||
# We will filter the episodes with length less than `epsd_len_thresh_low`
|
|
||||||
epsd_len_thresh_low: 32
|
|
||||||
# For those more than `epsd_len_thresh_high`,
|
|
||||||
# we will randomly sample `epsd_len_thresh_high` steps each time we load the episode
|
|
||||||
# to better balance the training datasets
|
|
||||||
epsd_len_thresh_high: 2048
|
|
||||||
# How to fit the image size
|
|
||||||
image_aspect_ratio: pad
|
|
||||||
# Maximum number of language tokens
|
|
||||||
tokenizer_max_length: 1024
|
|
||||||
|
|
||||||
model:
|
|
||||||
# Config for condition adpators
|
|
||||||
lang_adaptor: mlp2x_gelu
|
|
||||||
img_adaptor: mlp2x_gelu
|
|
||||||
state_adaptor: mlp3x_gelu
|
|
||||||
lang_token_dim: 4096
|
|
||||||
img_token_dim: 1152
|
|
||||||
# Dim of action or proprioception vector
|
|
||||||
# A `state` refers to an action or a proprioception vector
|
|
||||||
state_token_dim: 128
|
|
||||||
# Config for RDT structure
|
|
||||||
rdt:
|
|
||||||
# 1B: num_head 32 hidden_size 2048
|
|
||||||
hidden_size: 2048
|
|
||||||
depth: 28
|
|
||||||
num_heads: 32
|
|
||||||
cond_pos_embed_type: multimodal
|
|
||||||
# For noise scheduler
|
|
||||||
noise_scheduler:
|
|
||||||
type: ddpm
|
|
||||||
num_train_timesteps: 1000
|
|
||||||
num_inference_timesteps: 5
|
|
||||||
beta_schedule: squaredcos_cap_v2 # Critical choice
|
|
||||||
prediction_type: sample
|
|
||||||
clip_sample: False
|
|
||||||
# For EMA (params averaging)
|
|
||||||
# We do not use EMA currently
|
|
||||||
ema:
|
|
||||||
update_after_step: 0
|
|
||||||
inv_gamma: 1.0
|
|
||||||
power: 0.75
|
|
||||||
min_value: 0.0
|
|
||||||
max_value: 0.9999
|
|
||||||
@ -1,50 +0,0 @@
|
|||||||
{
|
|
||||||
"A": [
|
|
||||||
[
|
|
||||||
-0.2691913843154907,
|
|
||||||
-0.21995729207992554,
|
|
||||||
-0.182277649641037
|
|
||||||
],
|
|
||||||
[
|
|
||||||
0.35127854347229004,
|
|
||||||
0.2769763469696045,
|
|
||||||
0.17159393429756165
|
|
||||||
]
|
|
||||||
],
|
|
||||||
"B": [
|
|
||||||
[
|
|
||||||
-0.2576896846294403,
|
|
||||||
-0.22244493663311005,
|
|
||||||
-0.20557966828346252
|
|
||||||
],
|
|
||||||
[
|
|
||||||
0.32854634523391724,
|
|
||||||
0.2922680974006653,
|
|
||||||
0.17373555898666382
|
|
||||||
]
|
|
||||||
],
|
|
||||||
"C": [
|
|
||||||
[
|
|
||||||
-0.29205888509750366,
|
|
||||||
-0.24688798189163208,
|
|
||||||
-0.17577645182609558
|
|
||||||
],
|
|
||||||
[
|
|
||||||
0.25053921341896057,
|
|
||||||
0.3277084231376648,
|
|
||||||
0.16431939601898193
|
|
||||||
]
|
|
||||||
],
|
|
||||||
"D": [
|
|
||||||
[
|
|
||||||
-0.25131964683532715,
|
|
||||||
-0.15233077108860016,
|
|
||||||
-0.13294968008995056
|
|
||||||
],
|
|
||||||
[
|
|
||||||
0.19209328293800354,
|
|
||||||
0.19344553351402283,
|
|
||||||
0.1370421051979065
|
|
||||||
]
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@ -1,65 +0,0 @@
|
|||||||
{
|
|
||||||
"fractal20220817_data": 3,
|
|
||||||
"taco_play": 15,
|
|
||||||
"jaco_play": 10,
|
|
||||||
"berkeley_cable_routing": 10,
|
|
||||||
"nyu_door_opening_surprising_effectiveness": 3,
|
|
||||||
"viola": 20,
|
|
||||||
"berkeley_autolab_ur5": 5,
|
|
||||||
"toto": 30,
|
|
||||||
"kuka": 10,
|
|
||||||
"language_table": 10,
|
|
||||||
"columbia_cairlab_pusht_real": 10,
|
|
||||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20,
|
|
||||||
"nyu_rot_dataset_converted_externally_to_rlds":3,
|
|
||||||
"stanford_hydra_dataset_converted_externally_to_rlds": 10,
|
|
||||||
"austin_buds_dataset_converted_externally_to_rlds": 20,
|
|
||||||
"nyu_franka_play_dataset_converted_externally_to_rlds": 3,
|
|
||||||
"maniskill_dataset_converted_externally_to_rlds": 20,
|
|
||||||
"furniture_bench_dataset_converted_externally_to_rlds": 10,
|
|
||||||
"ucsd_kitchen_dataset_converted_externally_to_rlds": 2,
|
|
||||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3,
|
|
||||||
"austin_sailor_dataset_converted_externally_to_rlds": 20,
|
|
||||||
"austin_sirius_dataset_converted_externally_to_rlds": 20,
|
|
||||||
"bc_z": 10,
|
|
||||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": 10,
|
|
||||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 10,
|
|
||||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
|
|
||||||
"utokyo_xarm_bimanual_converted_externally_to_rlds": 10,
|
|
||||||
"berkeley_mvp_converted_externally_to_rlds": 5,
|
|
||||||
"berkeley_rpt_converted_externally_to_rlds": 30,
|
|
||||||
"kaist_nonprehensile_converted_externally_to_rlds": 10,
|
|
||||||
"stanford_mask_vit_converted_externally_to_rlds": 0,
|
|
||||||
"tokyo_u_lsmo_converted_externally_to_rlds": 10,
|
|
||||||
"dlr_sara_pour_converted_externally_to_rlds": 10,
|
|
||||||
"dlr_sara_grid_clamp_converted_externally_to_rlds": 10,
|
|
||||||
"dlr_edan_shared_control_converted_externally_to_rlds": 5,
|
|
||||||
"asu_table_top_converted_externally_to_rlds": 12.5,
|
|
||||||
"stanford_robocook_converted_externally_to_rlds": 5,
|
|
||||||
"eth_agent_affordances": 66.6,
|
|
||||||
"imperialcollege_sawyer_wrist_cam": 10,
|
|
||||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20,
|
|
||||||
"uiuc_d3field": 1,
|
|
||||||
"utaustin_mutex": 20,
|
|
||||||
"berkeley_fanuc_manipulation": 10,
|
|
||||||
"cmu_play_fusion": 5,
|
|
||||||
"cmu_stretch": 10,
|
|
||||||
"berkeley_gnm_recon": 3,
|
|
||||||
"berkeley_gnm_cory_hall": 5,
|
|
||||||
"berkeley_gnm_sac_son": 10,
|
|
||||||
"robo_net": 1,
|
|
||||||
"roboturk_real_towercreation": 10,
|
|
||||||
"roboturk_real_laundrylayout": 10,
|
|
||||||
"roboturk_real_objectsearch": 10,
|
|
||||||
"aloha_mobile": 50,
|
|
||||||
"aloha_static": 50,
|
|
||||||
"roboset": 5,
|
|
||||||
"droid": 15,
|
|
||||||
"fmb": 10,
|
|
||||||
"dobbe": 30,
|
|
||||||
"qut_dexterous_manpulation": 30,
|
|
||||||
"agilex": 25,
|
|
||||||
"rh20t": 10,
|
|
||||||
"calvin": 30,
|
|
||||||
"bridgev2": 5
|
|
||||||
}
|
|
||||||
@ -1,575 +0,0 @@
|
|||||||
{
|
|
||||||
"fractal20220817_data": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[
|
|
||||||
1,0,0,0
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"taco_play": {
|
|
||||||
"image_keys": [
|
|
||||||
"rgb_static",
|
|
||||||
"rgb_gripper",
|
|
||||||
"rgb_static",
|
|
||||||
"rgb_static"
|
|
||||||
],
|
|
||||||
"image_mask":[
|
|
||||||
1,1,0,0
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"jaco_play": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image_wrist",
|
|
||||||
"image_wrist",
|
|
||||||
"image_wrist"
|
|
||||||
],
|
|
||||||
"image_mask":[
|
|
||||||
1,1,0,0
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"berkeley_cable_routing": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist45_image",
|
|
||||||
"wrist225_image",
|
|
||||||
"top_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,1]
|
|
||||||
},
|
|
||||||
"nyu_door_opening_surprising_effectiveness": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"viola": {
|
|
||||||
"image_keys": [
|
|
||||||
"agentview_rgb",
|
|
||||||
"eye_in_hand_rgb",
|
|
||||||
"eye_in_hand_rgb",
|
|
||||||
"eye_in_hand_rgb"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"berkeley_autolab_ur5": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"hand_image",
|
|
||||||
"hand_image",
|
|
||||||
"hand_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"toto": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"kuka": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"language_table": {
|
|
||||||
"image_keys": [
|
|
||||||
"rgb",
|
|
||||||
"rgb",
|
|
||||||
"rgb",
|
|
||||||
"rgb"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"columbia_cairlab_pusht_real": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"nyu_rot_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"stanford_hydra_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"austin_buds_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"nyu_franka_play_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image_additional_view",
|
|
||||||
"image_additional_view",
|
|
||||||
"image_additional_view"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,1]
|
|
||||||
},
|
|
||||||
"maniskill_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"furniture_bench_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"austin_sailor_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"austin_sirius_dataset_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"bc_z": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"hand_image",
|
|
||||||
"hand_image",
|
|
||||||
"image2"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,1]
|
|
||||||
},
|
|
||||||
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"berkeley_mvp_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"hand_image",
|
|
||||||
"hand_image",
|
|
||||||
"hand_image",
|
|
||||||
"hand_image"
|
|
||||||
],
|
|
||||||
"image_mask":[0,1,0,0]
|
|
||||||
},
|
|
||||||
"berkeley_rpt_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"hand_image",
|
|
||||||
"hand_image",
|
|
||||||
"hand_image",
|
|
||||||
"hand_image"
|
|
||||||
],
|
|
||||||
"image_mask":[0,1,0,0]
|
|
||||||
},
|
|
||||||
"kaist_nonprehensile_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"stanford_mask_vit_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"tokyo_u_lsmo_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"dlr_sara_pour_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"dlr_sara_grid_clamp_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"dlr_edan_shared_control_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"asu_table_top_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"stanford_robocook_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image_2",
|
|
||||||
"image_1",
|
|
||||||
"image_3",
|
|
||||||
"image_4"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,1]
|
|
||||||
},
|
|
||||||
"eth_agent_affordances": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"imperialcollege_sawyer_wrist_cam": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[0,1,0,0]
|
|
||||||
},
|
|
||||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"uiuc_d3field": {
|
|
||||||
"image_keys": [
|
|
||||||
"image_1",
|
|
||||||
"image_2",
|
|
||||||
"image_3",
|
|
||||||
"image_4"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,1]
|
|
||||||
},
|
|
||||||
"utaustin_mutex": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"berkeley_fanuc_manipulation": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"cmu_play_fusion": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"cmu_stretch": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"berkeley_gnm_recon": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"berkeley_gnm_cory_hall": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"berkeley_gnm_sac_son": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"robo_net": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image1",
|
|
||||||
"image2",
|
|
||||||
"image2"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,1]
|
|
||||||
},
|
|
||||||
"roboturk_real_towercreation": {
|
|
||||||
"image_keys": [
|
|
||||||
"top_rgb_frame",
|
|
||||||
"front_rgb_frame",
|
|
||||||
"front_rgb_frame",
|
|
||||||
"front_rgb_frame"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,1]
|
|
||||||
},
|
|
||||||
"roboturk_real_laundrylayout": {
|
|
||||||
"image_keys": [
|
|
||||||
"top_rgb_frame",
|
|
||||||
"front_rgb_frame",
|
|
||||||
"front_rgb_frame",
|
|
||||||
"front_rgb_frame"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,1]
|
|
||||||
},
|
|
||||||
"roboturk_real_objectsearch": {
|
|
||||||
"image_keys": [
|
|
||||||
"top_rgb_frame",
|
|
||||||
"front_rgb_frame",
|
|
||||||
"front_rgb_frame",
|
|
||||||
"front_rgb_frame"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,1]
|
|
||||||
},
|
|
||||||
"aloha_mobile": {
|
|
||||||
"image_keys": [
|
|
||||||
"cam_high",
|
|
||||||
"cam_right_wrist",
|
|
||||||
"cam_left_wrist",
|
|
||||||
"cam_right_wrist"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,1,0]
|
|
||||||
},
|
|
||||||
"aloha_static": {
|
|
||||||
"image_keys": [
|
|
||||||
"cam_high",
|
|
||||||
"cam_right_wrist",
|
|
||||||
"cam_left_wrist",
|
|
||||||
"cam_low"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,1,1]
|
|
||||||
},
|
|
||||||
"roboset": {
|
|
||||||
"image_keys": [
|
|
||||||
"rgb_top",
|
|
||||||
"rgb_right",
|
|
||||||
"rgb_left",
|
|
||||||
"rgb_right"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,1,0]
|
|
||||||
},
|
|
||||||
"droid": {
|
|
||||||
"image_keys": [
|
|
||||||
"exterior_image_1_left",
|
|
||||||
"wrist_image_left",
|
|
||||||
"wrist_image_left",
|
|
||||||
"exterior_image_2_left"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,1]
|
|
||||||
},
|
|
||||||
"fmb": {
|
|
||||||
"image_keys": [
|
|
||||||
"image_side_1",
|
|
||||||
"image_wrist_1",
|
|
||||||
"image_wrist_1",
|
|
||||||
"image_side_2"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,1]
|
|
||||||
},
|
|
||||||
"dobbe": {
|
|
||||||
"image_keys": [
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[0,1,0,0]
|
|
||||||
},
|
|
||||||
"qut_dexterous_manpulation": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image",
|
|
||||||
"wrist_image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"agilex": {
|
|
||||||
"image_keys": [
|
|
||||||
"cam_high",
|
|
||||||
"cam_right_wrist",
|
|
||||||
"cam_left_wrist",
|
|
||||||
"cam_right_wrist"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,1,0]
|
|
||||||
},
|
|
||||||
"rh20t": {
|
|
||||||
"image_keys": [
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image",
|
|
||||||
"image"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
},
|
|
||||||
"calvin": {
|
|
||||||
"image_keys": [
|
|
||||||
"rgb_static",
|
|
||||||
"rgb_gripper",
|
|
||||||
"rgb_gripper",
|
|
||||||
"rgb_gripper"
|
|
||||||
],
|
|
||||||
"image_mask":[1,1,0,0]
|
|
||||||
},
|
|
||||||
"bridgev2": {
|
|
||||||
"image_keys": [
|
|
||||||
"images0",
|
|
||||||
"images0",
|
|
||||||
"images0",
|
|
||||||
"images0"
|
|
||||||
],
|
|
||||||
"image_mask":[1,0,0,0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,525 +0,0 @@
|
|||||||
{
|
|
||||||
"agilex": {
|
|
||||||
"dataset_name": "agilex",
|
|
||||||
"state_mean": [
|
|
||||||
-0.0036545392947090432,
|
|
||||||
-0.2773659935760079,
|
|
||||||
0.3147616748061523,
|
|
||||||
0.3813313179910183,
|
|
||||||
0.04028575944090457,
|
|
||||||
0.034888520819083294,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0
|
|
||||||
],
|
|
||||||
"state_std": [
|
|
||||||
0.05763674563578847,
|
|
||||||
0.2580181064167735,
|
|
||||||
0.19785840483767897,
|
|
||||||
0.05020347749331385,
|
|
||||||
0.054529239104671424,
|
|
||||||
0.05020521339363586,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0
|
|
||||||
],
|
|
||||||
"state_min": [
|
|
||||||
-0.17447535196940103,
|
|
||||||
-0.5522612677680121,
|
|
||||||
-0.3340397516886393,
|
|
||||||
0.21861712137858072,
|
|
||||||
-0.09725829230414497,
|
|
||||||
0.003396739231215583,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0
|
|
||||||
],
|
|
||||||
"state_max": [
|
|
||||||
0.21961932712131077,
|
|
||||||
0.30613206227620443,
|
|
||||||
0.5444545321994357,
|
|
||||||
0.4866888682047526,
|
|
||||||
0.31486290825737845,
|
|
||||||
0.3355223337809245,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
0.0
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
[
|
|
||||||
"agilex"
|
|
||||||
]
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
{
|
|
||||||
"agilex": 100
|
|
||||||
}
|
|
||||||
@ -1,48 +0,0 @@
|
|||||||
[
|
|
||||||
"fractal20220817_data",
|
|
||||||
"jaco_play",
|
|
||||||
"taco_play",
|
|
||||||
"berkeley_cable_routing",
|
|
||||||
"viola",
|
|
||||||
"berkeley_autolab_ur5",
|
|
||||||
"toto",
|
|
||||||
"nyu_door_opening_surprising_effectiveness",
|
|
||||||
"columbia_cairlab_pusht_real",
|
|
||||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds",
|
|
||||||
"austin_buds_dataset_converted_externally_to_rlds",
|
|
||||||
"kuka",
|
|
||||||
"utokyo_xarm_bimanual_converted_externally_to_rlds",
|
|
||||||
"stanford_hydra_dataset_converted_externally_to_rlds",
|
|
||||||
"maniskill_dataset_converted_externally_to_rlds",
|
|
||||||
"ucsd_kitchen_dataset_converted_externally_to_rlds",
|
|
||||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds",
|
|
||||||
"austin_sailor_dataset_converted_externally_to_rlds",
|
|
||||||
"austin_sirius_dataset_converted_externally_to_rlds",
|
|
||||||
"bc_z",
|
|
||||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds",
|
|
||||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds",
|
|
||||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds",
|
|
||||||
"berkeley_mvp_converted_externally_to_rlds",
|
|
||||||
"berkeley_rpt_converted_externally_to_rlds",
|
|
||||||
"kaist_nonprehensile_converted_externally_to_rlds",
|
|
||||||
"tokyo_u_lsmo_converted_externally_to_rlds",
|
|
||||||
"dlr_sara_grid_clamp_converted_externally_to_rlds",
|
|
||||||
"stanford_robocook_converted_externally_to_rlds",
|
|
||||||
"imperialcollege_sawyer_wrist_cam",
|
|
||||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds",
|
|
||||||
"utaustin_mutex",
|
|
||||||
"berkeley_fanuc_manipulation",
|
|
||||||
"cmu_play_fusion",
|
|
||||||
"language_table",
|
|
||||||
"furniture_bench_dataset_converted_externally_to_rlds",
|
|
||||||
"droid",
|
|
||||||
"fmb",
|
|
||||||
"dobbe",
|
|
||||||
"qut_dexterous_manpulation",
|
|
||||||
"aloha_mobile",
|
|
||||||
"aloha_static",
|
|
||||||
"roboset",
|
|
||||||
"rh20t",
|
|
||||||
"calvin",
|
|
||||||
"bridgev2"
|
|
||||||
]
|
|
||||||
@ -1,48 +0,0 @@
|
|||||||
{
|
|
||||||
"fractal20220817_data": 271,
|
|
||||||
"taco_play": 60,
|
|
||||||
"jaco_play": 33,
|
|
||||||
"berkeley_cable_routing": 8,
|
|
||||||
"nyu_door_opening_surprising_effectiveness": 10,
|
|
||||||
"viola": 12,
|
|
||||||
"berkeley_autolab_ur5": 32,
|
|
||||||
"toto": 32,
|
|
||||||
"kuka": 50,
|
|
||||||
"language_table": 100,
|
|
||||||
"columbia_cairlab_pusht_real": 12,
|
|
||||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 55,
|
|
||||||
"stanford_hydra_dataset_converted_externally_to_rlds": 24,
|
|
||||||
"austin_buds_dataset_converted_externally_to_rlds": 7,
|
|
||||||
"maniskill_dataset_converted_externally_to_rlds": 174,
|
|
||||||
"furniture_bench_dataset_converted_externally_to_rlds": 71,
|
|
||||||
"ucsd_kitchen_dataset_converted_externally_to_rlds": 12,
|
|
||||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": 37,
|
|
||||||
"austin_sailor_dataset_converted_externally_to_rlds": 15,
|
|
||||||
"austin_sirius_dataset_converted_externally_to_rlds": 24,
|
|
||||||
"bc_z": 208,
|
|
||||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": 9,
|
|
||||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 15,
|
|
||||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
|
|
||||||
"utokyo_xarm_bimanual_converted_externally_to_rlds": 1,
|
|
||||||
"berkeley_mvp_converted_externally_to_rlds": 22,
|
|
||||||
"berkeley_rpt_converted_externally_to_rlds": 30,
|
|
||||||
"kaist_nonprehensile_converted_externally_to_rlds": 14,
|
|
||||||
"tokyo_u_lsmo_converted_externally_to_rlds": 7,
|
|
||||||
"dlr_sara_grid_clamp_converted_externally_to_rlds": 1,
|
|
||||||
"stanford_robocook_converted_externally_to_rlds": 50,
|
|
||||||
"imperialcollege_sawyer_wrist_cam": 13,
|
|
||||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": 25,
|
|
||||||
"utaustin_mutex": 39,
|
|
||||||
"berkeley_fanuc_manipulation": 20,
|
|
||||||
"cmu_play_fusion": 24,
|
|
||||||
"droid": 303,
|
|
||||||
"fmb": 42,
|
|
||||||
"dobbe": 36,
|
|
||||||
"qut_dexterous_manpulation": 14,
|
|
||||||
"aloha_mobile": 150,
|
|
||||||
"aloha_static": 150,
|
|
||||||
"roboset": 135,
|
|
||||||
"rh20t": 331,
|
|
||||||
"calvin": 100,
|
|
||||||
"bridgev2": 224
|
|
||||||
}
|
|
||||||
@ -1,126 +0,0 @@
|
|||||||
STATE_VEC_IDX_MAPPING = {
|
|
||||||
# [0, 10): right arm joint positions
|
|
||||||
**{
|
|
||||||
"arm_joint_{}_pos".format(i): i
|
|
||||||
for i in range(10)
|
|
||||||
},
|
|
||||||
**{
|
|
||||||
"right_arm_joint_{}_pos".format(i): i
|
|
||||||
for i in range(10)
|
|
||||||
},
|
|
||||||
# [10, 15): right gripper joint positions
|
|
||||||
**{
|
|
||||||
"gripper_joint_{}_pos".format(i): i + 10
|
|
||||||
for i in range(5)
|
|
||||||
},
|
|
||||||
**{
|
|
||||||
"right_gripper_joint_{}_pos".format(i): i + 10
|
|
||||||
for i in range(5)
|
|
||||||
},
|
|
||||||
"gripper_open": 10, # alias of right_gripper_joint_0_pos
|
|
||||||
"right_gripper_open": 10,
|
|
||||||
# [15, 25): right arm joint velocities
|
|
||||||
**{
|
|
||||||
"arm_joint_{}_vel".format(i): i + 15
|
|
||||||
for i in range(10)
|
|
||||||
},
|
|
||||||
**{
|
|
||||||
"right_arm_joint_{}_vel".format(i): i + 15
|
|
||||||
for i in range(10)
|
|
||||||
},
|
|
||||||
# [25, 30): right gripper joint velocities
|
|
||||||
**{
|
|
||||||
"gripper_joint_{}_vel".format(i): i + 25
|
|
||||||
for i in range(5)
|
|
||||||
},
|
|
||||||
**{
|
|
||||||
"right_gripper_joint_{}_vel".format(i): i + 25
|
|
||||||
for i in range(5)
|
|
||||||
},
|
|
||||||
"gripper_open_vel": 25, # alias of right_gripper_joint_0_vel
|
|
||||||
"right_gripper_open_vel": 25,
|
|
||||||
# [30, 33): right end effector positions
|
|
||||||
"eef_pos_x": 30,
|
|
||||||
"right_eef_pos_x": 30,
|
|
||||||
"eef_pos_y": 31,
|
|
||||||
"right_eef_pos_y": 31,
|
|
||||||
"eef_pos_z": 32,
|
|
||||||
"right_eef_pos_z": 32,
|
|
||||||
# [33, 39): right end effector 6D pose
|
|
||||||
"eef_angle_0": 33,
|
|
||||||
"right_eef_angle_0": 33,
|
|
||||||
"eef_angle_1": 34,
|
|
||||||
"right_eef_angle_1": 34,
|
|
||||||
"eef_angle_2": 35,
|
|
||||||
"right_eef_angle_2": 35,
|
|
||||||
"eef_angle_3": 36,
|
|
||||||
"right_eef_angle_3": 36,
|
|
||||||
"eef_angle_4": 37,
|
|
||||||
"right_eef_angle_4": 37,
|
|
||||||
"eef_angle_5": 38,
|
|
||||||
"right_eef_angle_5": 38,
|
|
||||||
# [39, 42): right end effector velocities
|
|
||||||
"eef_vel_x": 39,
|
|
||||||
"right_eef_vel_x": 39,
|
|
||||||
"eef_vel_y": 40,
|
|
||||||
"right_eef_vel_y": 40,
|
|
||||||
"eef_vel_z": 41,
|
|
||||||
"right_eef_vel_z": 41,
|
|
||||||
# [42, 45): right end effector angular velocities
|
|
||||||
"eef_angular_vel_roll": 42,
|
|
||||||
"right_eef_angular_vel_roll": 42,
|
|
||||||
"eef_angular_vel_pitch": 43,
|
|
||||||
"right_eef_angular_vel_pitch": 43,
|
|
||||||
"eef_angular_vel_yaw": 44,
|
|
||||||
"right_eef_angular_vel_yaw": 44,
|
|
||||||
# [45, 50): reserved
|
|
||||||
# [50, 60): left arm joint positions
|
|
||||||
**{
|
|
||||||
"left_arm_joint_{}_pos".format(i): i + 50
|
|
||||||
for i in range(10)
|
|
||||||
},
|
|
||||||
# [60, 65): left gripper joint positions
|
|
||||||
**{
|
|
||||||
"left_gripper_joint_{}_pos".format(i): i + 60
|
|
||||||
for i in range(5)
|
|
||||||
},
|
|
||||||
"left_gripper_open": 60, # alias of left_gripper_joint_0_pos
|
|
||||||
# [65, 75): left arm joint velocities
|
|
||||||
**{
|
|
||||||
"left_arm_joint_{}_vel".format(i): i + 65
|
|
||||||
for i in range(10)
|
|
||||||
},
|
|
||||||
# [75, 80): left gripper joint velocities
|
|
||||||
**{
|
|
||||||
"left_gripper_joint_{}_vel".format(i): i + 75
|
|
||||||
for i in range(5)
|
|
||||||
},
|
|
||||||
"left_gripper_open_vel": 75, # alias of left_gripper_joint_0_vel
|
|
||||||
# [80, 83): left end effector positions
|
|
||||||
"left_eef_pos_x": 80,
|
|
||||||
"left_eef_pos_y": 81,
|
|
||||||
"left_eef_pos_z": 82,
|
|
||||||
# [83, 89): left end effector 6D pose
|
|
||||||
"left_eef_angle_0": 83,
|
|
||||||
"left_eef_angle_1": 84,
|
|
||||||
"left_eef_angle_2": 85,
|
|
||||||
"left_eef_angle_3": 86,
|
|
||||||
"left_eef_angle_4": 87,
|
|
||||||
"left_eef_angle_5": 88,
|
|
||||||
# [89, 92): left end effector velocities
|
|
||||||
"left_eef_vel_x": 89,
|
|
||||||
"left_eef_vel_y": 90,
|
|
||||||
"left_eef_vel_z": 91,
|
|
||||||
# [92, 95): left end effector angular velocities
|
|
||||||
"left_eef_angular_vel_roll": 92,
|
|
||||||
"left_eef_angular_vel_pitch": 93,
|
|
||||||
"left_eef_angular_vel_yaw": 94,
|
|
||||||
# [95, 100): reserved
|
|
||||||
# [100, 102): base linear velocities
|
|
||||||
"base_vel_x": 100,
|
|
||||||
"base_vel_y": 101,
|
|
||||||
# [102, 103): base angular velocities
|
|
||||||
"base_angular_vel": 102,
|
|
||||||
# [103, 128): reserved
|
|
||||||
}
|
|
||||||
STATE_VEC_LEN = 128
|
|
||||||
@ -1,14 +0,0 @@
|
|||||||
{
|
|
||||||
"bf16": {
|
|
||||||
"enabled": "auto"
|
|
||||||
},
|
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
|
||||||
"train_batch_size": "auto",
|
|
||||||
"gradient_accumulation_steps": "auto",
|
|
||||||
"zero_optimization": {
|
|
||||||
"stage": 2,
|
|
||||||
"overlap_comm": true,
|
|
||||||
"contiguous_gradients": true,
|
|
||||||
"sub_group_size": 1e9
|
|
||||||
}
|
|
||||||
}
|
|
||||||
2
RDT-1B/data/.gitignore
vendored
2
RDT-1B/data/.gitignore
vendored
@ -1,2 +0,0 @@
|
|||||||
# Ignore data files
|
|
||||||
datasets
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,154 +0,0 @@
|
|||||||
import tensorflow as tf
|
|
||||||
import h5py
|
|
||||||
import os
|
|
||||||
import fnmatch
|
|
||||||
import shutil
|
|
||||||
from tqdm import tqdm
|
|
||||||
from multiprocessing import Pool
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def _bytes_feature(value):
|
|
||||||
"""Returns a bytes_list from a string / byte."""
|
|
||||||
if isinstance(value, type(tf.constant(0))):
|
|
||||||
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
|
|
||||||
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
|
|
||||||
|
|
||||||
|
|
||||||
def _bool_feature(value):
|
|
||||||
"""Returns a bool_list from a boolean."""
|
|
||||||
return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_example(
|
|
||||||
action,
|
|
||||||
base_action,
|
|
||||||
qpos,
|
|
||||||
qvel,
|
|
||||||
cam_high,
|
|
||||||
cam_left_wrist,
|
|
||||||
cam_right_wrist,
|
|
||||||
instruction,
|
|
||||||
terminate_episode,
|
|
||||||
):
|
|
||||||
feature = {
|
|
||||||
"action":
|
|
||||||
_bytes_feature(tf.io.serialize_tensor(action)),
|
|
||||||
"base_action":
|
|
||||||
_bytes_feature(tf.io.serialize_tensor(base_action)),
|
|
||||||
"qpos":
|
|
||||||
_bytes_feature(tf.io.serialize_tensor(qpos)),
|
|
||||||
"qvel":
|
|
||||||
_bytes_feature(tf.io.serialize_tensor(qvel)),
|
|
||||||
"cam_high":
|
|
||||||
_bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_high.tobytes(), dtype=tf.string))),
|
|
||||||
"cam_left_wrist":
|
|
||||||
_bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_left_wrist.tobytes(), dtype=tf.string))),
|
|
||||||
"cam_right_wrist":
|
|
||||||
_bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_right_wrist.tobytes(), dtype=tf.string))),
|
|
||||||
"instruction":
|
|
||||||
_bytes_feature(instruction),
|
|
||||||
"terminate_episode":
|
|
||||||
_bool_feature(terminate_episode),
|
|
||||||
}
|
|
||||||
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
|
|
||||||
return example_proto.SerializeToString()
|
|
||||||
|
|
||||||
|
|
||||||
def process_hdf5_file(args):
|
|
||||||
filepath, root_dir, out_dir = args
|
|
||||||
output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir))
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
filename = os.path.basename(filepath)
|
|
||||||
tfrecord_path = os.path.join(output_dir, filename.replace(".hdf5", ".tfrecord"))
|
|
||||||
|
|
||||||
if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
|
|
||||||
return f"TFRecords already exist at {tfrecord_path}"
|
|
||||||
try:
|
|
||||||
with h5py.File(filepath, "r") as f, tf.io.TFRecordWriter(tfrecord_path) as writer:
|
|
||||||
num_episodes = f["action"].shape[0]
|
|
||||||
# Remove the first few still steps
|
|
||||||
EPS = 1e-2
|
|
||||||
qpos = f["observations"]["qpos"][:]
|
|
||||||
# Get the idx of the first qpos whose delta exceeds the threshold
|
|
||||||
qpos_delta = np.abs(qpos - qpos[0:1])
|
|
||||||
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
|
|
||||||
if len(indices) > 0:
|
|
||||||
first_idx = indices[0]
|
|
||||||
else:
|
|
||||||
raise ValueError("Found no qpos that exceeds the threshold.")
|
|
||||||
|
|
||||||
for i in range(first_idx - 1, num_episodes):
|
|
||||||
action = f["action"][i]
|
|
||||||
base_action = f["base_action"][i]
|
|
||||||
qpos = f["observations"]["qpos"][i]
|
|
||||||
qvel = f["observations"]["qvel"][i]
|
|
||||||
cam_high = f["observations"]["images"]["cam_high"][i]
|
|
||||||
cam_left_wrist = f["observations"]["images"]["cam_left_wrist"][i]
|
|
||||||
cam_right_wrist = f["observations"]["images"]["cam_right_wrist"][i]
|
|
||||||
instruction = f["instruction"][()]
|
|
||||||
terminate_episode = i == num_episodes - 1
|
|
||||||
serialized_example = serialize_example(
|
|
||||||
action,
|
|
||||||
base_action,
|
|
||||||
qpos,
|
|
||||||
qvel,
|
|
||||||
cam_high,
|
|
||||||
cam_left_wrist,
|
|
||||||
cam_right_wrist,
|
|
||||||
instruction,
|
|
||||||
terminate_episode,
|
|
||||||
)
|
|
||||||
writer.write(serialized_example)
|
|
||||||
except Exception as e:
|
|
||||||
with open("error_log.txt", "a") as f:
|
|
||||||
f.write(f"{filepath}\n")
|
|
||||||
print(f"error at {filepath}: {e}")
|
|
||||||
return f"TFRecords written to {tfrecord_path}"
|
|
||||||
|
|
||||||
|
|
||||||
def write_tfrecords(root_dir, out_dir):
|
|
||||||
if not os.path.exists(out_dir):
|
|
||||||
os.makedirs(out_dir)
|
|
||||||
|
|
||||||
hdf5_files = []
|
|
||||||
for root, dirs, files in os.walk(root_dir):
|
|
||||||
if os.path.exists(os.path.join(root, "expanded_instruction_gpt-4-turbo.json")):
|
|
||||||
# copy the instruction file
|
|
||||||
target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
|
|
||||||
os.makedirs(target_path, exist_ok=True)
|
|
||||||
shutil.copy(os.path.join(root, "expanded_instruction_gpt-4-turbo.json"), target_path)
|
|
||||||
elif os.path.exists(os.path.join(root, "expanded_instruction.json")):
|
|
||||||
print(root)
|
|
||||||
target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
|
|
||||||
os.makedirs(target_path, exist_ok=True)
|
|
||||||
shutil.copy(os.path.join(root, "expanded_instruction.json"), target_path)
|
|
||||||
# rename into expanded_instruction_gpt-4-turbo.json
|
|
||||||
os.rename(
|
|
||||||
os.path.join(
|
|
||||||
out_dir,
|
|
||||||
os.path.relpath(root, root_dir),
|
|
||||||
"expanded_instruction.json",
|
|
||||||
),
|
|
||||||
os.path.join(
|
|
||||||
out_dir,
|
|
||||||
os.path.relpath(root, root_dir),
|
|
||||||
"expanded_instruction_gpt-4-turbo.json",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for filename in fnmatch.filter(files, "*.hdf5"):
|
|
||||||
filepath = os.path.join(root, filename)
|
|
||||||
hdf5_files.append((filepath, root_dir, out_dir))
|
|
||||||
|
|
||||||
with Pool(16) as pool:
|
|
||||||
max_count = len(hdf5_files)
|
|
||||||
with tqdm(total=max_count) as pbar:
|
|
||||||
for _ in pool.imap_unordered(process_hdf5_file, hdf5_files):
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
print(f"TFRecords written to {out_dir}")
|
|
||||||
|
|
||||||
|
|
||||||
root_dir = "../datasets/agilex/rdt_data/"
|
|
||||||
out_dir = "../datasets/agilex/tfrecords/"
|
|
||||||
write_tfrecords(root_dir, out_dir)
|
|
||||||
@ -1,256 +0,0 @@
|
|||||||
"""
|
|
||||||
This file will compute the min, max, mean, and standard deviation of each datasets
|
|
||||||
in `pretrain_datasets.json` or `pretrain_datasets.json`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
|
|
||||||
# from multiprocessing import Pool, Manager
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from data.vla_dataset import VLADataset
|
|
||||||
from data.hdf5_vla_dataset import HDF5VLADataset
|
|
||||||
from data.preprocess import generate_json_state
|
|
||||||
|
|
||||||
|
|
||||||
# Process each dataset to get the statistics
|
|
||||||
@tf.autograph.experimental.do_not_convert
|
|
||||||
def process_dataset(name_dataset_pair):
|
|
||||||
# print(f"PID {os.getpid()} processing {name_dataset_pair[0]}")
|
|
||||||
dataset_iter = name_dataset_pair[1]
|
|
||||||
|
|
||||||
MAX_EPISODES = 100000
|
|
||||||
EPS = 1e-8
|
|
||||||
# For debugging
|
|
||||||
# MAX_EPISODES = 10
|
|
||||||
episode_cnt = 0
|
|
||||||
state_sum = 0
|
|
||||||
state_sum_sq = 0
|
|
||||||
z_state_sum = 0
|
|
||||||
z_state_sum_sq = 0
|
|
||||||
state_cnt = 0
|
|
||||||
nz_state_cnt = None
|
|
||||||
state_max = None
|
|
||||||
state_min = None
|
|
||||||
for episode in dataset_iter:
|
|
||||||
episode_cnt += 1
|
|
||||||
if episode_cnt % 1000 == 0:
|
|
||||||
print(f"Processing episodes {episode_cnt}/{MAX_EPISODES}")
|
|
||||||
if episode_cnt > MAX_EPISODES:
|
|
||||||
break
|
|
||||||
episode_dict = episode["episode_dict"]
|
|
||||||
dataset_name = episode["dataset_name"]
|
|
||||||
|
|
||||||
res_tup = generate_json_state(episode_dict, dataset_name)
|
|
||||||
states = res_tup[1]
|
|
||||||
|
|
||||||
# Convert to numpy
|
|
||||||
states = states.numpy()
|
|
||||||
|
|
||||||
# Zero the values that are close to zero
|
|
||||||
z_states = states.copy()
|
|
||||||
z_states[np.abs(states) <= EPS] = 0
|
|
||||||
# Compute the non-zero count
|
|
||||||
if nz_state_cnt is None:
|
|
||||||
nz_state_cnt = np.zeros(states.shape[1])
|
|
||||||
nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
|
|
||||||
|
|
||||||
# Update statistics
|
|
||||||
state_sum += np.sum(states, axis=0)
|
|
||||||
state_sum_sq += np.sum(states**2, axis=0)
|
|
||||||
z_state_sum += np.sum(z_states, axis=0)
|
|
||||||
z_state_sum_sq += np.sum(z_states**2, axis=0)
|
|
||||||
state_cnt += states.shape[0]
|
|
||||||
if state_max is None:
|
|
||||||
state_max = np.max(states, axis=0)
|
|
||||||
state_min = np.min(states, axis=0)
|
|
||||||
else:
|
|
||||||
state_max = np.maximum(state_max, np.max(states, axis=0))
|
|
||||||
state_min = np.minimum(state_min, np.min(states, axis=0))
|
|
||||||
|
|
||||||
# Add one to avoid division by zero
|
|
||||||
nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"dataset_name":
|
|
||||||
name_dataset_pair[0],
|
|
||||||
"state_mean": (state_sum / state_cnt).tolist(),
|
|
||||||
"state_std":
|
|
||||||
np.sqrt(
|
|
||||||
np.maximum(
|
|
||||||
(z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
|
|
||||||
np.zeros_like(state_sum_sq),
|
|
||||||
)).tolist(),
|
|
||||||
"state_min":
|
|
||||||
state_min.tolist(),
|
|
||||||
"state_max":
|
|
||||||
state_max.tolist(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def process_hdf5_dataset(vla_dataset):
|
|
||||||
EPS = 1e-8
|
|
||||||
episode_cnt = 0
|
|
||||||
state_sum = 0
|
|
||||||
state_sum_sq = 0
|
|
||||||
z_state_sum = 0
|
|
||||||
z_state_sum_sq = 0
|
|
||||||
state_cnt = 0
|
|
||||||
nz_state_cnt = None
|
|
||||||
state_max = None
|
|
||||||
state_min = None
|
|
||||||
for i in tqdm(range(len(vla_dataset))):
|
|
||||||
episode = vla_dataset.get_item(i, state_only=True)
|
|
||||||
episode_cnt += 1
|
|
||||||
|
|
||||||
states = episode["state"]
|
|
||||||
|
|
||||||
# Zero the values that are close to zero
|
|
||||||
z_states = states.copy()
|
|
||||||
z_states[np.abs(states) <= EPS] = 0
|
|
||||||
# Compute the non-zero count
|
|
||||||
if nz_state_cnt is None:
|
|
||||||
nz_state_cnt = np.zeros(states.shape[1])
|
|
||||||
nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
|
|
||||||
|
|
||||||
# Update statistics
|
|
||||||
state_sum += np.sum(states, axis=0)
|
|
||||||
state_sum_sq += np.sum(states**2, axis=0)
|
|
||||||
z_state_sum += np.sum(z_states, axis=0)
|
|
||||||
z_state_sum_sq += np.sum(z_states**2, axis=0)
|
|
||||||
state_cnt += states.shape[0]
|
|
||||||
if state_max is None:
|
|
||||||
state_max = np.max(states, axis=0)
|
|
||||||
state_min = np.min(states, axis=0)
|
|
||||||
else:
|
|
||||||
state_max = np.maximum(state_max, np.max(states, axis=0))
|
|
||||||
state_min = np.minimum(state_min, np.min(states, axis=0))
|
|
||||||
|
|
||||||
# Add one to avoid division by zero
|
|
||||||
nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"dataset_name":
|
|
||||||
vla_dataset.get_dataset_name(),
|
|
||||||
"state_mean": (state_sum / state_cnt).tolist(),
|
|
||||||
"state_std":
|
|
||||||
np.sqrt(
|
|
||||||
np.maximum(
|
|
||||||
(z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
|
|
||||||
np.zeros_like(state_sum_sq),
|
|
||||||
)).tolist(),
|
|
||||||
"state_min":
|
|
||||||
state_min.tolist(),
|
|
||||||
"state_max":
|
|
||||||
state_max.tolist(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
# Multiprocessing currently with bugs
|
|
||||||
# parser.add_argument('--n_workers', type=int, default=1,
|
|
||||||
# help="Number of parallel workers.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--dataset_type",
|
|
||||||
type=str,
|
|
||||||
default="pretrain",
|
|
||||||
help="Whether to load the pretrain dataset or finetune dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--save_path",
|
|
||||||
type=str,
|
|
||||||
default="configs/dataset_stat.json",
|
|
||||||
help="JSON file path to save the dataset statistics.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--skip_exist",
|
|
||||||
action="store_true",
|
|
||||||
help="Whether to skip the existing dataset statistics.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--hdf5_dataset",
|
|
||||||
action="store_true",
|
|
||||||
help="Whether to load the dataset from the HDF5 files.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.hdf5_dataset:
|
|
||||||
vla_dataset = HDF5VLADataset()
|
|
||||||
dataset_name = vla_dataset.get_dataset_name()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(args.save_path, "r") as f:
|
|
||||||
results = json.load(f)
|
|
||||||
except FileNotFoundError:
|
|
||||||
results = {}
|
|
||||||
if args.skip_exist and dataset_name in results:
|
|
||||||
print(f"Skipping existed {dataset_name} dataset statistics")
|
|
||||||
else:
|
|
||||||
print(f"Processing {dataset_name} dataset")
|
|
||||||
result = process_hdf5_dataset(vla_dataset)
|
|
||||||
results[result["dataset_name"]] = result
|
|
||||||
with open(args.save_path, "w") as f:
|
|
||||||
json.dump(results, f, indent=4)
|
|
||||||
print("All datasets have been processed.")
|
|
||||||
os._exit(0)
|
|
||||||
|
|
||||||
vla_dataset = VLADataset(seed=0, dataset_type=args.dataset_type, repeat=False)
|
|
||||||
name_dataset_pairs = vla_dataset.name2dataset.items()
|
|
||||||
# num_workers = args.n_workers
|
|
||||||
|
|
||||||
for name_dataset_pair in tqdm(name_dataset_pairs):
|
|
||||||
try:
|
|
||||||
with open(args.save_path, "r") as f:
|
|
||||||
results = json.load(f)
|
|
||||||
except FileNotFoundError:
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
if args.skip_exist and name_dataset_pair[0] in results:
|
|
||||||
print(f"Skipping existed {name_dataset_pair[0]} dataset statistics")
|
|
||||||
continue
|
|
||||||
print(f"Processing {name_dataset_pair[0]} dataset")
|
|
||||||
|
|
||||||
result = process_dataset(name_dataset_pair)
|
|
||||||
|
|
||||||
results[result["dataset_name"]] = result
|
|
||||||
|
|
||||||
# Save the results in the json file after each dataset (for resume)
|
|
||||||
with open(args.save_path, "w") as f:
|
|
||||||
json.dump(results, f, indent=4)
|
|
||||||
|
|
||||||
print("All datasets have been processed.")
|
|
||||||
|
|
||||||
# with Manager() as manager:
|
|
||||||
# # Create shared dictionary and lock through the manager, accessible by all processes
|
|
||||||
# progress = manager.dict(processed=0, results={})
|
|
||||||
# progress_lock = manager.Lock()
|
|
||||||
|
|
||||||
# # Callback function to update progress
|
|
||||||
# def update_progress(result):
|
|
||||||
# with progress_lock:
|
|
||||||
# progress['processed'] += 1
|
|
||||||
# print(f"{result['dataset_name']} - {progress['processed']}/{len(name_dataset_pairs)} datasets have been processed")
|
|
||||||
# # Append the result to the shared dictionary
|
|
||||||
# progress['results'][result["dataset_name"]] = result
|
|
||||||
|
|
||||||
# with Pool(num_workers) as p:
|
|
||||||
# for name_dataset_pair in name_dataset_pairs:
|
|
||||||
# p.apply_async(process_dataset, args=(name_dataset_pair,), callback=update_progress)
|
|
||||||
|
|
||||||
# # Close the pool and wait for the work to finish
|
|
||||||
# p.close()
|
|
||||||
# p.join()
|
|
||||||
|
|
||||||
# # Save the results in the json file
|
|
||||||
# with open(args.save_path, 'w') as f:
|
|
||||||
# json.dump(progress['results'], f, indent=4)
|
|
||||||
@ -1,112 +0,0 @@
|
|||||||
"""
|
|
||||||
This file will compute the min, max, mean, and standard deviation of each datasets
|
|
||||||
in `pretrain_datasets.json` or `pretrain_datasets.json`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from data.hdf5_vla_dataset import HDF5VLADataset
|
|
||||||
|
|
||||||
|
|
||||||
def process_hdf5_dataset(vla_dataset):
|
|
||||||
EPS = 1e-8
|
|
||||||
episode_cnt = 0
|
|
||||||
state_sum = 0
|
|
||||||
state_sum_sq = 0
|
|
||||||
z_state_sum = 0
|
|
||||||
z_state_sum_sq = 0
|
|
||||||
state_cnt = 0
|
|
||||||
nz_state_cnt = None
|
|
||||||
state_max = None
|
|
||||||
state_min = None
|
|
||||||
for i in tqdm(range(len(vla_dataset))):
|
|
||||||
episode = vla_dataset.get_item(i, state_only=True)
|
|
||||||
episode_cnt += 1
|
|
||||||
|
|
||||||
states = episode["state"]
|
|
||||||
|
|
||||||
# Zero the values that are close to zero
|
|
||||||
z_states = states.copy()
|
|
||||||
z_states[np.abs(states) <= EPS] = 0
|
|
||||||
# Compute the non-zero count
|
|
||||||
if nz_state_cnt is None:
|
|
||||||
nz_state_cnt = np.zeros(states.shape[1])
|
|
||||||
nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
|
|
||||||
|
|
||||||
# Update statistics
|
|
||||||
state_sum += np.sum(states, axis=0)
|
|
||||||
state_sum_sq += np.sum(states**2, axis=0)
|
|
||||||
z_state_sum += np.sum(z_states, axis=0)
|
|
||||||
z_state_sum_sq += np.sum(z_states**2, axis=0)
|
|
||||||
state_cnt += states.shape[0]
|
|
||||||
if state_max is None:
|
|
||||||
state_max = np.max(states, axis=0)
|
|
||||||
state_min = np.min(states, axis=0)
|
|
||||||
else:
|
|
||||||
state_max = np.maximum(state_max, np.max(states, axis=0))
|
|
||||||
state_min = np.minimum(state_min, np.min(states, axis=0))
|
|
||||||
|
|
||||||
# Add one to avoid division by zero
|
|
||||||
nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"dataset_name":
|
|
||||||
vla_dataset.get_dataset_name(),
|
|
||||||
"state_mean": (state_sum / state_cnt).tolist(),
|
|
||||||
"state_std":
|
|
||||||
np.sqrt(
|
|
||||||
np.maximum(
|
|
||||||
(z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
|
|
||||||
np.zeros_like(state_sum_sq),
|
|
||||||
)).tolist(),
|
|
||||||
"state_min":
|
|
||||||
state_min.tolist(),
|
|
||||||
"state_max":
|
|
||||||
state_max.tolist(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--task_name",
|
|
||||||
type=str,
|
|
||||||
default="configs/dataset_stat.json",
|
|
||||||
help="JSON file path to save the dataset statistics.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--save_path",
|
|
||||||
type=str,
|
|
||||||
default="configs/dataset_stat.json",
|
|
||||||
help="JSON file path to save the dataset statistics.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--skip_exist",
|
|
||||||
action="store_true",
|
|
||||||
help="Whether to skip the existing dataset statistics.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
vla_dataset = HDF5VLADataset(f"model_config/{args.task_name}.yml")
|
|
||||||
dataset_name = vla_dataset.get_dataset_name()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(args.save_path, "r") as f:
|
|
||||||
results = json.load(f)
|
|
||||||
except FileNotFoundError:
|
|
||||||
results = {}
|
|
||||||
if args.skip_exist and dataset_name in results:
|
|
||||||
print(f"Skipping existed {dataset_name} dataset statistics")
|
|
||||||
else:
|
|
||||||
print(f"Processing {dataset_name} dataset")
|
|
||||||
result = process_hdf5_dataset(vla_dataset)
|
|
||||||
results[result["dataset_name"]] = result
|
|
||||||
with open(args.save_path, "w") as f:
|
|
||||||
json.dump(results, f, indent=4)
|
|
||||||
print("All datasets have been processed.")
|
|
||||||
Binary file not shown.
@ -1,398 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from data.preprocess import generate_json_state
|
|
||||||
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
|
||||||
|
|
||||||
# Read the config
|
|
||||||
with open("configs/base.yaml", "r") as file:
|
|
||||||
config = yaml.safe_load(file)
|
|
||||||
# Load some constants from the config
|
|
||||||
IMG_HISTORY_SIZE = config["common"]["img_history_size"]
|
|
||||||
if IMG_HISTORY_SIZE < 1:
|
|
||||||
raise ValueError("Config `img_history_size` must be at least 1.")
|
|
||||||
ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
|
|
||||||
if ACTION_CHUNK_SIZE < 1:
|
|
||||||
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
|
||||||
|
|
||||||
|
|
||||||
@tf.function
def process_episode(epsd: dict, dataset_name: str, image_keys: list, image_mask: list) -> dict:
    """
    Process an episode to extract per-camera frame histories.

    Args:
        epsd: episode dict with a "steps" iterable; each step provides
            uint8 images under step["observation"][image_key].
        dataset_name: name of the source dataset (passed through unchanged).
        image_keys: observation keys of (up to) four cameras.
        image_mask: per-camera 0/1 flags; a masked-out camera is replaced
            by empty (0, 0, 0)-shaped uint8 tensors.

    Returns:
        dict carrying the episode, a per-step id range, and for each camera k
        a sliding window ``past_frames_k`` with shape
        (num_steps, IMG_HISTORY_SIZE, H, W, C) — front-padded by repeating the
        first frame — plus a boolean ``past_frames_k_time_mask`` that is False
        on the padded positions.
    """
    # One TensorArray per camera, unrolled by hand: tf.function cannot
    # reliably trace a Python list of TensorArrays.
    frames_0 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_1 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_2 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
    frames_3 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)

    # Traverse the episode and collect every camera's frame per step;
    # masked-out cameras yield empty placeholder tensors.
    for step in iter(epsd["steps"]):
        frames_0 = frames_0.write(
            frames_0.size(),
            tf.cond(
                tf.equal(image_mask[0], 1),
                lambda: step["observation"][image_keys[0]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_1 = frames_1.write(
            frames_1.size(),
            tf.cond(
                tf.equal(image_mask[1], 1),
                lambda: step["observation"][image_keys[1]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_2 = frames_2.write(
            frames_2.size(),
            tf.cond(
                tf.equal(image_mask[2], 1),
                lambda: step["observation"][image_keys[2]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )
        frames_3 = frames_3.write(
            frames_3.size(),
            tf.cond(
                tf.equal(image_mask[3], 1),
                lambda: step["observation"][image_keys[3]],
                lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
            ),
        )

    def _history_window(frames_ta):
        """Stack a TensorArray and build per-step frame-history windows.

        Returns (frames, past_frames, past_frames_time_mask). Each step's
        window holds the previous IMG_HISTORY_SIZE frames; the front of the
        episode is padded with the first frame and flagged False in the mask.
        This helper replaces four identical hand-copied code sections.
        """
        frames = frames_ta.stack()
        first_frame = tf.expand_dims(frames[0], axis=0)
        first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
        padded_frames = tf.concat([first_frame, frames], axis=0)
        indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames)[0] + IMG_HISTORY_SIZE)
        past_frames = tf.map_fn(lambda i: padded_frames[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
        time_mask = tf.ones([tf.shape(frames)[0]], dtype=tf.bool)
        padded_time_mask = tf.pad(
            time_mask,
            [[IMG_HISTORY_SIZE - 1, 0]],
            "CONSTANT",
            constant_values=False,
        )
        past_time_mask = tf.map_fn(
            lambda i: padded_time_mask[i - IMG_HISTORY_SIZE:i],
            indices,
            dtype=tf.bool,
        )
        return frames, past_frames, past_time_mask

    frames_0, past_frames_0, past_frames_0_time_mask = _history_window(frames_0)
    _, past_frames_1, past_frames_1_time_mask = _history_window(frames_1)
    _, past_frames_2, past_frames_2_time_mask = _history_window(frames_2)
    _, past_frames_3, past_frames_3_time_mask = _history_window(frames_3)

    # Create the ids for each step.
    step_id = tf.range(0, tf.shape(frames_0)[0])

    return {
        "dataset_name": dataset_name,
        "episode_dict": epsd,
        "step_id": step_id,
        "past_frames_0": past_frames_0,
        "past_frames_0_time_mask": past_frames_0_time_mask,
        "past_frames_1": past_frames_1,
        "past_frames_1_time_mask": past_frames_1_time_mask,
        "past_frames_2": past_frames_2,
        "past_frames_2_time_mask": past_frames_2_time_mask,
        "past_frames_3": past_frames_3,
        "past_frames_3_time_mask": past_frames_3_time_mask,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@tf.function
def bgr_to_rgb(epsd: dict):
    """
    Convert BGR images to RGB images.

    Every 3-channel ``past_frames_k`` tensor has its channel order reversed;
    empty placeholder tensors (last dim != 3) pass through untouched. All
    other episode entries are forwarded unchanged.
    """
    converted = {
        "dataset_name": epsd["dataset_name"],
        "episode_dict": epsd["episode_dict"],
        "step_id": epsd["step_id"],
    }
    # This Python-level loop is unrolled at trace time, producing the same
    # graph as four hand-written conversion sections.
    for cam in range(4):
        frames_key = f"past_frames_{cam}"
        mask_key = f"past_frames_{cam}_time_mask"
        frames = epsd[frames_key]
        converted[frames_key] = tf.cond(
            tf.equal(tf.shape(frames)[-1], 3),
            # Bind `frames` as a default argument so each lambda captures
            # the current camera's tensor, not the final loop value.
            lambda frames=frames: tf.stack(
                [frames[..., 2], frames[..., 1], frames[..., 0]],
                axis=-1,
            ),
            lambda frames=frames: frames,
        )
        converted[mask_key] = epsd[mask_key]
    return converted
|
|
||||||
|
|
||||||
|
|
||||||
def flatten_episode(episode: dict) -> tf.data.Dataset:
    """
    Flatten the episode to a list of steps.

    Args:
        episode: output of `process_episode` — the raw episode dict, per-step
            ids and per-camera frame-history windows.

    Returns:
        list[dict]: one sample per timestep with state/action chunks, their
        boolean time masks, the per-camera frame windows, and per-episode
        state statistics broadcast to every step.
    """
    episode_dict = episode["episode_dict"]
    dataset_name = episode["dataset_name"]

    # NOTE(review): `flatten_episode_agilex` unpacks FOUR values from
    # generate_json_state while this call expects three — confirm both call
    # sites match the helper's actual signature.
    json_content, states, masks = generate_json_state(episode_dict, dataset_name)

    # Calculate the past_states for each step.
    # Each step has a window of previous states with size ACTION_CHUNK_SIZE,
    # front-padded with the first state, so past_states has shape
    # (num_steps, ACTION_CHUNK_SIZE, state_dim).
    first_state = tf.expand_dims(states[0], axis=0)
    first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0)
    padded_states = tf.concat([first_state, states], axis=0)
    indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
    past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
    # Time mask: padded (repeated-first-state) positions are False.
    states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(
        states_time_mask,
        [[ACTION_CHUNK_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
        indices,
        dtype=tf.bool,
    )

    # Calculate the future_states for each step.
    # Each step has a window of future states with size ACTION_CHUNK_SIZE,
    # back-padded with the last state, so future_states has shape
    # (num_steps, ACTION_CHUNK_SIZE, state_dim).
    last_state = tf.expand_dims(states[-1], axis=0)
    last_state = tf.repeat(last_state, ACTION_CHUNK_SIZE, axis=0)
    padded_states = tf.concat([states, last_state], axis=0)
    indices = tf.range(1, tf.shape(states)[0] + 1)
    future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
    future_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
        indices,
        dtype=tf.bool,
    )

    # Per-dimension statistics over the whole episode, repeated so each
    # step carries its own copy.
    state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
    state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
    state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
    state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)

    # Root-mean-square per dimension (used downstream for normalization).
    state_norm = tf.math.reduce_mean(tf.math.square(states), axis=0, keepdims=True)
    state_norm = tf.math.sqrt(state_norm)
    state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)

    # Create a list of steps (requires eager execution for the int loop bound).
    step_data = []
    for i in range(tf.shape(states)[0]):
        step_data.append({
            "step_id": episode["step_id"][i],
            "json_content": json_content,
            "state_chunk": past_states[i],
            "state_chunk_time_mask": past_states_time_mask[i],
            "action_chunk": future_states[i],
            "action_chunk_time_mask": future_states_time_mask[i],
            "state_vec_mask": masks[i],
            "past_frames_0": episode["past_frames_0"][i],
            "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
            "past_frames_1": episode["past_frames_1"][i],
            "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
            "past_frames_2": episode["past_frames_2"][i],
            "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
            "past_frames_3": episode["past_frames_3"][i],
            "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
            "state_std": state_std[i],
            "state_mean": state_mean[i],
            "state_norm": state_norm[i],
        })

    return step_data
|
|
||||||
|
|
||||||
|
|
||||||
def flatten_episode_agilex(episode: dict) -> tf.data.Dataset:
    """
    Flatten the episode to a list of steps (Agilex variant).

    Unlike `flatten_episode`, the future ("action") chunks are built from the
    recorded actions `acts` rather than from the states, and `state_norm` is
    computed from the actions.

    Args:
        episode: output of `process_episode`.

    Returns:
        list[dict]: one sample per timestep, same schema as `flatten_episode`.
    """
    episode_dict = episode["episode_dict"]
    dataset_name = episode["dataset_name"]

    # NOTE(review): this call site expects FOUR return values while
    # `flatten_episode` expects three — confirm generate_json_state's
    # signature.
    json_content, states, masks, acts = generate_json_state(episode_dict, dataset_name)

    # Calculate the past_states for each step.
    # Each step has a window of previous states with size ACTION_CHUNK_SIZE,
    # front-padded with the first state, so past_states has shape
    # (num_steps, ACTION_CHUNK_SIZE, state_dim).
    first_state = tf.expand_dims(states[0], axis=0)
    first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0)
    padded_states = tf.concat([first_state, states], axis=0)
    indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
    past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(
        states_time_mask,
        [[ACTION_CHUNK_SIZE - 1, 0]],
        "CONSTANT",
        constant_values=False,
    )
    past_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
        indices,
        dtype=tf.bool,
    )

    # NOTE bg the future states shall be actions.
    # Calculate the future_states for each step: a window of the next
    # ACTION_CHUNK_SIZE actions, back-padded with the last action, shape
    # (num_steps, ACTION_CHUNK_SIZE, state_dim).
    last_act = tf.expand_dims(acts[-1], axis=0)
    last_act = tf.repeat(last_act, ACTION_CHUNK_SIZE, axis=0)
    padded_states = tf.concat([acts, last_act], axis=0)
    # indices = tf.range(1, tf.shape(states)[0] + 1)
    indices = tf.range(0, tf.shape(acts)[0])  # NOTE time 0 action = time 1 state
    future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
    states_time_mask = tf.ones([tf.shape(acts)[0]], dtype=tf.bool)
    padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
    future_states_time_mask = tf.map_fn(
        lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
        indices,
        dtype=tf.bool,
    )

    # Per-dimension std/mean over the episode's states, broadcast per step.
    state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
    state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
    state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
    state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)

    # Root-mean-square computed from the ACTIONS here (cf. flatten_episode,
    # which uses the states).
    state_norm = tf.math.reduce_mean(tf.math.square(acts), axis=0, keepdims=True)
    state_norm = tf.math.sqrt(state_norm)
    state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)

    # Create a list of steps (requires eager execution for the int loop bound).
    step_data = []
    for i in range(tf.shape(states)[0]):
        step_data.append({
            "step_id": episode["step_id"][i],
            "json_content": json_content,
            "state_chunk": past_states[i],
            "state_chunk_time_mask": past_states_time_mask[i],
            "action_chunk": future_states[i],
            "action_chunk_time_mask": future_states_time_mask[i],
            "state_vec_mask": masks[i],
            "past_frames_0": episode["past_frames_0"][i],
            "past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
            "past_frames_1": episode["past_frames_1"][i],
            "past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
            "past_frames_2": episode["past_frames_2"][i],
            "past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
            "past_frames_3": episode["past_frames_3"][i],
            "past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
            "state_std": state_std[i],
            "state_mean": state_mean[i],
            "state_norm": state_norm[i],
        })

    return step_data
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
import fcntl
|
|
||||||
|
|
||||||
|
|
||||||
class FileLock:
    """
    A small advisory file-lock helper built on POSIX ``fcntl.flock``.

    The lock is taken on a sidecar ``<filename>.lock`` file, non-blocking:
    the ``acquire_*`` methods raise ``OSError`` (``BlockingIOError``)
    immediately when another process already holds a conflicting lock.
    """

    def __init__(self, filename):
        # Path of the protected file; the lock lives in `<filename>.lock`.
        self.filename = filename
        # Open handle to the lock file while a lock is held, else None.
        self.handle = None

    def acquire_read_lock(self):
        """Take a shared (read) lock, non-blocking.

        Raises:
            OSError: if the lock is held exclusively elsewhere, or the
                lock file does not exist ("r" mode does not create it).
        """
        handle = open(self.filename + ".lock", "r")
        try:
            fcntl.flock(handle, fcntl.LOCK_SH | fcntl.LOCK_NB)
        except OSError:
            # Fix: don't leak the descriptor when the lock cannot be taken.
            handle.close()
            raise
        self.handle = handle

    def acquire_write_lock(self):
        """Take an exclusive (write) lock, non-blocking; creates the lock file.

        Raises:
            OSError: if any other process holds the lock.
        """
        handle = open(self.filename + ".lock", "w")
        try:
            fcntl.flock(handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except OSError:
            # Fix: don't leak the descriptor when the lock cannot be taken.
            handle.close()
            raise
        self.handle = handle

    def release_lock(self):
        """Unlock and close the handle; a no-op when no lock is held."""
        if self.handle is not None:
            fcntl.flock(self.handle, fcntl.LOCK_UN)
            self.handle.close()
            self.handle = None
|
|
||||||
@ -1,372 +0,0 @@
|
|||||||
import os
|
|
||||||
import fnmatch
|
|
||||||
import json
|
|
||||||
|
|
||||||
import h5py
|
|
||||||
import yaml
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
|
||||||
|
|
||||||
class HDF5VLADataset:
|
|
||||||
"""
|
|
||||||
This class is used to sample episodes from the embododiment dataset
|
|
||||||
stored in HDF5.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model_config_path) -> None:
    """Index all HDF5 episodes referenced by a per-task model config.

    Args:
        model_config_path: YAML file whose ``data_path`` entry points at the
            directory tree containing one episode per ``*.hdf5`` file.

    Raises:
        ValueError: if no readable, non-empty episode is found (previously
            this silently produced NaN sampling weights via a 0/0 division).
    """
    # [Modify] The path to the HDF5 dataset directory;
    # each HDF5 file contains one episode.
    with open(model_config_path, "r") as f:
        model_config = yaml.safe_load(f)
    HDF5_DIR = model_config["data_path"]
    self.DATASET_NAME = "agilex"

    self.file_paths = []
    for root, _, files in os.walk(HDF5_DIR):
        for filename in fnmatch.filter(files, "*.hdf5"):
            self.file_paths.append(os.path.join(root, filename))

    # Load the shared pipeline config.
    with open("configs/base.yaml", "r") as file:
        config = yaml.safe_load(file)
    self.CHUNK_SIZE = config["common"]["action_chunk_size"]
    self.IMG_HISORY_SIZE = config["common"]["img_history_size"]
    self.STATE_DIM = config["common"]["state_dim"]

    # Get each episode's length (original length, not standardized length).
    # Reading .shape avoids materializing the whole qpos array per file.
    episode_lens = []
    for file_path in self.file_paths:
        try:
            with h5py.File(file_path, "r") as f:
                num_steps = f["observations"]["qpos"].shape[0]
        except Exception as e:
            # Unreadable episodes get weight 0 but stay index-aligned.
            print(f"Warning: Could not read {file_path}: {e}")
            num_steps = 0
        episode_lens.append(num_steps)

    total_steps = np.sum(episode_lens)
    if total_steps == 0:
        raise ValueError(f"No readable HDF5 episodes found under {HDF5_DIR}")
    # Sample episodes proportionally to their length.
    self.episode_sample_weights = np.array(episode_lens) / total_steps
|
|
||||||
|
|
||||||
def __len__(self):
    # Number of discovered HDF5 episode files (not total timesteps).
    return len(self.file_paths)
|
|
||||||
|
|
||||||
def get_dataset_name(self):
    """Return the dataset name used as the key in the statistics JSON."""
    return self.DATASET_NAME
|
|
||||||
|
|
||||||
def get_item(self, index: int = None, state_only=False):
    """Get a training sample at a random timestep.

    Args:
        index (int, optional): the index of the episode.
            If not provided, a random episode is drawn with probability
            proportional to its length.
        state_only (bool, optional): Whether to return only the state.
            In this way, the sample contains a complete trajectory rather
            than a single timestep. Defaults to False.

    Returns:
        sample (dict): a dictionary containing the training sample.
    """
    # Pick the parsing strategy once; both share the (valid, sample) contract.
    parse = self.parse_hdf5_file_state_only if state_only else self.parse_hdf5_file
    while True:
        if index is None:
            chosen_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
        else:
            chosen_path = self.file_paths[index]
        valid, sample = parse(chosen_path)
        if valid:
            return sample
        # Invalid episode: retry with a uniformly random index.
        index = np.random.randint(0, len(self.file_paths))
|
|
||||||
|
|
||||||
def parse_hdf5_file(self, file_path):
    """[Modify] Parse a hdf5 file to generate a training sample at
    a random timestep.

    Args:
        file_path (str): the path to the hdf5 file

    Returns:
        valid (bool): whether the episode is valid, which is useful for
            filtering. If False, this episode will be dropped.
        dict: the training sample, or None if the episode is invalid:
            {
                "meta": {"dataset_name", "#steps", "step_id", "instruction"},
                "state": state[t], (1, STATE_DIM),
                "state_std"/"state_mean"/"state_norm": (STATE_DIM,) stats,
                "actions": action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM),
                "state_indicator": validness of each dim, (STATE_DIM,),
                "cam_high"/"cam_left_wrist"/"cam_right_wrist":
                    (IMG_HISORY_SIZE, H, W, 3) image history, or
                    (IMG_HISORY_SIZE, 0, 0, 0) when the camera is absent,
                "cam_*_mask": (IMG_HISORY_SIZE,) bool, False on padding,
            }
    """
    with h5py.File(file_path, "r") as f:
        qpos = f["observations"]["qpos"][:]
        num_steps = qpos.shape[0]

        # [Optional] We skip the first few still steps: find the first
        # timestep whose joint position deviates from the initial pose.
        EPS = 1e-2
        qpos_delta = np.abs(qpos - qpos[0:1])
        indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
        if len(indices) > 0:
            first_idx = indices[0]
        else:
            raise ValueError("Found no qpos that exceeds the threshold.")

        # We randomly sample a timestep; first_idx - 1 is the last still frame.
        step_id = np.random.randint(first_idx - 1, num_steps)

        # Use a randomly chosen precomputed language-embedding file (*.pt)
        # stored next to the episode as the instruction.
        dir_path = os.path.dirname(file_path)
        instructions_path = os.path.join(dir_path, "instructions")
        instruction_candidates = [
            os.path.join(instructions_path, name)
            for name in os.listdir(instructions_path)
            if name.endswith(".pt")
        ]
        instruction = np.random.choice(instruction_candidates)

        # Assemble the meta
        meta = {
            "dataset_name": self.DATASET_NAME,
            "#steps": num_steps,
            "step_id": step_id,
            "instruction": instruction,
        }

        # Rescale the joint angles.
        # NOTE(review): hard-coded for a single 6-DoF arm with no gripper
        # dimension — confirm against the robot's actual state layout.
        qpos = qpos / np.array(
            [[180, 180, 180, 180, 180, 180]]
        )
        target_qpos = f['action'][step_id:step_id + self.CHUNK_SIZE] / np.array(
            [[180, 180, 180, 180, 180, 180]]
        )

        # Parse the state and action
        state = qpos[step_id:step_id + 1]
        state_std = np.std(qpos, axis=0)
        state_mean = np.mean(qpos, axis=0)
        state_norm = np.sqrt(np.mean(qpos**2, axis=0))
        actions = target_qpos
        if actions.shape[0] < self.CHUNK_SIZE:
            # Pad the action chunk by repeating the last action.
            actions = np.concatenate(
                [
                    actions,
                    np.tile(actions[-1:], (self.CHUNK_SIZE - actions.shape[0], 1)),
                ],
                axis=0,
            )

        # Fill the state/action into the unified vector.
        def fill_in_state(values):
            # Scatter the 6 joint values into the unified state vector at
            # the indices of the right arm's joint positions.
            UNI_STATE_INDICES = [
                STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
            ]
            uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
            uni_vec[..., UNI_STATE_INDICES] = values
            return uni_vec

        state = fill_in_state(state)
        # Uses the pre-fill, 6-dim state_std on purpose: marks which unified
        # dims are populated.
        state_indicator = fill_in_state(np.ones_like(state_std))
        state_std = fill_in_state(state_std)
        state_mean = fill_in_state(state_mean)
        state_norm = fill_in_state(state_norm)
        # Actions share the state's format, so fill_in_state is reused.
        actions = fill_in_state(actions)

        # Parse the images
        def parse_img(key):
            # Decode the last IMG_HISORY_SIZE JPEG-encoded frames up to and
            # including step_id; front-pad with the first decoded frame.
            imgs = []
            for i in range(max(step_id - self.IMG_HISORY_SIZE + 1, 0), step_id + 1):
                img_bits = f["observations"]["images"][key][i]
                img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)
                imgs.append(img)
            imgs = np.stack(imgs)
            if imgs.shape[0] < self.IMG_HISORY_SIZE:
                imgs = np.concatenate(
                    [
                        np.tile(
                            imgs[:1],
                            (self.IMG_HISORY_SIZE - imgs.shape[0], 1, 1, 1),
                        ),
                        imgs,
                    ],
                    axis=0,
                )
            return imgs

        # `cam_high` is the external camera image.
        cam_high = parse_img("cam_high")
        # For step_id = first_idx - 1, the valid_len should be one.
        valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
        cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len)
        # The left-wrist camera is unavailable on this robot: zero-shaped
        # images and an all-False mask.
        cam_left_wrist = np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0))
        cam_left_wrist_mask = np.array([False] * self.IMG_HISORY_SIZE)
        cam_right_wrist = parse_img("cam_right_wrist")
        # Same validity window as the external camera.
        cam_right_wrist_mask = cam_high_mask.copy()

        # Return the resulting sample.
        return True, {
            "meta": meta,
            "state": state,
            "state_std": state_std,
            "state_mean": state_mean,
            "state_norm": state_norm,
            "actions": actions,
            "state_indicator": state_indicator,
            "cam_high": cam_high,
            "cam_high_mask": cam_high_mask,
            "cam_left_wrist": cam_left_wrist,
            "cam_left_wrist_mask": cam_left_wrist_mask,
            "cam_right_wrist": cam_right_wrist,
            "cam_right_wrist_mask": cam_right_wrist_mask,
        }
|
|
||||||
|
|
||||||
def parse_hdf5_file_state_only(self, file_path):
|
|
||||||
"""[Modify] Parse a hdf5 file to generate a state trajectory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path (str): the path to the hdf5 file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
valid (bool): whether the episode is valid, which is useful for filtering.
|
|
||||||
If False, this episode will be dropped.
|
|
||||||
dict: a dictionary containing the training sample,
|
|
||||||
{
|
|
||||||
"state": ndarray, # state[:], (T, STATE_DIM).
|
|
||||||
"action": ndarray, # action[:], (T, STATE_DIM).
|
|
||||||
} or None if the episode is invalid.
|
|
||||||
"""
|
|
||||||
with h5py.File(file_path, "r") as f:
|
|
||||||
qpos = f["observations"]["qpos"][:]
|
|
||||||
left_arm_dim = f["observations"]["left_arm_dim"][:]
|
|
||||||
right_arm_dim = f["observations"]["right_arm_dim"][:]
|
|
||||||
|
|
||||||
num_steps = qpos.shape[0]
|
|
||||||
# [Optional] We drop too-short episode
|
|
||||||
# if num_steps < 128:
|
|
||||||
# return False, None
|
|
||||||
|
|
||||||
# [Optional] We skip the first few still steps
|
|
||||||
EPS = 1e-2
|
|
||||||
# Get the idx of the first qpos whose delta exceeds the threshold
|
|
||||||
qpos_delta = np.abs(qpos - qpos[0:1])
|
|
||||||
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
|
|
||||||
if len(indices) > 0:
|
|
||||||
first_idx = indices[0]
|
|
||||||
else:
|
|
||||||
raise ValueError("Found no qpos that exceeds the threshold.")
|
|
||||||
|
|
||||||
# Rescale gripper to [0, 1]
|
|
||||||
# qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])
|
|
||||||
# target_qpos = f["action"][:] / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])
|
|
||||||
|
|
||||||
qpos = qpos / np.array(
|
|
||||||
# [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]
|
|
||||||
[[180, 180, 180, 180, 180, 180]]
|
|
||||||
)
|
|
||||||
target_qpos = f['action'][first_idx - 1:] / np.array(
|
|
||||||
# [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]]
|
|
||||||
[[180, 180, 180, 180, 180, 180]]
|
|
||||||
)
|
|
||||||
# Parse the state and action
|
|
||||||
state = qpos[first_idx - 1:]
|
|
||||||
action = target_qpos[first_idx - 1:]
|
|
||||||
|
|
||||||
# Standardize trajectory length to avoid batch size mismatch
|
|
||||||
# Use a fixed length (e.g., 128) or pad/truncate to match
|
|
||||||
target_length = 128 # You can adjust this value
|
|
||||||
if state.shape[0] > target_length:
|
|
||||||
# Truncate to target length
|
|
||||||
state = state[:target_length]
|
|
||||||
action = action[:target_length]
|
|
||||||
elif state.shape[0] < target_length:
|
|
||||||
# Pad with the last state/action
|
|
||||||
pad_length = target_length - state.shape[0]
|
|
||||||
state = np.concatenate([state, np.tile(state[-1:], (pad_length, 1))], axis=0)
|
|
||||||
action = np.concatenate([action, np.tile(action[-1:], (pad_length, 1))], axis=0)
|
|
||||||
|
|
||||||
# Fill the state/action into the unified vector
|
|
||||||
def fill_in_state(values):
|
|
||||||
# Target indices corresponding to your state space
|
|
||||||
# In this example: 6 joints + 1 gripper for each arm
|
|
||||||
UNI_STATE_INDICES = [
|
|
||||||
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
|
|
||||||
# ] + [
|
|
||||||
# STATE_VEC_IDX_MAPPING["right_gripper_open"]
|
|
||||||
]
|
|
||||||
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
|
|
||||||
uni_vec[..., UNI_STATE_INDICES] = values
|
|
||||||
return uni_vec
|
|
||||||
|
|
||||||
state = fill_in_state(state)
|
|
||||||
action = fill_in_state(action)
|
|
||||||
|
|
||||||
# Return the resulting sample
|
|
||||||
return True, {"state": state, "action": action}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
ds = HDF5VLADataset()
|
|
||||||
for i in range(len(ds)):
|
|
||||||
print(f"Processing episode {i}/{len(ds)}...")
|
|
||||||
ds.get_item(i)
|
|
||||||
@ -1,299 +0,0 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from data.preprocess_scripts import *
|
|
||||||
from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN
|
|
||||||
from data.utils import capitalize_and_period
|
|
||||||
|
|
||||||
# The dataset without state
|
|
||||||
DATASET_NAMES_NO_STATE = [
|
|
||||||
"nyu_door_opening_surprising_effectiveness",
|
|
||||||
"usc_cloth_sim_converted_externally_to_rlds",
|
|
||||||
"cmu_franka_exploration_dataset_converted_externally_to_rlds",
|
|
||||||
"imperialcollege_sawyer_wrist_cam",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Read the image keys of each dataset
|
|
||||||
with open("configs/dataset_img_keys.json", "r") as file:
|
|
||||||
IMAGE_KEYS = json.load(file)
|
|
||||||
# Read the config
|
|
||||||
with open("configs/base.yaml", "r") as file:
|
|
||||||
config = yaml.safe_load(file)
|
|
||||||
|
|
||||||
|
|
||||||
def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str, base_concat=None, base_format=None) -> tf.Tensor:
|
|
||||||
"""
|
|
||||||
Assemble the state/action vector from the arm and base.
|
|
||||||
"""
|
|
||||||
state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
|
|
||||||
mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
|
|
||||||
|
|
||||||
# Assemble the arm state
|
|
||||||
arm_concat = tf.cast(arm_concat, tf.float32)
|
|
||||||
arm_format = arm_format.split(",")
|
|
||||||
# Use the scatter_nd to avoid the duplicate indices
|
|
||||||
state_vec = tf.tensor_scatter_nd_update(state_vec, [[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
|
|
||||||
arm_concat)
|
|
||||||
mask_vec = tf.tensor_scatter_nd_update(
|
|
||||||
mask_vec,
|
|
||||||
[[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
|
|
||||||
tf.ones(len(arm_format), dtype=tf.float32),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assemble the base state if exists
|
|
||||||
if base_concat is not None:
|
|
||||||
base_concat = tf.cast(base_concat, tf.float32)
|
|
||||||
base_format = base_format.split(",")
|
|
||||||
state_vec = tf.tensor_scatter_nd_update(
|
|
||||||
state_vec,
|
|
||||||
[[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
|
|
||||||
base_concat,
|
|
||||||
)
|
|
||||||
mask_vec = tf.tensor_scatter_nd_update(
|
|
||||||
mask_vec,
|
|
||||||
[[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
|
|
||||||
tf.ones(len(base_format), dtype=tf.float32),
|
|
||||||
)
|
|
||||||
return state_vec, mask_vec
|
|
||||||
|
|
||||||
|
|
||||||
@tf.autograph.experimental.do_not_convert
|
|
||||||
def _generate_json_state_agilex(episode: dict, dataset_name: str):
|
|
||||||
"""
|
|
||||||
Generate the json dict and state for a given episode.
|
|
||||||
"""
|
|
||||||
# Load some constants from the config
|
|
||||||
IMG_HISTORY_SIZE = config["common"]["img_history_size"]
|
|
||||||
if IMG_HISTORY_SIZE < 1:
|
|
||||||
raise ValueError("Config `img_history_size` must be at least 1.")
|
|
||||||
ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
|
|
||||||
if ACTION_CHUNK_SIZE < 1:
|
|
||||||
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
|
||||||
|
|
||||||
# Initialize the episode_metadata
|
|
||||||
episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}
|
|
||||||
|
|
||||||
# Check whether this episode has an 'END'
|
|
||||||
base_act = None
|
|
||||||
last_base_act = None
|
|
||||||
episode_states = []
|
|
||||||
episode_acts = []
|
|
||||||
episode_masks = []
|
|
||||||
has_base = None
|
|
||||||
for step_id, step in enumerate(iter(episode["steps"])):
|
|
||||||
# Parse the action
|
|
||||||
action = step["action"]
|
|
||||||
if has_base is None:
|
|
||||||
has_base = "base_concat" in action
|
|
||||||
if has_base:
|
|
||||||
base_act = action["base_concat"]
|
|
||||||
|
|
||||||
# Parse the state
|
|
||||||
state = step["observation"]
|
|
||||||
|
|
||||||
arm_format = state["format"].numpy().decode("utf-8")
|
|
||||||
base_format = None
|
|
||||||
if has_base:
|
|
||||||
act_format = action["format"].numpy().decode("utf-8")
|
|
||||||
base_formate_idx = act_format.find("base")
|
|
||||||
base_format = act_format[base_formate_idx:]
|
|
||||||
|
|
||||||
arm_state = state["arm_concat"]
|
|
||||||
base_state = None
|
|
||||||
if has_base:
|
|
||||||
if last_base_act is None:
|
|
||||||
base_state = base_act * 0
|
|
||||||
else:
|
|
||||||
base_state = last_base_act
|
|
||||||
last_base_act = base_act
|
|
||||||
|
|
||||||
# Assemble the state vector
|
|
||||||
state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)
|
|
||||||
|
|
||||||
act_vec, mask_vec = assemble_state_vec(action["arm_concat"], arm_format, base_state, base_format)
|
|
||||||
|
|
||||||
episode_states.append(state_vec)
|
|
||||||
episode_masks.append(mask_vec)
|
|
||||||
episode_acts.append(act_vec)
|
|
||||||
|
|
||||||
# Parse the task instruction
|
|
||||||
instr = step["observation"]["natural_language_instruction"]
|
|
||||||
instr = instr.numpy().decode("utf-8")
|
|
||||||
instr = capitalize_and_period(instr)
|
|
||||||
|
|
||||||
# Write to the episode_metadata
|
|
||||||
if episode_metadata["instruction"] is None:
|
|
||||||
episode_metadata["instruction"] = instr
|
|
||||||
|
|
||||||
episode_metadata["#steps"] = step_id
|
|
||||||
|
|
||||||
episode_states = tf.stack(episode_states)
|
|
||||||
episode_masks = tf.stack(episode_masks)
|
|
||||||
episode_acts = tf.stack(episode_acts)
|
|
||||||
|
|
||||||
return episode_metadata, episode_states, episode_masks, episode_acts
|
|
||||||
|
|
||||||
|
|
||||||
@tf.autograph.experimental.do_not_convert
|
|
||||||
def _generate_json_state(episode: dict, dataset_name: str):
|
|
||||||
"""
|
|
||||||
Generate the json dict and state for a given episode.
|
|
||||||
"""
|
|
||||||
# Load some constants from the config
|
|
||||||
IMG_HISTORY_SIZE = config["common"]["img_history_size"]
|
|
||||||
if IMG_HISTORY_SIZE < 1:
|
|
||||||
raise ValueError("Config `img_history_size` must be at least 1.")
|
|
||||||
ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
|
|
||||||
if ACTION_CHUNK_SIZE < 1:
|
|
||||||
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
|
||||||
|
|
||||||
# Initialize the episode_metadata
|
|
||||||
episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}
|
|
||||||
|
|
||||||
# Check whether this episode has an 'END'
|
|
||||||
base_act = None
|
|
||||||
last_base_act = None
|
|
||||||
episode_states = []
|
|
||||||
episode_masks = []
|
|
||||||
has_base = None
|
|
||||||
for step_id, step in enumerate(iter(episode["steps"])):
|
|
||||||
# Parse the action
|
|
||||||
action = step["action"]
|
|
||||||
if has_base is None:
|
|
||||||
has_base = "base_concat" in action
|
|
||||||
if has_base:
|
|
||||||
base_act = action["base_concat"]
|
|
||||||
|
|
||||||
# Parse the state
|
|
||||||
state = step["observation"]
|
|
||||||
|
|
||||||
arm_format = state["format"].numpy().decode("utf-8")
|
|
||||||
base_format = None
|
|
||||||
if has_base:
|
|
||||||
act_format = action["format"].numpy().decode("utf-8")
|
|
||||||
base_formate_idx = act_format.find("base")
|
|
||||||
base_format = act_format[base_formate_idx:]
|
|
||||||
|
|
||||||
arm_state = state["arm_concat"]
|
|
||||||
base_state = None
|
|
||||||
if has_base:
|
|
||||||
if last_base_act is None:
|
|
||||||
base_state = base_act * 0
|
|
||||||
else:
|
|
||||||
base_state = last_base_act
|
|
||||||
last_base_act = base_act
|
|
||||||
|
|
||||||
# Assemble the state vector
|
|
||||||
state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)
|
|
||||||
|
|
||||||
episode_states.append(state_vec)
|
|
||||||
episode_masks.append(mask_vec)
|
|
||||||
|
|
||||||
# Parse the task instruction
|
|
||||||
instr = step["observation"]["natural_language_instruction"]
|
|
||||||
instr = instr.numpy().decode("utf-8")
|
|
||||||
instr = capitalize_and_period(instr)
|
|
||||||
|
|
||||||
# Write to the episode_metadata
|
|
||||||
if episode_metadata["instruction"] is None:
|
|
||||||
episode_metadata["instruction"] = instr
|
|
||||||
|
|
||||||
episode_metadata["#steps"] = step_id
|
|
||||||
episode_states = tf.stack(episode_states)
|
|
||||||
episode_masks = tf.stack(episode_masks)
|
|
||||||
|
|
||||||
return episode_metadata, episode_states, episode_masks
|
|
||||||
|
|
||||||
|
|
||||||
@tf.autograph.experimental.do_not_convert
|
|
||||||
def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
|
|
||||||
"""
|
|
||||||
Generate the json dict and state for an episode in the dataset without state.
|
|
||||||
If not state, we use the last action as current state.
|
|
||||||
"""
|
|
||||||
# Load some constants from the config
|
|
||||||
IMG_HISTORY_SIZE = config["common"]["img_history_size"]
|
|
||||||
if IMG_HISTORY_SIZE < 1:
|
|
||||||
raise ValueError("Config `img_history_size` must be at least 1.")
|
|
||||||
ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
|
|
||||||
if ACTION_CHUNK_SIZE < 1:
|
|
||||||
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
|
||||||
|
|
||||||
# Initialize the episode_metadata
|
|
||||||
episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}
|
|
||||||
|
|
||||||
last_base_act = None
|
|
||||||
last_arm_act = None
|
|
||||||
episode_states = []
|
|
||||||
episode_masks = []
|
|
||||||
has_base = None
|
|
||||||
for step_id, step in enumerate(iter(episode["steps"])):
|
|
||||||
# Parse the action
|
|
||||||
action = step["action"]
|
|
||||||
if has_base is None:
|
|
||||||
has_base = "base_concat" in action
|
|
||||||
if has_base:
|
|
||||||
base_act = action["base_concat"]
|
|
||||||
if last_base_act is None:
|
|
||||||
last_base_act = base_act * 0 # Initialize
|
|
||||||
|
|
||||||
# Parse the arm action
|
|
||||||
arm_act = action["arm_concat"]
|
|
||||||
if last_arm_act is None:
|
|
||||||
last_arm_act = arm_act * 0 # Initialize
|
|
||||||
|
|
||||||
# Parse the act format
|
|
||||||
# Action format as the state format
|
|
||||||
act_format = action["format"].numpy().decode("utf-8")
|
|
||||||
|
|
||||||
# Assemble the state vector
|
|
||||||
if has_base:
|
|
||||||
last_act_concat = tf.concat([last_arm_act, last_base_act], axis=0)
|
|
||||||
else:
|
|
||||||
last_act_concat = last_arm_act
|
|
||||||
state_vec, mask_vec = assemble_state_vec(last_act_concat, act_format)
|
|
||||||
|
|
||||||
episode_states.append(state_vec)
|
|
||||||
episode_masks.append(mask_vec)
|
|
||||||
|
|
||||||
# Parse the task instruction
|
|
||||||
instr = step["observation"]["natural_language_instruction"]
|
|
||||||
instr = instr.numpy().decode("utf-8")
|
|
||||||
instr = capitalize_and_period(instr)
|
|
||||||
|
|
||||||
# Write to the episode_metadata
|
|
||||||
if episode_metadata["instruction"] is None:
|
|
||||||
episode_metadata["instruction"] = instr
|
|
||||||
|
|
||||||
# Update the last_arm_act and last_base_act
|
|
||||||
last_arm_act = arm_act
|
|
||||||
if has_base:
|
|
||||||
last_base_act = base_act
|
|
||||||
|
|
||||||
episode_metadata["#steps"] = step_id
|
|
||||||
episode_states = tf.stack(episode_states)
|
|
||||||
episode_masks = tf.stack(episode_masks)
|
|
||||||
|
|
||||||
return episode_metadata, episode_states, episode_masks
|
|
||||||
|
|
||||||
|
|
||||||
@tf.autograph.experimental.do_not_convert
|
|
||||||
def generate_json_state(episode: dict, dataset_name: str):
|
|
||||||
"""
|
|
||||||
Generate the json dict and state for an episode.
|
|
||||||
"""
|
|
||||||
if isinstance(dataset_name, tf.Tensor):
|
|
||||||
dataset_name = dataset_name.numpy().decode("utf-8")
|
|
||||||
|
|
||||||
# Process each step in the episode
|
|
||||||
episode["steps"] = episode["steps"].map(globals()[dataset_name].process_step, )
|
|
||||||
|
|
||||||
if dataset_name == "agilex":
|
|
||||||
return _generate_json_state_agilex(episode, dataset_name)
|
|
||||||
|
|
||||||
if dataset_name in DATASET_NAMES_NO_STATE:
|
|
||||||
return _generate_json_state_nostate_ds(episode, dataset_name)
|
|
||||||
|
|
||||||
return _generate_json_state(episode, dataset_name)
|
|
||||||
@ -1,313 +0,0 @@
|
|||||||
import time
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
import signal
|
|
||||||
import random
|
|
||||||
from multiprocessing import Process
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from data.vla_dataset import VLADataset
|
|
||||||
from data.filelock import FileLock
|
|
||||||
|
|
||||||
# Producer does not need GPU
|
|
||||||
tf.config.set_visible_devices([], "GPU")
|
|
||||||
|
|
||||||
# Read the config
|
|
||||||
with open("configs/base.yaml", "r") as file:
|
|
||||||
config = yaml.safe_load(file)
|
|
||||||
# Load some constants from the config
|
|
||||||
BUF_PATH = config["dataset"]["buf_path"]
|
|
||||||
BUF_NUM_CHUNKS = config["dataset"]["buf_num_chunks"]
|
|
||||||
if BUF_NUM_CHUNKS < 1:
|
|
||||||
raise ValueError("Config `buf_num_chunks` must be at least 1.")
|
|
||||||
BUF_CHUNK_SIZE = config["dataset"]["buf_chunk_size"]
|
|
||||||
if BUF_CHUNK_SIZE < 1:
|
|
||||||
raise ValueError("Config `buf_chunk_size` must be at least 1.")
|
|
||||||
|
|
||||||
|
|
||||||
def get_dirty_item(chunk_dir):
|
|
||||||
"""
|
|
||||||
Get indexes of dirty items in a chunk.
|
|
||||||
"""
|
|
||||||
dirty_bit = read_dirty_bit(chunk_dir)
|
|
||||||
return np.where(dirty_bit)[0].tolist()
|
|
||||||
|
|
||||||
|
|
||||||
def get_clean_item(chunk_dir):
|
|
||||||
"""
|
|
||||||
Get indexes of clean items in a chunk.
|
|
||||||
"""
|
|
||||||
dirty_bit = read_dirty_bit(chunk_dir)
|
|
||||||
return np.where(1 - dirty_bit)[0].tolist()
|
|
||||||
|
|
||||||
|
|
||||||
def save_dirty_bit(chunk_dir, dirty_bit):
|
|
||||||
"""
|
|
||||||
Save the dirty bit to the chunk directory.
|
|
||||||
"""
|
|
||||||
time_stmp = time.time()
|
|
||||||
while time.time() - time_stmp < 10.0:
|
|
||||||
try:
|
|
||||||
file_path = os.path.join(chunk_dir, "dirty_bit")
|
|
||||||
lock = FileLock(file_path)
|
|
||||||
lock.acquire_write_lock()
|
|
||||||
with open(file_path, "wb") as file:
|
|
||||||
file.write(dirty_bit.tobytes())
|
|
||||||
lock.release_lock()
|
|
||||||
return
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
lock.release_lock()
|
|
||||||
raise KeyboardInterrupt
|
|
||||||
except BaseException:
|
|
||||||
lock.release_lock()
|
|
||||||
continue
|
|
||||||
# raise RuntimeError("Failed to save dirty bit.")
|
|
||||||
print("Failed to save dirty bit.")
|
|
||||||
|
|
||||||
|
|
||||||
def read_dirty_bit(chunk_dir):
|
|
||||||
"""
|
|
||||||
Read the dirty bit from the chunk directory.
|
|
||||||
"""
|
|
||||||
# If error occurs, retry
|
|
||||||
time_stmp = time.time()
|
|
||||||
while time.time() - time_stmp < 10.0:
|
|
||||||
try:
|
|
||||||
file_path = os.path.join(chunk_dir, "dirty_bit")
|
|
||||||
lock = FileLock(file_path)
|
|
||||||
lock.acquire_read_lock()
|
|
||||||
with open(file_path, "rb") as file:
|
|
||||||
dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
|
|
||||||
lock.release_lock()
|
|
||||||
assert len(dirty_bit) == BUF_CHUNK_SIZE
|
|
||||||
return dirty_bit
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
lock.release_lock()
|
|
||||||
raise KeyboardInterrupt
|
|
||||||
except BaseException:
|
|
||||||
lock.release_lock()
|
|
||||||
continue
|
|
||||||
# If failed to read the dirty bit, return all ones for robustness
|
|
||||||
return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
|
|
||||||
|
|
||||||
|
|
||||||
def save_sample(step_dict, chunk_dir, chunk_item_idx):
|
|
||||||
"""
|
|
||||||
Save a sample to the chunk directory.
|
|
||||||
"""
|
|
||||||
# Save the json content
|
|
||||||
time_stmp = time.time()
|
|
||||||
while time.time() - time_stmp < 10.0:
|
|
||||||
try:
|
|
||||||
locks = []
|
|
||||||
json_content = step_dict["json_content"]
|
|
||||||
file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
|
|
||||||
lock = FileLock(file_path)
|
|
||||||
locks.append(lock)
|
|
||||||
lock.acquire_write_lock()
|
|
||||||
with open(file_path, "w") as file:
|
|
||||||
json.dump(json_content, file, indent=4)
|
|
||||||
lock.release_lock()
|
|
||||||
# Save all other tensors in a npz
|
|
||||||
file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
|
|
||||||
lock = FileLock(file_path)
|
|
||||||
locks.append(lock)
|
|
||||||
lock.acquire_write_lock()
|
|
||||||
with open(file_path, "wb") as file:
|
|
||||||
np.savez(
|
|
||||||
file,
|
|
||||||
step_id=step_dict["step_id"].numpy(),
|
|
||||||
state_chunk=step_dict["state_chunk"].numpy(),
|
|
||||||
state_chunk_time_mask=step_dict["state_chunk_time_mask"].numpy(),
|
|
||||||
action_chunk=step_dict["action_chunk"].numpy(),
|
|
||||||
action_chunk_time_mask=step_dict["action_chunk_time_mask"].numpy(),
|
|
||||||
state_vec_mask=step_dict["state_vec_mask"].numpy(),
|
|
||||||
past_frames_0=step_dict["past_frames_0"].numpy(),
|
|
||||||
past_frames_0_time_mask=step_dict["past_frames_0_time_mask"].numpy(),
|
|
||||||
past_frames_1=step_dict["past_frames_1"].numpy(),
|
|
||||||
past_frames_1_time_mask=step_dict["past_frames_1_time_mask"].numpy(),
|
|
||||||
past_frames_2=step_dict["past_frames_2"].numpy(),
|
|
||||||
past_frames_2_time_mask=step_dict["past_frames_2_time_mask"].numpy(),
|
|
||||||
past_frames_3=step_dict["past_frames_3"].numpy(),
|
|
||||||
past_frames_3_time_mask=step_dict["past_frames_3_time_mask"].numpy(),
|
|
||||||
state_std=step_dict["state_std"].numpy(),
|
|
||||||
state_mean=step_dict["state_mean"].numpy(),
|
|
||||||
state_norm=step_dict["state_norm"].numpy(),
|
|
||||||
)
|
|
||||||
lock.release_lock()
|
|
||||||
return
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
for lock in locks:
|
|
||||||
lock.release_lock()
|
|
||||||
raise KeyboardInterrupt
|
|
||||||
except BaseException:
|
|
||||||
for lock in locks:
|
|
||||||
lock.release_lock()
|
|
||||||
continue
|
|
||||||
# raise RuntimeError("Failed to save sample.")
|
|
||||||
print("Failed to save sample.")
|
|
||||||
|
|
||||||
|
|
||||||
def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type):
|
|
||||||
"""
|
|
||||||
Run the producer.
|
|
||||||
The producer will first fill up the buffer with samples.
|
|
||||||
Then it will keep replacing dirty samples
|
|
||||||
(i.e., samples that have been read by the consumer)
|
|
||||||
with new samples.
|
|
||||||
"""
|
|
||||||
vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type)
|
|
||||||
chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers
|
|
||||||
chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers
|
|
||||||
if fill_up:
|
|
||||||
print(f"Worker {worker_id}: Start filling up the buffer...")
|
|
||||||
elif clean_dirty:
|
|
||||||
# Only refresh the dirty bits
|
|
||||||
print(f"Worker {worker_id}: Start refreshing the dirty bits...")
|
|
||||||
for chunk_idx in range(chunk_start_idx, chunk_end_idx):
|
|
||||||
chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}")
|
|
||||||
dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
|
|
||||||
save_dirty_bit(chunk_dir, dirty_bit)
|
|
||||||
print(f"Worker {worker_id}: Refreshed the dirty bits.")
|
|
||||||
|
|
||||||
fill_chunk_idx = chunk_start_idx
|
|
||||||
fill_chunk_item_idx = 0
|
|
||||||
dirty_chunk_idx = chunk_start_idx
|
|
||||||
dirty_chunk_item_idxs = []
|
|
||||||
time_stmp = time.time()
|
|
||||||
for episode_steps in vla_dataset:
|
|
||||||
for step in episode_steps:
|
|
||||||
if fill_up and fill_chunk_idx < chunk_end_idx:
|
|
||||||
# Fill up the buffer
|
|
||||||
chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}")
|
|
||||||
if fill_chunk_item_idx == 0:
|
|
||||||
# Create a new chunk
|
|
||||||
os.makedirs(chunk_dir, exist_ok=True)
|
|
||||||
# Write the dirty bit of size BUF_CHUNK_SIZE
|
|
||||||
dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
|
|
||||||
save_dirty_bit(chunk_dir, dirty_bit)
|
|
||||||
|
|
||||||
# Save the sample
|
|
||||||
save_sample(step, chunk_dir, fill_chunk_item_idx)
|
|
||||||
|
|
||||||
# print(f"Filled up chunk {fill_chunk_item_idx+1}/{BUF_CHUNK_SIZE} {fill_chunk_idx+1}/{BUF_NUM_CHUNKS}")
|
|
||||||
local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx
|
|
||||||
local_num_chunks = chunk_end_idx - chunk_start_idx
|
|
||||||
if (local_fill_chunk_idx % 10 == 0
|
|
||||||
or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0:
|
|
||||||
print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx+1}/{local_num_chunks}")
|
|
||||||
fill_chunk_item_idx += 1
|
|
||||||
if fill_chunk_item_idx == BUF_CHUNK_SIZE:
|
|
||||||
fill_chunk_idx += 1
|
|
||||||
fill_chunk_item_idx = 0
|
|
||||||
if fill_chunk_idx == BUF_NUM_CHUNKS:
|
|
||||||
print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Search for the dirty chunk to replace
|
|
||||||
while len(dirty_chunk_item_idxs) == 0:
|
|
||||||
dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
|
|
||||||
dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir)
|
|
||||||
# Print the dirty ratio
|
|
||||||
if time.time() - time_stmp > 2.0:
|
|
||||||
dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE
|
|
||||||
print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}")
|
|
||||||
time_stmp = time.time()
|
|
||||||
|
|
||||||
if len(dirty_chunk_item_idxs) > 0:
|
|
||||||
# Lock the chunk
|
|
||||||
dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
|
|
||||||
save_dirty_bit(dirty_chunk_dir, dirty_bit)
|
|
||||||
|
|
||||||
# Iterate over the chunks
|
|
||||||
dirty_chunk_idx += 1
|
|
||||||
if dirty_chunk_idx == chunk_end_idx:
|
|
||||||
dirty_chunk_idx = chunk_start_idx
|
|
||||||
|
|
||||||
# Replace the dirty item
|
|
||||||
dirty_item_idx = dirty_chunk_item_idxs.pop()
|
|
||||||
chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
|
|
||||||
# Save the sample
|
|
||||||
save_sample(step, chunk_dir, dirty_item_idx)
|
|
||||||
|
|
||||||
# If we have replaced all dirty items in the chunk
|
|
||||||
if len(dirty_chunk_item_idxs) == 0:
|
|
||||||
# Unlock the chunk
|
|
||||||
dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
|
|
||||||
save_dirty_bit(dirty_chunk_dir, dirty_bit)
|
|
||||||
print(f"Worker {worker_id}: Replaced dirty chunk {dirty_chunk_idx}.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Args: n_workers, fill_up
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--n_workers",
|
|
||||||
type=int,
|
|
||||||
default=2,
|
|
||||||
help="Number of parallel workers. It should be less than or equal to the number of chunks.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fill_up",
|
|
||||||
action="store_true",
|
|
||||||
help="Whether to fill up the buffer before replacing dirty samples.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--clean_dirty",
|
|
||||||
action="store_true",
|
|
||||||
help=
|
|
||||||
"Whether to clean the dirty bits before replacing dirty samples. This option is ignored when `fill_up` is set.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seed",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Random seed. If not set, the seed will be randomly generated.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dataset_type",
|
|
||||||
type=str,
|
|
||||||
default="pretrain",
|
|
||||||
help="Whether to load the pretrain dataset or finetune dataset.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the producer
|
|
||||||
args = parser.parse_args()
|
|
||||||
if args.seed is not None:
|
|
||||||
print(f"Base seed: {args.seed}")
|
|
||||||
random.seed(args.seed)
|
|
||||||
|
|
||||||
processes = []
|
|
||||||
process_seeds = [random.randint(0, 2**32) for _ in range(args.n_workers)]
|
|
||||||
print(f"Process seeds: {process_seeds}")
|
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
|
||||||
print("Ctrl+C received. Terminating child processes...")
|
|
||||||
for p in processes:
|
|
||||||
p.terminate()
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
for worker_id in range(args.n_workers):
|
|
||||||
p = Process(
|
|
||||||
target=run_producer,
|
|
||||||
args=(
|
|
||||||
process_seeds[worker_id],
|
|
||||||
args.n_workers,
|
|
||||||
worker_id,
|
|
||||||
args.fill_up,
|
|
||||||
args.clean_dirty,
|
|
||||||
args.dataset_type,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
p.start()
|
|
||||||
processes.append(p)
|
|
||||||
|
|
||||||
for p in processes:
|
|
||||||
p.join()
|
|
||||||
@ -1,242 +0,0 @@
|
|||||||
import tensorflow as tf
|
|
||||||
import tensorflow_graphics.geometry.transformation.euler as tf_euler
|
|
||||||
import tensorflow_graphics.geometry.transformation.quaternion as tf_quat
|
|
||||||
import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as tf_rotmat
|
|
||||||
|
|
||||||
|
|
||||||
def dataset_to_path(dataset_name: str, dir_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Return the path to the dataset.
|
|
||||||
"""
|
|
||||||
if (dataset_name == "robo_net" or dataset_name == "cmu_playing_with_food" or dataset_name == "droid"):
|
|
||||||
version = "1.0.0"
|
|
||||||
elif (dataset_name == "language_table" or dataset_name == "fmb" or dataset_name == "dobbe"):
|
|
||||||
version = "0.0.1"
|
|
||||||
elif dataset_name == "nyu_door_opening_surprising_effectiveness":
|
|
||||||
version = ""
|
|
||||||
elif dataset_name == "cmu_play_fusion":
|
|
||||||
version = ""
|
|
||||||
elif dataset_name == "berkeley_gnm_recon":
|
|
||||||
version = ""
|
|
||||||
else:
|
|
||||||
version = "0.1.0"
|
|
||||||
return f"{dir_name}/{dataset_name}/{version}"
|
|
||||||
|
|
||||||
|
|
||||||
def clean_task_instruction(task_instruction: tf.Tensor, replacements: dict) -> tf.Tensor:
|
|
||||||
"""
|
|
||||||
Clean up the natural language task instruction.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Create a function that applies all replacements
|
|
||||||
def apply_replacements(tensor):
|
|
||||||
for old, new in replacements.items():
|
|
||||||
tensor = tf.strings.regex_replace(tensor, old, new)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
# Apply the replacements and strip leading and trailing spaces
|
|
||||||
cleaned_task_instruction = apply_replacements(task_instruction)
|
|
||||||
cleaned_task_instruction = tf.strings.strip(cleaned_task_instruction)
|
|
||||||
return cleaned_task_instruction
|
|
||||||
|
|
||||||
|
|
||||||
def quaternion_to_euler(quaternion: tf.Tensor) -> tf.Tensor:
|
|
||||||
"""
|
|
||||||
Convert a quaternion (x, y, z, w) to Euler angles (roll, pitch, yaw).
|
|
||||||
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
|
||||||
"""
|
|
||||||
# Normalize the quaternion
|
|
||||||
quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
|
|
||||||
return tf_euler.from_quaternion(quaternion)
|
|
||||||
|
|
||||||
|
|
||||||
def euler_to_quaternion(euler: tf.Tensor) -> tf.Tensor:
|
|
||||||
"""
|
|
||||||
Convert Euler angles (roll, pitch, yaw) to a quaternion (x, y, z, w).
|
|
||||||
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
|
||||||
"""
|
|
||||||
quaternion = tf_quat.from_euler(euler)
|
|
||||||
return tf.nn.l2_normalize(quaternion, axis=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def rotation_matrix_to_euler(matrix: tf.Tensor) -> tf.Tensor:
|
|
||||||
"""
|
|
||||||
Convert a 3x3 rotation matrix to Euler angles (roll, pitch, yaw).
|
|
||||||
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
|
||||||
"""
|
|
||||||
return tf_euler.from_rotation_matrix(matrix)
|
|
||||||
|
|
||||||
|
|
||||||
def rotation_matrix_to_quaternion(matrix: tf.Tensor) -> tf.Tensor:
|
|
||||||
"""
|
|
||||||
Convert a 3x3 rotation matrix to a quaternion (x, y, z, w).
|
|
||||||
"""
|
|
||||||
quaternion = tf_quat.from_rotation_matrix(matrix)
|
|
||||||
return tf.nn.l2_normalize(quaternion, axis=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def euler_to_rotation_matrix(euler: tf.Tensor) -> tf.Tensor:
    """
    Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix.

    The (roll, pitch, yaw) follows the `Rotation.as_euler("xyz")` convention.
    """
    # Direct delegation; tf_rotmat handles batched inputs.
    return tf_rotmat.from_euler(euler)
|
|
||||||
|
|
||||||
|
|
||||||
def quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
    """
    # Normalize first: the closed-form conversion assumes a unit quaternion.
    unit_quat = tf.nn.l2_normalize(quaternion, axis=-1)
    return tf_rotmat.from_quaternion(unit_quat)
|
|
||||||
|
|
||||||
|
|
||||||
def quaternion_to_rotation_matrix_wo_static_check(quaternion: tf.Tensor) -> tf.Tensor:
    """
    Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.

    Unlike `quaternion_to_rotation_matrix`, this implementation avoids the
    library's static shape checks ("this function is used to make tensorflow
    happy") by expanding the standard closed-form quaternion-to-matrix
    formula by hand.

    Input: (..., 4) quaternion in (x, y, z, w) order.
    Output: (..., 3, 3) rotation matrix.
    """
    # Normalize the quaternion — the closed-form expansion below assumes
    # a unit quaternion.
    quaternion = tf.nn.l2_normalize(quaternion, axis=-1)

    # Component order is (x, y, z, w), matching the rest of this module.
    x = quaternion[..., 0]
    y = quaternion[..., 1]
    z = quaternion[..., 2]
    w = quaternion[..., 3]

    # Precompute the doubled products used by the closed-form matrix.
    tx = 2.0 * x
    ty = 2.0 * y
    tz = 2.0 * z
    twx = tx * w
    twy = ty * w
    twz = tz * w
    txx = tx * x
    txy = ty * x
    txz = tz * x
    tyy = ty * y
    tyz = tz * y
    tzz = tz * z
    # The nine matrix entries in row-major order; reshaped to (..., 3, 3) below.
    matrix = tf.stack(
        (
            1.0 - (tyy + tzz),
            txy - twz,
            txz + twy,
            txy + twz,
            1.0 - (txx + tzz),
            tyz - twx,
            txz - twy,
            tyz + twx,
            1.0 - (txx + tyy),
        ),
        axis=-1,
    )  # pyformat: disable
    # Dynamic (runtime) shape computation, so no static shape check is needed.
    output_shape = tf.concat((tf.shape(input=quaternion)[:-1], (3, 3)), axis=-1)
    return tf.reshape(matrix, shape=output_shape)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
Below is a continuous 6D rotation representation adapted from
|
|
||||||
On the Continuity of Rotation Representations in Neural Networks
|
|
||||||
https://arxiv.org/pdf/1812.07035.pdf
|
|
||||||
https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def rotation_matrix_to_ortho6d(matrix: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d represents the first two column vectors a1 and a2 of the
    rotation matrix: [ | , |, | ]
                     [ a1, a2, a3]
                     [ | , |, | ]
    Input: (A1, ..., An, 3, 3)
    Output: (A1, ..., An, 6)
    """
    # Keep only the first two columns of each rotation matrix.
    ortho6d = matrix[..., :, :2]
    # Transpose the last two dimensions so each column vector becomes
    # contiguous before flattening (output order is a1 then a2).
    perm = list(range(len(ortho6d.shape)))
    perm[-2], perm[-1] = perm[-1], perm[-2]
    ortho6d = tf.transpose(ortho6d, perm)
    # Flatten the last two dimensions into a single 6-vector.
    # NOTE(review): relies on statically known batch shape
    # (`ortho6d.shape[:-2]`); see `rotation_matrix_to_ortho6d_1d` for the
    # static-shape-free 1-D variant.
    ortho6d = tf.reshape(ortho6d, ortho6d.shape[:-2] + [6])
    return ortho6d
|
|
||||||
|
|
||||||
|
|
||||||
def rotation_matrix_to_ortho6d_1d(matrix: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d represents the first two column vectors a1 and a2 of the
    rotation matrix: [ | , |, | ]
                     [ a1, a2, a3]
                     [ | , |, | ]
    Input: (3, 3)
    Output: (6,)
    This function is used to make tensorflow happy.
    """
    # Take the first two columns, lay each column out contiguously,
    # and flatten into a single 6-vector.
    first_two_cols = matrix[:, :2]
    cols_as_rows = tf.transpose(first_two_cols)
    return tf.reshape(cols_as_rows, [6])
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_vector(v):
    """
    Normalize the last axis of `v` to unit length.

    v: (..., N)
    Returns a tensor of the same shape with each trailing vector scaled
    to unit norm; near-zero vectors are divided by 1e-8 instead of 0.
    """
    norm = tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
    # Clamp the norm away from zero to avoid division by zero.
    safe_norm = tf.maximum(norm, 1e-8)
    return v / safe_norm
|
|
||||||
|
|
||||||
|
|
||||||
def cross_product(u, v):
    """
    Batched 3-D cross product.

    u: (..., 3)
    v: (..., 3)
    Returns u x v with shape (..., 3).
    """
    ux, uy, uz = u[..., 0], u[..., 1], u[..., 2]
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    # Standard determinant expansion of the cross product.
    return tf.stack(
        [
            uy * vz - uz * vy,
            uz * vx - ux * vz,
            ux * vy - uy * vx,
        ],
        axis=-1,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def ortho6d_to_rotation_matrix(ortho6d: tf.Tensor) -> tf.Tensor:
    """
    The ortho6d represents the first two column vectors a1 and a2 of the
    rotation matrix: [ | , |, | ]
                     [ a1, a2, a3]
                     [ | , |, | ]
    Input: (A1, ..., An, 6)
    Output: (A1, ..., An, 3, 3)
    """
    a1_raw, a2_raw = ortho6d[..., 0:3], ortho6d[..., 3:6]

    # Gram-Schmidt-style orthonormalization:
    # x is the normalized first vector; z is orthogonal to both inputs;
    # y completes the right-handed frame.
    x = normalize_vector(a1_raw)
    z = normalize_vector(cross_product(x, a2_raw))
    y = cross_product(z, x)

    # Columns of the rotation matrix are x, y, z.
    return tf.stack([x, y, z], axis=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def capitalize_and_period(instr: str) -> str:
    """
    Capitalize the first letter of a string and add a period to the end
    if it's not there.

    Empty strings are returned unchanged.
    """
    # Guard clause: nothing to do for the empty string.
    if not instr:
        return instr
    # Upper-case the first character if it is not already.
    if not instr[0].isupper():
        instr = instr[0].upper() + instr[1:]
    # Append the terminating period if missing.
    if not instr.endswith("."):
        instr += "."
    return instr
|
|
||||||
@ -1,149 +0,0 @@
|
|||||||
import json
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
|
||||||
import tensorflow_datasets as tfds
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from data.episode_transform import (
|
|
||||||
process_episode,
|
|
||||||
flatten_episode,
|
|
||||||
flatten_episode_agilex,
|
|
||||||
bgr_to_rgb,
|
|
||||||
)
|
|
||||||
from data.utils import dataset_to_path
|
|
||||||
from data.preprocess_scripts import *
|
|
||||||
|
|
||||||
# Producer does not need GPU
tf.config.set_visible_devices([], "GPU")

# Root directory of the Open X-Embodiment dataset builders.
OPENX_EMBOD_DIR = "data/datasets/openx_embod"

# Datasets loaded via custom preprocess scripts (data.preprocess_scripts)
# rather than from the Open X-Embodiment directory above.
DATASET_NAMES_NOOPENX = [
    "aloha_mobile",
    "aloha_static",
    "roboset",
    "agilex",
    "rh20t",
    "calvin",
    "bridgev2",
]

# Read the config
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
# Load some constants from the config: episodes shorter than the low
# threshold are dropped, longer than the high threshold are subsampled.
EPSD_LEN_THRESH_LOW = config["dataset"]["epsd_len_thresh_low"]
EPSD_LEN_THRESH_HIGH = config["dataset"]["epsd_len_thresh_high"]
# Read the image keys of each dataset
with open("configs/dataset_img_keys.json", "r") as file:
    IMAGE_KEYS = json.load(file)
|
|
||||||
|
|
||||||
|
|
||||||
class VLADataset:
    """
    This class is used to sample episodes from the embodiment dataset
    mixture, weighted per-dataset by the configured sample weights.
    """

    def __init__(self, seed, dataset_type, repeat=True):
        """
        seed: the random seed
        dataset_type: 'pretrain' or 'finetune', which dataset to load
        repeat: whether to repeat to infinite length
        """
        # Pick the dataset-name and sample-weight config files for the phase.
        dataset_names_cfg = ("configs/pretrain_datasets.json"
                             if dataset_type == "pretrain" else "configs/finetune_datasets.json")
        with open(dataset_names_cfg, "r") as file:
            DATASET_NAMES = json.load(file)
        self.dataset_names = DATASET_NAMES
        sample_weights_cfg = ("configs/pretrain_sample_weights.json"
                              if dataset_type == "pretrain" else "configs/finetune_sample_weights.json")
        # Load the sample weights
        with open(sample_weights_cfg, "r") as file:
            SAMPLE_WEIGHTS = json.load(file)
        self.openx_dir = OPENX_EMBOD_DIR
        # Episode-length thresholds from the module-level config.
        self.epsd_len_thresh_low = EPSD_LEN_THRESH_LOW
        self.epsd_len_thresh_high = EPSD_LEN_THRESH_HIGH
        self.repeat = repeat

        # Set the random seed
        tf.random.set_seed(seed)
        np.random.seed(seed)

        # Weights of the each dataset in the collection to sample from
        sample_weights = []

        # Map dataset name -> infinite iterator over processed episodes.
        self.name2dataset = {}
        for dataset_name in self.dataset_names:
            if dataset_name in DATASET_NAMES_NOOPENX:
                # Custom datasets are loaded via their preprocess module,
                # resolved by name from the wildcard import.
                dataset = globals()[dataset_name].load_dataset(seed)
            else:
                dataset_path = dataset_to_path(dataset_name, self.openx_dir)
                dataset = tfds.builder_from_directory(builder_dir=dataset_path)
                dataset = dataset.as_dataset(split="all", shuffle_files=True)

                # You can add filter for other datasets.
                # Per-dataset success / instruction filters:
                if dataset_name == "kuka":
                    dataset = dataset.filter(lambda x: x["success"])
                elif dataset_name == "bc_z":
                    # bc_z stores success as a float score on the first step.
                    dataset = dataset.filter(lambda x: tf.math.greater(
                        next(iter(x["steps"]))["observation"]["episode_success"],
                        0.5,
                    ))
                elif (dataset_name == "ucsd_pick_and_place_dataset_converted_externally_to_rlds"):
                    dataset = dataset.filter(lambda x: x["episode_metadata"]["success"])
                elif (dataset_name == "utokyo_xarm_bimanual_converted_externally_to_rlds"):
                    # Only preserve the meaningful episodes
                    dataset = dataset.filter(lambda x: tf.math.equal(
                        next(iter(x["steps"]))["language_instruction"],
                        tf.constant("Unfold a wrinkled towel."),
                    ))

            # Note: use cache() will cause the unexpected crash
            # dataset = dataset.map().cache().shuffle().repeat()
            dataset = dataset.map(lambda x: process_episode(
                x,
                dataset_name,
                IMAGE_KEYS[dataset_name]["image_keys"],
                IMAGE_KEYS[dataset_name]["image_mask"],
            ))

            # Change BGR to RGB if needed
            if dataset_name == "fmb":
                dataset = dataset.map(bgr_to_rgb)

            if self.repeat:
                dataset = dataset.repeat()
            self.name2dataset[dataset_name] = iter(dataset)
            sample_weights.append(SAMPLE_WEIGHTS[dataset_name])
        # Normalize the sample weights
        sample_weights = np.array(sample_weights)
        self.sample_weights = sample_weights / np.sum(sample_weights)

    def __iter__(self):
        """
        Sample batches of episodes for an epoch.

        Yields lists of flattened episode steps; episodes shorter than
        `epsd_len_thresh_low` are skipped and episodes longer than
        `epsd_len_thresh_high` are randomly subsampled.
        """
        while True:
            # Choose which dataset to draw from, weighted by sample_weights.
            dataset_name = np.random.choice(self.dataset_names, p=self.sample_weights)
            episode = next(self.name2dataset[dataset_name])
            # agilex uses a dedicated flattening routine.
            if dataset_name == "agilex":
                episode_steps = flatten_episode_agilex(episode)
            else:
                episode_steps = flatten_episode(episode)
            # Filter too short
            if len(episode_steps) < self.epsd_len_thresh_low:
                continue
            # Randomly sample too long
            if len(episode_steps) > self.epsd_len_thresh_high:
                episode_steps = random.sample(episode_steps, self.epsd_len_thresh_high)

            yield episode_steps
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Smoke test: sample one episode from the finetune mixture and print
    # its first step.
    dataset = VLADataset(0, "finetune")
    for episode in dataset:
        print(episode[0])
        break
|
|
||||||
@ -1,107 +0,0 @@
|
|||||||
#!/bin/bash
# Finetune launcher for RDT-1B: reads hyper-parameters from a YAML model
# config, exports the cluster/NCCL environment, and launches training via
# `accelerate`. Emits an output JSON summary at the end.

BEGIN_TIME=$(date +%s)

CONFIG_NAME="Train_Config_Default"
CONFIG_FILE="model_config/$CONFIG_NAME.yml"
echo "CONFIG_FILE_PATH: $CONFIG_FILE"

### ============Read Input Config and ReLoad Config YAML===================
# NOTE(review): machine-specific absolute path — parameterize for other hosts.
# Guard so a pre-existing link/dir does not make `ln -s` fail on re-runs.
if [ ! -e ../weights ]; then
    ln -s /home/qi.xiong/Temp/RDT-1B/input/weights ../weights
fi

TRAIN_CONFIG_FILE="input/config.json"
echo "TRAIN_CONFIG_FILE_PATH: $TRAIN_CONFIG_FILE"
python scripts/read_config.py "$TRAIN_CONFIG_FILE" "$CONFIG_FILE"

### ============Read Input Config and ReLoad Config YAML===================

export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0
export NCCL_SOCKET_IFNAME=eth0
# export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
export VISION_ENCODER_NAME="../weights/siglip-so400m-patch14-384"
export CFLAGS="-I/usr/include"
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
export WANDB_PROJECT="RDT-1B"
export WANDB_DEFAULT_RUN_NAME=$CONFIG_NAME
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1

# check if YAML exist
if [ ! -f "$CONFIG_FILE" ]; then
    echo "Config file $CONFIG_FILE does not exist!"
    exit 1
fi

# Pull every tunable out of the YAML config.
PRETRAINED_MODEL_NAME=$(python scripts/read_yaml.py "$CONFIG_FILE" pretrained_model_name_or_path)
TRAIN_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" train_batch_size)
SAMPLE_BATCH_SIZE=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_batch_size)
MAX_TRAIN_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" max_train_steps)
CHECKPOINTING_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpointing_period)
SAMPLE_PERIOD=$(python scripts/read_yaml.py "$CONFIG_FILE" sample_period)
CHECKPOINTS_TOTAL_LIMIT=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpoints_total_limit)
LR_SCHEDULER=$(python scripts/read_yaml.py "$CONFIG_FILE" lr_scheduler)
LEARNING_RATE=$(python scripts/read_yaml.py "$CONFIG_FILE" learning_rate)
DATALOADER_NUM_WORKERS=$(python scripts/read_yaml.py "$CONFIG_FILE" dataloader_num_workers)
DATASET_TYPE=$(python scripts/read_yaml.py "$CONFIG_FILE" dataset_type)
STATE_NOISE_SNR=$(python scripts/read_yaml.py "$CONFIG_FILE" state_noise_snr)
GRAD_ACCUM_STEPS=$(python scripts/read_yaml.py "$CONFIG_FILE" gradient_accumulation_steps)
OUTPUT_DIR=$(python scripts/read_yaml.py "$CONFIG_FILE" checkpoint_path)
CUDA_USE=$(python scripts/read_yaml.py "$CONFIG_FILE" cuda_visible_device)

# read_yaml.py emits quoted strings; strip the quotes before reuse.
PRETRAINED_MODEL_NAME=$(echo "$PRETRAINED_MODEL_NAME" | tr -d '"')
CUDA_USE=$(echo "$CUDA_USE" | tr -d '"')
OUTPUT_DIR=$(echo "$OUTPUT_DIR" | tr -d '"')
LR_SCHEDULER=$(echo "$LR_SCHEDULER" | tr -d '"')
DATASET_TYPE=$(echo "$DATASET_TYPE" | tr -d '"')

# create output path
if [ ! -d "$OUTPUT_DIR" ]; then
    mkdir -p "$OUTPUT_DIR"
    echo "Created output directory: $OUTPUT_DIR"
else
    echo "Output directory already exists: $OUTPUT_DIR"
fi

export CUDA_VISIBLE_DEVICES=$CUDA_USE

python -m data.compute_dataset_stat_hdf5 --task_name "$CONFIG_NAME"

# FIX: pass the configured lr_scheduler/dataset_type instead of hard-coded
# "constant"/"finetune" — the values were read from the YAML but ignored.
accelerate launch --main_process_port=28499 main.py \
    --deepspeed="./configs/zero2.json" \
    --pretrained_model_name_or_path="$PRETRAINED_MODEL_NAME" \
    --pretrained_text_encoder_name_or_path="$TEXT_ENCODER_NAME" \
    --pretrained_vision_encoder_name_or_path="$VISION_ENCODER_NAME" \
    --output_dir="$OUTPUT_DIR" \
    --train_batch_size="$TRAIN_BATCH_SIZE" \
    --sample_batch_size="$SAMPLE_BATCH_SIZE" \
    --max_train_steps="$MAX_TRAIN_STEPS" \
    --checkpointing_period="$CHECKPOINTING_PERIOD" \
    --sample_period="$SAMPLE_PERIOD" \
    --checkpoints_total_limit="$CHECKPOINTS_TOTAL_LIMIT" \
    --lr_scheduler="$LR_SCHEDULER" \
    --learning_rate="$LEARNING_RATE" \
    --mixed_precision="bf16" \
    --dataloader_num_workers="$DATALOADER_NUM_WORKERS" \
    --image_aug \
    --dataset_type="$DATASET_TYPE" \
    --state_noise_snr="$STATE_NOISE_SNR" \
    --load_from_hdf5 \
    --report_to=wandb \
    --precomp_lang_embed \
    --gradient_accumulation_steps="$GRAD_ACCUM_STEPS" \
    --model_config_path="$CONFIG_FILE" \
    --CONFIG_NAME="$CONFIG_NAME" \
    --output_log_path="$OUTPUT_DIR/output.log"

END_TIME=$(date +%s)
RUNTIME=$((END_TIME - BEGIN_TIME))
echo "Total runtime: $RUNTIME seconds"

### ============Generate Output JSON===================

python scripts/generate_output_json.py "$TRAIN_CONFIG_FILE" "$OUTPUT_DIR" "$RUNTIME"

### ============Generate Output JSON===================
|
||||||
Binary file not shown.
@ -1,5 +0,0 @@
|
|||||||
#!/bin/bash
# Generate a model config YAML for the given model name.
# Usage: ./<this script> <model_name>

model_name=${1}

python ./model_config/_generate_model_config.py $model_name
|
|
||||||
351
RDT-1B/main.py
351
RDT-1B/main.py
@ -1,351 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
from train.train import train
|
|
||||||
|
|
||||||
from accelerate.logging import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args(input_args=None):
    """
    Build and run the argument parser for RDT training.

    Args:
        input_args: optional list of argument strings; when None,
            ``sys.argv`` is parsed.

    Returns:
        argparse.Namespace with all training options; ``local_rank`` is
        overridden by the ``LOCAL_RANK`` environment variable when set.
    """
    parser = argparse.ArgumentParser(description="Main script for training RDT.")
    parser.add_argument(
        "--model_config_path",
        type=str,
        default="model_config/sjoe_place_D435_100_finetune_config.yaml",
        help=
        "Path to the finetune data and model configuration file. Default is `model_config/sjoe_place_D435_100_finetune_config.yaml`.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/base.yaml",
        help="Path to the configuration file. Default is `configs/base.yaml`.",
    )
    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help=
        "Enable DeepSpeed and pass the path to its config file or an already initialized DeepSpeed config dictionary",
    )
    parser.add_argument(
        "--pretrained_text_encoder_name_or_path",
        type=str,
        default=None,
        help="Pretrained text encoder name or path if not the same as model_name",
    )
    parser.add_argument(
        "--pretrained_vision_encoder_name_or_path",
        type=str,
        default=None,
        help="Pretrained vision encoder name or path if not the same as model_name",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default="checkpoints",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")

    parser.add_argument(
        "--load_from_hdf5",
        action="store_true",
        default=False,
        help=("Whether to load the dataset directly from HDF5 files. "
              "If False, the dataset will be loaded using producer-consumer pattern, "
              "where the producer reads TFRecords and saves them to buffer, and the consumer reads from buffer."),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the sampling dataloader.",
    )
    parser.add_argument(
        "--num_sample_batches",
        type=int,
        default=2,
        help="Number of batches to sample from the dataset.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_period",
        type=int,
        default=500,
        help=
        ("Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
         "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
         "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
         "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
         "instructions."),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=
        ("Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
         " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
         " for more details"),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
              ' `--checkpointing_period`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    # FIX: the original help was accidentally a 2-tuple (stray comma after
    # the first string), which argparse renders as a tuple repr. Joined into
    # a single implicitly-concatenated string.
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help=("Path or name of a pretrained checkpoint to load the model from.\n"
              " This can be either:\n"
              " - a string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co, e.g., `robotics-diffusion-transformer/rdt-1b`,\n"
              " - a path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method, e.g., `./my_model_directory/`.\n"
              " - a path to model checkpoint (*.pt), .e.g, `my_model_directory/checkpoint-10000/pytorch_model/mp_rank_00_model_states.pt`"
              " - `None` if you are randomly initializing model using configuration at `config_path`."),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--cond_mask_prob",
        type=float,
        default=0.1,
        help=("The probability to randomly mask the conditions (except states) during training. "
              "If set to 0, the conditions are not masked."),
    )
    parser.add_argument(
        "--cam_ext_mask_prob",
        type=float,
        default=-1.0,
        help=("The probability to randomly mask the external camera image during training. "
              "If set to < 0, the external camera image is masked with the probability of `cond_mask_prob`."),
    )
    parser.add_argument(
        "--state_noise_snr",
        type=float,
        default=None,
        help=("The signal-to-noise ratio (SNR, unit: dB) for adding noise to the states. "
              "Default is None, which means no noise is added."),
    )
    parser.add_argument(
        "--image_aug",
        action="store_true",
        default=False,
        help="Whether or not to apply image augmentation (ColorJitter, blur, noise, etc) to the input images.",
    )
    parser.add_argument(
        "--precomp_lang_embed",
        action="store_true",
        default=False,
        help="Whether or not to use precomputed language embeddings.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.9,
        help="The moving average coefficient for each dataset's loss.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
              " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
              ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'),
    )
    parser.add_argument(
        "--sample_period",
        type=int,
        default=-1,
        help=("Run sampling every X steps. During the sampling phase, the model will sample a trajectory"
              " and report the error between the sampled trajectory and groud-truth trajectory"
              " in the training batch."),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
    )

    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=("Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
              " behaviors, so disable this argument if it causes any problems. More info:"
              " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"),
    )

    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        required=False,
        help="Whether to load the pretrain dataset or finetune dataset.",
    )

    # NOTE(review): required=True makes the default "Null" unreachable;
    # kept for interface compatibility.
    parser.add_argument(
        "--CONFIG_NAME",
        type=str,
        default="Null",
        required=True,
    )

    parser.add_argument(
        "--output_log_path",
        type=str,
        default="output/output.log",
        help="The path to the output log file.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Let the launcher's LOCAL_RANK env var win over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Use accelerate's logger so log records respect the process rank.
    logger = get_logger(__name__)
    args = parse_args()
    train(args, logger)
|
|
||||||
269
RDT-1B/model.py
269
RDT-1B/model.py
@ -1,269 +0,0 @@
|
|||||||
#!/home/lin/software/miniconda3/envs/aloha/bin/python
|
|
||||||
# -- coding: UTF-8
|
|
||||||
"""
|
|
||||||
#!/usr/bin/python3
|
|
||||||
"""
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# get current workspace
|
|
||||||
current_file = Path(__file__)
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
|
|
||||||
parent_dir = current_file.parent
|
|
||||||
sys.path.append(str(parent_dir))
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import yaml
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image as PImage
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
import sys, os
|
|
||||||
|
|
||||||
# get current workspace
|
|
||||||
current_file = Path(__file__)
|
|
||||||
sys.path.append(os.path.join(current_file.parent, "models"))
|
|
||||||
|
|
||||||
from scripts.agilex_model import create_model
|
|
||||||
from multimodal_encoder.t5_encoder import T5Embedder
|
|
||||||
|
|
||||||
global_path = parent_dir.parent
|
|
||||||
|
|
||||||
|
|
||||||
class RDT:
    """Wrapper around the RDT-1B diffusion policy for dual-arm control.

    Loads the policy and a T5 text encoder, keeps a two-frame observation
    window (images + joint state), and exposes ``get_action`` for inference.
    """

    def __init__(
        self,
        pretrained_model_name_or_path,
        task_name,
        left_arm_dim,
        right_arm_dim,
        rdt_step,
    ):
        # Resolve paths relative to this file's grandparent directory.
        current_file = Path(__file__)
        self.global_path = current_file.parent.parent

        # Runtime configuration.
        self.config = {
            "episode_len": 10000,  # maximum number of action publishing steps
            # state layout: [left joints, left gripper, right joints, right gripper]
            "state_dim": left_arm_dim + 1 + right_arm_dim + 1,
            "chunk_size": 64,  # number of actions predicted per inference
            "camera_names": ["cam_high", "cam_right_wrist", "cam_left_wrist"],
        }

        # Arguments forwarded to the policy constructor.
        self.args = {
            "max_publish_step": 10000,  # maximum number of action publishing steps
            "seed": None,  # random seed
            "ctrl_freq": 25,  # control frequency of the robot
            "chunk_size": 64,  # action chunk size
            "config_path": os.path.join(self.global_path, "RDT/configs/base.yaml"),
            "pretrained_model_name_or_path": pretrained_model_name_or_path,
        }

        # Load the RDT policy (make_policy reads the arm dims set just above).
        self.left_arm_dim, self.right_arm_dim = left_arm_dim, right_arm_dim
        self.policy = self.make_policy(self.args)
        self.max_publish_step = self.config["episode_len"]
        self.chunk_size = self.config["chunk_size"]
        self.task_name = task_name
        self.observation_window = None
        self.img_size = (640, 480)
        self.set_language_embed()
        self.rdt_step = rdt_step

    def set_img_size(self, img_size):
        """Set the (width, height) all camera frames are resized to."""
        self.img_size = img_size

    def set_language_embed(self):
        """Load the T5-XXL tokenizer/encoder used to embed instructions."""
        GPU = 0
        MODEL_PATH = os.path.join(self.global_path, "weights/RDT/t5-v1_1-xxl")
        CONFIG_PATH = os.path.join(self.global_path, "RDT/configs/base.yaml")
        with open(CONFIG_PATH, "r") as fp:
            config = yaml.safe_load(fp)
        device = torch.device(f"cuda:{GPU}")
        text_embedder = T5Embedder(
            from_pretrained=MODEL_PATH,
            model_max_length=config["dataset"]["tokenizer_max_length"],
            device=device,
            use_offload_folder=None,
        )
        self.tokenizer, self.text_encoder = text_embedder.tokenizer, text_embedder.model
        # Inference-only encoder.
        self.text_encoder.eval()

    def random_set_language(self, instruction=None):
        """Set the language instruction (must be provided by the caller)."""
        assert instruction is not None, "Missing input instruction"
        self.set_language_instruction(instruction)

    def set_language_instruction(self, language_instruction, save_dir=None, task_name=None):
        """Encode an instruction (or load a pre-embedded one) into embeddings.

        ``language_instruction`` may be raw text or a path to a pre-computed
        ``.pt`` embedding file. ``save_dir``/``task_name`` must be given
        together (or both omitted) to cache freshly computed embeddings.
        """
        # Both-or-neither: XOR of the two None checks must be False.
        assert ((save_dir is None) ^ (task_name is None)) == False, "input error"

        if os.path.isfile(language_instruction):
            # Pre-embedded instruction file: load it directly.
            lang_dict = torch.load(language_instruction)
            print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
            self.lang_embeddings = lang_dict["embeddings"]
            print("loading instruction from pre-embed path")
        else:
            # Encode the raw text with the T5 encoder.
            device = next(self.text_encoder.parameters()).device
            with torch.no_grad():
                tokens = self.tokenizer(
                    language_instruction,
                    return_tensors="pt",
                    padding="longest",
                    truncation=True,
                )["input_ids"].to(device)
                tokens = tokens.view(1, -1)
                output = self.text_encoder(tokens)
                pred = output.last_hidden_state.detach().cpu()

            if save_dir is not None:
                # Cache the embedding for reuse.
                save_path = os.path.join(save_dir, f"{task_name}.pt")
                torch.save({
                    "name": task_name,
                    "instruction": language_instruction,
                    "embeddings": pred,
                }, save_path)

            # Free encoder activations before continuing.
            del tokens, output
            torch.cuda.empty_cache()
            self.lang_embeddings = pred

        print(f"successfully set instruction: {language_instruction}")

    def update_observation_window(self, img_arr, state):
        """Push the latest camera frames + joint state into the window."""

        # JPEG round-trip so inference images match the training-time
        # compression artifacts.
        def jpeg_mapping(img):
            if img is None:
                return None
            img = cv2.imencode(".jpg", img)[1].tobytes()
            img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
            return img

        def resize_img(img, size):
            return cv2.resize(img, size)

        if self.observation_window is None:
            self.observation_window = deque(maxlen=2)

            # Seed the window with a dummy frame so index -2 is valid.
            self.observation_window.append({
                "qpos": None,
                "images": {
                    self.config["camera_names"][0]: None,
                    self.config["camera_names"][1]: None,
                    self.config["camera_names"][2]: None,
                },
            })

        img_front, img_right, img_left, puppet_arm = (
            img_arr[0],
            img_arr[1],
            img_arr[2],
            state,
        )
        # Resize, then JPEG-encode/decode each frame.
        img_front = resize_img(img_front, self.img_size)
        img_left = resize_img(img_left, self.img_size)
        img_right = resize_img(img_right, self.img_size)
        img_front = jpeg_mapping(img_front)
        img_left = jpeg_mapping(img_left)
        img_right = jpeg_mapping(img_right)

        qpos = np.array(puppet_arm)
        qpos = torch.from_numpy(qpos).float().cuda()
        self.observation_window.append({
            "qpos": qpos,
            "images": {
                self.config["camera_names"][0]: img_front,
                self.config["camera_names"][1]: img_right,
                self.config["camera_names"][2]: img_left,
            },
        })

    def get_action(self, img_arr=None, state=None):
        """Run policy inference and return the predicted action chunk."""
        # img_arr and state must be supplied together (or both omitted).
        assert (img_arr is None) ^ (state is None) == False, "input error"
        if (img_arr is not None) and (state is not None):
            self.update_observation_window(img_arr, state)

        with torch.inference_mode():
            action_buffer = inference_fn(self.config, self.policy, self.lang_embeddings, self.observation_window).copy()

        return action_buffer

    def reset_obsrvationwindows(self):
        """Clear the cached observation window and language embeddings."""
        self.lang_embeddings = None
        self.observation_window = None
        # Fixed typo in the status message ("intruction" -> "instruction").
        print("successfully unset obs and language instruction")

    def make_policy(self, args):
        """Instantiate the RDT policy from the YAML config and weights."""
        with open(args["config_path"], "r") as fp:
            config_base_yaml = yaml.safe_load(fp)
        args["config"] = config_base_yaml
        args["config"]["arm_dim"] = {
            "left_arm_dim": self.left_arm_dim,
            "right_arm_dim": self.right_arm_dim,
        }
        # pretrained_text_encoder_name_or_path = "weights/RDT/t5-v1_1-xxl"
        pretrained_vision_encoder_name_or_path = os.path.join(self.global_path, "weights/RDT/siglip-so400m-patch14-384")
        model = create_model(
            args=args["config"],
            dtype=torch.bfloat16,
            pretrained=args["pretrained_model_name_or_path"],
            pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
            control_frequency=args["ctrl_freq"],
        )

        return model
|
||||||
# RDT inference
def inference_fn(config, policy, lang_embeddings, observation_window):
    """Run one step of RDT policy inference.

    Builds the model input from the two most recent observations in
    ``observation_window`` and returns the predicted action chunk as a
    numpy array of shape [chunk_size, state_dim].

    Note: the original wrapped this body in a ``while True:`` loop that
    returned unconditionally on the first iteration; the dead loop and an
    unused timing variable have been removed (behavior unchanged).
    """
    # Fetch images in sequence [front, right, left] for t-1 then t.
    cam_names = config["camera_names"]
    image_arrs = [
        observation_window[frame]["images"][cam]
        for frame in (-2, -1)
        for cam in cam_names
    ]

    images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]

    # Latest proprioceptive state, shape [14] -> [1, 14].
    proprio = observation_window[-1]["qpos"].unsqueeze(0)

    # Policy returns actions shaped [1, chunk_size, state_dim] in
    # [left, right] order; drop the batch dim and move to CPU.
    actions = (policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy())

    return actions
@ -1,40 +0,0 @@
|
|||||||
import os
import yaml
import argparse
from datetime import datetime

if __name__ == "__main__":
    # Generate a per-task finetune config YAML under model_config/ and
    # make sure the task's training-data directory exists.
    parser = argparse.ArgumentParser(description="Generate finetune config.")
    parser.add_argument("model_name", type=str, help="The name of the task (e.g., beat_block_hammer)")
    args = parser.parse_args()
    model_name = args.model_name

    # Fixed local-variable typo: fintune_data_path -> finetune_data_path.
    finetune_data_path = os.path.join("training_data/", f"{model_name}")
    checkpoint_path = os.path.join("checkpoints/", f"{model_name}")
    data = {
        "model": model_name,
        "data_path": finetune_data_path,
        "checkpoint_path": checkpoint_path,
        "pretrained_model_name_or_path": "../weights/RDT/rdt-1b",
        "cuda_visible_device": "...",  # args.gpu_use,
        "train_batch_size": 32,
        "sample_batch_size": 64,
        "max_train_steps": 20000,
        "checkpointing_period": 2500,
        "sample_period": 100,
        "checkpoints_total_limit": 40,
        "learning_rate": 1e-4,
        "dataloader_num_workers": 8,
        "state_noise_snr": 40,
        "gradient_accumulation_steps": 1,
    }
    task_config_path = os.path.join("model_config/", f"{model_name}.yml")

    # Stamp the file with its generation time as a YAML comment.
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    time_comment = f"# Generated on {current_time}\n"

    with open(task_config_path, "w") as f:
        f.write(time_comment)
        yaml.dump(data, f, default_flow_style=False, sort_keys=False)

    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(finetune_data_path, exist_ok=True)
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,82 +0,0 @@
|
|||||||
# Reference: DiffusionPolicy [https://github.com/real-stanford/diffusion_policy]

import torch
from torch.nn.modules.batchnorm import _BatchNorm


class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
        """
        @crowsonkb's notes on EMA Warmup:
        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
        at 215.4k steps).
        Args:
            model: the module whose weights are averaged; it is frozen and
                put into eval mode, and serves as the EMA snapshot.
            update_after_step (int): number of steps to skip before EMA starts.
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
            max_value (float): The maximum EMA decay rate. Default: 0.9999.
        """

        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.decay = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        # Early return moved before the warmup computation: the original
        # computed `value` even when it was about to discard it.
        if step <= 0:
            return 0.0

        value = 1 - (1 + step / self.inv_gamma)**-self.power
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        """Blend ``new_model``'s parameters into the averaged model in place."""
        self.decay = self.get_decay(self.optimization_step)

        # Walk both module trees in lockstep so BatchNorm modules can be
        # detected and their parameters copied rather than averaged.
        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):
            for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)):
                # iterate over immediate parameters only.
                if isinstance(param, dict):
                    raise RuntimeError('Dict parameter not supported')

                if isinstance(module, _BatchNorm):
                    # skip batchnorms: copy verbatim, never average.
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                elif not param.requires_grad:
                    # Frozen parameters are copied verbatim too.
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                else:
                    # ema = decay * ema + (1 - decay) * param
                    ema_param.mul_(self.decay)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)

        self.optimization_step += 1
@ -1,75 +0,0 @@
|
|||||||
import os
from pathlib import Path
from typing import Dict, Optional, Union

from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub.constants import (PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE)
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, is_torch_available

if is_torch_available():
    import torch  # type: ignore


class CompatiblePyTorchModelHubMixin(PyTorchModelHubMixin):
    """Mixin class to load Pytorch models from the Hub."""

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        # Bypass the default safetensors serialization: always write the
        # pickle-based PYTORCH_WEIGHTS_NAME checkpoint.
        target = self.module if hasattr(self, "module") else self  # type: ignore
        torch.save(target.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""
        model = cls(**model_kwargs)

        if os.path.isdir(model_id):
            # Local checkpoint directory: prefer safetensors, fall back to
            # the pickle checkpoint when the safetensors file is absent.
            print("Loading weights from local directory")
            try:
                local_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
                return cls._load_as_safetensor(model, local_file, map_location, strict)
            except FileNotFoundError:
                local_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
                return cls._load_as_pickle(model, local_file, map_location, strict)

        # Remote Hub repo: same safetensors-first strategy, downloading
        # with identical options for both attempts.
        download_kwargs = dict(
            repo_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
        try:
            weights = hf_hub_download(filename=SAFETENSORS_SINGLE_FILE, **download_kwargs)
            return cls._load_as_safetensor(model, weights, map_location, strict)
        except EntryNotFoundError:
            weights = hf_hub_download(filename=PYTORCH_WEIGHTS_NAME, **download_kwargs)
            return cls._load_as_pickle(model, weights, map_location, strict)
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,159 +0,0 @@
|
|||||||
import torch
import torch.nn as nn

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig


class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision backbone exposing intermediate patch features."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load eagerly unless deferred; a tower marked trainable is loaded
        # even when the caller asked to defer.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the image processor and the frozen CLIP model."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick features from the configured hidden layer."""
        hidden = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            return hidden[:, 1:]  # drop the CLS token
        if self.select_feature == 'cls_patch':
            return hidden
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor, or a list of single images, to features."""
        if type(images) is not list:
            outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                     output_hidden_states=True)
            return self.feature_select(outs).to(images.dtype)

        features = []
        for img in images:
            out = self.vision_tower(img.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                    output_hidden_states=True)
            features.append(self.feature_select(out).to(img.dtype))
        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the standalone config when weights were never loaded.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return self.num_patches_per_side ** 2
|
|
||||||
class CLIPVisionTowerS2(CLIPVisionTower):
    """CLIP tower with multi-scale (S2) feature extraction."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        raw_scales = getattr(args, 's2_scales', '336,672,1008')
        self.s2_scales = sorted(int(s) for s in raw_scales.split(','))
        self.s2_split_size = self.s2_scales[0]   # smallest scale: tile size
        self.s2_image_size = self.s2_scales[-1]  # largest scale: input size

        try:
            from s2wrapper import forward as multiscale_forward
        except ImportError:
            raise ImportError(
                'Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git'
            )
        self.multiscale_forward = multiscale_forward

        # change resize/crop size in preprocessing to the largest image size in s2_scale
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.image_processor.size['shortest_edge'] = self.s2_image_size
            self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    def load_model(self, device_map=None):
        """Load the CLIP model, then retarget preprocessing to the S2 size."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        processor = self.image_processor
        processor.size['shortest_edge'] = self.s2_image_size
        processor.crop_size['height'] = processor.crop_size['width'] = self.s2_image_size

        self.is_loaded = True

    @torch.no_grad()
    def forward_feature(self, images):
        """Single-scale feature pass, used as the S2 inner callback."""
        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                 output_hidden_states=True)
        return self.feature_select(outs).to(images.dtype)

    @torch.no_grad()
    def forward(self, images):
        """Run multi-scale extraction on a batch or a list of images."""
        if type(images) is not list:
            return self.multiscale_forward(self.forward_feature,
                                           images,
                                           img_sizes=self.s2_scales,
                                           max_split_size=self.s2_split_size)

        feats = []
        for img in images:
            feats.append(self.multiscale_forward(self.forward_feature,
                                                 img.unsqueeze(0),
                                                 img_sizes=self.s2_scales,
                                                 max_split_size=self.s2_split_size))
        return feats

    @property
    def hidden_size(self):
        # Features from all scales are concatenated along the channel dim.
        return self.config.hidden_size * len(self.s2_scales)
@ -1,87 +0,0 @@
|
|||||||
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoImageProcessor, AutoModel, Dinov2Model


class DinoV2VisionTower(nn.Module):
    """Frozen DINOv2 vision backbone exposing last-layer features."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load eagerly unless deferred; a trainable tower loads regardless.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the image processor and the frozen DINOv2 model."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)  # FIXME:

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Select patch (or cls+patch) tokens from the last hidden state."""
        hidden = image_forward_outs.last_hidden_state
        if self.select_feature == 'patch':
            return hidden[:, 1:]  # (B, 1369, 1536)
        if self.select_feature == 'cls_patch':
            return hidden  # (B, 1, 1536)
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor, or a list of single images, to features."""
        if type(images) is not list:
            outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
            return self.feature_select(outs).to(images.dtype)

        features = []
        for img in images:
            out = self.vision_tower(img.to(device=self.device, dtype=self.dtype).unsqueeze(0))
            features.append(self.feature_select(out).to(img.dtype))
        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the standalone config when weights were never loaded.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return self.num_patches_per_side ** 2
@ -1,86 +0,0 @@
|
|||||||
import torch
import torch.nn as nn
from transformers import AutoConfig, SiglipImageProcessor, SiglipVisionModel


class SiglipVisionTower(nn.Module):
    """SigLIP vision backbone (eval mode) exposing last-layer features."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # Load eagerly unless deferred; a trainable tower loads regardless.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the image processor and the SigLIP vision model."""
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.eval()

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Select the patch tokens or the pooled output."""
        if self.select_feature == 'patch':
            return image_forward_outs.last_hidden_state  # (B, 729, 1536)
        if self.select_feature == 'cls_patch':
            return image_forward_outs.pooler_output  # (B, 1, 1536)
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor, or a list of single images, to features."""
        if type(images) is not list:
            outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
            return self.feature_select(outs).to(images.dtype)

        features = []
        for img in images:
            out = self.vision_tower(img.to(device=self.device, dtype=self.dtype).unsqueeze(0))
            features.append(self.feature_select(out).to(img.dtype))
        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Fall back to the standalone config when weights were never loaded.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return self.num_patches_per_side ** 2
@ -1,111 +0,0 @@
|
|||||||
import torch
|
|
||||||
from transformers import AutoTokenizer, T5EncoderModel
|
|
||||||
|
|
||||||
|
|
||||||
class T5Embedder:
    """Frozen T5 text encoder plus tokenizer for producing language embeddings."""

    # available_models = ["google/t5-v1_1-xxl"]

    def __init__(
        self,
        device,
        from_pretrained=None,
        *,
        cache_dir=None,
        hf_token=None,
        use_text_preprocessing=True,
        t5_model_kwargs=None,
        torch_dtype=None,
        use_offload_folder=None,
        model_max_length=120,
        local_files_only=False,
    ):
        """Load tokenizer and T5 encoder.

        device: target device for the encoder (or the on-device part of it).
        from_pretrained: HF model id or local path, e.g. "google/t5-v1_1-xxl".
        use_offload_folder: if given, offload the upper half of the encoder to disk.
        """
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        self.cache_dir = cache_dir

        if t5_model_kwargs is None:
            t5_model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": self.torch_dtype,
            }

        if use_offload_folder is not None:
            # Keep the embeddings and the first 12 encoder blocks on `device`;
            # offload blocks 12-23 plus the final norm/dropout to disk.
            t5_model_kwargs["offload_folder"] = use_offload_folder
            device_map = {
                "shared": self.device,
                "encoder.embed_tokens": self.device,
            }
            for i in range(24):
                device_map[f"encoder.block.{i}"] = self.device if i < 12 else "disk"
            device_map["encoder.final_layer_norm"] = "disk"
            device_map["encoder.dropout"] = "disk"
            t5_model_kwargs["device_map"] = device_map
        else:
            t5_model_kwargs["device_map"] = {
                "shared": self.device,
                "encoder": self.device,
            }

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token

        # assert from_pretrained in self.available_models
        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            model_max_length=model_max_length,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.model = T5EncoderModel.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            **t5_model_kwargs,
        ).eval()
        self.model_max_length = model_max_length

    def get_text_embeddings(self, texts):
        """Tokenize `texts` and return (last_hidden_state, attention_mask)."""
        tokenized = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding="longest",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        input_ids = tokenized["input_ids"].to(self.device)
        attention_mask = tokenized["attention_mask"].to(self.device)
        with torch.no_grad():
            embeddings = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )["last_hidden_state"].detach()
        return embeddings, attention_mask
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Smoke test: load the embedder on a single GPU.
    T5Embedder(from_pretrained="google/t5-v1_1-xxl", device='cuda:7')
|
|
||||||
Binary file not shown.
Binary file not shown.
@ -1,304 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# References:
|
|
||||||
# DiT: https://github.com/facebookresearch/DiT
|
|
||||||
# GLIDE: https://github.com/openai/glide-text2im
|
|
||||||
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import math
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.jit import Final
|
|
||||||
from timm.models.vision_transformer import Attention, Mlp, RmsNorm, use_fused_attn
|
|
||||||
|
|
||||||
|
|
||||||
#################################################################################
|
|
||||||
# Embedding Layers for Timesteps and Condition Inptus #
|
|
||||||
#################################################################################
|
|
||||||
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations: a sinusoidal
    frequency embedding followed by a two-layer SiLU MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=torch.bfloat16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.dtype = dtype

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor of positional embeddings in ``self.dtype``.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(max_period) * exponents)
        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Odd dim: pad with a zero column so the output has exactly `dim` features.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding.to(self.dtype)

    def forward(self, t):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb)
|
|
||||||
|
|
||||||
|
|
||||||
#################################################################################
|
|
||||||
# Cross Attention Layers #
|
|
||||||
#################################################################################
|
|
||||||
class CrossAttention(nn.Module):
    """
    A cross-attention layer with flash attention: queries come from the main
    sequence `x`, keys/values from the condition `c`.
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0,
        proj_drop: float = 0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = use_fused_attn()

        # Separate projections: queries from x, keys+values from c.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, c: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        B, N, C = x.shape
        _, L, _ = c.shape
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv(c).reshape(B, L, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # Broadcast the (B, L) condition mask over heads and query positions.
        if mask is not None:
            mask = mask.reshape(B, 1, 1, L).expand(-1, -1, N, -1)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                dropout_p=self.attn_drop.p if self.training else 0.,
                attn_mask=mask)
        else:
            attn = (q * self.scale) @ k.transpose(-2, -1)
            if mask is not None:
                attn = attn.masked_fill_(mask.logical_not(), float('-inf'))
            attn = attn.softmax(dim=-1)
            if self.attn_drop.p > 0:
                attn = self.attn_drop(attn)
            x = attn @ v

        x = self.proj(x.permute(0, 2, 1, 3).reshape(B, N, C))
        if self.proj_drop.p > 0:
            x = self.proj_drop(x)
        return x
|
|
||||||
|
|
||||||
|
|
||||||
#################################################################################
|
|
||||||
# RDT Block #
|
|
||||||
#################################################################################
|
|
||||||
class RDTBlock(nn.Module):
    """
    A RDT block: self-attention, cross-attention conditioning, and a GELU FFN,
    each in a pre-RMSNorm residual branch.
    """

    def __init__(self, hidden_size, num_heads, **block_kwargs):
        super().__init__()
        attn_kwargs = dict(num_heads=num_heads,
                           qkv_bias=True,
                           qk_norm=True,
                           norm_layer=RmsNorm,
                           **block_kwargs)
        self.norm1 = RmsNorm(hidden_size, eps=1e-6)
        self.attn = Attention(dim=hidden_size, **attn_kwargs)
        self.cross_attn = CrossAttention(hidden_size, **attn_kwargs)

        self.norm2 = RmsNorm(hidden_size, eps=1e-6)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.ffn = Mlp(in_features=hidden_size, hidden_features=hidden_size, act_layer=approx_gelu, drop=0)
        self.norm3 = RmsNorm(hidden_size, eps=1e-6)

    def forward(self, x, c, mask=None):
        x = x + self.attn(self.norm1(x))
        x = x + self.cross_attn(self.norm2(x), c, mask)
        x = x + self.ffn(self.norm3(x))
        return x
|
|
||||||
|
|
||||||
|
|
||||||
class FinalLayer(nn.Module):
    """
    The final layer of RDT: RMSNorm followed by a GELU MLP that projects
    hidden_size -> out_channels.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = RmsNorm(hidden_size, eps=1e-6)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.ffn_final = Mlp(in_features=hidden_size,
                             hidden_features=hidden_size,
                             out_features=out_channels,
                             act_layer=approx_gelu,
                             drop=0)

    def forward(self, x):
        return self.ffn_final(self.norm_final(x))
|
|
||||||
|
|
||||||
|
|
||||||
#################################################################################
|
|
||||||
# Sine/Cosine Positional Embedding Functions #
|
|
||||||
#################################################################################
|
|
||||||
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
|
||||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: a list of positions to be encoded: size (M,)
    out: (M, embed_dim), laid out as [sin | cos] halves
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder: 1 .. 1/10000, one per sin/cos pair.
    omega = np.arange(half, dtype=np.float64) / (embed_dim / 2.)
    omega = 1. / 10000**omega  # (half,)

    if not isinstance(pos, np.ndarray):
        pos = np.array(pos, dtype=np.float64)
    angles = np.einsum('m,d->md', pos.reshape(-1), omega)  # (M, half), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, embed_dim)
|
|
||||||
|
|
||||||
|
|
||||||
def get_nd_sincos_pos_embed_from_grid(embed_dim, grid_sizes):
    """
    embed_dim: output dimension for each position
    grid_sizes: the grids sizes in each dimension (K,) tuple.
    out: (grid_sizes[0], ..., grid_sizes[K-1], embed_dim)
    """
    # Grids of size 1 carry no positional information and get no embedding.
    num_valid_sizes = len([s for s in grid_sizes if s > 1])
    emb = np.zeros(grid_sizes + (embed_dim, ))
    # Split the channel dimension evenly across valid grids, rounded down to even.
    dim_per_grid = embed_dim // num_valid_sizes
    if dim_per_grid % 2 != 0:
        dim_per_grid -= 1

    valid_idx = 0
    for axis, grid_size in enumerate(grid_sizes):
        if grid_size <= 1:
            continue
        # Shape the 1-D embedding so it broadcasts along every other axis.
        bcast_shape = [1] * len(grid_sizes) + [dim_per_grid]
        bcast_shape[axis] = -1
        emb[..., valid_idx * dim_per_grid:(valid_idx + 1) * dim_per_grid] += \
            get_1d_sincos_pos_embed_from_grid(dim_per_grid, np.arange(grid_size)).reshape(bcast_shape)
        valid_idx += 1
    return emb
|
|
||||||
|
|
||||||
|
|
||||||
def get_multimodal_cond_pos_embed(embed_dim, mm_cond_lens: OrderedDict, embed_modality=True):
    """
    Generate position embeddings for multimodal conditions.

    mm_cond_lens: an OrderedDict containing
        (modality name, modality token length) pairs.
        For `"image"` modality, the value can be a multi-dimensional tuple.
        If the length < 0, it means there is no position embedding for the modality or grid.
    embed_modality: whether to embed the modality information. Default is True.
    """
    num_modalities = len(mm_cond_lens)
    modality_pos_embed = np.zeros((num_modalities, embed_dim))
    if embed_modality:
        # First half of the channels encodes which modality each token belongs to.
        modality_pos_embed[:, :embed_dim // 2] = \
            get_1d_sincos_pos_embed_from_grid(embed_dim // 2, torch.arange(num_modalities))
        pos_embed_dim = embed_dim // 2
    else:
        # The whole embedding is for position embeddings.
        pos_embed_dim = embed_dim

    # Positions inside each modality go in the trailing pos_embed_dim channels.
    c_pos_emb = np.zeros((0, embed_dim))
    for idx, (modality, cond_len) in enumerate(mm_cond_lens.items()):
        if modality == "image" and isinstance(cond_len, (tuple, list)):
            all_grid_sizes = tuple(abs(s) for s in cond_len)
            embed_grid_sizes = tuple(s if s > 0 else 1 for s in cond_len)
            nd_embed = get_nd_sincos_pos_embed_from_grid(pos_embed_dim, embed_grid_sizes)
            cond_pos_embed = np.zeros(all_grid_sizes + (embed_dim, ))
            cond_pos_embed[..., -pos_embed_dim:] += nd_embed
            cond_pos_embed = cond_pos_embed.reshape((-1, embed_dim))
        else:
            sincos = get_1d_sincos_pos_embed_from_grid(
                pos_embed_dim, torch.arange(cond_len if cond_len > 0 else 1))
            cond_pos_embed = np.zeros((abs(cond_len), embed_dim))
            cond_pos_embed[:, -pos_embed_dim:] += sincos
        cond_pos_embed += modality_pos_embed[idx]
        c_pos_emb = np.concatenate([c_pos_emb, cond_pos_embed], axis=0)

    return c_pos_emb
|
|
||||||
@ -1,156 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
# --------------------------------------------------------
|
|
||||||
# References:
|
|
||||||
# DiT: https://github.com/facebookresearch/DiT
|
|
||||||
# GLIDE: https://github.com/openai/glide-text2im
|
|
||||||
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
|
||||||
# --------------------------------------------------------
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
import sys, os
|
|
||||||
# get current workspace
|
|
||||||
current_file = Path(__file__)
|
|
||||||
sys.path.append(str(current_file.parent.parent))
|
|
||||||
|
|
||||||
from rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder, get_1d_sincos_pos_embed_from_grid,
|
|
||||||
get_multimodal_cond_pos_embed)
|
|
||||||
|
|
||||||
|
|
||||||
class RDT(nn.Module):
    """
    Class for Robotics Diffusion Transformers.
    """

    def __init__(self,
                 output_dim=128,
                 horizon=32,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 max_lang_cond_len=1024,
                 img_cond_len=4096,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super().__init__()
        self.horizon = horizon
        self.hidden_size = hidden_size
        self.max_lang_cond_len = max_lang_cond_len
        self.img_cond_len = img_cond_len
        self.dtype = dtype
        self.lang_pos_embed_config = lang_pos_embed_config
        self.img_pos_embed_config = img_pos_embed_config

        # Embedders for the diffusion timestep and the control frequency.
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype)
        self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype)

        # Trainable position embeddings, initialized with sin-cos values:
        # main sequence is [timestep; ctrl_freq; state; action].
        self.x_pos_embed = nn.Parameter(torch.zeros(1, horizon + 3, hidden_size))
        # Language conditions (variable length up to max_lang_cond_len).
        self.lang_cond_pos_embed = nn.Parameter(torch.zeros(1, max_lang_cond_len, hidden_size))
        # Image conditions (fixed length).
        self.img_cond_pos_embed = nn.Parameter(torch.zeros(1, img_cond_len, hidden_size))

        self.blocks = nn.ModuleList([RDTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, output_dim)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init linear layers, sin-cos init position embeddings,
        zero-init the final projection, then cast all params to self.dtype."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize pos_embed by sin-cos embedding.
        x_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                    mm_cond_lens=OrderedDict([
                                                        ('timestep', 1),
                                                        ('ctrl_freq', 1),
                                                        ('state', 1),
                                                        ('action', self.horizon),
                                                    ]))
        self.x_pos_embed.data.copy_(torch.from_numpy(x_pos_embed).float().unsqueeze(0))

        if self.lang_pos_embed_config is None:
            lang_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size,
                                                           torch.arange(self.max_lang_cond_len))
        else:
            lang_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                       mm_cond_lens=OrderedDict(self.lang_pos_embed_config),
                                                       embed_modality=False)
        self.lang_cond_pos_embed.data.copy_(torch.from_numpy(lang_embed).float().unsqueeze(0))

        if self.img_pos_embed_config is None:
            img_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.img_cond_len))
        else:
            img_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                      mm_cond_lens=OrderedDict(self.img_pos_embed_config),
                                                      embed_modality=False)
        self.img_cond_pos_embed.data.copy_(torch.from_numpy(img_embed).float().unsqueeze(0))

        # Timestep and control-frequency embedding MLPs get a normal init.
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02)

        # Zero-out the final linear layer.
        nn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0)
        nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0)

        # Move all the params to the given data type.
        self.to(self.dtype)

    def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
        """
        Forward pass of RDT.

        x: (B, T, D), state + action token sequence, T = horizon + 1,
            dimension D is assumed to be the same as the hidden size.
        freq: (B,), a scalar indicating control frequency.
        t: (B,) or (1,), diffusion timesteps.
        lang_c: (B, L_lang, D) or None, language condition tokens (variable length),
            dimension D is assumed to be the same as the hidden size.
        img_c: (B, L_img, D) or None, image condition tokens (fixed length),
            dimension D is assumed to be the same as the hidden size.
        lang_mask: (B, L_lang) or None, language condition mask (True for valid).
        img_mask: (B, L_img) or None, image condition mask (True for valid).
        """
        t = self.t_embedder(t).unsqueeze(1)  # (B, 1, D) or (1, 1, D)
        freq = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
        # Prepend timestep (broadcast if shared) and frequency tokens.
        if t.shape[0] == 1:
            t = t.expand(x.shape[0], -1, -1)
        x = torch.cat([t, freq, x], dim=1)  # (B, T+1, D)

        # Add multimodal position embeddings; lang is variable-length.
        x = x + self.x_pos_embed
        lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
        img_c = img_c + self.img_cond_pos_embed

        # Alternate cross-attention conditioning between language and image,
        # so the language condition is injected at the final block.
        conds = [lang_c, img_c]
        masks = [lang_mask, img_mask]
        for i, block in enumerate(self.blocks):
            x = block(x, conds[i % 2], masks[i % 2])  # (B, T+1, D)
        x = self.final_layer(x)  # (B, T+1, out_channels)

        # Only preserve the action tokens.
        return x[:, -self.horizon:]
|
|
||||||
@ -1,246 +0,0 @@
|
|||||||
import re, sys, os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
|
||||||
from diffusers.schedulers.scheduling_dpmsolver_multistep import \
|
|
||||||
DPMSolverMultistepScheduler
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
# get current workspace
|
|
||||||
current_file = Path(__file__)
|
|
||||||
sys.path.append(os.path.join(current_file.parent))
|
|
||||||
from hub_mixin import CompatiblePyTorchModelHubMixin
|
|
||||||
from rdt.model import RDT
|
|
||||||
|
|
||||||
|
|
||||||
class RDTRunner(nn.Module,
|
|
||||||
CompatiblePyTorchModelHubMixin,
|
|
||||||
repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
action_dim,
|
|
||||||
pred_horizon,
|
|
||||||
config,
|
|
||||||
lang_token_dim,
|
|
||||||
img_token_dim,
|
|
||||||
state_token_dim,
|
|
||||||
max_lang_cond_len,
|
|
||||||
img_cond_len,
|
|
||||||
lang_pos_embed_config=None,
|
|
||||||
img_pos_embed_config=None,
|
|
||||||
dtype=torch.bfloat16):
|
|
||||||
super(RDTRunner, self).__init__()
|
|
||||||
# Create diffusion model
|
|
||||||
hidden_size = config['rdt']['hidden_size']
|
|
||||||
self.model = RDT(
|
|
||||||
output_dim=action_dim,
|
|
||||||
horizon=pred_horizon,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
depth=config['rdt']['depth'],
|
|
||||||
num_heads=config['rdt']['num_heads'],
|
|
||||||
max_lang_cond_len=max_lang_cond_len,
|
|
||||||
img_cond_len=img_cond_len,
|
|
||||||
lang_pos_embed_config=lang_pos_embed_config,
|
|
||||||
img_pos_embed_config=img_pos_embed_config,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create adpators for various conditional inputs
|
|
||||||
self.lang_adaptor = self.build_condition_adapter(config['lang_adaptor'],
|
|
||||||
in_features=lang_token_dim,
|
|
||||||
out_features=hidden_size)
|
|
||||||
self.img_adaptor = self.build_condition_adapter(config['img_adaptor'],
|
|
||||||
in_features=img_token_dim,
|
|
||||||
out_features=hidden_size)
|
|
||||||
# A `state` refers to an action or a proprioception vector
|
|
||||||
self.state_adaptor = self.build_condition_adapter(
|
|
||||||
config['state_adaptor'],
|
|
||||||
in_features=state_token_dim * 2, # state + state mask (indicator)
|
|
||||||
out_features=hidden_size)
|
|
||||||
|
|
||||||
# Create the noise scheduler
|
|
||||||
noise_scheduler_config = config['noise_scheduler']
|
|
||||||
self.noise_scheduler = DDPMScheduler(
|
|
||||||
num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
|
|
||||||
beta_schedule=noise_scheduler_config['beta_schedule'],
|
|
||||||
prediction_type=noise_scheduler_config['prediction_type'],
|
|
||||||
clip_sample=noise_scheduler_config['clip_sample'],
|
|
||||||
)
|
|
||||||
self.noise_scheduler_sample = DPMSolverMultistepScheduler(
|
|
||||||
num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
|
|
||||||
beta_schedule=noise_scheduler_config['beta_schedule'],
|
|
||||||
prediction_type=noise_scheduler_config['prediction_type'],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
|
|
||||||
self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
|
|
||||||
self.prediction_type = noise_scheduler_config['prediction_type']
|
|
||||||
|
|
||||||
self.pred_horizon = pred_horizon
|
|
||||||
self.action_dim = action_dim
|
|
||||||
|
|
||||||
print("Diffusion params: %e" %
|
|
||||||
sum([p.numel() for p in self.model.parameters()] + [p.numel() for p in self.lang_adaptor.parameters()] +
|
|
||||||
[p.numel()
|
|
||||||
for p in self.img_adaptor.parameters()] + [p.numel() for p in self.state_adaptor.parameters()]))
|
|
||||||
|
|
||||||
def build_condition_adapter(self, projector_type, in_features, out_features):
|
|
||||||
projector = None
|
|
||||||
if projector_type == 'linear':
|
|
||||||
projector = nn.Linear(in_features, out_features)
|
|
||||||
else:
|
|
||||||
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
|
||||||
if mlp_gelu_match:
|
|
||||||
mlp_depth = int(mlp_gelu_match.group(1))
|
|
||||||
modules = [nn.Linear(in_features, out_features)]
|
|
||||||
for _ in range(1, mlp_depth):
|
|
||||||
modules.append(nn.GELU(approximate="tanh"))
|
|
||||||
modules.append(nn.Linear(out_features, out_features))
|
|
||||||
projector = nn.Sequential(*modules)
|
|
||||||
|
|
||||||
if projector is None:
|
|
||||||
raise ValueError(f'Unknown projector type: {projector_type}')
|
|
||||||
|
|
||||||
return projector
|
|
||||||
|
|
||||||
def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
|
|
||||||
'''
|
|
||||||
lang_tokens: (batch_size, lang_len, lang_token_dim)
|
|
||||||
img_tokens: (batch_size, img_len, img_token_dim)
|
|
||||||
state_tokens: (batch_size, state_len, state_token_dim)
|
|
||||||
|
|
||||||
return: adpated (..., hidden_size) for all input tokens
|
|
||||||
'''
|
|
||||||
adpated_lang = self.lang_adaptor(lang_tokens)
|
|
||||||
adpated_img = self.img_adaptor(img_tokens)
|
|
||||||
adpated_state = self.state_adaptor(state_tokens)
|
|
||||||
|
|
||||||
return adpated_lang, adpated_img, adpated_state
|
|
||||||
|
|
||||||
def conditional_sample(self, lang_cond, lang_attn_mask, img_cond, state_traj, action_mask, ctrl_freqs):
|
|
||||||
'''
|
|
||||||
lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
|
|
||||||
lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
|
|
||||||
which should be True-False bool tensor.
|
|
||||||
img_cond: image conditional data, (batch_size, img_len, hidden_size).
|
|
||||||
state_traj: (batch_size, 1, hidden_size), state trajectory.
|
|
||||||
action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
|
|
||||||
indicating the valid action dimensions.
|
|
||||||
ctrl_freqs: (batch_size,), control frequency for each sample.
|
|
||||||
|
|
||||||
return: (batch_size, horizon, action_dim)
|
|
||||||
'''
|
|
||||||
device = state_traj.device
|
|
||||||
dtype = state_traj.dtype
|
|
||||||
noisy_action = torch.randn(size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
action_mask = action_mask.expand(-1, self.pred_horizon, -1)
|
|
||||||
|
|
||||||
# Set step values
|
|
||||||
self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)
|
|
||||||
|
|
||||||
for t in self.noise_scheduler_sample.timesteps:
|
|
||||||
# Prepare state-action trajectory
|
|
||||||
action_traj = torch.cat([noisy_action, action_mask], dim=2)
|
|
||||||
action_traj = self.state_adaptor(action_traj)
|
|
||||||
state_action_traj = torch.cat([state_traj, action_traj], dim=1)
|
|
||||||
|
|
||||||
# Predict the model output
|
|
||||||
model_output = self.model(state_action_traj,
|
|
||||||
ctrl_freqs,
|
|
||||||
t.unsqueeze(-1).to(device),
|
|
||||||
lang_cond,
|
|
||||||
img_cond,
|
|
||||||
lang_mask=lang_attn_mask)
|
|
||||||
|
|
||||||
# Compute previous actions: x_t -> x_t-1
|
|
||||||
noisy_action = self.noise_scheduler_sample.step(model_output, t, noisy_action).prev_sample
|
|
||||||
noisy_action = noisy_action.to(state_traj.dtype)
|
|
||||||
|
|
||||||
# Finally apply the action mask to mask invalid action dimensions
|
|
||||||
noisy_action = noisy_action * action_mask
|
|
||||||
|
|
||||||
return noisy_action
|
|
||||||
|
|
||||||
# ========= Train ============
|
|
||||||
def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_gt, action_mask,
|
|
||||||
ctrl_freqs) -> torch.Tensor:
|
|
||||||
'''
|
|
||||||
lang_tokens: (batch_size, lang_len, lang_token_dim)
|
|
||||||
lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
|
|
||||||
which should be True-False bool tensor.
|
|
||||||
img_tokens: (batch_size, img_len, img_token_dim)
|
|
||||||
state_tokens: (batch_size, 1, state_token_dim)
|
|
||||||
action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
|
|
||||||
action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
|
|
||||||
ctrl_freqs: (batch_size,), control frequency for each sample.
|
|
||||||
|
|
||||||
return: loss_value, a scalar tensor
|
|
||||||
'''
|
|
||||||
batch_size = lang_tokens.shape[0]
|
|
||||||
device = lang_tokens.device
|
|
||||||
# Sample noise that we'll add to the actions
|
|
||||||
noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)
|
|
||||||
# Sample random diffusion timesteps
|
|
||||||
timesteps = torch.randint(0, self.num_train_timesteps, (batch_size, ), device=device).long()
|
|
||||||
# Add noise to the clean actions according to the noise magnitude at each timestep
|
|
||||||
# (this is the forward diffusion process)
|
|
||||||
noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps)
|
|
||||||
|
|
||||||
# Concatenate the state and action tokens to form the input sequence
|
|
||||||
state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
|
|
||||||
# Append the action mask to the input sequence
|
|
||||||
action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
|
|
||||||
state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
|
|
||||||
# Align the dimension with the hidden size
|
|
||||||
lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj)
|
|
||||||
# Predict the denoised result
|
|
||||||
pred = self.model(state_action_traj, ctrl_freqs, timesteps, lang_cond, img_cond, lang_mask=lang_attn_mask)
|
|
||||||
|
|
||||||
pred_type = self.prediction_type
|
|
||||||
if pred_type == 'epsilon':
|
|
||||||
target = noise
|
|
||||||
elif pred_type == 'sample':
|
|
||||||
target = action_gt
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported prediction type {pred_type}")
|
|
||||||
loss = F.mse_loss(pred, target)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
# ========= Inference ============
|
|
||||||
def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_mask, ctrl_freqs):
|
|
||||||
'''
|
|
||||||
lang_tokens: (batch_size, lang_len, lang_token_dim)
|
|
||||||
lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
|
|
||||||
which should be True-False bool tensor.
|
|
||||||
img_tokens: (batch_size, img_len, img_token_dim)
|
|
||||||
state_tokens: (batch_size, 1, state_token_dim)
|
|
||||||
action_mask: (batch_size, 1, action_dim),
|
|
||||||
which should be a 0-1 **float** tensor.
|
|
||||||
ctrl_freqs: (batch_size,), control frequency for each sample.
|
|
||||||
|
|
||||||
return: (batch_size, horizon, action_dim), predicted action sequence
|
|
||||||
'''
|
|
||||||
# Prepare the state and conditions
|
|
||||||
state_tokens = torch.cat([state_tokens, action_mask], dim=2)
|
|
||||||
lang_cond, img_cond, state_traj = self.adapt_conditions(lang_tokens, img_tokens, state_tokens)
|
|
||||||
|
|
||||||
# Run sampling
|
|
||||||
action_pred = self.conditional_sample(
|
|
||||||
lang_cond,
|
|
||||||
lang_attn_mask,
|
|
||||||
img_cond,
|
|
||||||
state_traj,
|
|
||||||
action_mask,
|
|
||||||
ctrl_freqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return action_pred
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs) -> torch.Tensor:
|
|
||||||
return self.compute_loss(*args, **kwargs)
|
|
||||||
@ -1,49 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
|
|
||||||
export NCCL_IB_DISABLE=0
|
|
||||||
export NCCL_SOCKET_IFNAME=bond0
|
|
||||||
export NCCL_DEBUG=INFO
|
|
||||||
export NCCL_NVLS_ENABLE=0
|
|
||||||
|
|
||||||
export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
|
|
||||||
export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
|
|
||||||
export OUTPUT_DIR="./checkpoints/rdt-pretrain-1b"
|
|
||||||
export CFLAGS="-I/usr/include"
|
|
||||||
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
|
|
||||||
export CUTLASS_PATH="/path/to/cutlass"
|
|
||||||
|
|
||||||
export WANDB_PROJECT="robotics_diffusion_transformer"
|
|
||||||
|
|
||||||
if [ ! -d "$OUTPUT_DIR" ]; then
|
|
||||||
mkdir "$OUTPUT_DIR"
|
|
||||||
echo "Folder '$OUTPUT_DIR' created"
|
|
||||||
else
|
|
||||||
echo "Folder '$OUTPUT_DIR' already exists"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# For run in a single node/machine
|
|
||||||
# accelerate launch main.py \
|
|
||||||
# --deepspeed="./configs/zero2.json" \
|
|
||||||
# ...
|
|
||||||
|
|
||||||
deepspeed --hostfile=hostfile.txt main.py \
|
|
||||||
--deepspeed="./configs/zero2.json" \
|
|
||||||
--pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
|
|
||||||
--pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
|
|
||||||
--output_dir=$OUTPUT_DIR \
|
|
||||||
--train_batch_size=32 \
|
|
||||||
--sample_batch_size=64 \
|
|
||||||
--max_train_steps=1000000 \
|
|
||||||
--checkpointing_period=1000 \
|
|
||||||
--sample_period=500 \
|
|
||||||
--checkpoints_total_limit=40 \
|
|
||||||
--lr_scheduler="constant" \
|
|
||||||
--learning_rate=1e-4 \
|
|
||||||
--mixed_precision="bf16" \
|
|
||||||
--dataloader_num_workers=8 \
|
|
||||||
--dataset_type="pretrain" \
|
|
||||||
--report_to=wandb
|
|
||||||
|
|
||||||
# Use this to resume training from some previous checkpoint
|
|
||||||
# --resume_from_checkpoint="checkpoint-1000" \
|
|
||||||
@ -1,9 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
task_name=${1}
|
|
||||||
task_config=${2}
|
|
||||||
expert_data_num=${3}
|
|
||||||
gpu_id=${4}
|
|
||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES=${gpu_id}
|
|
||||||
python scripts/process_data.py $task_name $task_config $expert_data_num
|
|
||||||
@ -1,23 +0,0 @@
|
|||||||
numpy<2.0
|
|
||||||
packaging==24.0
|
|
||||||
wandb==0.17.0
|
|
||||||
deepspeed==0.14.2
|
|
||||||
accelerate==0.30.1
|
|
||||||
diffusers==0.27.2
|
|
||||||
timm==1.0.3
|
|
||||||
transformers==4.41.0
|
|
||||||
sentencepiece==0.2.0
|
|
||||||
h5py==3.11.0
|
|
||||||
opencv-python==4.9.0.80
|
|
||||||
imgaug==0.4.0
|
|
||||||
pytz>=2020.1
|
|
||||||
|
|
||||||
# requirements_data.txt
|
|
||||||
tfds-nightly==4.9.4.dev202402070044
|
|
||||||
gsutil==5.27
|
|
||||||
tensorflow==2.15.0.post1
|
|
||||||
pillow==10.2.0
|
|
||||||
pyyaml==6.0.1
|
|
||||||
tensorflow-graphics==2021.12.3
|
|
||||||
imageio==2.34.0
|
|
||||||
imageio-ffmpeg==0.4.9
|
|
||||||
@ -1,941 +0,0 @@
|
|||||||
#!/home/lin/software/miniconda3/envs/aloha/bin/python
|
|
||||||
# -- coding: UTF-8
|
|
||||||
"""
|
|
||||||
#!/usr/bin/python3
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import yaml
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import rospy
|
|
||||||
import torch
|
|
||||||
from cv_bridge import CvBridge
|
|
||||||
from geometry_msgs.msg import Twist
|
|
||||||
from nav_msgs.msg import Odometry
|
|
||||||
from PIL import Image as PImage
|
|
||||||
from sensor_msgs.msg import Image, JointState
|
|
||||||
from std_msgs.msg import Header
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
from scripts.agilex_model import create_model
|
|
||||||
|
|
||||||
# sys.path.append("./")
|
|
||||||
|
|
||||||
CAMERA_NAMES = ["cam_high", "cam_right_wrist", "cam_left_wrist"]
|
|
||||||
|
|
||||||
observation_window = None
|
|
||||||
|
|
||||||
lang_embeddings = None
|
|
||||||
|
|
||||||
# debug
|
|
||||||
preload_images = None
|
|
||||||
|
|
||||||
|
|
||||||
# Initialize the model
|
|
||||||
def make_policy(args):
|
|
||||||
with open(args.config_path, "r") as fp:
|
|
||||||
config = yaml.safe_load(fp)
|
|
||||||
args.config = config
|
|
||||||
|
|
||||||
# pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
|
|
||||||
pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
|
|
||||||
model = create_model(
|
|
||||||
args=args.config,
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
pretrained=args.pretrained_model_name_or_path,
|
|
||||||
# pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
|
|
||||||
pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
|
|
||||||
control_frequency=args.ctrl_freq,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed):
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
# Interpolate the actions to make the robot move smoothly
|
|
||||||
def interpolate_action(args, prev_action, cur_action):
|
|
||||||
steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
|
|
||||||
diff = np.abs(cur_action - prev_action)
|
|
||||||
step = np.ceil(diff / steps).astype(int)
|
|
||||||
step = np.max(step)
|
|
||||||
if step <= 1:
|
|
||||||
return cur_action[np.newaxis, :]
|
|
||||||
new_actions = np.linspace(prev_action, cur_action, step + 1)
|
|
||||||
return new_actions[1:]
|
|
||||||
|
|
||||||
|
|
||||||
def get_config(args):
|
|
||||||
config = {
|
|
||||||
"episode_len": args.max_publish_step,
|
|
||||||
"state_dim": 14,
|
|
||||||
"chunk_size": args.chunk_size,
|
|
||||||
"camera_names": CAMERA_NAMES,
|
|
||||||
}
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
# Get the observation from the ROS topic
|
|
||||||
def get_ros_observation(args, ros_operator):
|
|
||||||
rate = rospy.Rate(args.publish_rate)
|
|
||||||
print_flag = True
|
|
||||||
|
|
||||||
while True and not rospy.is_shutdown():
|
|
||||||
result = ros_operator.get_frame()
|
|
||||||
if not result:
|
|
||||||
if print_flag:
|
|
||||||
print("syn fail when get_ros_observation")
|
|
||||||
print_flag = False
|
|
||||||
rate.sleep()
|
|
||||||
continue
|
|
||||||
print_flag = True
|
|
||||||
(
|
|
||||||
img_front,
|
|
||||||
img_left,
|
|
||||||
img_right,
|
|
||||||
img_front_depth,
|
|
||||||
img_left_depth,
|
|
||||||
img_right_depth,
|
|
||||||
puppet_arm_left,
|
|
||||||
puppet_arm_right,
|
|
||||||
robot_base,
|
|
||||||
) = result
|
|
||||||
# print(f"sync success when get_ros_observation")
|
|
||||||
return (img_front, img_left, img_right, puppet_arm_left, puppet_arm_right)
|
|
||||||
|
|
||||||
|
|
||||||
# Update the observation window buffer
|
|
||||||
def update_observation_window(args, config, ros_operator):
|
|
||||||
# JPEG transformation
|
|
||||||
# Align with training
|
|
||||||
def jpeg_mapping(img):
|
|
||||||
img = cv2.imencode(".jpg", img)[1].tobytes()
|
|
||||||
img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
|
|
||||||
return img
|
|
||||||
|
|
||||||
global observation_window
|
|
||||||
if observation_window is None:
|
|
||||||
observation_window = deque(maxlen=2)
|
|
||||||
|
|
||||||
# Append the first dummy image
|
|
||||||
observation_window.append({
|
|
||||||
"qpos": None,
|
|
||||||
"images": {
|
|
||||||
config["camera_names"][0]: None,
|
|
||||||
config["camera_names"][1]: None,
|
|
||||||
config["camera_names"][2]: None,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = (get_ros_observation(args, ros_operator))
|
|
||||||
img_front = jpeg_mapping(img_front)
|
|
||||||
img_left = jpeg_mapping(img_left)
|
|
||||||
img_right = jpeg_mapping(img_right)
|
|
||||||
|
|
||||||
qpos = np.concatenate(
|
|
||||||
(np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
qpos = torch.from_numpy(qpos).float().cuda()
|
|
||||||
observation_window.append({
|
|
||||||
"qpos": qpos,
|
|
||||||
"images": {
|
|
||||||
config["camera_names"][0]: img_front,
|
|
||||||
config["camera_names"][1]: img_right,
|
|
||||||
config["camera_names"][2]: img_left,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
# RDT inference
|
|
||||||
def inference_fn(args, config, policy, t):
|
|
||||||
global observation_window
|
|
||||||
global lang_embeddings
|
|
||||||
|
|
||||||
# print(f"Start inference_thread_fn: t={t}")
|
|
||||||
while True and not rospy.is_shutdown():
|
|
||||||
time1 = time.time()
|
|
||||||
|
|
||||||
# fetch images in sequence [front, right, left]
|
|
||||||
image_arrs = [
|
|
||||||
observation_window[-2]["images"][config["camera_names"][0]],
|
|
||||||
observation_window[-2]["images"][config["camera_names"][1]],
|
|
||||||
observation_window[-2]["images"][config["camera_names"][2]],
|
|
||||||
observation_window[-1]["images"][config["camera_names"][0]],
|
|
||||||
observation_window[-1]["images"][config["camera_names"][1]],
|
|
||||||
observation_window[-1]["images"][config["camera_names"][2]],
|
|
||||||
]
|
|
||||||
|
|
||||||
# fetch debug images in sequence [front, right, left]
|
|
||||||
# image_arrs = [
|
|
||||||
# preload_images[config['camera_names'][0]][max(t - 1, 0)],
|
|
||||||
# preload_images[config['camera_names'][2]][max(t - 1, 0)],
|
|
||||||
# preload_images[config['camera_names'][1]][max(t - 1, 0)],
|
|
||||||
# preload_images[config['camera_names'][0]][t],
|
|
||||||
# preload_images[config['camera_names'][2]][t],
|
|
||||||
# preload_images[config['camera_names'][1]][t]
|
|
||||||
# ]
|
|
||||||
# # encode the images
|
|
||||||
# for i in range(len(image_arrs)):
|
|
||||||
# image_arrs[i] = cv2.imdecode(np.frombuffer(image_arrs[i], np.uint8), cv2.IMREAD_COLOR)
|
|
||||||
# proprio = torch.from_numpy(preload_images['qpos'][t]).float().cuda()
|
|
||||||
|
|
||||||
images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]
|
|
||||||
|
|
||||||
# for i, pos in enumerate(['f', 'r', 'l'] * 2):
|
|
||||||
# images[i].save(f'{t}-{i}-{pos}.png')
|
|
||||||
|
|
||||||
# get last qpos in shape [14, ]
|
|
||||||
proprio = observation_window[-1]["qpos"]
|
|
||||||
# unsqueeze to [1, 14]
|
|
||||||
proprio = proprio.unsqueeze(0)
|
|
||||||
|
|
||||||
# actions shaped as [1, 64, 14] in format [left, right]
|
|
||||||
actions = (policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy())
|
|
||||||
# print(f"inference_actions: {actions.squeeze()}")
|
|
||||||
|
|
||||||
# print(f"Model inference time: {time.time() - time1} s")
|
|
||||||
|
|
||||||
# print(f"Finish inference_thread_fn: t={t}")
|
|
||||||
return actions
|
|
||||||
|
|
||||||
|
|
||||||
# Main loop for the manipulation task
|
|
||||||
def model_inference(args, config, ros_operator):
|
|
||||||
global lang_embeddings
|
|
||||||
|
|
||||||
# Load rdt model
|
|
||||||
policy = make_policy(args)
|
|
||||||
|
|
||||||
lang_dict = torch.load(args.lang_embeddings_path)
|
|
||||||
print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
|
|
||||||
lang_embeddings = lang_dict["embeddings"]
|
|
||||||
|
|
||||||
max_publish_step = config["episode_len"]
|
|
||||||
chunk_size = config["chunk_size"]
|
|
||||||
|
|
||||||
# Initialize position of the puppet arm
|
|
||||||
left0 = [
|
|
||||||
-0.00133514404296875,
|
|
||||||
0.00209808349609375,
|
|
||||||
0.01583099365234375,
|
|
||||||
-0.032616615295410156,
|
|
||||||
-0.00286102294921875,
|
|
||||||
0.00095367431640625,
|
|
||||||
3.557830810546875,
|
|
||||||
]
|
|
||||||
right0 = [
|
|
||||||
-0.00133514404296875,
|
|
||||||
0.00438690185546875,
|
|
||||||
0.034523963928222656,
|
|
||||||
-0.053597450256347656,
|
|
||||||
-0.00476837158203125,
|
|
||||||
-0.00209808349609375,
|
|
||||||
3.557830810546875,
|
|
||||||
]
|
|
||||||
left1 = [
|
|
||||||
-0.00133514404296875,
|
|
||||||
0.00209808349609375,
|
|
||||||
0.01583099365234375,
|
|
||||||
-0.032616615295410156,
|
|
||||||
-0.00286102294921875,
|
|
||||||
0.00095367431640625,
|
|
||||||
-0.3393220901489258,
|
|
||||||
]
|
|
||||||
right1 = [
|
|
||||||
-0.00133514404296875,
|
|
||||||
0.00247955322265625,
|
|
||||||
0.01583099365234375,
|
|
||||||
-0.032616615295410156,
|
|
||||||
-0.00286102294921875,
|
|
||||||
0.00095367431640625,
|
|
||||||
-0.3397035598754883,
|
|
||||||
]
|
|
||||||
ros_operator.puppet_arm_publish_continuous(left0, right0)
|
|
||||||
input("Press enter to continue")
|
|
||||||
ros_operator.puppet_arm_publish_continuous(left1, right1)
|
|
||||||
# Initialize the previous action to be the initial robot state
|
|
||||||
pre_action = np.zeros(config["state_dim"])
|
|
||||||
pre_action[:14] = np.array([
|
|
||||||
-0.00133514404296875,
|
|
||||||
0.00209808349609375,
|
|
||||||
0.01583099365234375,
|
|
||||||
-0.032616615295410156,
|
|
||||||
-0.00286102294921875,
|
|
||||||
0.00095367431640625,
|
|
||||||
-0.3393220901489258,
|
|
||||||
] + [
|
|
||||||
-0.00133514404296875,
|
|
||||||
0.00247955322265625,
|
|
||||||
0.01583099365234375,
|
|
||||||
-0.032616615295410156,
|
|
||||||
-0.00286102294921875,
|
|
||||||
0.00095367431640625,
|
|
||||||
-0.3397035598754883,
|
|
||||||
])
|
|
||||||
action = None
|
|
||||||
# Inference loop
|
|
||||||
with torch.inference_mode():
|
|
||||||
while True and not rospy.is_shutdown():
|
|
||||||
# The current time step
|
|
||||||
t = 0
|
|
||||||
rate = rospy.Rate(args.publish_rate)
|
|
||||||
|
|
||||||
action_buffer = np.zeros([chunk_size, config["state_dim"]])
|
|
||||||
|
|
||||||
while t < max_publish_step and not rospy.is_shutdown():
|
|
||||||
# Update observation window
|
|
||||||
update_observation_window(args, config, ros_operator)
|
|
||||||
|
|
||||||
# When coming to the end of the action chunk
|
|
||||||
if t % chunk_size == 0:
|
|
||||||
# Start inference
|
|
||||||
action_buffer = inference_fn(args, config, policy, t).copy()
|
|
||||||
|
|
||||||
raw_action = action_buffer[t % chunk_size]
|
|
||||||
action = raw_action
|
|
||||||
# Interpolate the original action sequence
|
|
||||||
if args.use_actions_interpolation:
|
|
||||||
# print(f"Time {t}, pre {pre_action}, act {action}")
|
|
||||||
interp_actions = interpolate_action(args, pre_action, action)
|
|
||||||
else:
|
|
||||||
interp_actions = action[np.newaxis, :]
|
|
||||||
# Execute the interpolated actions one by one
|
|
||||||
for act in interp_actions:
|
|
||||||
left_action = act[:7]
|
|
||||||
right_action = act[7:14]
|
|
||||||
|
|
||||||
if not args.disable_puppet_arm:
|
|
||||||
ros_operator.puppet_arm_publish(left_action,
|
|
||||||
right_action) # puppet_arm_publish_continuous_thread
|
|
||||||
|
|
||||||
if args.use_robot_base:
|
|
||||||
vel_action = act[14:16]
|
|
||||||
ros_operator.robot_base_publish(vel_action)
|
|
||||||
rate.sleep()
|
|
||||||
# print(f"doing action: {act}")
|
|
||||||
t += 1
|
|
||||||
|
|
||||||
print("Published Step", t)
|
|
||||||
pre_action = action.copy()
|
|
||||||
|
|
||||||
|
|
||||||
# ROS operator class
|
|
||||||
class RosOperator:
|
|
||||||
|
|
||||||
def __init__(self, args):
|
|
||||||
self.robot_base_deque = None
|
|
||||||
self.puppet_arm_right_deque = None
|
|
||||||
self.puppet_arm_left_deque = None
|
|
||||||
self.img_front_deque = None
|
|
||||||
self.img_right_deque = None
|
|
||||||
self.img_left_deque = None
|
|
||||||
self.img_front_depth_deque = None
|
|
||||||
self.img_right_depth_deque = None
|
|
||||||
self.img_left_depth_deque = None
|
|
||||||
self.bridge = None
|
|
||||||
self.puppet_arm_left_publisher = None
|
|
||||||
self.puppet_arm_right_publisher = None
|
|
||||||
self.robot_base_publisher = None
|
|
||||||
self.puppet_arm_publish_thread = None
|
|
||||||
self.puppet_arm_publish_lock = None
|
|
||||||
self.args = args
|
|
||||||
self.init()
|
|
||||||
self.init_ros()
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
self.bridge = CvBridge()
|
|
||||||
self.img_left_deque = deque()
|
|
||||||
self.img_right_deque = deque()
|
|
||||||
self.img_front_deque = deque()
|
|
||||||
self.img_left_depth_deque = deque()
|
|
||||||
self.img_right_depth_deque = deque()
|
|
||||||
self.img_front_depth_deque = deque()
|
|
||||||
self.puppet_arm_left_deque = deque()
|
|
||||||
self.puppet_arm_right_deque = deque()
|
|
||||||
self.robot_base_deque = deque()
|
|
||||||
self.puppet_arm_publish_lock = threading.Lock()
|
|
||||||
self.puppet_arm_publish_lock.acquire()
|
|
||||||
|
|
||||||
def puppet_arm_publish(self, left, right):
|
|
||||||
joint_state_msg = JointState()
|
|
||||||
joint_state_msg.header = Header()
|
|
||||||
joint_state_msg.header.stamp = rospy.Time.now() # Set timestep
|
|
||||||
joint_state_msg.name = [
|
|
||||||
"joint0",
|
|
||||||
"joint1",
|
|
||||||
"joint2",
|
|
||||||
"joint3",
|
|
||||||
"joint4",
|
|
||||||
"joint5",
|
|
||||||
"joint6",
|
|
||||||
] # 设置关节名称
|
|
||||||
joint_state_msg.position = left
|
|
||||||
self.puppet_arm_left_publisher.publish(joint_state_msg)
|
|
||||||
joint_state_msg.position = right
|
|
||||||
self.puppet_arm_right_publisher.publish(joint_state_msg)
|
|
||||||
|
|
||||||
def robot_base_publish(self, vel):
|
|
||||||
vel_msg = Twist()
|
|
||||||
vel_msg.linear.x = vel[0]
|
|
||||||
vel_msg.linear.y = 0
|
|
||||||
vel_msg.linear.z = 0
|
|
||||||
vel_msg.angular.x = 0
|
|
||||||
vel_msg.angular.y = 0
|
|
||||||
vel_msg.angular.z = vel[1]
|
|
||||||
self.robot_base_publisher.publish(vel_msg)
|
|
||||||
|
|
||||||
def puppet_arm_publish_continuous(self, left, right):
|
|
||||||
rate = rospy.Rate(self.args.publish_rate)
|
|
||||||
left_arm = None
|
|
||||||
right_arm = None
|
|
||||||
while True and not rospy.is_shutdown():
|
|
||||||
if len(self.puppet_arm_left_deque) != 0:
|
|
||||||
left_arm = list(self.puppet_arm_left_deque[-1].position)
|
|
||||||
if len(self.puppet_arm_right_deque) != 0:
|
|
||||||
right_arm = list(self.puppet_arm_right_deque[-1].position)
|
|
||||||
if left_arm is None or right_arm is None:
|
|
||||||
rate.sleep()
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
|
|
||||||
right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
|
|
||||||
flag = True
|
|
||||||
step = 0
|
|
||||||
while flag and not rospy.is_shutdown():
|
|
||||||
if self.puppet_arm_publish_lock.acquire(False):
|
|
||||||
return
|
|
||||||
left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
|
|
||||||
right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
|
|
||||||
flag = False
|
|
||||||
for i in range(len(left)):
|
|
||||||
if left_diff[i] < self.args.arm_steps_length[i]:
|
|
||||||
left_arm[i] = left[i]
|
|
||||||
else:
|
|
||||||
left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
|
|
||||||
flag = True
|
|
||||||
for i in range(len(right)):
|
|
||||||
if right_diff[i] < self.args.arm_steps_length[i]:
|
|
||||||
right_arm[i] = right[i]
|
|
||||||
else:
|
|
||||||
right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
|
|
||||||
flag = True
|
|
||||||
joint_state_msg = JointState()
|
|
||||||
joint_state_msg.header = Header()
|
|
||||||
joint_state_msg.header.stamp = rospy.Time.now() # Set the timestep
|
|
||||||
joint_state_msg.name = [
|
|
||||||
"joint0",
|
|
||||||
"joint1",
|
|
||||||
"joint2",
|
|
||||||
"joint3",
|
|
||||||
"joint4",
|
|
||||||
"joint5",
|
|
||||||
"joint6",
|
|
||||||
] # 设置关节名称
|
|
||||||
joint_state_msg.position = left_arm
|
|
||||||
self.puppet_arm_left_publisher.publish(joint_state_msg)
|
|
||||||
joint_state_msg.position = right_arm
|
|
||||||
self.puppet_arm_right_publisher.publish(joint_state_msg)
|
|
||||||
step += 1
|
|
||||||
print("puppet_arm_publish_continuous:", step)
|
|
||||||
rate.sleep()
|
|
||||||
|
|
||||||
def puppet_arm_publish_linear(self, left, right):
|
|
||||||
num_step = 100
|
|
||||||
rate = rospy.Rate(200)
|
|
||||||
|
|
||||||
left_arm = None
|
|
||||||
right_arm = None
|
|
||||||
|
|
||||||
while True and not rospy.is_shutdown():
|
|
||||||
if len(self.puppet_arm_left_deque) != 0:
|
|
||||||
left_arm = list(self.puppet_arm_left_deque[-1].position)
|
|
||||||
if len(self.puppet_arm_right_deque) != 0:
|
|
||||||
right_arm = list(self.puppet_arm_right_deque[-1].position)
|
|
||||||
if left_arm is None or right_arm is None:
|
|
||||||
rate.sleep()
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
traj_left_list = np.linspace(left_arm, left, num_step)
|
|
||||||
traj_right_list = np.linspace(right_arm, right, num_step)
|
|
||||||
|
|
||||||
for i in range(len(traj_left_list)):
|
|
||||||
traj_left = traj_left_list[i]
|
|
||||||
traj_right = traj_right_list[i]
|
|
||||||
traj_left[-1] = left[-1]
|
|
||||||
traj_right[-1] = right[-1]
|
|
||||||
joint_state_msg = JointState()
|
|
||||||
joint_state_msg.header = Header()
|
|
||||||
joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳
|
|
||||||
joint_state_msg.name = [
|
|
||||||
"joint0",
|
|
||||||
"joint1",
|
|
||||||
"joint2",
|
|
||||||
"joint3",
|
|
||||||
"joint4",
|
|
||||||
"joint5",
|
|
||||||
"joint6",
|
|
||||||
] # 设置关节名称
|
|
||||||
joint_state_msg.position = traj_left
|
|
||||||
self.puppet_arm_left_publisher.publish(joint_state_msg)
|
|
||||||
joint_state_msg.position = traj_right
|
|
||||||
self.puppet_arm_right_publisher.publish(joint_state_msg)
|
|
||||||
rate.sleep()
|
|
||||||
|
|
||||||
def puppet_arm_publish_continuous_thread(self, left, right):
|
|
||||||
if self.puppet_arm_publish_thread is not None:
|
|
||||||
self.puppet_arm_publish_lock.release()
|
|
||||||
self.puppet_arm_publish_thread.join()
|
|
||||||
self.puppet_arm_publish_lock.acquire(False)
|
|
||||||
self.puppet_arm_publish_thread = None
|
|
||||||
self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
|
|
||||||
self.puppet_arm_publish_thread.start()
|
|
||||||
|
|
||||||
def get_frame(self):
|
|
||||||
if (len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or
|
|
||||||
(self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0
|
|
||||||
or len(self.img_front_depth_deque) == 0))):
|
|
||||||
return False
|
|
||||||
if self.args.use_depth_image:
|
|
||||||
frame_time = min([
|
|
||||||
self.img_left_deque[-1].header.stamp.to_sec(),
|
|
||||||
self.img_right_deque[-1].header.stamp.to_sec(),
|
|
||||||
self.img_front_deque[-1].header.stamp.to_sec(),
|
|
||||||
self.img_left_depth_deque[-1].header.stamp.to_sec(),
|
|
||||||
self.img_right_depth_deque[-1].header.stamp.to_sec(),
|
|
||||||
self.img_front_depth_deque[-1].header.stamp.to_sec(),
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
frame_time = min([
|
|
||||||
self.img_left_deque[-1].header.stamp.to_sec(),
|
|
||||||
self.img_right_deque[-1].header.stamp.to_sec(),
|
|
||||||
self.img_front_deque[-1].header.stamp.to_sec(),
|
|
||||||
])
|
|
||||||
|
|
||||||
if (len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time):
|
|
||||||
return False
|
|
||||||
if (len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time):
|
|
||||||
return False
|
|
||||||
if (len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time):
|
|
||||||
return False
|
|
||||||
if (len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time):
|
|
||||||
return False
|
|
||||||
if (len(self.puppet_arm_right_deque) == 0
|
|
||||||
or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time):
|
|
||||||
return False
|
|
||||||
if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0
|
|
||||||
or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
|
|
||||||
return False
|
|
||||||
if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0
|
|
||||||
or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
|
|
||||||
return False
|
|
||||||
if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0
|
|
||||||
or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
|
|
||||||
return False
|
|
||||||
if self.args.use_robot_base and (len(self.robot_base_deque) == 0
|
|
||||||
or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
|
|
||||||
return False
|
|
||||||
|
|
||||||
while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
|
|
||||||
self.img_left_deque.popleft()
|
|
||||||
img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), "passthrough")
|
|
||||||
|
|
||||||
while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
|
|
||||||
self.img_right_deque.popleft()
|
|
||||||
img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), "passthrough")
|
|
||||||
|
|
||||||
while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
|
|
||||||
self.img_front_deque.popleft()
|
|
||||||
img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), "passthrough")
|
|
||||||
|
|
||||||
while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
|
|
||||||
self.puppet_arm_left_deque.popleft()
|
|
||||||
puppet_arm_left = self.puppet_arm_left_deque.popleft()
|
|
||||||
|
|
||||||
while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
|
|
||||||
self.puppet_arm_right_deque.popleft()
|
|
||||||
puppet_arm_right = self.puppet_arm_right_deque.popleft()
|
|
||||||
|
|
||||||
img_left_depth = None
|
|
||||||
if self.args.use_depth_image:
|
|
||||||
while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
|
|
||||||
self.img_left_depth_deque.popleft()
|
|
||||||
img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), "passthrough")
|
|
||||||
|
|
||||||
img_right_depth = None
|
|
||||||
if self.args.use_depth_image:
|
|
||||||
while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
|
|
||||||
self.img_right_depth_deque.popleft()
|
|
||||||
img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), "passthrough")
|
|
||||||
|
|
||||||
img_front_depth = None
|
|
||||||
if self.args.use_depth_image:
|
|
||||||
while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
|
|
||||||
self.img_front_depth_deque.popleft()
|
|
||||||
img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), "passthrough")
|
|
||||||
|
|
||||||
robot_base = None
|
|
||||||
if self.args.use_robot_base:
|
|
||||||
while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
|
|
||||||
self.robot_base_deque.popleft()
|
|
||||||
robot_base = self.robot_base_deque.popleft()
|
|
||||||
|
|
||||||
return (
|
|
||||||
img_front,
|
|
||||||
img_left,
|
|
||||||
img_right,
|
|
||||||
img_front_depth,
|
|
||||||
img_left_depth,
|
|
||||||
img_right_depth,
|
|
||||||
puppet_arm_left,
|
|
||||||
puppet_arm_right,
|
|
||||||
robot_base,
|
|
||||||
)
|
|
||||||
|
|
||||||
def img_left_callback(self, msg):
|
|
||||||
if len(self.img_left_deque) >= 2000:
|
|
||||||
self.img_left_deque.popleft()
|
|
||||||
self.img_left_deque.append(msg)
|
|
||||||
|
|
||||||
def img_right_callback(self, msg):
|
|
||||||
if len(self.img_right_deque) >= 2000:
|
|
||||||
self.img_right_deque.popleft()
|
|
||||||
self.img_right_deque.append(msg)
|
|
||||||
|
|
||||||
def img_front_callback(self, msg):
|
|
||||||
if len(self.img_front_deque) >= 2000:
|
|
||||||
self.img_front_deque.popleft()
|
|
||||||
self.img_front_deque.append(msg)
|
|
||||||
|
|
||||||
def img_left_depth_callback(self, msg):
|
|
||||||
if len(self.img_left_depth_deque) >= 2000:
|
|
||||||
self.img_left_depth_deque.popleft()
|
|
||||||
self.img_left_depth_deque.append(msg)
|
|
||||||
|
|
||||||
def img_right_depth_callback(self, msg):
    """Buffer an incoming right depth image message (bounded ring of 2000)."""
    buf = self.img_right_depth_deque
    if len(buf) >= 2000:  # drop the oldest frame before appending
        buf.popleft()
    buf.append(msg)
def img_front_depth_callback(self, msg):
    """Buffer an incoming front depth image message (bounded ring of 2000)."""
    buf = self.img_front_depth_deque
    if len(buf) >= 2000:  # drop the oldest frame before appending
        buf.popleft()
    buf.append(msg)
def puppet_arm_left_callback(self, msg):
    """Buffer a left puppet-arm joint-state message (bounded ring of 2000)."""
    buf = self.puppet_arm_left_deque
    if len(buf) >= 2000:  # drop the oldest sample before appending
        buf.popleft()
    buf.append(msg)
def puppet_arm_right_callback(self, msg):
    """Buffer a right puppet-arm joint-state message (bounded ring of 2000)."""
    buf = self.puppet_arm_right_deque
    if len(buf) >= 2000:  # drop the oldest sample before appending
        buf.popleft()
    buf.append(msg)
def robot_base_callback(self, msg):
    """Buffer a robot-base odometry message (bounded ring of 2000)."""
    buf = self.robot_base_deque
    if len(buf) >= 2000:  # drop the oldest sample before appending
        buf.popleft()
    buf.append(msg)
def init_ros(self):
    """Register every ROS subscriber and publisher used by this operator.

    Image topics (and, when --use_depth_image is set, depth topics) feed the
    bounded deques via the *_callback methods; arm/base command publishers
    are created last.
    """
    rospy.init_node("joint_state_publisher", anonymous=True)

    # (topic, message type, handler) triples for every subscription.
    subscriptions = [
        (self.args.img_left_topic, Image, self.img_left_callback),
        (self.args.img_right_topic, Image, self.img_right_callback),
        (self.args.img_front_topic, Image, self.img_front_callback),
    ]
    if self.args.use_depth_image:
        subscriptions += [
            (self.args.img_left_depth_topic, Image, self.img_left_depth_callback),
            (self.args.img_right_depth_topic, Image, self.img_right_depth_callback),
            (self.args.img_front_depth_topic, Image, self.img_front_depth_callback),
        ]
    subscriptions += [
        (self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback),
        (self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback),
        (self.args.robot_base_topic, Odometry, self.robot_base_callback),
    ]
    for topic, msg_type, handler in subscriptions:
        rospy.Subscriber(topic, msg_type, handler, queue_size=1000, tcp_nodelay=True)

    self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
    self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
    self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
def get_arguments():
    """Build the CLI parser for the RDT inference script and parse sys.argv.

    Returns:
        argparse.Namespace holding topics, rates, chunking options, and the
        (required) model / language-embedding paths.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--max_publish_step", action="store", type=int,
                        help="Maximum number of action publishing steps",
                        default=10000, required=False)
    parser.add_argument("--seed", action="store", type=int,
                        help="Random seed", default=None, required=False)

    # ROS topic options: for each, the help text is simply the option name.
    # Fix: --robot_base_cmd_topic previously had the copy-pasted help text
    # "robot_base_topic".
    for flag, default in (
        ("--img_front_topic", "/camera_f/color/image_raw"),
        ("--img_left_topic", "/camera_l/color/image_raw"),
        ("--img_right_topic", "/camera_r/color/image_raw"),
        ("--img_front_depth_topic", "/camera_f/depth/image_raw"),
        ("--img_left_depth_topic", "/camera_l/depth/image_raw"),
        ("--img_right_depth_topic", "/camera_r/depth/image_raw"),
        ("--puppet_arm_left_cmd_topic", "/master/joint_left"),
        ("--puppet_arm_right_cmd_topic", "/master/joint_right"),
        ("--puppet_arm_left_topic", "/puppet/joint_left"),
        ("--puppet_arm_right_topic", "/puppet/joint_right"),
        ("--robot_base_topic", "/odom_raw"),
        ("--robot_base_cmd_topic", "/cmd_vel"),
    ):
        parser.add_argument(flag, action="store", type=str, help=flag[2:],
                            default=default, required=False)

    parser.add_argument("--use_robot_base", action="store_true",
                        help="Whether to use the robot base to move around",
                        default=False, required=False)
    parser.add_argument("--publish_rate", action="store", type=int,
                        help="The rate at which to publish the actions",
                        default=30, required=False)
    parser.add_argument("--ctrl_freq", action="store", type=int,
                        help="The control frequency of the robot",
                        default=25, required=False)

    parser.add_argument("--chunk_size", action="store", type=int,
                        help="Action chunk size", default=64, required=False)
    parser.add_argument("--arm_steps_length", action="store", type=float,
                        help="The maximum change allowed for each joint per timestep",
                        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2],
                        required=False)

    parser.add_argument("--use_actions_interpolation", action="store_true",
                        help="Whether to interpolate the actions if the difference is too large",
                        default=False, required=False)
    parser.add_argument("--use_depth_image", action="store_true",
                        help="Whether to use depth images",
                        default=False, required=False)

    parser.add_argument("--disable_puppet_arm", action="store_true",
                        help="Whether to disable the puppet arm. This is useful for safely debugging",
                        default=False)

    parser.add_argument("--config_path", type=str, default="configs/base.yaml",
                        help="Path to the config file")
    # parser.add_argument('--cfg_scale', type=float, default=2.0,
    #                     help='the scaling factor used to modify the magnitude of the control features during denoising')
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True,
                        help="Name or path to the pretrained model")
    parser.add_argument("--lang_embeddings_path", type=str, required=True,
                        help="Path to the pre-encoded language instruction embeddings")

    return parser.parse_args()
def main():
    """CLI entry point: parse args, set up the ROS operator, run inference."""
    args = get_arguments()
    ros_operator = RosOperator(args)
    if args.seed is not None:
        set_seed(args.seed)  # deterministic behavior only when a seed is given
    config = get_config(args)
    model_inference(args, config, ros_operator)
# Run the inference loop only when executed as a script.
if __name__ == "__main__":
    main()
@ -1,344 +0,0 @@
|
|||||||
import os, sys
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# get current workspace
current_file = Path(__file__)
# Make the sibling "models" package importable. Fix: the original appended
# the identical path twice; append it once and only if not already present.
_models_dir = os.path.join(current_file.parent.parent, "models")
if _models_dir not in sys.path:
    sys.path.append(_models_dir)
from multimodal_encoder.siglip_encoder import SiglipVisionTower
|
|
||||||
from multimodal_encoder.t5_encoder import T5Embedder
|
|
||||||
from rdt_runner import RDTRunner
|
|
||||||
|
|
||||||
# The indices that the raw vector should be mapped to in the unified action vector
|
|
||||||
# AGILEX_STATE_INDICES = [
|
|
||||||
# STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(1)
|
|
||||||
# ] + [
|
|
||||||
# STATE_VEC_IDX_MAPPING["left_gripper_open"]
|
|
||||||
# ] + [
|
|
||||||
# STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(1)
|
|
||||||
# ] + [
|
|
||||||
# STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
|
|
||||||
# ]
|
|
||||||
# AGILEX_STATE_INDICES = None
|
|
||||||
|
|
||||||
|
|
||||||
# Create the RDT model
def create_model(args, **kwargs):
    """Instantiate the RDT wrapper and optionally load a file checkpoint.

    Args:
        args: config dict (must contain args["arm_dim"] with left/right dims,
            consumed by RoboticDiffusionTransformerModel itself).
        **kwargs: forwarded to RoboticDiffusionTransformerModel. If
            kwargs["pretrained"] points at an existing *file*, its weights are
            loaded here; directory-style checkpoints are handled inside the
            wrapper via RDTRunner.from_pretrained.

    Returns:
        A ready RoboticDiffusionTransformerModel instance.
    """
    # Fix: removed a locally computed AGILEX_STATE_INDICES (and the arm-dim
    # unpacking feeding it) that was never used — the wrapper derives the
    # same indices internally from args["arm_dim"].
    model = RoboticDiffusionTransformerModel(args, **kwargs)
    pretrained = kwargs.get("pretrained", None)
    if pretrained is not None and os.path.isfile(pretrained):
        model.load_pretrained_weights(pretrained)

    return model
class RoboticDiffusionTransformerModel(object):
|
|
||||||
"""A wrapper for the RDT model, which handles
|
|
||||||
1. Model initialization
|
|
||||||
2. Encodings of instructions
|
|
||||||
3. Model inference
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
    self,
    args,
    device="cuda",
    dtype=torch.bfloat16,
    image_size=None,
    control_frequency=25,
    pretrained=None,
    pretrained_vision_encoder_name_or_path=None,
):
    """Wire up the vision encoder and diffusion policy for inference.

    The T5 text encoder is deliberately not instantiated here (limited GPU
    memory); instructions are expected to arrive pre-encoded.
    """
    # Plain configuration.
    self.args = args
    self.device = device
    self.dtype = dtype
    self.image_size = image_size
    self.control_frequency = control_frequency
    self.left_arm_dim = args["arm_dim"]["left_arm_dim"]
    self.right_arm_dim = args["arm_dim"]["right_arm_dim"]

    # Heavy components: the vision tower must exist before the policy,
    # because get_policy() reads self.vision_model.num_patches.
    self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
    self.policy = self.get_policy(pretrained)

    self.reset()
def get_policy(self, pretrained):
    """Initialize the model.

    If `pretrained` is None or a plain checkpoint *file*, a fresh RDTRunner
    is built from the config (file weights are loaded later by
    load_pretrained_weights); otherwise `pretrained` is treated as a
    from_pretrained()-style directory/identifier.
    """
    # Initialize model with arguments
    if pretrained is None or os.path.isfile(pretrained):
        # Total image-conditioning length: history x cameras x patches/image.
        img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                        self.vision_model.num_patches)

        _model = RDTRunner(
            action_dim=self.args["common"]["state_dim"],
            pred_horizon=self.args["common"]["action_chunk_size"],
            config=self.args["model"],
            lang_token_dim=self.args["model"]["lang_token_dim"],
            img_token_dim=self.args["model"]["img_token_dim"],
            state_token_dim=self.args["model"]["state_token_dim"],
            max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
            img_cond_len=img_cond_len,
            img_pos_embed_config=[
                # No initial pos embed in the last grid size
                # since we've already done in ViT
                # NOTE(review): the negative sizes below appear to be a
                # sentinel meaning "skip initial positional embedding for
                # this axis" — confirm against RDTRunner's pos-embed code.
                (
                    "image",
                    (
                        self.args["common"]["img_history_size"],
                        self.args["common"]["num_cameras"],
                        -self.vision_model.num_patches,
                    ),
                ),
            ],
            lang_pos_embed_config=[
                # Similarly, no initial pos embed for language
                ("lang", -self.args["dataset"]["tokenizer_max_length"]),
            ],
            dtype=self.dtype,
        )
    else:
        _model = RDTRunner.from_pretrained(pretrained)

    return _model
def get_text_encoder(self, pretrained_text_encoder_name_or_path):
    """Build the T5 tokenizer/encoder pair used for instruction encoding."""
    embedder = T5Embedder(
        from_pretrained=pretrained_text_encoder_name_or_path,
        model_max_length=self.args["dataset"]["tokenizer_max_length"],
        device=self.device,
    )
    return embedder.tokenizer, embedder.model
def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
    """Build the SigLIP vision tower and return (image_processor, tower)."""
    tower = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
    return tower.image_processor, tower
def reset(self):
    """Set model to evaluation mode and move it to the configured device/dtype.

    The text encoder is intentionally untouched — it is disabled in __init__.
    """
    for attr in ("policy", "vision_model"):
        module = getattr(self, attr)
        module.eval()
        setattr(self, attr, module.to(self.device, dtype=self.dtype))
def load_pretrained_weights(self, pretrained=None):
    """Load policy weights from a .pt (DeepSpeed-style) or .safetensors file.

    Args:
        pretrained: path to the checkpoint file, or None to do nothing.

    Raises:
        NotImplementedError: for any other file extension.
    """
    if pretrained is None:
        return
    print(f"Loading weights from {pretrained}")
    filename = os.path.basename(pretrained)
    if filename.endswith(".pt"):
        # Fix: map_location="cpu" so GPU-saved checkpoints also load on
        # CPU-only hosts; load_state_dict copies onto the policy's device.
        checkpoint = torch.load(pretrained, map_location="cpu")
        self.policy.load_state_dict(checkpoint["module"])
    elif filename.endswith(".safetensors"):
        from safetensors.torch import load_model

        load_model(self.policy, pretrained)
    else:
        raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
def encode_instruction(self, instruction, device="cuda"):
    """Encode string instruction to latent embeddings.

    Args:
        instruction: a string of instruction
        device: a string of device

    Returns:
        pred: a tensor of latent embeddings of shape (text_max_length, 512)

    NOTE(review): __init__ leaves the text encoder disabled (the
    self.text_tokenizer / self.text_model assignment is commented out), so
    calling this raises AttributeError unless those attributes are set
    externally — confirm intended usage.
    """
    tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                 truncation=True)["input_ids"].to(device)

    # Give the token ids an explicit batch dimension: (1, seq_len).
    tokens = tokens.view(1, -1)
    with torch.no_grad():
        pred = self.text_model(tokens).last_hidden_state.detach()

    return pred
def _format_joint_to_state(self, joints):
    """
    Format the joint proprioception into the unified action vector.

    Args:
        joints (torch.Tensor): The joint proprioception to be formatted.
            qpos ([B, N, 14]).

    Returns:
        state (torch.Tensor): The formatted vector for RDT ([B, N, 128]).
        state_elem_mask (torch.Tensor): 1 where a unified dimension is
            populated by this robot ([B, 128]).
    """
    # Positions of this robot's joints inside the unified state vector:
    # left-arm joints, left gripper, right-arm joints, right gripper.
    AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"]
                             for i in range(self.left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] +
                            [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
                             for i in range(self.right_arm_dim)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]])
    # Rescale the gripper to the range of [0, 1]
    # NOTE(review): the divisor is all ones, so this is currently a no-op
    # placeholder — confirm whether real per-joint scales were intended.
    joints = joints / torch.tensor(
        [[[1 for i in range(self.left_arm_dim + 1 + self.right_arm_dim + 1)]]],
        device=joints.device,
        dtype=joints.dtype,
    )

    B, N, _ = joints.shape
    state = torch.zeros(
        (B, N, self.args["model"]["state_token_dim"]),
        device=joints.device,
        dtype=joints.dtype,
    )
    # Fill into the unified state vector
    state[:, :, AGILEX_STATE_INDICES] = joints
    # Assemble the mask indicating each dimension's availability
    state_elem_mask = torch.zeros(
        (B, self.args["model"]["state_token_dim"]),
        device=joints.device,
        dtype=joints.dtype,
    )
    state_elem_mask[:, AGILEX_STATE_INDICES] = 1
    return state, state_elem_mask
def _unformat_action_to_joint(self, action):
    """
    Unformat the unified action vector into the joint action to be executed.

    Args:
        action (torch.Tensor): The unified action vector to be unformatted.
            ([B, N, 128])

    Returns:
        joints (torch.Tensor): The unformatted robot joint action.
            qpos ([B, N, 14]).
    """
    # Same index layout as _format_joint_to_state: left-arm joints,
    # left gripper, right-arm joints, right gripper.
    AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"]
                             for i in range(self.left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] +
                            [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
                             for i in range(self.right_arm_dim)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]])
    action_indices = AGILEX_STATE_INDICES
    joints = action[:, :, action_indices]

    # Rescale the gripper back to the action range
    # Note that the action range and proprioception range are different
    # for Mobile ALOHA robot
    # NOTE(review): the scale tensor is all ones, so this multiply is
    # currently a no-op placeholder — confirm whether real per-joint
    # scales were intended here.
    joints = joints * torch.tensor(
        [[[1 for i in range(self.left_arm_dim + 1 + self.right_arm_dim + 1)]]],
        device=joints.device,
        dtype=joints.dtype,
    )

    return joints
@torch.no_grad()
def step(self, proprio, images, text_embeds):
    """
    Predict the next action chunk given the
    proprioceptive states, images, and instruction embeddings.

    Args:
        proprio: proprioceptive states
        images: RGB images, the order should be
            [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},
            ext_{t}, right_wrist_{t}, left_wrist_{t}]
        text_embeds: instruction embeddings

    Returns:
        action: predicted action chunk (float32 joint values)
    """
    device = self.device
    dtype = self.dtype

    # The background image used for padding missing camera views.
    background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                dtype=np.uint8).reshape(1, 1, 3)
    background_image = (np.ones(
        (
            self.image_processor.size["height"],
            self.image_processor.size["width"],
            3,
        ),
        dtype=np.uint8,
    ) * background_color)

    # Preprocess the images by order and encode them
    image_tensor_list = []
    for image in images:
        if image is None:
            # Replace it with the background image
            image = Image.fromarray(background_image)

        if self.image_size is not None:
            # Fix: the original read self.data_args.image_size, but this
            # class never defines `data_args` — the configured size lives
            # in self.image_size (set in __init__), so any non-None
            # image_size raised AttributeError.
            image = transforms.Resize(self.image_size)(image)

        if self.args["dataset"].get("auto_adjust_image_brightness", False):
            # Brighten images that are almost black (mean brightness <= 15%).
            pixel_values = list(image.getdata())
            average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
            if average_brightness <= 0.15:
                image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

        if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

            def expand2square(pil_img, background_color):
                # Pad the shorter side with the mean color to make a square.
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
        image_tensor_list.append(image)

    image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)

    # Encode all views in one batch, then flatten into a single token sequence.
    image_embeds = self.vision_model(image_tensor).detach()
    image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

    # Prepare the proprioception states and the control frequency
    joints = proprio.to(device).unsqueeze(0)  # (1, 1, 14)
    states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
    states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
    states = states[:, -1:, :]  # keep only the most recent state (1, 1, 128)
    ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

    text_embeds = text_embeds.to(device, dtype=dtype)

    # Predict the next action chunk given the inputs
    trajectory = self.policy.predict_action(
        lang_tokens=text_embeds,
        lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
        img_tokens=image_embeds,
        state_tokens=states,
        action_mask=state_elem_mask.unsqueeze(1),
        ctrl_freqs=ctrl_freqs,
    )
    trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)

    return trajectory
@ -1,53 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from models.multimodal_encoder.t5_encoder import T5Embedder
|
|
||||||
|
|
||||||
# Encoding configuration for the one-off instruction embedding script.
GPU = 0  # CUDA device index used for the T5 encoder
MODEL_PATH = "google/t5-v1_1-xxl"
CONFIG_PATH = "configs/base.yaml"
SAVE_DIR = "outs/"  # directory the embedding .pt file is written to

# Modify this to your task name and instruction
# NOTE(review): TASK_NAME says "handover_pan" but INSTRUCTION describes a
# marker/box pick-and-place task — confirm they are meant to match.
TASK_NAME = "handover_pan"
INSTRUCTION = "Pick up the black marker on the right and put it into the packaging box on the left."

# Note: if your GPU VRAM is less than 24GB,
# it is recommended to enable offloading by specifying an offload directory.
OFFLOAD_DIR = (
    None  # Specify your offload directory here, ensuring the directory exists.
)
def main():
    """Encode INSTRUCTION with the T5 text encoder and save the embeddings.

    Writes a dict {"name", "instruction", "embeddings"} to
    SAVE_DIR/<TASK_NAME>.pt.
    """
    with open(CONFIG_PATH, "r") as fp:
        config = yaml.safe_load(fp)

    device = torch.device(f"cuda:{GPU}")
    text_embedder = T5Embedder(
        from_pretrained=MODEL_PATH,
        model_max_length=config["dataset"]["tokenizer_max_length"],
        device=device,
        use_offload_folder=OFFLOAD_DIR,
    )
    tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    tokens = tokenizer(INSTRUCTION, return_tensors="pt", padding="longest", truncation=True)["input_ids"].to(device)

    tokens = tokens.view(1, -1)
    with torch.no_grad():
        pred = text_encoder(tokens).last_hidden_state.detach().cpu()

    # Fix: torch.save fails if SAVE_DIR does not exist yet.
    os.makedirs(SAVE_DIR, exist_ok=True)
    save_path = os.path.join(SAVE_DIR, f"{TASK_NAME}.pt")
    # We save the embeddings in a dictionary format
    torch.save({"name": TASK_NAME, "instruction": INSTRUCTION, "embeddings": pred}, save_path)

    print(
        f'"{INSTRUCTION}" from "{TASK_NAME}" is encoded by "{MODEL_PATH}" into shape {pred.shape} and saved to "{save_path}"'
    )
# Script entry point.
if __name__ == "__main__":
    main()
@ -1,57 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
import argparse
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from models.multimodal_encoder.t5_encoder import T5Embedder
|
|
||||||
|
|
||||||
|
|
||||||
def encode_lang(
    DATA_FILE_PATH,
    TARGET_DIR,
    GPU,
    desc_type="seen",
    tokenizer=None,
    text_encoder=None,
):
    """Encode task instructions with T5 and save one embedding file each.

    Args:
        DATA_FILE_PATH: JSON file holding instruction lists keyed by type.
        TARGET_DIR: output root; files go to TARGET_DIR/instructions/lang_embed_<i>.pt.
        GPU: CUDA device index.
        desc_type: which instruction list to encode (e.g. "seen").
        tokenizer, text_encoder: optional pre-built T5 pieces; built from the
            bundled weights when either is missing.

    Returns:
        (tokenizer, text_encoder) so callers can reuse them across tasks.
    """
    current_dir = os.path.dirname(__file__)

    with open(os.path.join(current_dir, "../configs/base.yaml"), "r") as fp:
        config = yaml.safe_load(fp)

    device = torch.device(f"cuda:{GPU}")
    if tokenizer is None or text_encoder is None:
        text_embedder = T5Embedder(
            from_pretrained=os.path.join(current_dir, "../../weights/RDT/t5-v1_1-xxl"),
            model_max_length=config["dataset"]["tokenizer_max_length"],
            device=device,
            use_offload_folder=None,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    with open(DATA_FILE_PATH, "r") as f_instr:
        instruction_dict = json.load(f_instr)

    instructions = instruction_dict[desc_type]

    # Encode the instructions in one padded batch.
    tokenized_res = tokenizer(instructions, return_tensors="pt", padding="longest", truncation=True)
    tokens = tokenized_res["input_ids"].to(device)
    attn_mask = tokenized_res["attention_mask"].to(device)

    with torch.no_grad():
        text_embeds = (text_encoder(input_ids=tokens, attention_mask=attn_mask)["last_hidden_state"].detach().cpu())

    attn_mask = attn_mask.cpu().bool()
    # Fix: race-free directory creation (the original used an
    # exists()-then-makedirs() pair).
    os.makedirs(os.path.join(TARGET_DIR, "instructions"), exist_ok=True)
    # Save the embeddings for training use, dropping padding positions
    # via the attention mask.
    for i in range(len(instructions)):
        text_embed = text_embeds[i][attn_mask[i]]
        save_path = os.path.join(TARGET_DIR, f"instructions/lang_embed_{i}.pt")
        torch.save(text_embed, save_path)

    return tokenizer, text_encoder
@ -1,84 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import re
|
|
||||||
|
|
||||||
def extract_metrics_from_log(log_file_path):
    """Scan a training log for metric dicts and return the best value of each.

    Returns a 4-tuple (agilex_sample_mse, agilex_sample_l2err,
    overall_avg_sample_mse, overall_avg_sample_l2err) of per-column minima,
    or (None, None, None, None) when the log is unreadable or has no metrics.
    """
    metric_re = re.compile(
        r"\{'agilex_sample_mse':\s*([0-9.eE+-]+),\s*'agilex_sample_l2err':\s*([0-9.eE+-]+),\s*'overall_avg_sample_mse':\s*([0-9.eE+-]+),\s*'overall_avg_sample_l2err':\s*([0-9.eE+-]+)\}"
    )
    found = []
    try:
        with open(log_file_path, 'r', encoding='utf-8') as handle:
            for line in handle:
                match = metric_re.search(line)
                if match is None:
                    continue
                values = tuple(float(match.group(g)) for g in range(1, 5))
                found.append(values)
                print(f"Find Metrics: agilex_sample_mse={values[0]}, agilex_sample_l2err={values[1]}, "
                      f"overall_avg_sample_mse={values[2]}, overall_avg_sample_l2err={values[3]}")
    except Exception as e:
        print(f"Failed to read log: {e}")
        return (None, None, None, None)

    if not found:
        print("No metrics found in the log file")
        return (None, None, None, None)

    print(f"\nTotal {len(found)} metrics found in the log file")

    # Column-wise minima across all matched lines.
    bests = tuple(min(column) for column in zip(*found))

    print(f"\nBest metrics:")
    print(f"  agilex_sample_mse: {bests[0]}")
    print(f"  agilex_sample_l2err: {bests[1]}")
    print(f"  overall_avg_sample_mse: {bests[2]}")
    print(f"  overall_avg_sample_l2err: {bests[3]}")

    return bests
def generate_output_json(input_config_file, output_dir, runtime):
|
|
||||||
with open(input_config_file, 'r') as f:
|
|
||||||
config = json.load(f)
|
|
||||||
|
|
||||||
log_file = os.path.join(output_dir, 'output.log')
|
|
||||||
agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err = extract_metrics_from_log(log_file)
|
|
||||||
|
|
||||||
if None in [agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err]:
|
|
||||||
print("Warning: Some metrics are missing in the log file.")
|
|
||||||
|
|
||||||
output_json = {
|
|
||||||
"task_id": config.get("task_id"),
|
|
||||||
"model_type": "RDT-1B",
|
|
||||||
"model_name": config.get("model_name") if "model_name" in config else config.get("train", {}).get("model"),
|
|
||||||
"gpu_id": config.get("gpu_id"),
|
|
||||||
"runtime": runtime,
|
|
||||||
"log_path": log_file,
|
|
||||||
"output_dir": output_dir,
|
|
||||||
"model_path": os.path.join(output_dir, 'pytorch_model.bin'),
|
|
||||||
"metrics": {
|
|
||||||
"agilex_sample_mse": agilex_sample_mse,
|
|
||||||
"agilex_sample_l2err": agilex_sample_l2err,
|
|
||||||
"overall_avg_sample_mse": overall_avg_sample_mse,
|
|
||||||
"overall_avg_sample_l2err": overall_avg_sample_l2err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 写入 output.json,格式化输出、确保null与规范json一致
|
|
||||||
output_json_path = os.path.join(output_dir, 'output.json')
|
|
||||||
with open(output_json_path, 'w') as f:
|
|
||||||
json.dump(output_json, f, indent=4, ensure_ascii=False)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if len(sys.argv) != 4:
|
|
||||||
print("Usage: python generate_output_json.py <input_config_file> <output_dir> <runtime>")
|
|
||||||
sys.exit(1)
|
|
||||||
generate_output_json(sys.argv[1], sys.argv[2], sys.argv[3])
|
|
||||||
@ -1,325 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
|
||||||
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
|
||||||
from models.multimodal_encoder.t5_encoder import T5Embedder
|
|
||||||
from models.rdt_runner import RDTRunner
|
|
||||||
|
|
||||||
MANISKILL_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
|
|
||||||
for i in range(7)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]]
|
|
||||||
|
|
||||||
|
|
||||||
def create_model(args, pretrained, **kwargs):
|
|
||||||
model = RoboticDiffusionTransformerModel(args, **kwargs)
|
|
||||||
if pretrained is not None:
|
|
||||||
model.load_pretrained_weights(pretrained)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
DATA_STAT = {
|
|
||||||
"state_min": [
|
|
||||||
-0.7463043928146362,
|
|
||||||
-0.0801204964518547,
|
|
||||||
-0.4976441562175751,
|
|
||||||
-2.657780647277832,
|
|
||||||
-0.5742632150650024,
|
|
||||||
1.8309762477874756,
|
|
||||||
-2.2423808574676514,
|
|
||||||
0.0,
|
|
||||||
],
|
|
||||||
"state_max": [
|
|
||||||
0.7645499110221863,
|
|
||||||
1.4967026710510254,
|
|
||||||
0.4650936424732208,
|
|
||||||
-0.3866899907588959,
|
|
||||||
0.5505855679512024,
|
|
||||||
3.2900545597076416,
|
|
||||||
2.5737812519073486,
|
|
||||||
0.03999999910593033,
|
|
||||||
],
|
|
||||||
"action_min": [
|
|
||||||
-0.7472005486488342,
|
|
||||||
-0.08631071448326111,
|
|
||||||
-0.4995281398296356,
|
|
||||||
-2.658363103866577,
|
|
||||||
-0.5751323103904724,
|
|
||||||
1.8290787935256958,
|
|
||||||
-2.245187997817993,
|
|
||||||
-1.0,
|
|
||||||
],
|
|
||||||
"action_max": [
|
|
||||||
0.7654682397842407,
|
|
||||||
1.4984270334243774,
|
|
||||||
0.46786263585090637,
|
|
||||||
-0.38181185722351074,
|
|
||||||
0.5517147779464722,
|
|
||||||
3.291581630706787,
|
|
||||||
2.575840711593628,
|
|
||||||
1.0,
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class RoboticDiffusionTransformerModel(object):
|
|
||||||
"""A wrapper for the RDT model, which handles
|
|
||||||
1. Model initialization
|
|
||||||
2. Encodings of instructions
|
|
||||||
3. Model inference
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
args,
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
image_size=None,
|
|
||||||
control_frequency=25,
|
|
||||||
pretrained_text_encoder_name_or_path=None,
|
|
||||||
pretrained_vision_encoder_name_or_path=None,
|
|
||||||
):
|
|
||||||
self.args = args
|
|
||||||
self.dtype = dtype
|
|
||||||
self.image_size = image_size
|
|
||||||
self.device = device
|
|
||||||
self.control_frequency = control_frequency
|
|
||||||
self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
|
|
||||||
self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
|
|
||||||
self.policy = self.get_policy()
|
|
||||||
|
|
||||||
self.state_min = torch.tensor(DATA_STAT["state_min"]).to(device)
|
|
||||||
self.state_max = torch.tensor(DATA_STAT["state_max"]).to(device)
|
|
||||||
self.action_min = torch.tensor(DATA_STAT["action_min"]).to(device)
|
|
||||||
self.action_max = torch.tensor(DATA_STAT["action_max"]).to(device)
|
|
||||||
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def get_policy(self):
|
|
||||||
"""Initialize the model."""
|
|
||||||
# Initialize model with arguments
|
|
||||||
img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
|
|
||||||
self.vision_model.num_patches)
|
|
||||||
|
|
||||||
_model = RDTRunner(
|
|
||||||
action_dim=self.args["common"]["state_dim"],
|
|
||||||
pred_horizon=self.args["common"]["action_chunk_size"],
|
|
||||||
config=self.args["model"],
|
|
||||||
lang_token_dim=self.args["model"]["lang_token_dim"],
|
|
||||||
img_token_dim=self.args["model"]["img_token_dim"],
|
|
||||||
state_token_dim=self.args["model"]["state_token_dim"],
|
|
||||||
max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
|
|
||||||
img_cond_len=img_cond_len,
|
|
||||||
img_pos_embed_config=[
|
|
||||||
# No initial pos embed in the last grid size
|
|
||||||
# since we've already done in ViT
|
|
||||||
(
|
|
||||||
"image",
|
|
||||||
(
|
|
||||||
self.args["common"]["img_history_size"],
|
|
||||||
self.args["common"]["num_cameras"],
|
|
||||||
-self.vision_model.num_patches,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
lang_pos_embed_config=[
|
|
||||||
# Similarly, no initial pos embed for language
|
|
||||||
("lang", -self.args["dataset"]["tokenizer_max_length"]),
|
|
||||||
],
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
return _model
|
|
||||||
|
|
||||||
def get_text_encoder(self, pretrained_text_encoder_name_or_path):
|
|
||||||
text_embedder = T5Embedder(
|
|
||||||
from_pretrained=pretrained_text_encoder_name_or_path,
|
|
||||||
model_max_length=self.args["dataset"]["tokenizer_max_length"],
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
|
||||||
return tokenizer, text_encoder
|
|
||||||
|
|
||||||
def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
|
|
||||||
vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
|
|
||||||
image_processor = vision_encoder.image_processor
|
|
||||||
return image_processor, vision_encoder
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""Set model to evaluation mode."""
|
|
||||||
device = self.device
|
|
||||||
weight_dtype = self.dtype
|
|
||||||
self.policy.eval()
|
|
||||||
self.text_model.eval()
|
|
||||||
self.vision_model.eval()
|
|
||||||
|
|
||||||
self.policy = self.policy.to(device, dtype=weight_dtype)
|
|
||||||
self.text_model = self.text_model.to(device, dtype=weight_dtype)
|
|
||||||
self.vision_model = self.vision_model.to(device, dtype=weight_dtype)
|
|
||||||
|
|
||||||
def load_pretrained_weights(self, pretrained=None):
|
|
||||||
if pretrained is None:
|
|
||||||
return
|
|
||||||
print(f"Loading weights from {pretrained}")
|
|
||||||
filename = os.path.basename(pretrained)
|
|
||||||
if filename.endswith(".pt"):
|
|
||||||
checkpoint = torch.load(pretrained)
|
|
||||||
self.policy.load_state_dict(checkpoint["module"])
|
|
||||||
elif filename.endswith(".safetensors"):
|
|
||||||
from safetensors.torch import load_model
|
|
||||||
|
|
||||||
load_model(self.policy, pretrained)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
|
|
||||||
|
|
||||||
def encode_instruction(self, instruction, device="cuda"):
|
|
||||||
"""Encode string instruction to latent embeddings.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
instruction: a string of instruction
|
|
||||||
device: a string of device
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
pred: a tensor of latent embeddings of shape (text_max_length, 512)
|
|
||||||
"""
|
|
||||||
tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
|
|
||||||
truncation=True)["input_ids"].to(device)
|
|
||||||
|
|
||||||
tokens = tokens.view(1, -1)
|
|
||||||
with torch.no_grad():
|
|
||||||
pred = self.text_model(tokens).last_hidden_state.detach()
|
|
||||||
|
|
||||||
return pred
|
|
||||||
|
|
||||||
def _format_joint_to_state(self, joints):
|
|
||||||
"""
|
|
||||||
Format the robot joint state into the unified state vector.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
joints (torch.Tensor): The joint state to be formatted.
|
|
||||||
qpos ([B, N, 14]).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
state (torch.Tensor): The formatted state for RDT ([B, N, 128]).
|
|
||||||
"""
|
|
||||||
# Rescale the gripper
|
|
||||||
# joints = joints / torch.tensor(
|
|
||||||
# [[[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]],
|
|
||||||
# device=joints.device, dtype=joints.dtype
|
|
||||||
# )
|
|
||||||
|
|
||||||
# normalize to -1,1
|
|
||||||
joints = (joints - self.state_min) / (self.state_max - self.state_min) * 2 - 1
|
|
||||||
B, N, _ = joints.shape
|
|
||||||
state = torch.zeros(
|
|
||||||
(B, N, self.args["model"]["state_token_dim"]),
|
|
||||||
device=joints.device,
|
|
||||||
dtype=joints.dtype,
|
|
||||||
)
|
|
||||||
# assemble the unifed state vector
|
|
||||||
state[:, :, MANISKILL_INDICES] = joints
|
|
||||||
state_elem_mask = torch.zeros(
|
|
||||||
(B, self.args["model"]["state_token_dim"]),
|
|
||||||
device=joints.device,
|
|
||||||
dtype=joints.dtype,
|
|
||||||
)
|
|
||||||
state_elem_mask[:, MANISKILL_INDICES] = 1
|
|
||||||
return state, state_elem_mask
|
|
||||||
|
|
||||||
def _unformat_action_to_joint(self, action):
|
|
||||||
action_indices = MANISKILL_INDICES
|
|
||||||
joints = action[:, :, action_indices]
|
|
||||||
|
|
||||||
# denormalize to action space
|
|
||||||
|
|
||||||
joints = (joints + 1) / 2 * (self.action_max - self.action_min) + self.action_min
|
|
||||||
|
|
||||||
return joints
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def step(self, proprio, images, text_embeds):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
proprio: proprioceptive states
|
|
||||||
images: RGB images
|
|
||||||
text_embeds: instruction embeddings
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
action: predicted action
|
|
||||||
"""
|
|
||||||
device = self.device
|
|
||||||
dtype = self.dtype
|
|
||||||
|
|
||||||
background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
|
|
||||||
dtype=np.uint8).reshape(1, 1, 3)
|
|
||||||
background_image = (np.ones(
|
|
||||||
(
|
|
||||||
self.image_processor.size["height"],
|
|
||||||
self.image_processor.size["width"],
|
|
||||||
3,
|
|
||||||
),
|
|
||||||
dtype=np.uint8,
|
|
||||||
) * background_color)
|
|
||||||
|
|
||||||
image_tensor_list = []
|
|
||||||
for image in images:
|
|
||||||
if image is None:
|
|
||||||
# Replace it with the background image
|
|
||||||
image = Image.fromarray(background_image)
|
|
||||||
|
|
||||||
if self.image_size is not None:
|
|
||||||
image = transforms.Resize(self.data_args.image_size)(image)
|
|
||||||
|
|
||||||
if self.args["dataset"].get("auto_adjust_image_brightness", False):
|
|
||||||
pixel_values = list(image.getdata())
|
|
||||||
average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
|
|
||||||
if average_brightness <= 0.15:
|
|
||||||
image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)
|
|
||||||
|
|
||||||
if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":
|
|
||||||
|
|
||||||
def expand2square(pil_img, background_color):
|
|
||||||
width, height = pil_img.size
|
|
||||||
if width == height:
|
|
||||||
return pil_img
|
|
||||||
elif width > height:
|
|
||||||
result = Image.new(pil_img.mode, (width, width), background_color)
|
|
||||||
result.paste(pil_img, (0, (width - height) // 2))
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
result = Image.new(pil_img.mode, (height, height), background_color)
|
|
||||||
result.paste(pil_img, ((height - width) // 2, 0))
|
|
||||||
return result
|
|
||||||
|
|
||||||
image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
|
|
||||||
image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
|
||||||
image_tensor_list.append(image)
|
|
||||||
|
|
||||||
image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
|
|
||||||
|
|
||||||
image_embeds = self.vision_model(image_tensor).detach()
|
|
||||||
image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)
|
|
||||||
|
|
||||||
# history of actions
|
|
||||||
joints = proprio.to(device).unsqueeze(0) # (1, 1, 14)
|
|
||||||
states, state_elem_mask = self._format_joint_to_state(joints) # (1, 1, 128), (1, 128)
|
|
||||||
states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
|
|
||||||
states = states[:, -1:, :] # (1, 1, 128)
|
|
||||||
ctrl_freqs = torch.tensor([self.control_frequency]).to(device)
|
|
||||||
|
|
||||||
text_embeds = text_embeds.to(device, dtype=dtype)
|
|
||||||
|
|
||||||
trajectory = self.policy.predict_action(
|
|
||||||
lang_tokens=text_embeds,
|
|
||||||
lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
|
|
||||||
img_tokens=image_embeds,
|
|
||||||
state_tokens=states,
|
|
||||||
action_mask=state_elem_mask.unsqueeze(1),
|
|
||||||
ctrl_freqs=ctrl_freqs,
|
|
||||||
)
|
|
||||||
trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
|
|
||||||
|
|
||||||
return trajectory
|
|
||||||
@ -1,169 +0,0 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
sys.path.append("./")
|
|
||||||
|
|
||||||
import os
|
|
||||||
import h5py
|
|
||||||
import numpy as np
|
|
||||||
import pickle
|
|
||||||
import cv2
|
|
||||||
import argparse
|
|
||||||
import yaml
|
|
||||||
from scripts.encode_lang_batch_once import encode_lang
|
|
||||||
|
|
||||||
|
|
||||||
def load_hdf5(dataset_path):
|
|
||||||
if not os.path.isfile(dataset_path):
|
|
||||||
print(f"Dataset does not exist at \n{dataset_path}\n")
|
|
||||||
exit()
|
|
||||||
|
|
||||||
with h5py.File(dataset_path, "r") as root:
|
|
||||||
left_gripper, left_arm = (
|
|
||||||
root["/joint_action/left_gripper"][()],
|
|
||||||
root["/joint_action/left_arm"][()],
|
|
||||||
)
|
|
||||||
right_gripper, right_arm = (
|
|
||||||
root["/joint_action/right_gripper"][()],
|
|
||||||
root["/joint_action/right_arm"][()],
|
|
||||||
)
|
|
||||||
image_dict = dict()
|
|
||||||
for cam_name in root[f"/observation/"].keys():
|
|
||||||
image_dict[cam_name] = root[f"/observation/{cam_name}/rgb"][()]
|
|
||||||
|
|
||||||
return left_gripper, left_arm, right_gripper, right_arm, image_dict
|
|
||||||
|
|
||||||
|
|
||||||
def images_encoding(imgs):
|
|
||||||
encode_data = []
|
|
||||||
padded_data = []
|
|
||||||
max_len = 0
|
|
||||||
for i in range(len(imgs)):
|
|
||||||
success, encoded_image = cv2.imencode(".jpg", imgs[i])
|
|
||||||
jpeg_data = encoded_image.tobytes()
|
|
||||||
encode_data.append(jpeg_data)
|
|
||||||
max_len = max(max_len, len(jpeg_data))
|
|
||||||
# padding
|
|
||||||
for i in range(len(imgs)):
|
|
||||||
padded_data.append(encode_data[i].ljust(max_len, b"\0"))
|
|
||||||
return encode_data, max_len
|
|
||||||
|
|
||||||
|
|
||||||
def get_task_config(task_name):
|
|
||||||
with open(f"./task_config/{task_name}.yml", "r", encoding="utf-8") as f:
|
|
||||||
args = yaml.load(f.read(), Loader=yaml.FullLoader)
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def data_transform(path, episode_num, save_path):
|
|
||||||
begin = 0
|
|
||||||
floders = os.listdir(path)
|
|
||||||
assert episode_num <= len(floders), "data num not enough"
|
|
||||||
|
|
||||||
if not os.path.exists(save_path):
|
|
||||||
os.makedirs(save_path)
|
|
||||||
|
|
||||||
for i in range(episode_num):
|
|
||||||
left_gripper_all, left_arm_all, right_gripper_all, right_arm_all, image_dict = (load_hdf5(
|
|
||||||
os.path.join(path, f"episode{i}.hdf5")))
|
|
||||||
qpos = []
|
|
||||||
actions = []
|
|
||||||
cam_high = []
|
|
||||||
cam_right_wrist = []
|
|
||||||
cam_left_wrist = []
|
|
||||||
left_arm_dim = []
|
|
||||||
right_arm_dim = []
|
|
||||||
|
|
||||||
last_state = None
|
|
||||||
for j in range(0, left_gripper_all.shape[0]):
|
|
||||||
|
|
||||||
left_gripper, left_arm, right_gripper, right_arm = (
|
|
||||||
left_gripper_all[j],
|
|
||||||
left_arm_all[j],
|
|
||||||
right_gripper_all[j],
|
|
||||||
right_arm_all[j],
|
|
||||||
)
|
|
||||||
|
|
||||||
state = np.concatenate((left_arm, [left_gripper], right_arm, [right_gripper]), axis=0) # joint
|
|
||||||
state = state.astype(np.float32)
|
|
||||||
|
|
||||||
if j != left_gripper_all.shape[0] - 1:
|
|
||||||
|
|
||||||
qpos.append(state)
|
|
||||||
|
|
||||||
camera_high_bits = image_dict["head_camera"][j]
|
|
||||||
camera_high = cv2.imdecode(np.frombuffer(camera_high_bits, np.uint8), cv2.IMREAD_COLOR)
|
|
||||||
camera_high_resized = cv2.resize(camera_high, (640, 480))
|
|
||||||
cam_high.append(camera_high_resized)
|
|
||||||
|
|
||||||
camera_right_wrist_bits = image_dict["right_camera"][j]
|
|
||||||
camera_right_wrist = cv2.imdecode(np.frombuffer(camera_right_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
|
|
||||||
camera_right_wrist_resized = cv2.resize(camera_right_wrist, (640, 480))
|
|
||||||
cam_right_wrist.append(camera_right_wrist_resized)
|
|
||||||
|
|
||||||
camera_left_wrist_bits = image_dict["left_camera"][j]
|
|
||||||
camera_left_wrist = cv2.imdecode(np.frombuffer(camera_left_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
|
|
||||||
camera_left_wrist_resized = cv2.resize(camera_left_wrist, (640, 480))
|
|
||||||
cam_left_wrist.append(camera_left_wrist_resized)
|
|
||||||
|
|
||||||
if j != 0:
|
|
||||||
action = state
|
|
||||||
actions.append(action)
|
|
||||||
left_arm_dim.append(left_arm.shape[0])
|
|
||||||
right_arm_dim.append(right_arm.shape[0])
|
|
||||||
|
|
||||||
if not os.path.exists(os.path.join(save_path, f"episode_{i}")):
|
|
||||||
os.makedirs(os.path.join(save_path, f"episode_{i}"))
|
|
||||||
hdf5path = os.path.join(save_path, f"episode_{i}/episode_{i}.hdf5")
|
|
||||||
|
|
||||||
with h5py.File(hdf5path, "w") as f:
|
|
||||||
f.create_dataset("action", data=np.array(actions))
|
|
||||||
obs = f.create_group("observations")
|
|
||||||
obs.create_dataset("qpos", data=np.array(qpos))
|
|
||||||
obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))
|
|
||||||
obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))
|
|
||||||
image = obs.create_group("images")
|
|
||||||
cam_high_enc, len_high = images_encoding(cam_high)
|
|
||||||
cam_right_wrist_enc, len_right = images_encoding(cam_right_wrist)
|
|
||||||
cam_left_wrist_enc, len_left = images_encoding(cam_left_wrist)
|
|
||||||
image.create_dataset("cam_high", data=cam_high_enc, dtype=f"S{len_high}")
|
|
||||||
image.create_dataset("cam_right_wrist", data=cam_right_wrist_enc, dtype=f"S{len_right}")
|
|
||||||
image.create_dataset("cam_left_wrist", data=cam_left_wrist_enc, dtype=f"S{len_left}")
|
|
||||||
|
|
||||||
begin += 1
|
|
||||||
print(f"proccess {i} success!")
|
|
||||||
|
|
||||||
return begin
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Process some episodes.")
|
|
||||||
parser.add_argument("task_name", type=str)
|
|
||||||
parser.add_argument("task_config", type=str)
|
|
||||||
parser.add_argument("expert_data_num", type=int)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
task_name = args.task_name
|
|
||||||
task_config = args.task_config
|
|
||||||
expert_data_num = args.expert_data_num
|
|
||||||
|
|
||||||
load_dir = os.path.join("../../data", str(task_name), str(task_config), "data")
|
|
||||||
|
|
||||||
print(f"read data from path: {load_dir}")
|
|
||||||
begin = data_transform(
|
|
||||||
load_dir,
|
|
||||||
expert_data_num,
|
|
||||||
f"./processed_data/{task_name}-{task_config}-{expert_data_num}",
|
|
||||||
)
|
|
||||||
tokenizer, text_encoder = None, None
|
|
||||||
for idx in range(expert_data_num):
|
|
||||||
print(f"Processing Language: {idx}", end="\r")
|
|
||||||
data_file_path = (f"../../data/{task_name}/{task_config}/instructions/episode{idx}.json")
|
|
||||||
target_dir = (f"processed_data/{task_name}-{task_config}-{expert_data_num}/episode_{idx}")
|
|
||||||
tokenizer, text_encoder = encode_lang(
|
|
||||||
DATA_FILE_PATH=data_file_path,
|
|
||||||
TARGET_DIR=target_dir,
|
|
||||||
GPU=0,
|
|
||||||
desc_type="seen",
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
text_encoder=text_encoder,
|
|
||||||
)
|
|
||||||
@ -1,31 +0,0 @@
|
|||||||
import json
|
|
||||||
import yaml
|
|
||||||
import sys
|
|
||||||
|
|
||||||
def read_config(config_file, yaml_file):
|
|
||||||
with open(config_file, 'r') as f:
|
|
||||||
json_config = json.load(f)
|
|
||||||
with open(yaml_file, 'r') as f:
|
|
||||||
yaml_config = yaml.load(f, Loader=yaml.FullLoader)
|
|
||||||
|
|
||||||
yaml_config["model"] = json_config["train"]["model"] + json_config["task_id"]
|
|
||||||
yaml_config["data_path"] = json_config["train"]["input_data_path"] + "/data"
|
|
||||||
yaml_config["checkpoint_path"] = json_config["train"]["checkpoint_path"] + "/" + json_config["task_id"]
|
|
||||||
yaml_config["pretrained_model_name_or_path"] = json_config["train"]["input_data_path"] + "/weights/rdt-1b"
|
|
||||||
yaml_config["cuda_visible_device"] = str(json_config["gpu_id"])
|
|
||||||
print(f"cuda_visible_device: {yaml_config['cuda_visible_device']}")
|
|
||||||
yaml_config["train_batch_size"] = int(json_config["train"]["batch_size"])
|
|
||||||
yaml_config["sample_batch_size"] = int(json_config["train"]["batch_size"]) * 2
|
|
||||||
yaml_config["max_train_steps"] = int(json_config["train"]["epochs"])
|
|
||||||
yaml_config["checkpointing_period"] = int(int(json_config["train"]["epochs"]) / 10)
|
|
||||||
yaml_config["sample_period"] = 200
|
|
||||||
yaml_config["checkpoints_total_limit"] = 50
|
|
||||||
|
|
||||||
|
|
||||||
with open(yaml_file, 'w') as f:
|
|
||||||
yaml.dump(yaml_config, f, default_flow_style=False)
|
|
||||||
|
|
||||||
print("Config YAML file updated successfully")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
read_config(sys.argv[1], sys.argv[2])
|
|
||||||
@ -1,22 +0,0 @@
|
|||||||
import sys
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
|
|
||||||
def read_yaml_value(file_path, key):
|
|
||||||
with open(file_path, "r") as file:
|
|
||||||
data = yaml.safe_load(file)
|
|
||||||
value = data.get(key)
|
|
||||||
if value is not None:
|
|
||||||
print(value)
|
|
||||||
else:
|
|
||||||
print(f"Key '{key}' not found in {file_path}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if len(sys.argv) != 3:
|
|
||||||
print("Usage: python read_yaml.py <file_path> <key>")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
file_path = sys.argv[1]
|
|
||||||
key = sys.argv[2]
|
|
||||||
read_yaml_value(file_path, key)
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,479 +0,0 @@
|
|||||||
import traceback
|
|
||||||
import time
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
from typing import Dict, Sequence
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from torchvision import transforms
|
|
||||||
from PIL import Image
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from data.filelock import FileLock
|
|
||||||
from data.hdf5_vla_dataset import HDF5VLADataset
|
|
||||||
from train.image_corrupt import image_corrupt
|
|
||||||
|
|
||||||
|
|
||||||
def get_clean_item(chunk_dir):
|
|
||||||
"""
|
|
||||||
Get indexes of clean items in a chunk.
|
|
||||||
"""
|
|
||||||
dirty_bit = read_dirty_bit(chunk_dir)
|
|
||||||
return np.where(1 - dirty_bit)[0].tolist()
|
|
||||||
|
|
||||||
|
|
||||||
def save_dirty_bit(chunk_dir, dirty_bit):
|
|
||||||
"""
|
|
||||||
Save the dirty bit to the chunk directory.
|
|
||||||
"""
|
|
||||||
time_stmp = time.time()
|
|
||||||
while time.time() - time_stmp < 10.0:
|
|
||||||
try:
|
|
||||||
file_path = os.path.join(chunk_dir, "dirty_bit")
|
|
||||||
lock = FileLock(file_path)
|
|
||||||
lock.acquire_write_lock()
|
|
||||||
with open(file_path, "wb") as file:
|
|
||||||
file.write(dirty_bit.tobytes())
|
|
||||||
lock.release_lock()
|
|
||||||
return
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
lock.release_lock()
|
|
||||||
raise KeyboardInterrupt
|
|
||||||
except BaseException:
|
|
||||||
lock.release_lock()
|
|
||||||
continue
|
|
||||||
raise RuntimeError("Failed to save dirty bit.")
|
|
||||||
|
|
||||||
|
|
||||||
def read_dirty_bit(chunk_dir):
|
|
||||||
"""
|
|
||||||
Read the dirty bit from the chunk directory.
|
|
||||||
"""
|
|
||||||
# If error occurs, retry
|
|
||||||
time_stmp = time.time()
|
|
||||||
while time.time() - time_stmp < 10.0:
|
|
||||||
try:
|
|
||||||
file_path = os.path.join(chunk_dir, "dirty_bit")
|
|
||||||
lock = FileLock(file_path)
|
|
||||||
lock.acquire_read_lock()
|
|
||||||
with open(file_path, "rb") as file:
|
|
||||||
dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
|
|
||||||
lock.release_lock()
|
|
||||||
assert len(dirty_bit) > 0
|
|
||||||
return dirty_bit
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
lock.release_lock()
|
|
||||||
raise KeyboardInterrupt
|
|
||||||
except BaseException:
|
|
||||||
lock.release_lock()
|
|
||||||
continue
|
|
||||||
raise RuntimeError("Failed to read dirty bit.")
|
|
||||||
|
|
||||||
|
|
||||||
class VLAConsumerDataset(Dataset):
|
|
||||||
"""A vision-languange-action Dataset for supervised training.
|
|
||||||
This dataset will load data from the buffer directory.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_config_path,
|
|
||||||
config,
|
|
||||||
tokenizer,
|
|
||||||
image_processor,
|
|
||||||
num_cameras,
|
|
||||||
img_history_size,
|
|
||||||
image_size=None,
|
|
||||||
auto_adjust_image_brightness=False,
|
|
||||||
image_aug=False,
|
|
||||||
dataset_type="pretrain",
|
|
||||||
cond_mask_prob=0.1,
|
|
||||||
cam_ext_mask_prob=-1.0,
|
|
||||||
state_noise_snr=None,
|
|
||||||
use_hdf5=False,
|
|
||||||
use_precomp_lang_embed=False,
|
|
||||||
):
|
|
||||||
super(VLAConsumerDataset, self).__init__()
|
|
||||||
|
|
||||||
# Load the control frequency for each dataset
|
|
||||||
with open("configs/dataset_control_freq.json", "r") as fp:
|
|
||||||
self.control_freq = json.load(fp)
|
|
||||||
# Load the dataset names
|
|
||||||
dataset_names_cfg = ("configs/pretrain_datasets.json"
|
|
||||||
if dataset_type == "pretrain" else "configs/finetune_datasets.json")
|
|
||||||
with open(dataset_names_cfg, "r") as file:
|
|
||||||
DATASET_NAMES = json.load(file)
|
|
||||||
# Create the mapping between dataset name and id
|
|
||||||
self.dataset_name2id = {name: i for i, name in enumerate(DATASET_NAMES)}
|
|
||||||
self.dataset_id2name = {i: name for i, name in enumerate(DATASET_NAMES)}
|
|
||||||
|
|
||||||
self.image_processor = image_processor
|
|
||||||
self.model_config_path = model_config_path
|
|
||||||
self.buffer_dir = config["buf_path"]
|
|
||||||
self.num_chunks = config["buf_num_chunks"]
|
|
||||||
self.chunk_size = config["buf_chunk_size"]
|
|
||||||
self.tokenizer_max_length = config["tokenizer_max_length"]
|
|
||||||
self.image_aspect_ratio = config["image_aspect_ratio"]
|
|
||||||
self.state_noise_snr = state_noise_snr
|
|
||||||
self.num_cameras = num_cameras
|
|
||||||
self.img_history_size = img_history_size
|
|
||||||
self.cond_mask_prob = cond_mask_prob
|
|
||||||
self.cam_ext_mask_prob = cam_ext_mask_prob
|
|
||||||
self.use_hdf5 = use_hdf5
|
|
||||||
self.hdf5_dataset = None
|
|
||||||
if use_hdf5:
|
|
||||||
self.hdf5_dataset = HDF5VLADataset(self.model_config_path)
|
|
||||||
self.use_precomp_lang_embed = use_precomp_lang_embed
|
|
||||||
if use_precomp_lang_embed:
|
|
||||||
self.empty_lang_embed = torch.load("data/empty_lang_embed.pt")
|
|
||||||
|
|
||||||
# Load dataset stat
|
|
||||||
with open("configs/dataset_stat.json", "r") as f:
|
|
||||||
dataset_stat = json.load(f)
|
|
||||||
self.dataset_stat = dataset_stat
|
|
||||||
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.image_size = image_size
|
|
||||||
self.auto_adjust_image_brightness = auto_adjust_image_brightness
|
|
||||||
self.image_aug = image_aug
|
|
||||||
|
|
||||||
self.last_content = None
|
|
||||||
self.last_meta = None
|
|
||||||
|
|
||||||
def get_dataset_name2id(self):
|
|
||||||
return self.dataset_name2id
|
|
||||||
|
|
||||||
def get_dataset_id2name(self):
|
|
||||||
return self.dataset_id2name
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def pairwise(iterable):
|
|
||||||
a = iter(iterable)
|
|
||||||
return zip(a, a)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _load_data_from_chunk(chunk_dir, chunk_item_idx):
|
|
||||||
# If error occurs, retry
|
|
||||||
time_stmp = time.time()
|
|
||||||
while time.time() - time_stmp < 10.0:
|
|
||||||
try:
|
|
||||||
locks = []
|
|
||||||
file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
|
|
||||||
lock = FileLock(file_path)
|
|
||||||
locks.append(lock)
|
|
||||||
lock.acquire_read_lock()
|
|
||||||
with open(file_path, "r") as file:
|
|
||||||
json_content = json.load(file)
|
|
||||||
lock.release_lock()
|
|
||||||
file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
|
|
||||||
lock = FileLock(file_path)
|
|
||||||
locks.append(lock)
|
|
||||||
lock.acquire_read_lock()
|
|
||||||
with open(file_path, "rb") as file:
|
|
||||||
sample_dict = np.load(file)
|
|
||||||
meta = tuple(sample_dict.values())
|
|
||||||
lock.release_lock()
|
|
||||||
return json_content, meta
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
for lock in locks:
|
|
||||||
lock.release_lock()
|
|
||||||
raise KeyboardInterrupt
|
|
||||||
except BaseException:
|
|
||||||
for lock in locks:
|
|
||||||
lock.release_lock()
|
|
||||||
continue
|
|
||||||
raise RuntimeError("Failed to load sample.")
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
if self.use_hdf5:
|
|
||||||
return len(self.hdf5_dataset)
|
|
||||||
else:
|
|
||||||
return self.num_chunks * self.chunk_size
|
|
||||||
|
|
||||||
def _safe_load(self, index):
    """Load one sample from the disk-backed producer/consumer buffer.

    Searches the chunk directories for one that has at least one "clean"
    (not yet consumed) item, marks the chosen item dirty so other consumers
    skip it, and reads its content.  If the final read fails, the most
    recently loaded sample is returned instead, so training never crashes
    on a single bad item.

    Returns:
        Tuple ``(json_content, *npz_arrays)`` as produced by
        ``_load_data_from_chunk``.
    """
    read_chunk_item_indices = []
    # Start searching from a random chunk
    read_chunk_idx = index // self.chunk_size
    # Walk chunks (wrapping around) until one with clean items is found.
    while len(read_chunk_item_indices) == 0:
        read_chunk_dir = os.path.join(self.buffer_dir, f"chunk_{read_chunk_idx}")
        try:
            # Indices of items in this chunk not yet consumed by any reader.
            read_chunk_item_indices = get_clean_item(read_chunk_dir)
        except BaseException as e:
            # Print the error info
            print("Error catched when searching a clean chunk:", e)
            traceback.print_exc()
            read_chunk_item_indices = []
        read_chunk_idx = (read_chunk_idx + 1) % self.num_chunks

    # read_chunk_item_index = random.choice(read_chunk_item_indices)
    # read_chunk_item_index = read_chunk_item_indices.pop()
    # Pick an item deterministically from the requested index.
    random_item_index = index % len(read_chunk_item_indices)
    read_chunk_item_index = read_chunk_item_indices[random_item_index]

    # Modify the dirty bit
    # (best-effort: a failure here only risks re-reading the same item)
    try:
        dirty_bit = read_dirty_bit(read_chunk_dir)
        dirty_bit[read_chunk_item_index] = 1
        save_dirty_bit(read_chunk_dir, dirty_bit)
    except BaseException as e:
        # Print the error info
        print("Error catched when modifying the dirty bit:", e)
        traceback.print_exc()

    # load the sample
    try:
        content, meta = self._load_data_from_chunk(read_chunk_dir, read_chunk_item_index)
        self.last_content, self.last_meta = content, meta
    except BaseException as e:
        # Print the error info
        print("Error catched when loading sample:", e)
        traceback.print_exc()

        # If failed to load the data, return the last loaded data for robustness
        content, meta = self.last_content, self.last_meta

    return (content, *meta)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
    """Assemble one training sample as a dict of tensors.

    Reads a raw sample (from HDF5 or from the producer/consumer buffer),
    applies conditioning dropout (states / images / language masked with
    probability ``cond_mask_prob``), optional state noise, image
    augmentation and preprocessing, and tokenizes or loads the language
    instruction.  Failures are logged and the next index is tried, so this
    method never raises to the caller.
    """
    # For robustness, we will try to load the data until we succeed
    while True:
        data_dict = None
        try:
            if self.use_hdf5:
                res = self.hdf5_dataset.get_item()
                content = res["meta"]
                states = res["state"]
                actions = res["actions"]
                state_elem_mask = res["state_indicator"]
                # Alternating (images, mask) entries, one pair per camera.
                image_metas = [
                    res["cam_high"],
                    res["cam_high_mask"],
                    res["cam_right_wrist"],
                    res["cam_right_wrist_mask"],
                    res["cam_left_wrist"],
                    res["cam_left_wrist_mask"],
                ]
                state_std = res["state_std"]
                state_mean = res["state_mean"]
                state_norm = res["state_norm"]
            else:
                (
                    content,
                    _,
                    states,
                    _,
                    actions,
                    _,
                    state_elem_mask,
                    *image_metas,
                    state_std,
                    state_mean,
                    state_norm,
                ) = self._safe_load(index)

            data_dict = {}
            data_dict["dataset_name"] = content["dataset_name"]
            data_dict["data_idx"] = self.dataset_name2id[data_dict["dataset_name"]]
            # Control frequency is zeroed when conditioning is dropped.
            data_dict["ctrl_freq"] = (self.control_freq[data_dict["dataset_name"]]
                                      if random.random() > self.cond_mask_prob else 0)

            if self.state_noise_snr is not None:
                # Additive Gaussian noise scaled to the requested SNR (dB).
                states += np.random.normal(
                    0.0,
                    state_std / np.sqrt(10**(self.state_noise_snr / 10)),
                    states.shape,
                )
            ds_state_mean = np.array(self.dataset_stat[data_dict["dataset_name"]]["state_mean"])
            ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
            # Randomly mask the states by the mean state
            data_dict["states"] = (states if random.random() > self.cond_mask_prob else ds_state_mean)
            data_dict["actions"] = actions
            data_dict["state_elem_mask"] = (state_elem_mask if random.random() > self.cond_mask_prob else
                                            np.zeros_like(state_elem_mask))

            # Stat for the episode that the step belongs to
            data_dict["state_norm"] = state_norm

            # We replace the invalid images with the background image
            # and also randomly mask images by the background image
            background_color = np.array(
                [int(x * 255) for x in self.image_processor.image_mean],
                dtype=np.uint8,
            ).reshape(1, 1, 3)
            background_image = (np.ones(
                (
                    self.image_processor.size["height"],
                    self.image_processor.size["width"],
                    3,
                ),
                dtype=np.uint8,
            ) * background_color)

            # Regroup the flat meta list into (images, mask) pairs per camera.
            image_metas = list(self.pairwise(image_metas))
            mask_probs = [self.cond_mask_prob] * self.num_cameras
            if self.cam_ext_mask_prob >= 0.0:
                # The first (exterior) camera may use its own mask probability.
                mask_probs[0] = self.cam_ext_mask_prob
            rearranged_images = []
            for i in range(self.img_history_size):
                for j in range(self.num_cameras):
                    images, image_mask = image_metas[j]
                    image, valid = images[i], image_mask[i]
                    if (valid and (math.prod(image.shape) > 0) and (random.random() > mask_probs[j])):
                        rearranged_images.append((image, True))
                    else:
                        rearranged_images.append((background_image.copy(), False))

            preprocessed_images = []
            processor = self.image_processor
            for image, valid in rearranged_images:
                image = Image.fromarray(image)
                if self.image_size is not None:
                    image = transforms.Resize(self.image_size)(image)  # (1008, 336)
                # assert image.height == 336, "We haven't prepare for training with images of different resolutions."

                if valid and self.auto_adjust_image_brightness:
                    pixel_values = list(image.getdata())
                    average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                    if average_brightness <= 0.15:
                        image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

                # Only apply image augmentation to 50% of the images
                if valid and self.image_aug and (random.random() > 0.5):
                    aug_type = random.choice(["corrput_only", "color_only", "both"])
                    if aug_type != "corrput_only":
                        image = transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5,
                                                       hue=0.03)(image)
                    if aug_type != "color_only":
                        image = image_corrupt(image)

                if self.image_aspect_ratio == "pad":

                    def expand2square(pil_img, background_color):
                        # Pad the shorter side with the background color so the
                        # image becomes square, keeping content centered.
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result

                    image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
                preprocessed_images.append(image)
            data_dict["images"] = preprocessed_images

            if self.use_precomp_lang_embed:
                # In precomputed mode, `instruction` holds a path passed to
                # torch.load.  NOTE(review): a trailing "." is stripped,
                # presumably because embedding files were saved without it
                # — confirm against the producer side.
                if content["instruction"][-1] == ".":
                    content["instruction"] = content["instruction"][:-1]
                data_dict["lang_embed"] = (torch.load(content["instruction"])
                                           if random.random() > self.cond_mask_prob else self.empty_lang_embed)
            else:
                instruction = (content["instruction"] if random.random() > self.cond_mask_prob else "")
                data_dict["input_ids"] = self.tokenizer(
                    instruction,
                    return_tensors="pt",
                    padding="longest",
                    truncation=False,
                ).input_ids[0]

                assert (
                    len(data_dict["input_ids"]) <= self.tokenizer_max_length
                ), f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."

            # Convert every remaining numpy array to a torch tensor.
            for k, v in data_dict.items():
                if isinstance(v, np.ndarray):
                    data_dict[k] = torch.from_numpy(v)

            # Sanity check: nothing should still be a numpy array here.
            for k, v in data_dict.items():
                assert not isinstance(v, np.ndarray), f"key: {k}, value: {v}"
                # data_dict[k] = torch.from_numpy(v)

            return data_dict
        except BaseException as e:
            # Print the error info
            if data_dict is not None:
                print(
                    f"Error catched when processing sample from {data_dict.get('dataset_name')}:",
                    e,
                )
            else:
                print(f"Error catched when processing sample:", e)
            traceback.print_exc()
            # Try incresing the index
            index = (index + 1) % len(self)
|
|
||||||
|
|
||||||
|
|
||||||
class DataCollatorForVLAConsumerDataset(object):
    """Collate examples for supervised training.

    Stacks the per-sample tensors produced by ``VLAConsumerDataset`` into a
    batch dict, and pads either tokenized instructions (``input_ids``) or
    precomputed language embeddings (``lang_embed``) — whichever the samples
    carry — producing a matching boolean ``lang_attn_mask``.
    """

    # Keys whose values are stacked into batch tensors.
    _TENSOR_KEYS = ("states", "actions", "state_elem_mask", "state_norm")

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer) -> None:
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        batch = {key: [] for key in (*self._TENSOR_KEYS, "images", "data_indices", "ctrl_freqs")}
        token_rows = []
        embed_rows = []
        embed_lengths = []

        for sample in instances:
            # Convert any lingering numpy arrays to tensors before stacking.
            for key in self._TENSOR_KEYS:
                value = sample[key]
                if not isinstance(value, torch.Tensor):
                    value = torch.from_numpy(value)
                batch[key].append(value)

            # A sample carries either token ids or a precomputed embedding.
            if "input_ids" in sample:
                token_rows.append(sample["input_ids"])
            else:
                embed = sample["lang_embed"]
                embed_rows.append(embed)
                embed_lengths.append(embed.shape[0])

            batch["images"].append(torch.stack(sample["images"], dim=0))
            batch["data_indices"].append(sample["data_idx"])
            batch["ctrl_freqs"].append(sample["ctrl_freq"])

        for key in (*self._TENSOR_KEYS, "images"):
            batch[key] = torch.stack(batch[key], dim=0)

        batch["ctrl_freqs"] = torch.tensor(batch["ctrl_freqs"])

        if len(token_rows) > 0:
            padded_ids = torch.nn.utils.rnn.pad_sequence(token_rows,
                                                         batch_first=True,
                                                         padding_value=self.tokenizer.pad_token_id)
            batch["input_ids"] = padded_ids
            batch["lang_attn_mask"] = padded_ids.ne(self.tokenizer.pad_token_id)
        else:
            padded_embeds = torch.nn.utils.rnn.pad_sequence(embed_rows, batch_first=True, padding_value=0)
            attn_mask = torch.zeros(padded_embeds.shape[0], padded_embeds.shape[1], dtype=torch.bool)
            for row, length in enumerate(embed_lengths):
                attn_mask[row, :length] = True
            batch["lang_embeds"] = padded_embeds
            batch["lang_attn_mask"] = attn_mask

        return batch
|
|
||||||
@ -1,45 +0,0 @@
|
|||||||
import warnings

# Silence FutureWarnings raised during the imports below.
warnings.simplefilter(action="ignore", category=FutureWarning)

import numpy as np

# Compatibility shim: restore the `np.bool` alias before importing imgaug.
# NOTE(review): presumably needed because imgaug references np.bool, which
# newer NumPy releases removed — confirm against the pinned imgaug version.
np.bool = np.bool_
import imgaug.augmenters as iaa
from PIL import Image
|
|
||||||
|
|
||||||
# Define our sequence of augmentation steps that will be applied to every image.

# Always apply exactly one of these additive-noise augmenters.
_noise_aug = iaa.OneOf([
    iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
    iaa.AdditiveLaplaceNoise(scale=(0.0, 0.05 * 255), per_channel=0.5),
    iaa.AdditivePoissonNoise(lam=(0.0, 0.05 * 255), per_channel=0.5),
])

# Apply one or none of these blur augmenters.
_blur_aug = iaa.SomeOf(
    (0, 1),
    [
        iaa.OneOf([
            iaa.GaussianBlur((0, 3.0)),
            iaa.AverageBlur(k=(2, 7)),
            iaa.MedianBlur(k=(3, 11)),
        ]),
        iaa.MotionBlur(k=(3, 36)),
    ],
)

seq = iaa.Sequential(
    [
        _noise_aug,
        _blur_aug,
    ],
    # do all of the above augmentations in random order
    random_order=True,
)
|
|
||||||
|
|
||||||
|
|
||||||
def image_corrupt(image: Image):
    """Apply the module-level `seq` noise/blur augmentation to a PIL image."""
    # imgaug operates on batches, so add (and later drop) a leading batch axis.
    batch = np.array(image)[None, ...]
    batch = seq(images=batch)
    return Image.fromarray(batch[0])
|
|
||||||
@ -1,101 +0,0 @@
|
|||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def log_sample_res(
    text_encoder,
    vision_encoder,
    rdt,
    args,
    accelerator,
    weight_dtype,
    dataset_id2name,
    dataloader,
    logger,
):
    """Evaluate the policy on a few batches and collect per-dataset losses.

    Runs ``rdt.predict_action`` on up to ``args.num_sample_batches`` batches
    under fp16 autocast, computes masked MSE and norm-scaled L2 errors
    against the ground-truth actions, gathers results across processes, and
    returns a dict of rounded average losses (per dataset and overall).
    The model is restored to train mode before returning.
    """
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logger.info(f"Running sampling for {args.num_sample_batches} batches...")

        rdt.eval()

        loss_for_log = defaultdict(float)
        loss_counter = defaultdict(int)
        for step, batch in enumerate(dataloader):
            if step >= args.num_sample_batches:
                break

            data_indices = batch["data_indices"]
            ctrl_freqs = batch["ctrl_freqs"]
            state_norm = batch["state_norm"].to(dtype=weight_dtype)
            images = batch["images"].to(dtype=weight_dtype)
            states = batch["states"].to(dtype=weight_dtype)
            # We only use the last state as input
            states = states[:, -1:, :]
            actions = batch["actions"].to(dtype=weight_dtype)
            state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)

            # Flatten (batch, views) so the vision encoder sees one image
            # per row, then restore the batch dimension on the embeddings.
            batch_size, _, C, H, W = images.shape
            image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
            image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

            lang_attn_mask = batch["lang_attn_mask"]
            # Use precomputed embeddings when available, otherwise encode.
            text_embeds = (batch["lang_embeds"].to(dtype=weight_dtype) if args.precomp_lang_embed else text_encoder(
                input_ids=batch["input_ids"], attention_mask=lang_attn_mask)["last_hidden_state"].detach())

            pred_actions = rdt.predict_action(
                lang_tokens=text_embeds,
                lang_attn_mask=lang_attn_mask,
                img_tokens=image_embeds,
                state_tokens=states,
                action_mask=state_elem_mask.unsqueeze(1),
                ctrl_freqs=ctrl_freqs,
            )

            # Broadcast the per-element mask and norm over prediction steps.
            num_steps = pred_actions.shape[1]
            expanded_state_elem_mask = (state_elem_mask.unsqueeze(1).tile((1, num_steps, 1)).float())
            expanded_state_norm = (state_norm.unsqueeze(1).tile((1, num_steps, 1)).float())

            loss = F.mse_loss(pred_actions, actions, reduction="none").float()

            # Per-sample masked averages (only valid state elements count).
            mse_loss_per_entry = (loss * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)
            l2_loss_per_entry = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss_per_entry = (l2_loss_per_entry * expanded_state_elem_mask).reshape(
                (batch_size, -1)).sum(1) / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1)

            # Gather per-sample losses from all processes for logging.
            dataset_indices, mse_losses, l2_losses = accelerator.gather_for_metrics((
                torch.LongTensor(data_indices).to(device=pred_actions.device),
                mse_loss_per_entry,
                l2_loss_per_entry,
            ), )
            dataset_indices = dataset_indices.tolist()
            if accelerator.is_main_process:
                for loss_suffix, losses in zip(["_sample_mse", "_sample_l2err"], [mse_losses, l2_losses]):
                    for dataset_idx, loss_tensor in zip(dataset_indices, losses):
                        loss_name = dataset_id2name[dataset_idx] + loss_suffix
                        loss_for_log[loss_name] += loss_tensor.item()
                        loss_counter[loss_name] += 1

            mse_loss = (loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            mse_loss_scaler = accelerator.gather(mse_loss).mean().item()
            loss_for_log["overall_avg_sample_mse"] += mse_loss_scaler

            l2_loss = loss.sqrt() / (expanded_state_norm + 1e-3)
            l2_loss = (l2_loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
            l2_loss_scaler = accelerator.gather(l2_loss).mean().item()
            loss_for_log["overall_avg_sample_l2err"] += l2_loss_scaler

        # Average the accumulated sums: overall losses over the number of
        # batches, per-dataset losses over their own observation counts.
        for name in loss_for_log:
            if name in ["overall_avg_sample_mse", "overall_avg_sample_l2err"]:
                loss_scaler = loss_for_log[name]
                loss_for_log[name] = round(loss_scaler / (args.num_sample_batches), 4)
            else:
                loss_for_log[name] = round(loss_for_log[name] / loss_counter[name], 4)

    rdt.train()
    torch.cuda.empty_cache()

    return dict(loss_for_log)
|
|
||||||
@ -1,521 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import diffusers
|
|
||||||
import torch
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
import transformers
|
|
||||||
import yaml
|
|
||||||
from accelerate import Accelerator
|
|
||||||
from accelerate.utils import DeepSpeedPlugin, ProjectConfiguration, set_seed
|
|
||||||
from diffusers.optimization import get_scheduler
|
|
||||||
from diffusers.utils import is_wandb_available
|
|
||||||
from huggingface_hub import create_repo, upload_folder
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from safetensors.torch import load_model
|
|
||||||
|
|
||||||
from models.ema_model import EMAModel
|
|
||||||
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
|
||||||
from models.multimodal_encoder.t5_encoder import T5Embedder
|
|
||||||
from models.rdt_runner import RDTRunner
|
|
||||||
from train.dataset import DataCollatorForVLAConsumerDataset, VLAConsumerDataset
|
|
||||||
from train.sample import log_sample_res
|
|
||||||
|
|
||||||
if is_wandb_available():
|
|
||||||
import wandb
|
|
||||||
|
|
||||||
|
|
||||||
def save_model_card(repo_id: str, base_model=str, repo_folder=None):
    """Write a Hugging Face model card (README.md with YAML front matter).

    Args:
        repo_id: name of the hub repo, used in the card title.
        base_model: identifier of the base model this checkpoint derives
            from.  NOTE(review): the default is the `str` *type* (a template
            wart kept for backward compatibility); callers are expected to
            always pass it explicitly.
        repo_folder: directory into which README.md is written.
    """
    # Renamed from `yaml` to stop shadowing the module-level `import yaml`.
    yaml_front_matter = f"""
---
license: mit
base_model: {base_model}
language:
- en
pipeline_tag: robotics
library_name: transformers
tags:
- robotics
- pytorch
- multimodal
- pretraining
- vla
- diffusion
- rdt
---
    """
    model_card = f"""
# RDT - {repo_id}

This is a RDT model derived from {base_model}. The weights were trained using [RDT](https://rdt-robotics.github.io/rdt-robotics/).
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml_front_matter + model_card)
|
|
||||||
|
|
||||||
|
|
||||||
def train(args, logger):
|
|
||||||
# Read the config
|
|
||||||
with open(args.config_path, "r") as fp:
|
|
||||||
config = yaml.safe_load(fp)
|
|
||||||
|
|
||||||
with open(args.model_config_path, "r") as f:
|
|
||||||
model_config = yaml.safe_load(f)
|
|
||||||
# print(model_config)
|
|
||||||
args.output_dir = model_config["checkpoint_path"]
|
|
||||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
|
||||||
|
|
||||||
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
|
|
||||||
accelerator = Accelerator(
|
|
||||||
deepspeed_plugin=(DeepSpeedPlugin(hf_ds_config=args.deepspeed) if args.deepspeed is not None else None),
|
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
|
||||||
mixed_precision=args.mixed_precision,
|
|
||||||
log_with=args.report_to,
|
|
||||||
project_dir=logging_dir,
|
|
||||||
project_config=accelerator_project_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.report_to == "wandb":
|
|
||||||
if not is_wandb_available():
|
|
||||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
|
||||||
|
|
||||||
# Make one log on every process with the configuration for debugging.
|
|
||||||
logging.basicConfig(
|
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
|
||||||
level=logging.INFO,
|
|
||||||
filename=args.output_log_path,
|
|
||||||
filemode='w',
|
|
||||||
)
|
|
||||||
logger.info(accelerator.state, main_process_only=False)
|
|
||||||
if accelerator.is_local_main_process:
|
|
||||||
transformers.utils.logging.set_verbosity_warning()
|
|
||||||
diffusers.utils.logging.set_verbosity_info()
|
|
||||||
else:
|
|
||||||
transformers.utils.logging.set_verbosity_error()
|
|
||||||
diffusers.utils.logging.set_verbosity_error()
|
|
||||||
|
|
||||||
# If passed along, set the training seed now.
|
|
||||||
if args.seed is not None:
|
|
||||||
set_seed(args.seed)
|
|
||||||
|
|
||||||
# Handle the repository creation
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
if args.output_dir is not None:
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if args.push_to_hub:
|
|
||||||
repo_id = create_repo(
|
|
||||||
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
|
||||||
exist_ok=True,
|
|
||||||
token=args.hub_token,
|
|
||||||
).repo_id
|
|
||||||
|
|
||||||
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
|
||||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
|
||||||
weight_dtype = torch.float32
|
|
||||||
if accelerator.mixed_precision == "fp16":
|
|
||||||
weight_dtype = torch.float16
|
|
||||||
elif accelerator.mixed_precision == "bf16":
|
|
||||||
weight_dtype = torch.bfloat16
|
|
||||||
|
|
||||||
if args.precomp_lang_embed:
|
|
||||||
tokenizer, text_encoder = None, None
|
|
||||||
else:
|
|
||||||
text_embedder = T5Embedder(
|
|
||||||
from_pretrained=args.pretrained_text_encoder_name_or_path,
|
|
||||||
model_max_length=config["dataset"]["tokenizer_max_length"],
|
|
||||||
device=accelerator.device,
|
|
||||||
)
|
|
||||||
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
|
||||||
|
|
||||||
vision_encoder = SiglipVisionTower(vision_tower=args.pretrained_vision_encoder_name_or_path, args=None)
|
|
||||||
image_processor = vision_encoder.image_processor
|
|
||||||
|
|
||||||
# Load from a pretrained checkpoint
|
|
||||||
if args.pretrained_model_name_or_path is not None and not os.path.isfile(args.pretrained_model_name_or_path):
|
|
||||||
logger.info("Constructing model from pretrained checkpoint.")
|
|
||||||
rdt = RDTRunner.from_pretrained(args.pretrained_model_name_or_path)
|
|
||||||
else:
|
|
||||||
logger.info("Constructing model from provided config.")
|
|
||||||
# Calculate the image condition length
|
|
||||||
img_cond_len = (config["common"]["img_history_size"] * config["common"]["num_cameras"] *
|
|
||||||
vision_encoder.num_patches)
|
|
||||||
rdt = RDTRunner(
|
|
||||||
action_dim=config["common"]["state_dim"],
|
|
||||||
pred_horizon=config["common"]["action_chunk_size"],
|
|
||||||
config=config["model"],
|
|
||||||
lang_token_dim=config["model"]["lang_token_dim"],
|
|
||||||
img_token_dim=config["model"]["img_token_dim"],
|
|
||||||
state_token_dim=config["model"]["state_token_dim"],
|
|
||||||
max_lang_cond_len=config["dataset"]["tokenizer_max_length"],
|
|
||||||
img_cond_len=img_cond_len,
|
|
||||||
img_pos_embed_config=[
|
|
||||||
# No initial pos embed in the last grid size
|
|
||||||
# since we've already done in ViT
|
|
||||||
(
|
|
||||||
"image",
|
|
||||||
(
|
|
||||||
config["common"]["img_history_size"],
|
|
||||||
config["common"]["num_cameras"],
|
|
||||||
-vision_encoder.num_patches,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
lang_pos_embed_config=[
|
|
||||||
# Similarly, no initial pos embed for language
|
|
||||||
("lang", -config["dataset"]["tokenizer_max_length"]),
|
|
||||||
],
|
|
||||||
dtype=weight_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
ema_rdt = copy.deepcopy(rdt)
|
|
||||||
ema_model = EMAModel(
|
|
||||||
ema_rdt,
|
|
||||||
update_after_step=config["model"]["ema"]["update_after_step"],
|
|
||||||
inv_gamma=config["model"]["ema"]["inv_gamma"],
|
|
||||||
power=config["model"]["ema"]["power"],
|
|
||||||
min_value=config["model"]["ema"]["min_value"],
|
|
||||||
max_value=config["model"]["ema"]["max_value"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
|
||||||
# which ensure saving model in huggingface format (config.json + pytorch_model.bin)
|
|
||||||
def save_model_hook(models, weights, output_dir):
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
for model in models:
|
|
||||||
model_to_save = model.module if hasattr(model, "module") else model # type: ignore
|
|
||||||
if isinstance(model_to_save, type(accelerator.unwrap_model(rdt))):
|
|
||||||
model_to_save.save_pretrained(output_dir)
|
|
||||||
|
|
||||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
|
||||||
# TODO:
|
|
||||||
raise NotImplementedError("Gradient checkpointing is not yet implemented.")
|
|
||||||
|
|
||||||
# Enable TF32 for faster training on Ampere GPUs,
|
|
||||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
|
||||||
if args.allow_tf32:
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
|
|
||||||
if args.scale_lr:
|
|
||||||
args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size *
|
|
||||||
accelerator.num_processes)
|
|
||||||
|
|
||||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
|
||||||
if args.use_8bit_adam:
|
|
||||||
try:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
|
|
||||||
|
|
||||||
optimizer_class = bnb.optim.AdamW8bit
|
|
||||||
else:
|
|
||||||
optimizer_class = torch.optim.AdamW
|
|
||||||
|
|
||||||
# Optimizer creation
|
|
||||||
params_to_optimize = rdt.parameters()
|
|
||||||
optimizer = optimizer_class(
|
|
||||||
params_to_optimize,
|
|
||||||
lr=args.learning_rate,
|
|
||||||
betas=(args.adam_beta1, args.adam_beta2),
|
|
||||||
weight_decay=args.adam_weight_decay,
|
|
||||||
eps=args.adam_epsilon,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Dataset and DataLoaders creation:
|
|
||||||
train_dataset = VLAConsumerDataset(
|
|
||||||
model_config_path=args.model_config_path, # TODO
|
|
||||||
config=config["dataset"],
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
image_processor=image_processor,
|
|
||||||
num_cameras=config["common"]["num_cameras"],
|
|
||||||
img_history_size=config["common"]["img_history_size"],
|
|
||||||
dataset_type=args.dataset_type,
|
|
||||||
image_aug=args.image_aug,
|
|
||||||
cond_mask_prob=args.cond_mask_prob,
|
|
||||||
cam_ext_mask_prob=args.cam_ext_mask_prob,
|
|
||||||
state_noise_snr=args.state_noise_snr,
|
|
||||||
use_hdf5=args.load_from_hdf5,
|
|
||||||
use_precomp_lang_embed=args.precomp_lang_embed,
|
|
||||||
)
|
|
||||||
sample_dataset = VLAConsumerDataset(
|
|
||||||
model_config_path=args.model_config_path, # TODO
|
|
||||||
config=config["dataset"],
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
image_processor=image_processor,
|
|
||||||
num_cameras=config["common"]["num_cameras"],
|
|
||||||
img_history_size=config["common"]["img_history_size"],
|
|
||||||
dataset_type=args.dataset_type,
|
|
||||||
image_aug=False,
|
|
||||||
cond_mask_prob=0,
|
|
||||||
cam_ext_mask_prob=-1,
|
|
||||||
state_noise_snr=None,
|
|
||||||
use_hdf5=args.load_from_hdf5,
|
|
||||||
use_precomp_lang_embed=args.precomp_lang_embed,
|
|
||||||
)
|
|
||||||
|
|
||||||
data_collator = DataCollatorForVLAConsumerDataset(tokenizer)
|
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
|
||||||
train_dataset,
|
|
||||||
batch_size=args.train_batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
collate_fn=data_collator,
|
|
||||||
num_workers=args.dataloader_num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
persistent_workers=True,
|
|
||||||
)
|
|
||||||
sample_dataloader = torch.utils.data.DataLoader(
|
|
||||||
sample_dataset,
|
|
||||||
batch_size=args.sample_batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
collate_fn=data_collator,
|
|
||||||
num_workers=args.dataloader_num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
persistent_workers=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Scheduler and math around the number of training steps.
|
|
||||||
overrode_max_train_steps = False
|
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
|
||||||
if args.max_train_steps is None:
|
|
||||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
|
||||||
overrode_max_train_steps = True
|
|
||||||
|
|
||||||
lr_scheduler = get_scheduler(
|
|
||||||
args.lr_scheduler,
|
|
||||||
optimizer=optimizer,
|
|
||||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
|
||||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
|
||||||
num_cycles=args.lr_num_cycles,
|
|
||||||
power=args.lr_power,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare everything with our `accelerator`.
|
|
||||||
rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler = (accelerator.prepare(
|
|
||||||
rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler))
|
|
||||||
|
|
||||||
ema_rdt.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
|
|
||||||
if text_encoder is not None:
|
|
||||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
|
|
||||||
if vision_encoder is not None:
|
|
||||||
vision_encoder.vision_tower.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
|
|
||||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
|
||||||
if overrode_max_train_steps:
|
|
||||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
|
||||||
# Afterwards we recalculate our number of training epochs
|
|
||||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
|
||||||
|
|
||||||
# We need to initialize the trackers we use, and also store our configuration.
|
|
||||||
# The trackers initializes automatically on the main process.
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
accelerator.init_trackers(
|
|
||||||
"VLA",
|
|
||||||
config=vars(args),
|
|
||||||
init_kwargs={"wandb": {
|
|
||||||
"name": f"RoboTwin_RDT_{args.CONFIG_NAME}",
|
|
||||||
}},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Train!
|
|
||||||
total_batch_size = (args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps)
|
|
||||||
|
|
||||||
logger.info("***** Running training *****")
|
|
||||||
logger.info(f" Num examples = {len(train_dataset)}")
|
|
||||||
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
|
||||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
|
||||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
|
||||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
|
||||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
|
||||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
|
||||||
global_step = 0
|
|
||||||
first_epoch = 0
|
|
||||||
|
|
||||||
# Load from a pretrained checkpoint
|
|
||||||
if (args.resume_from_checkpoint is None and args.pretrained_model_name_or_path is not None
|
|
||||||
and os.path.isfile(args.pretrained_model_name_or_path)):
|
|
||||||
# Since EMA is deprecated, we do not load EMA from the pretrained checkpoint
|
|
||||||
logger.info("Loading from a pretrained checkpoint.")
|
|
||||||
checkpoint = torch.load(args.pretrained_model_name_or_path)
|
|
||||||
rdt.module.load_state_dict(checkpoint["module"])
|
|
||||||
|
|
||||||
# Potentially load in the weights and states from a previous save
|
|
||||||
if args.resume_from_checkpoint:
|
|
||||||
if args.resume_from_checkpoint != "latest":
|
|
||||||
path = os.path.basename(args.resume_from_checkpoint)
|
|
||||||
else:
|
|
||||||
# Get the mos recent checkpoint
|
|
||||||
dirs = os.listdir(args.output_dir)
|
|
||||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
|
||||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
|
||||||
path = dirs[-1] if len(dirs) > 0 else None
|
|
||||||
|
|
||||||
if path is None:
|
|
||||||
accelerator.print(
|
|
||||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
|
|
||||||
args.resume_from_checkpoint = None
|
|
||||||
else:
|
|
||||||
accelerator.print(f"Resuming from checkpoint {path}")
|
|
||||||
try:
|
|
||||||
accelerator.load_state(os.path.join(args.output_dir, path)) # load_module_strict=False
|
|
||||||
except:
|
|
||||||
# load deepspeed's state_dict
|
|
||||||
logger.info("Resuming training state failed. Attempting to only load from model checkpoint.")
|
|
||||||
checkpoint = torch.load(
|
|
||||||
os.path.join(
|
|
||||||
args.output_dir,
|
|
||||||
path,
|
|
||||||
"pytorch_model",
|
|
||||||
"mp_rank_00_model_states.pt",
|
|
||||||
))
|
|
||||||
rdt.module.load_state_dict(checkpoint["module"])
|
|
||||||
|
|
||||||
load_model(ema_rdt, os.path.join(args.output_dir, path, "ema", "model.safetensors"))
|
|
||||||
global_step = int(path.split("-")[1])
|
|
||||||
|
|
||||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
|
||||||
first_epoch = global_step // num_update_steps_per_epoch
|
|
||||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
|
||||||
|
|
||||||
# Only show the progress bar once on each machine.
|
|
||||||
progress_bar = tqdm(
|
|
||||||
range(global_step, args.max_train_steps),
|
|
||||||
disable=not accelerator.is_local_main_process,
|
|
||||||
)
|
|
||||||
progress_bar.set_description("Steps")
|
|
||||||
|
|
||||||
loss_for_log = {}
|
|
||||||
for epoch in range(first_epoch, args.num_train_epochs):
|
|
||||||
|
|
||||||
rdt.train()
|
|
||||||
|
|
||||||
# Set the progress_bar to correct position
|
|
||||||
if args.resume_from_checkpoint and epoch == first_epoch:
|
|
||||||
progress_bar.update(resume_step // args.gradient_accumulation_steps)
|
|
||||||
|
|
||||||
# Forward and backward...
|
|
||||||
for batch in train_dataloader:
|
|
||||||
with accelerator.accumulate(rdt):
|
|
||||||
images = batch["images"].to(dtype=weight_dtype)
|
|
||||||
states = batch["states"].to(dtype=weight_dtype) # (B, T, D_a)
|
|
||||||
# We only use the last state as input
|
|
||||||
states = states[:, -1:, :]
|
|
||||||
actions = batch["actions"].to(dtype=weight_dtype)
|
|
||||||
state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)
|
|
||||||
ctrl_freqs = batch["ctrl_freqs"]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
batch_size, _, C, H, W = images.shape
|
|
||||||
image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
|
|
||||||
image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))
|
|
||||||
|
|
||||||
lang_attn_mask = batch["lang_attn_mask"]
|
|
||||||
text_embeds = (batch["lang_embeds"].to(
|
|
||||||
dtype=weight_dtype) if args.precomp_lang_embed else text_encoder(
|
|
||||||
input_ids=batch["input_ids"], attention_mask=lang_attn_mask)["last_hidden_state"].detach())
|
|
||||||
|
|
||||||
state_elem_mask = state_elem_mask.unsqueeze(1)
|
|
||||||
loss = rdt(
|
|
||||||
lang_tokens=text_embeds,
|
|
||||||
lang_attn_mask=lang_attn_mask,
|
|
||||||
img_tokens=image_embeds,
|
|
||||||
state_tokens=states,
|
|
||||||
action_gt=actions,
|
|
||||||
action_mask=state_elem_mask,
|
|
||||||
ctrl_freqs=ctrl_freqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
accelerator.backward(loss)
|
|
||||||
if accelerator.sync_gradients:
|
|
||||||
params_to_clip = rdt.parameters()
|
|
||||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
|
||||||
optimizer.step()
|
|
||||||
lr_scheduler.step()
|
|
||||||
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
|
||||||
|
|
||||||
ema_model.step(accelerator.unwrap_model(rdt))
|
|
||||||
|
|
||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
|
||||||
if accelerator.sync_gradients:
|
|
||||||
progress_bar.update(1)
|
|
||||||
global_step += 1
|
|
||||||
|
|
||||||
if global_step % args.checkpointing_period == 0:
|
|
||||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
|
||||||
accelerator.save_state(save_path)
|
|
||||||
ema_save_path = os.path.join(save_path, f"ema")
|
|
||||||
accelerator.save_model(ema_rdt, ema_save_path)
|
|
||||||
logger.info(f"Saved state to {save_path}")
|
|
||||||
|
|
||||||
if args.sample_period > 0 and global_step % args.sample_period == 0:
|
|
||||||
sample_loss_for_log = log_sample_res(
|
|
||||||
text_encoder,
|
|
||||||
vision_encoder,
|
|
||||||
rdt, # We do not use EMA currently
|
|
||||||
args,
|
|
||||||
accelerator,
|
|
||||||
weight_dtype,
|
|
||||||
sample_dataset.get_dataset_id2name(),
|
|
||||||
sample_dataloader,
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
logger.info(sample_loss_for_log)
|
|
||||||
accelerator.log(sample_loss_for_log, step=global_step)
|
|
||||||
|
|
||||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
|
||||||
logs.update(loss_for_log)
|
|
||||||
# logger.info(logs)
|
|
||||||
accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Create the pipeline using using the trained modules and save it.
|
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
accelerator.unwrap_model(rdt).save_pretrained(args.output_dir)
|
|
||||||
ema_save_path = os.path.join(args.output_dir, f"ema")
|
|
||||||
accelerator.save_model(ema_rdt, ema_save_path)
|
|
||||||
|
|
||||||
logger.info(f"Saved Model to {args.output_dir}")
|
|
||||||
|
|
||||||
if args.push_to_hub:
|
|
||||||
save_model_card(
|
|
||||||
repo_id,
|
|
||||||
base_model=args.pretrained_model_name_or_path,
|
|
||||||
repo_folder=args.output_dir,
|
|
||||||
)
|
|
||||||
upload_folder(
|
|
||||||
repo_id=repo_id,
|
|
||||||
folder_path=args.output_dir,
|
|
||||||
commit_message="End of training",
|
|
||||||
token=args.hub_token,
|
|
||||||
allow_patterns=["pytorch_model.bin", "*.json", "*.md"],
|
|
||||||
# ignore_patterns=["step_*", "epoch_*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
accelerator.end_training()
|
|
||||||
Loading…
x
Reference in New Issue
Block a user