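"""Export the RDT (Robotics Diffusion Transformer) policy to ONNX and dump
calibration data for BPU quantization.

The script runs in two passes:
1. A "dump" copy of the policy replays training samples and saves the
   intermediate tensors (image-adaptor inputs and per-timestep DiT inputs)
   as .npy calibration files.
2. A CPU copy of the policy exports the image/state/language adaptors and
   the DiT backbone as separate ONNX models.
"""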
import os
import re
import json
import logging
from time import time
from collections import OrderedDict
from dataclasses import dataclass

import yaml
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import h5py
from PIL import Image as PImage
from torchvision import transforms

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler

from scripts.agilex_model import create_model
from configs.state_vec import STATE_VEC_IDX_MAPPING
from models.hub_mixin import CompatiblePyTorchModelHubMixin
from models.rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder,
                               get_1d_sincos_pos_embed_from_grid,
                               get_multimodal_cond_pos_embed)
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.multimodal_encoder.t5_encoder import T5Embedder

logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S')
logger = logging.getLogger("RDT_EXPORT")

os.environ["WANDB_MODE"] = "disabled"

@dataclass
class ExportConfig:
    task_id: str = None
    output_path: str = None
    model_path: str = None
    calibration_num: int = 100
    lang_calibration_num: int = 1
    dataset_path: str = None
    gpu_id: str = "0"
    march: str = None
    model_type: str = None
    pretrained_vision_encoder_name_or_path: str = None
    ctrl_freq: int = 25
    cal_data_device: str = "cuda"

# Indices of the six right-arm joint positions within the unified state vector
AGILEX_STATE_INDICES = [
    STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
]
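
# The unified state vector has state_token_dim entries (128 in the configs used
# here, judging by the shape comments in `step` below); this export populates
# only the six right-arm joint positions.  A minimal round-trip sketch, assuming
# `model` is a RoboticDiffusionTransformerModel_Dump instance:
#
#     joints = torch.tensor([[[30., -15., 45., 0., 10., -5.]]])  # (1, 1, 6), degrees
#     state, mask = model._format_joint_to_state(joints)   # (1, 1, 128), (1, 128)
#     joints_back = model._unformat_action_to_joint(state)  # back to degrees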

def dump_img_adaptor(img_tokens):
    global img_adaptor_cal_ws
    global dump_cnt, dump_dataset_name
    np.save(
        os.path.join(img_adaptor_cal_ws, f"img_adaptor_{dump_dataset_name}_{dump_cnt}.npy"),
        img_tokens.float().contiguous().cpu().detach().numpy())

def dump_dit(state_action_traj, ctrl_freqs, t, lang_cond, img_cond, lang_attn_mask):
    t_str = str(t)
    x = state_action_traj.float().contiguous().cpu().detach().numpy()
    freq = ctrl_freqs.float().contiguous().cpu().detach().numpy().astype(np.int32).copy()
    t_ = t.float().contiguous().cpu().detach().numpy()
    t_ = np.expand_dims(t_.astype(np.int32), axis=0).copy()
    lang_c = lang_cond.float().contiguous().cpu().detach().numpy()
    img_c = img_cond.float().contiguous().cpu().detach().numpy()
    lang_mask = lang_attn_mask.float().contiguous().cpu().detach().numpy()
    # Pad the language tokens and mask to a fixed length of 64, then turn the
    # boolean mask into an additive attention bias: 0.0 for valid tokens,
    # -512.0 for padding.
    pad_rows = 64 - lang_mask.shape[1]
    padded = np.pad(lang_mask, ((0, 0), (0, pad_rows)), mode="constant")
    mask_float = np.where(padded, 0.0, -512.0).astype(np.float32)
    lang_cond_padded = np.pad(lang_c, pad_width=((0, 0), (0, pad_rows), (0, 0)), mode="constant", constant_values=0)

    global dit_cal_path_x, dit_cal_path_freq, dit_cal_path_t, dit_cal_path_lang_c, dit_cal_path_img_c, dit_cal_path_lang_mask
    global dump_cnt, dump_dataset_name
    np.save(os.path.join(dit_cal_path_x, f"x_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), x)
    np.save(os.path.join(dit_cal_path_freq, f"freq_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), freq)
    np.save(os.path.join(dit_cal_path_t, f"t_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), t_)
    np.save(os.path.join(dit_cal_path_lang_c, f"lang_c_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), lang_cond_padded)
    np.save(os.path.join(dit_cal_path_img_c, f"img_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), img_c)
    np.save(os.path.join(dit_cal_path_lang_mask, f"lang_mask_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), mask_float)
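
# A worked example of the mask transform above (values chosen for
# illustration): a 14-token instruction gives lang_attn_mask of shape (1, 14),
# all True.  After padding and conversion:
#
#     mask = np.ones((1, 14), dtype=np.float32)
#     padded = np.pad(mask, ((0, 0), (0, 50)))   # (1, 64)
#     bias = np.where(padded, 0.0, -512.0)       # 14 zeros, then 50 x -512.0
#
# The deployed DiT presumably consumes this as an additive attention bias, so
# the large negative value suppresses the padded positions after the softmax.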

def create_dump_model(args, **kwargs):
    # left_arm_dim, right_arm_dim = (args["arm_dim"]["left_arm_dim"], args["arm_dim"]["right_arm_dim"],)
    # AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"]
    #                          for i in range(left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] +
    #                         [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"]
    #                          for i in range(right_arm_dim)] + [STATE_VEC_IDX_MAPPING["right_gripper_open"]])

    model = RoboticDiffusionTransformerModel_Dump(args, **kwargs)
    pretrained = kwargs.get("pretrained", None)
    if pretrained is not None and os.path.isfile(pretrained):
        model.load_pretrained_weights(pretrained)
    return model

class RDT_Dump(nn.Module):

    def __init__(self,
                 output_dim=128,
                 horizon=32,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 max_lang_cond_len=1024,
                 img_cond_len=4096,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super().__init__()
        self.horizon = horizon
        self.hidden_size = hidden_size
        self.max_lang_cond_len = max_lang_cond_len
        self.img_cond_len = img_cond_len
        self.dtype = dtype
        self.lang_pos_embed_config = lang_pos_embed_config
        self.img_pos_embed_config = img_pos_embed_config

        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype)
        self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype)

        # We will use trainable sin-cos embeddings for
        # [timestep; ctrl_freq; state; action]
        self.x_pos_embed = nn.Parameter(torch.zeros(1, horizon + 3, hidden_size))
        # Language conditions
        self.lang_cond_pos_embed = nn.Parameter(torch.zeros(1, max_lang_cond_len, hidden_size))
        # Image conditions
        self.img_cond_pos_embed = nn.Parameter(torch.zeros(1, img_cond_len, hidden_size))

        self.blocks = nn.ModuleList([RDTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, output_dim)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize pos_embed by sin-cos embedding
        x_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                    mm_cond_lens=OrderedDict([
                                                        ('timestep', 1),
                                                        ('ctrl_freq', 1),
                                                        ('state', 1),
                                                        ('action', self.horizon),
                                                    ]))
        self.x_pos_embed.data.copy_(torch.from_numpy(x_pos_embed).float().unsqueeze(0))

        if self.lang_pos_embed_config is None:
            lang_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size,
                                                                    torch.arange(self.max_lang_cond_len))
        else:
            lang_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                                mm_cond_lens=OrderedDict(self.lang_pos_embed_config),
                                                                embed_modality=False)
        self.lang_cond_pos_embed.data.copy_(torch.from_numpy(lang_cond_pos_embed).float().unsqueeze(0))

        if self.img_pos_embed_config is None:
            img_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.img_cond_len))
        else:
            img_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                               mm_cond_lens=OrderedDict(self.img_pos_embed_config),
                                                               embed_modality=False)
        self.img_cond_pos_embed.data.copy_(torch.from_numpy(img_cond_pos_embed).float().unsqueeze(0))

        # Initialize the timestep and control-frequency embedding MLPs
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02)

        # Initialize the final layer: zero-out the final linear layer
        nn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0)
        nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0)

        # Move all the params to the given data type
        self.to(self.dtype)

    def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
        t = self.t_embedder(t).unsqueeze(1)  # (B, 1, D) or (1, 1, D)
        freq = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
        # Prepend the timestep and control-frequency tokens to the input tokens
        if t.shape[0] == 1:
            t = t.expand(x.shape[0], -1, -1)
        x = torch.cat([t, freq, x], dim=1)  # (B, T+3, D), where T is the horizon
        # Add multimodal position embeddings
        x = x + self.x_pos_embed
        # Note the language condition is of variable length
        lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
        img_c = img_c + self.img_cond_pos_embed

        # Forward pass: blocks alternate between the language and image conditions
        conds = [lang_c, img_c]
        masks = [lang_mask, img_mask]
        for i, block in enumerate(self.blocks):
            c, mask = conds[i % 2], masks[i % 2]
            x = block(x, c, mask)  # (B, T+3, D)
        x = self.final_layer(x)  # (B, T+3, out_channels)

        # Only preserve the action tokens
        x = x[:, -self.horizon:]
        return x
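
# Shape contract for RDT_Dump.forward, matching the dummy inputs used for the
# ONNX export in main() (D = hidden_size; the exported config uses a
# 64-step action horizon):
#
#     x          (1, 65, D)    # 1 state token + 64 noisy action tokens
#     freq       (1,) int      # control frequency
#     t          (1,) int      # diffusion timestep
#     lang_c     (1, 64, D)    # adapted language tokens, padded to 64
#     img_c      (1, 4374, D)  # adapted image tokens
#     lang_mask  (1, 64)       # language attention mask
#     -> output  (1, 64, output_dim)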

class RDTRunner_Dump(nn.Module,
                     CompatiblePyTorchModelHubMixin,
                     repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):

    def __init__(self,
                 *,
                 action_dim,
                 pred_horizon,
                 config,
                 lang_token_dim,
                 img_token_dim,
                 state_token_dim,
                 max_lang_cond_len,
                 img_cond_len,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super(RDTRunner_Dump, self).__init__()
        # Create the diffusion model
        hidden_size = config['rdt']['hidden_size']
        self.model = RDT_Dump(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=config['rdt']['depth'],
            num_heads=config['rdt']['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Create adaptors for the various conditional inputs
        self.lang_adaptor = self.build_condition_adapter(config['lang_adaptor'],
                                                         in_features=lang_token_dim,
                                                         out_features=hidden_size)
        self.img_adaptor = self.build_condition_adapter(config['img_adaptor'],
                                                        in_features=img_token_dim,
                                                        out_features=hidden_size)
        # A `state` refers to an action or a proprioception vector
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,  # state + state mask (indicator)
            out_features=hidden_size)

        # Create the noise schedulers
        noise_scheduler_config = config['noise_scheduler']
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
            clip_sample=noise_scheduler_config['clip_sample'],
        )
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
        )

        self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.prediction_type = noise_scheduler_config['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        print("Diffusion params: %e" %
              sum([p.numel() for p in self.model.parameters()] +
                  [p.numel() for p in self.lang_adaptor.parameters()] +
                  [p.numel() for p in self.img_adaptor.parameters()] +
                  [p.numel() for p in self.state_adaptor.parameters()]))

    def build_condition_adapter(self, projector_type, in_features, out_features):
        projector = None
        if projector_type == 'linear':
            projector = nn.Linear(in_features, out_features)
        else:
            mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
            if mlp_gelu_match:
                mlp_depth = int(mlp_gelu_match.group(1))
                modules = [nn.Linear(in_features, out_features)]
                for _ in range(1, mlp_depth):
                    modules.append(nn.GELU(approximate="tanh"))
                    modules.append(nn.Linear(out_features, out_features))
                projector = nn.Sequential(*modules)

        if projector is None:
            raise ValueError(f'Unknown projector type: {projector_type}')
        return projector

    def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
        adapted_lang = self.lang_adaptor(lang_tokens)
        # Dump the raw image tokens for image-adaptor calibration
        dump_img_adaptor(img_tokens)
        adapted_img = self.img_adaptor(img_tokens)
        adapted_state = self.state_adaptor(state_tokens)
        return adapted_lang, adapted_img, adapted_state

    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond, state_traj, action_mask, ctrl_freqs):
        device = state_traj.device
        dtype = state_traj.dtype
        noisy_action = torch.randn(size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
                                   dtype=dtype, device=device)
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        # Set step values
        self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)

        for t in self.noise_scheduler_sample.timesteps:
            # Prepare the state-action trajectory
            action_traj = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(action_traj)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # Dump the DiT inputs for this denoising step
            dump_dit(state_action_traj, ctrl_freqs, t, lang_cond, img_cond, lang_attn_mask)

            # Predict the model output
            model_output = self.model(state_action_traj,
                                      ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond,
                                      img_cond,
                                      lang_mask=lang_attn_mask)

            # Compute previous actions: x_t -> x_{t-1}
            noisy_action = self.noise_scheduler_sample.step(model_output, t, noisy_action).prev_sample
            noisy_action = noisy_action.to(state_traj.dtype)

        # Finally apply the action mask to zero out invalid action dimensions
        noisy_action = noisy_action * action_mask
        return noisy_action

    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_gt, action_mask,
                     ctrl_freqs) -> torch.Tensor:
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be a True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor
        ctrl_freqs: (batch_size,), control frequency for each sample

        return: loss_value, a scalar tensor
        '''
        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device

        # Sample noise that we'll add to the actions
        noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)
        # Sample random diffusion timesteps
        timesteps = torch.randint(0, self.num_train_timesteps, (batch_size, ), device=device).long()
        # Add noise to the clean actions according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps)

        # Concatenate the state and action tokens to form the input sequence
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        # Append the action mask to the input sequence
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
        # Align the dimension with the hidden size
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj)
        # Predict the denoised result
        pred = self.model(state_action_traj, ctrl_freqs, timesteps, lang_cond, img_cond, lang_mask=lang_attn_mask)

        pred_type = self.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = action_gt
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target)
        return loss

    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_mask, ctrl_freqs):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be a True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
        ctrl_freqs: (batch_size,), control frequency for each sample

        return: (batch_size, horizon, action_dim), predicted action sequence
        '''
        # Prepare the state and conditions
        state_tokens = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(lang_tokens, img_tokens, state_tokens)

        # Run sampling
        action_pred = self.conditional_sample(
            lang_cond,
            lang_attn_mask,
            img_cond,
            state_traj,
            action_mask,
            ctrl_freqs,
        )
        return action_pred

    def forward(self, *args, **kwargs) -> torch.Tensor:
        return self.compute_loss(*args, **kwargs)
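
# Note on calibration volume: each call to predict_action runs the complete
# DPM-Solver sampling loop, so every training sample replayed in main() yields
# num_inference_timesteps sets of DiT calibration tensors (one per denoising
# step) plus one image-adaptor input dump.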

class RoboticDiffusionTransformerModel_Dump(object):

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        # We do not use the text encoder due to limited GPU memory
        # self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy(pretrained)

        self.reset()

    def get_policy(self, pretrained):
        # Initialize the model with arguments
        if pretrained is None or os.path.isfile(pretrained):
            img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                            self.vision_model.num_patches)

            _model = RDTRunner_Dump(
                action_dim=self.args["common"]["state_dim"],
                pred_horizon=self.args["common"]["action_chunk_size"],
                config=self.args["model"],
                lang_token_dim=self.args["model"]["lang_token_dim"],
                img_token_dim=self.args["model"]["img_token_dim"],
                state_token_dim=self.args["model"]["state_token_dim"],
                max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
                img_cond_len=img_cond_len,
                img_pos_embed_config=[
                    # No initial pos embed in the last grid size
                    # since we've already done it in the ViT
                    (
                        "image",
                        (
                            self.args["common"]["img_history_size"],
                            self.args["common"]["num_cameras"],
                            -self.vision_model.num_patches,
                        ),
                    ),
                ],
                lang_pos_embed_config=[
                    # Similarly, no initial pos embed for language
                    ("lang", -self.args["dataset"]["tokenizer_max_length"]),
                ],
                dtype=self.dtype,
            )
        else:
            _model = RDTRunner_Dump.from_pretrained(pretrained)

        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
        return tokenizer, text_encoder

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        image_processor = vision_encoder.image_processor
        return image_processor, vision_encoder

    def reset(self):
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        # self.text_model.eval()
        self.vision_model.eval()
        self.policy = self.policy.to(device, dtype=weight_dtype)
        # self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            checkpoint = torch.load(pretrained)
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model
            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        # Note: requires the text encoder, which is not loaded by default (see __init__)
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)
        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()
        return pred

    def _format_joint_to_state(self, joints):
        # Normalize the joint angles from degrees to [-1, 1]
        joints = joints / torch.tensor(
            [[[180, 180, 180, 180, 180, 180]]],
            device=joints.device,
            dtype=joints.dtype,
        )
        B, N, _ = joints.shape
        state = torch.zeros(
            (B, N, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        # Fill into the unified state vector
        state[:, :, AGILEX_STATE_INDICES] = joints
        # Assemble the mask indicating each dimension's availability
        state_elem_mask = torch.zeros(
            (B, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        state_elem_mask[:, AGILEX_STATE_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        action_indices = AGILEX_STATE_INDICES
        joints = action[:, :, action_indices]
        # Rescale the normalized actions back to joint angles in degrees
        joints = joints * torch.tensor(
            [[[180, 180, 180, 180, 180, 180]]],
            device=joints.device,
            dtype=joints.dtype,
        )
        return joints

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        device = self.device
        dtype = self.dtype

        # The background image used for padding
        background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                    dtype=np.uint8).reshape(1, 1, 3)
        background_image = (np.ones(
            (
                self.image_processor.size["height"],
                self.image_processor.size["width"],
                3,
            ),
            dtype=np.uint8,
        ) * background_color)

        # Preprocess the images in order and encode them
        image_tensor_list = []
        for image in images:
            if image is None:
                # Replace it with the background image
                image = PImage.fromarray(background_image)
            else:
                # Convert numpy array to PIL Image if needed
                if isinstance(image, np.ndarray):
                    image = PImage.fromarray(image)

            if self.image_size is not None:
                image = transforms.Resize(self.image_size)(image)

            if self.args["dataset"].get("auto_adjust_image_brightness", False):
                pixel_values = list(image.getdata())
                average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                if average_brightness <= 0.15:
                    image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

            if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = PImage.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = PImage.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result

                image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            image_tensor_list.append(image)

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
        image_embeds = self.vision_model(image_tensor).detach()
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # Prepare the proprioception states and the control frequency
        joints = proprio.to(device).unsqueeze(0)  # (1, 1, 6)
        states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        states = states[:, -1:, :]  # (1, 1, 128)
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        # Predict the next action chunk given the inputs
        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
        return trajectory
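
# A minimal usage sketch for step(), with assumed example inputs (the real
# driver loop is in main() below):
#
#     proprio = torch.zeros(1, 6)                    # current joint angles, degrees
#     images = [high_0, right_0, left_0,             # np.ndarray frames (or None),
#               high_1, right_1, left_1]             # 3 cameras x 2 history steps
#     text_embeds = torch.load("lang_embed_0.pt").unsqueeze(0)  # precomputed T5 embedding
#     actions = model.step(proprio, images, text_embeds)  # (1, horizon, 6), degrees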

def get_training_samples(data_dirs, num_samples=5, instructions_per_episode=1):
    """
    Get training samples from one or multiple data directories.

    Args:
        data_dirs: A single directory path (str) or a list of directory paths
        num_samples: Total number of samples to generate across all directories
        instructions_per_episode: Number of instructions per episode
    """
    training_samples = []

    # Handle both a single directory and a list of directories
    if isinstance(data_dirs, str):
        data_dirs = [data_dirs]

    logger.info(f"Get Training Data From: {len(data_dirs)} dataset(s).")

    # First, collect all available episode files from all directories
    episode_files = []
    for data_dir in data_dirs:
        if not os.path.isdir(data_dir):
            logger.warning(f"Directory not found: {data_dir}, skipping")
            continue
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                if file.endswith('.hdf5'):
                    file_path = os.path.join(root, file)
                    episode_files.append(file_path)

    if len(episode_files) == 0:
        logger.warning("No episode files found in the provided directories")
        return training_samples

    logger.info(f"Found {len(episode_files)} episode files across all datasets.")

    # Generate samples by randomly selecting from episodes
    while len(training_samples) < num_samples:
        # Randomly select an episode file
        file_path = np.random.choice(episode_files)
        try:
            with h5py.File(file_path, 'r') as f:
                observations = f['observations']
                actions = f['action'][:]
                images = observations['images']
                qpos = observations['qpos'][:]
                episode_dir = os.path.dirname(file_path)
                instructions_dir = os.path.join(episode_dir, 'instructions')
                num_steps = len(qpos)
                # The image input needs the high/left/right camera views plus
                # their history frames, which together form the 4374 image tokens
                if num_steps > 1:
                    lang_step_idx = int(np.random.randint(0, max(instructions_per_episode, 1)))

                    lang_embed, lang_str = None, None
                    # Language embedding (optional)
                    lang_embed_path = os.path.join(instructions_dir, f"lang_embed_{lang_step_idx}.pt")
                    if os.path.exists(lang_embed_path):
                        try:
                            lang_embed = torch.load(lang_embed_path, map_location="cpu")
                        except Exception as e:
                            logger.error(f"Error reading {lang_embed_path}: {e}")

                    # Language string (optional)
                    lang_str_path = os.path.join(instructions_dir, f"txt_lang_embed_{lang_step_idx}.txt")
                    if os.path.exists(lang_str_path):
                        try:
                            with open(lang_str_path, "r", encoding="utf-8") as tf:
                                lang_str = tf.read().strip()
                        except Exception as e:
                            logger.error(f"Error reading {lang_str_path}: {e}")
                    lang_str = lang_str or ""

                    # Fetch multi-camera images with history frames
                    step_idx = np.random.randint(0, num_steps)
                    multi_cam_images = {}
                    ref_frame = images['cam_high'][0]
                    ref_img = cv2.imdecode(np.frombuffer(ref_frame, np.uint8), cv2.IMREAD_COLOR)
                    IMG_HEIGHT, IMG_WIDTH = ref_img.shape[:2]
                    ground_image = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.uint8)
                    for cam_name in ['cam_high', 'cam_left_wrist', 'cam_right_wrist']:
                        if cam_name in images:
                            cam_images = []
                            # Fetch 2 history frames (previous + current)
                            for i in range(max(step_idx - 1, 0), step_idx + 1):
                                img_bits = images[cam_name][i]
                                img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)
                                cam_images.append(img)
                            if len(cam_images) < 2:
                                cam_images = [cam_images[0]] * 2
                            multi_cam_images[cam_name] = cam_images
                        else:
                            # Camera missing: substitute black frames of the reference size
                            multi_cam_images[cam_name] = [ground_image] * 2

                    training_samples.append({
                        'multi_cam_images': multi_cam_images,
                        'joints': actions[step_idx],
                        'lang_embed': lang_embed,
                        'lang_str': lang_str,
                        'source': file_path,
                        'step': step_idx
                    })
                    logger.debug(f"TimeStep: {step_idx}, Sample: {file_path}")
        except Exception as e:
            logger.error(f"Failed: {file_path} : {e}")
            continue

    logger.info(f"Total Num: {len(training_samples)}.")
    return training_samples
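
# Each element returned by get_training_samples is a dict of the form:
#
#     {
#         'multi_cam_images': {cam_name: [frame_prev, frame_curr]},  # BGR np.ndarray
#         'joints':     actions[step_idx],     # ground-truth action at the step
#         'lang_embed': torch.Tensor or None,  # precomputed T5 embedding
#         'lang_str':   str,                   # instruction text ("" if missing)
#         'source':     str,                   # originating .hdf5 path
#         'step':       int,                   # sampled timestep index
#     }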

def main(config_path):
    with open(config_path, "r") as f:
        cfg = json.load(f)

    export_info = cfg.get("export", {})

    opt = ExportConfig(
        task_id=cfg.get("task_id"),
        output_path=os.path.join(export_info.get("output_path", "."), cfg.get("task_id", "")),
        model_path=export_info.get("model_path"),
        calibration_num=export_info.get("calibration_num", 100),
        dataset_path=export_info.get("dataset_path"),
        gpu_id=cfg.get("gpu_id", "0"),
        march=export_info.get("march"),
        model_type=export_info.get("model_type"),
        pretrained_vision_encoder_name_or_path="/home/qi.xiong/DualArm/Work_Docker/RDT/weights/siglip-so400m-patch14-384",
        ctrl_freq=export_info.get("ctrl_freq", 25),
        cal_data_device=cfg.get("cal_data_device", "cuda"),
        lang_calibration_num=export_info.get("lang_calibration_num", 1)
    )

    if opt.model_type not in ["170M", "1B"]:
        raise ValueError(f"RDT ONLY SUPPORT 170M AND 1B, BUT GOT {opt.model_type}")

    logger.info(f"Export config loaded: {opt}")
    os.makedirs(opt.output_path, exist_ok=True)

    # Prepare the output workspace
    ## BPU_RDT_Policy
    bpu_rdt_name = "BPU_RDT_Policy_170M" if opt.model_type == "170M" else "BPU_RDT_Policy_1B"
    bpu_rdt_path = os.path.join(opt.output_path, bpu_rdt_name)
    os.makedirs(bpu_rdt_path, exist_ok=True)
    os.system(f"cp configs/base_{opt.model_type}.yaml {bpu_rdt_path}/base.yaml")
    rdt_config_path = os.path.join(bpu_rdt_path, "base.yaml")

    ## Test data
    test_data_name = "test_data"
    test_data_path = os.path.join(opt.output_path, test_data_name)
    os.makedirs(test_data_path, exist_ok=True)

    ## Instructions
    instruction_ws_name = "instructions"
    instruction_ws_path = os.path.join(opt.output_path, instruction_ws_name)
    os.makedirs(instruction_ws_path, exist_ok=True)
    for name in os.listdir(opt.dataset_path):
        os.makedirs(os.path.join(instruction_ws_path, name), exist_ok=True)

    ## Image adaptor
    global img_adaptor_cal_ws
    img_adaptor_ws_name = "img_adaptor_WorkSpace"
    img_adaptor_cal_name = "rdt_image_adaptor_calibration"
    img_adaptor_name = "rdt_image_adaptor.onnx"
    img_adaptor_config_name = "config.yaml"
    img_adaptor_ws = os.path.join(opt.output_path, img_adaptor_ws_name)
    img_adaptor_path = os.path.join(img_adaptor_ws, img_adaptor_name)
    img_adaptor_cal_ws = os.path.join(img_adaptor_ws, img_adaptor_cal_name)
    os.makedirs(img_adaptor_ws, exist_ok=True)
    os.makedirs(img_adaptor_cal_ws, exist_ok=True)

    ## State adaptor
    state_adaptor_name1 = "rdt_state_adaptor_1x1x256.onnx"
    state_adaptor_path1 = os.path.join(opt.output_path, bpu_rdt_name, state_adaptor_name1)
    state_adaptor_name2 = "rdt_state_adaptor_1x64x256.onnx"
    state_adaptor_path2 = os.path.join(opt.output_path, bpu_rdt_name, state_adaptor_name2)

    ## Lang adaptor
    lang_adaptor_name = "rdt_lang_adaptor.onnx"
    lang_adaptor_path = os.path.join(opt.output_path, bpu_rdt_name, lang_adaptor_name)

    ## DiT policy
    dit_ws_name = "DiT_WorkSpace"
    dit_cal_name = "rdt_dit_calibration"
    dit_name = "rdt_dit.onnx"
    dit_config_name = "config.yaml"
    dit_json_name = "quant_config.json"
    dit_ws = os.path.join(opt.output_path, dit_ws_name)
    dit_path = os.path.join(dit_ws, dit_name)
    dit_cal_ws = os.path.join(dit_ws, dit_cal_name)
    os.makedirs(dit_ws, exist_ok=True)
    os.makedirs(dit_cal_ws, exist_ok=True)

    global dit_cal_path_x, dit_cal_path_freq, dit_cal_path_t, dit_cal_path_lang_c, dit_cal_path_img_c, dit_cal_path_lang_mask
    dit_cal_path_x = os.path.join(dit_cal_ws, "x")
    os.makedirs(dit_cal_path_x, exist_ok=True)
    dit_cal_path_freq = os.path.join(dit_cal_ws, "freq")
    os.makedirs(dit_cal_path_freq, exist_ok=True)
    dit_cal_path_t = os.path.join(dit_cal_ws, "t")
    os.makedirs(dit_cal_path_t, exist_ok=True)
    dit_cal_path_lang_c = os.path.join(dit_cal_ws, "lang_c")
    os.makedirs(dit_cal_path_lang_c, exist_ok=True)
    dit_cal_path_img_c = os.path.join(dit_cal_ws, "img_c")
    os.makedirs(dit_cal_path_img_c, exist_ok=True)
    dit_cal_path_lang_mask = os.path.join(dit_cal_ws, "lang_mask")
    os.makedirs(dit_cal_path_lang_mask, exist_ok=True)

    # Prepare calibration data
    with open(rdt_config_path, "r") as f:
        rdt_config = yaml.safe_load(f)

    dump_model = create_dump_model(
        args=rdt_config,
        dtype=torch.float32,
        pretrained=opt.model_path,
        pretrained_vision_encoder_name_or_path=opt.pretrained_vision_encoder_name_or_path,
        control_frequency=opt.ctrl_freq,
        device=opt.cal_data_device
    )

    # Load training data from all datasets
    global dump_cnt, dump_dataset_name
    test_data_cnt = 0

    # Collect all dataset paths
    all_dataset_paths = []
    for dump_dataset_name in os.listdir(opt.dataset_path):
        dump_dataset_path = os.path.join(opt.dataset_path, dump_dataset_name)
        if os.path.isdir(dump_dataset_path):
            all_dataset_paths.append(dump_dataset_path)

    # Get training samples from all datasets together
    training_samples = get_training_samples(all_dataset_paths,
                                            num_samples=opt.calibration_num,
                                            instructions_per_episode=opt.lang_calibration_num)

    if len(training_samples) == 0:
        logger.warning("No training samples found, skipping calibration data generation")
    else:
        # Only process up to the number of samples we actually have
        num_samples_to_process = min(len(training_samples), opt.calibration_num)
        for dump_cnt in range(num_samples_to_process):
            sample = training_samples[dump_cnt]
            # Extract the dataset name from the sample's source path
            sample_source = sample['source']
            dump_dataset_name = os.path.basename(os.path.dirname(os.path.dirname(sample_source)))
            instruction_emb = {
                "lang_cond": sample["lang_embed"].float().cpu(),
                "lang_str": sample["lang_str"]
            }
            ins_str_name = sample["lang_str"].replace(" ", "_") + "__"
            torch.save(instruction_emb, os.path.join(instruction_ws_path, dump_dataset_name, f"{ins_str_name}.pt"))
            image_arrs = [
                sample['multi_cam_images']['cam_high'][0],
                sample['multi_cam_images']['cam_right_wrist'][0],
                sample['multi_cam_images']['cam_left_wrist'][0],
                sample['multi_cam_images']['cam_high'][1],
                sample['multi_cam_images']['cam_right_wrist'][1],
                sample['multi_cam_images']['cam_left_wrist'][1],
            ]
            test_data_cnt += 1
            np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_high_0.npy"), sample['multi_cam_images']['cam_high'][0])
            np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_right_wrist_0.npy"), sample['multi_cam_images']['cam_right_wrist'][0])
            np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_left_wrist_0.npy"), sample['multi_cam_images']['cam_left_wrist'][0])
            np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_high_1.npy"), sample['multi_cam_images']['cam_high'][1])
            np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_right_wrist_1.npy"), sample['multi_cam_images']['cam_right_wrist'][1])
            np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_left_wrist_1.npy"), sample['multi_cam_images']['cam_left_wrist'][1])
            images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]
            proprio = torch.from_numpy(sample['joints']).float().unsqueeze(0).to(opt.cal_data_device)
            np.save(os.path.join(test_data_path, f"{test_data_cnt}_joints.npy"), sample['joints'])
            lang_embeddings = sample['lang_embed'].float().unsqueeze(0).to(opt.cal_data_device)
            torch.save(lang_embeddings, os.path.join(test_data_path, f"{test_data_cnt}_lang_embeddings.pt"))
            dump_model.reset()
            begin_time = time()
            actions = dump_model.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy()
            np.save(os.path.join(test_data_path, f"{test_data_cnt}_actions.npy"), actions)
            logger.debug(f"Dump: Cost {(1000 * (time() - begin_time)):.1f} ms, cnt: {dump_cnt}, name: {dump_dataset_name}")
    logger.info("End Generate Calibration Data.")
    del dump_model

    # Load the RDT policy as a CPU model for ONNX export
    with open(rdt_config_path, "r") as f:
        rdt_config = yaml.safe_load(f)

    model = create_model(
        args=rdt_config,
        dtype=torch.float32,
        pretrained=opt.model_path,
        pretrained_vision_encoder_name_or_path=opt.pretrained_vision_encoder_name_or_path,
        control_frequency=opt.ctrl_freq,
        device="cpu"
    )

    # Image adaptor: ONNX model
    m = model.policy.img_adaptor
    m.eval()

    input_data = torch.randn(1, 4374, rdt_config['model']['img_token_dim'])  # batch size 1
    output = m(input_data)  # sanity-check forward pass

    torch.onnx.export(
        m,
        input_data,
        img_adaptor_path,
        opset_version=14,
        do_constant_folding=True,
        input_names=["img_tokens"],
        output_names=["adapted_img"],
        dynamic_axes=None,
        verbose=False
    )
    logger.info("Export RDT [img_adaptor] Model Success.")
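
    # An optional integrity check, assuming the `onnx` package is available
    # (not a step the script performs itself):
    #
    #     import onnx
    #     onnx.checker.check_model(onnx.load(img_adaptor_path))
    #
    # The same check applies to the DiT and adaptor exports below.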

    # DiT
    hidden_size = rdt_config['model']["rdt"]['hidden_size']

    m = model.policy.model
    m = m.eval().cpu()
    x = torch.randn(1, 65, hidden_size)
    freq = torch.tensor([1], dtype=torch.int32)
    t = torch.tensor([10], dtype=torch.int32)
    lang_c = torch.randn(1, 64, hidden_size)
    img_c = torch.randn(1, 4374, hidden_size)
    lang_mask = torch.ones(1, 64, dtype=torch.float32)
    dummy_inputs = (x, freq, t, lang_c, img_c, lang_mask)
    # outputs = m(x, freq, t, lang_c, img_c, lang_mask)
    torch.onnx.export(
        m,
        dummy_inputs,
        dit_path,
        # export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=["x", "freq", "t", "lang_c", "img_c", "lang_mask"],
        output_names=["actions"],
        verbose=False
    )

    logger.info("Export RDT [DiT] Model Success.")

    # State adaptor
    m = model.policy.state_adaptor
    m.eval()

    input_data = torch.randn(1, 1, 256)  # batch size 1
    output = m(input_data)  # sanity-check forward pass

    torch.onnx.export(
        m,
        input_data,
        state_adaptor_path1,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=["state_tokens"],
        output_names=["state_traj"],
        dynamic_axes=None,
        verbose=False
    )

    logger.info("Export RDT [state 1x1x256] Model Success.")

    input_data = torch.randn(1, 64, 256)
    output = m(input_data)

    torch.onnx.export(
        m,
        input_data,
        state_adaptor_path2,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['state_tokens'],
        output_names=['state_traj'],
        dynamic_axes=None,
        verbose=False
    )

    logger.info("Export RDT [state 1x64x256] Model Success.")

    # Lang adaptor
    m = model.policy.lang_adaptor
    m.eval()

    input_data = torch.randn(1, 14, 4096)
    output = m(input_data)  # sanity-check forward pass

    torch.onnx.export(
        m,
        input_data,
        lang_adaptor_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=["text_embeds"],
        output_names=["lang_cond"],
        dynamic_axes={
            "text_embeds": {1: "N"},
            "lang_cond": {1: "N"}
        },
        verbose=False
    )

    logger.info("Export RDT [lang adaptor] Model Success.")

if __name__ == "__main__":
    main("/home/qi.xiong/DualArm/Work_Docker/RDT/rdt-export/input/config.json")
    logger.info("All models have been exported successfully.")