import os
import re
import json
import logging
import argparse
from time import time
from collections import OrderedDict
from dataclasses import dataclass

import yaml
import cv2
import numpy as np
import torch
import h5py
from PIL import Image as PImage
import onnx
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler

from scripts.agilex_model import create_model
from configs.state_vec import STATE_VEC_IDX_MAPPING
from models.hub_mixin import CompatiblePyTorchModelHubMixin
from models.rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder,
                               get_1d_sincos_pos_embed_from_grid,
                               get_multimodal_cond_pos_embed)
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.multimodal_encoder.t5_encoder import T5Embedder

logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S')
logger = logging.getLogger("RDT_EXPORT")

os.environ["WANDB_MODE"] = "disabled"


@dataclass
class ExportConfig:
    task_id: str = None
    output_path: str = None
    model_path: str = None
    calibration_num: int = 100
    lang_calibration_num: int = 1
    dataset_path: str = None
    gpu_id: str = "0"
    march: str = None
    model_type: str = None
    pretrained_vision_encoder_name_or_path: str = None
    ctrl_freq: int = 25
    cal_data_device: str = "cuda"


AGILEX_STATE_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)]


def dump_img_adaptor(img_tokens):
    global img_adaptor_cal_ws
    global dump_cnt, dump_dataset_name
    np.save(os.path.join(img_adaptor_cal_ws, f"img_adaptor_{dump_dataset_name}_{dump_cnt}.npy"),
            img_tokens.float().contiguous().cpu().detach().numpy())


def dump_dit(state_action_traj, ctrl_freqs, t, lang_cond, img_cond, lang_attn_mask):
    t_str = str(t)
    x = state_action_traj.float().contiguous().cpu().detach().numpy()
    freq = ctrl_freqs.float().contiguous().cpu().detach().numpy().astype(np.int32).copy()
    t_ = t.float().contiguous().cpu().detach().numpy()
    t_ = np.expand_dims(t_.astype(np.int32), axis=0).copy()
    lang_c = lang_cond.float().contiguous().cpu().detach().numpy()
    img_c = img_cond.float().contiguous().cpu().detach().numpy()
    lang_mask = lang_attn_mask.float().contiguous().cpu().detach().numpy()
    pad_rows = 64 - lang_mask.shape[1]
    padded = np.pad(lang_mask, ((0, 0), (0, pad_rows)), mode="constant")
    mask_float = np.where(padded, 0.0, -512.0).astype(np.float32)
    lang_cond_padded = np.pad(lang_c, pad_width=((0, 0), (0, pad_rows), (0, 0)),
                              mode="constant", constant_values=0)
    global dit_cal_path_x, dit_cal_path_freq, dit_cal_path_t, dit_cal_path_lang_c, dit_cal_path_img_c, dit_cal_path_lang_mask
    global dump_cnt, dump_dataset_name
    np.save(os.path.join(dit_cal_path_x, f"x_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), x)
    np.save(os.path.join(dit_cal_path_freq, f"freq_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), freq)
    np.save(os.path.join(dit_cal_path_t, f"t_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), t_)
    np.save(os.path.join(dit_cal_path_lang_c, f"lang_c_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), lang_cond_padded)
    np.save(os.path.join(dit_cal_path_img_c, f"img_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), img_c)
    np.save(os.path.join(dit_cal_path_lang_mask, f"lang_mask_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), mask_float)
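

# Hedged illustration (documentation only; this helper is not called anywhere in the
# export flow): how dump_dit() turns a boolean language-attention mask into the additive
# float mask written to the calibration set, assuming the fixed 64-token language window
# used above. Valid positions map to 0.0, padding maps to -512.0.
def _bool_mask_to_additive(mask_bool: np.ndarray, max_len: int = 64) -> np.ndarray:
    padded = np.pad(mask_bool, ((0, 0), (0, max_len - mask_bool.shape[1])), mode="constant")
    return np.where(padded, 0.0, -512.0).astype(np.float32)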
args["arm_dim"]["right_arm_dim"],) # AGILEX_STATE_INDICES = ([STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] # for i in range(left_arm_dim)] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]] + # [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] # for i in range(right_arm_dim)] + [STATE_VEC_IDX_MAPPING[f"right_gripper_open"]]) model = RoboticDiffusionTransformerModel_Dump(args, **kwargs) pretrained = kwargs.get("pretrained", None) if pretrained is not None and os.path.isfile(pretrained): model.load_pretrained_weights(pretrained) return model class RDT_Dump(nn.Module): def __init__(self, output_dim=128, horizon=32, hidden_size=1152, depth=28, num_heads=16, max_lang_cond_len=1024, img_cond_len=4096, lang_pos_embed_config=None, img_pos_embed_config=None, dtype=torch.bfloat16): super().__init__() self.horizon = horizon self.hidden_size = hidden_size self.max_lang_cond_len = max_lang_cond_len self.img_cond_len = img_cond_len self.dtype = dtype self.lang_pos_embed_config = lang_pos_embed_config self.img_pos_embed_config = img_pos_embed_config self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype) self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype) # We will use trainable sin-cos embeddings # [timestep; state; action] self.x_pos_embed = nn.Parameter(torch.zeros(1, horizon + 3, hidden_size)) # Language conditions self.lang_cond_pos_embed = nn.Parameter(torch.zeros(1, max_lang_cond_len, hidden_size)) # Image conditions self.img_cond_pos_embed = nn.Parameter(torch.zeros(1, img_cond_len, hidden_size)) self.blocks = nn.ModuleList([RDTBlock(hidden_size, num_heads) for _ in range(depth)]) self.final_layer = FinalLayer(hidden_size, output_dim) self.initialize_weights() def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize pos_embed by sin-cos embedding x_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size, mm_cond_lens=OrderedDict([ ('timestep', 1), ('ctrl_freq', 1), ('state', 1), ('action', self.horizon), ])) self.x_pos_embed.data.copy_(torch.from_numpy(x_pos_embed).float().unsqueeze(0)) if self.lang_pos_embed_config is None: lang_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.max_lang_cond_len)) else: lang_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size, mm_cond_lens=OrderedDict(self.lang_pos_embed_config), embed_modality=False) self.lang_cond_pos_embed.data.copy_(torch.from_numpy(lang_cond_pos_embed).float().unsqueeze(0)) if self.img_pos_embed_config is None: img_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.img_cond_len)) else: img_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size, mm_cond_lens=OrderedDict(self.img_pos_embed_config), embed_modality=False) self.img_cond_pos_embed.data.copy_(torch.from_numpy(img_cond_pos_embed).float().unsqueeze(0)) # Initialize timestep and control freq embedding MLP nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02) # Initialize the final layer: zero-out the final linear layer nn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0) nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0) # Move all the params to 

    def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
        t = self.t_embedder(t).unsqueeze(1)  # (B, 1, D) or (1, 1, D)
        freq = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
        # Append timestep to the input tokens
        if t.shape[0] == 1:
            t = t.expand(x.shape[0], -1, -1)
        x = torch.cat([t, freq, x], dim=1)  # (B, T+1, D)

        # Add multimodal position embeddings
        x = x + self.x_pos_embed
        # Note the lang is of variable length
        lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
        img_c = img_c + self.img_cond_pos_embed

        # Forward pass
        conds = [lang_c, img_c]
        masks = [lang_mask, img_mask]
        for i, block in enumerate(self.blocks):
            c, mask = conds[i % 2], masks[i % 2]
            x = block(x, c, mask)  # (B, T+1, D)
        # Inject the language condition at the final layer
        x = self.final_layer(x)  # (B, T+1, out_channels)

        # Only preserve the action tokens
        x = x[:, -self.horizon:]
        return x


class RDTRunner_Dump(nn.Module,
                     CompatiblePyTorchModelHubMixin,
                     repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):

    def __init__(self,
                 *,
                 action_dim,
                 pred_horizon,
                 config,
                 lang_token_dim,
                 img_token_dim,
                 state_token_dim,
                 max_lang_cond_len,
                 img_cond_len,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super(RDTRunner_Dump, self).__init__()

        # Create diffusion model
        hidden_size = config['rdt']['hidden_size']
        self.model = RDT_Dump(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=config['rdt']['depth'],
            num_heads=config['rdt']['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Create adaptors for various conditional inputs
        self.lang_adaptor = self.build_condition_adapter(config['lang_adaptor'],
                                                         in_features=lang_token_dim,
                                                         out_features=hidden_size)
        self.img_adaptor = self.build_condition_adapter(config['img_adaptor'],
                                                        in_features=img_token_dim,
                                                        out_features=hidden_size)
        # A `state` refers to an action or a proprioception vector
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,  # state + state mask (indicator)
            out_features=hidden_size)

        # Create the noise scheduler
        noise_scheduler_config = config['noise_scheduler']
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
            clip_sample=noise_scheduler_config['clip_sample'],
        )
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
        )

        self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.prediction_type = noise_scheduler_config['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        print("Diffusion params: %e" %
              sum([p.numel() for p in self.model.parameters()] +
                  [p.numel() for p in self.lang_adaptor.parameters()] +
                  [p.numel() for p in self.img_adaptor.parameters()] +
                  [p.numel() for p in self.state_adaptor.parameters()]))
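
    # Hedged illustration (documentation only; "mlp2x_gelu" is an example value, the real
    # type comes from the config): for a config entry matching the regex in
    # build_condition_adapter() below, the adaptor built is equivalent to
    #     nn.Sequential(nn.Linear(in_features, out_features),
    #                   nn.GELU(approximate="tanh"),
    #                   nn.Linear(out_features, out_features))
    # while "linear" yields a single nn.Linear(in_features, out_features).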

    def build_condition_adapter(self, projector_type, in_features, out_features):
        projector = None
        if projector_type == 'linear':
            projector = nn.Linear(in_features, out_features)
        else:
            mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
            if mlp_gelu_match:
                mlp_depth = int(mlp_gelu_match.group(1))
                modules = [nn.Linear(in_features, out_features)]
                for _ in range(1, mlp_depth):
                    modules.append(nn.GELU(approximate="tanh"))
                    modules.append(nn.Linear(out_features, out_features))
                projector = nn.Sequential(*modules)

        if projector is None:
            raise ValueError(f'Unknown projector type: {projector_type}')

        return projector

    def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
        adapted_lang = self.lang_adaptor(lang_tokens)
        dump_img_adaptor(img_tokens)
        adapted_img = self.img_adaptor(img_tokens)
        adapted_state = self.state_adaptor(state_tokens)
        return adapted_lang, adapted_img, adapted_state

    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond, state_traj, action_mask, ctrl_freqs):
        device = state_traj.device
        dtype = state_traj.dtype
        noisy_action = torch.randn(size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
                                   dtype=dtype,
                                   device=device)
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        # Set step values
        self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)

        for t in self.noise_scheduler_sample.timesteps:
            # Prepare state-action trajectory
            action_traj = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(action_traj)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # dump
            dump_dit(state_action_traj, ctrl_freqs, t, lang_cond, img_cond, lang_attn_mask)

            # Predict the model output
            model_output = self.model(state_action_traj,
                                      ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond,
                                      img_cond,
                                      lang_mask=lang_attn_mask)

            # Compute previous actions: x_t -> x_t-1
            noisy_action = self.noise_scheduler_sample.step(model_output, t, noisy_action).prev_sample
            noisy_action = noisy_action.to(state_traj.dtype)

        # Finally apply the action mask to mask invalid action dimensions
        noisy_action = noisy_action * action_mask
        return noisy_action
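
    # Hedged note on the calibration dumps above (shapes assume the configs exported in
    # main(): action horizon 64, 4374 image tokens, a 64-token language window): every
    # denoising step writes x (1, 65, hidden_size), freq (1,), t (1,), lang_c padded to
    # (1, 64, hidden_size), img_c (1, 4374, hidden_size), and an additive lang_mask
    # (1, 64) via dump_dit().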

    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_gt, action_mask,
                     ctrl_freqs) -> torch.Tensor:
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: loss_value, a scalar tensor
        '''
        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device

        # Sample noise that we'll add to the actions
        noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)
        # Sample random diffusion timesteps
        timesteps = torch.randint(0, self.num_train_timesteps, (batch_size, ), device=device).long()
        # Add noise to the clean actions according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps)

        # Concatenate the state and action tokens to form the input sequence
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        # Append the action mask to the input sequence
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
        # Align the dimension with the hidden size
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj)
        # Predict the denoised result
        pred = self.model(state_action_traj, ctrl_freqs, timesteps, lang_cond, img_cond, lang_mask=lang_attn_mask)

        pred_type = self.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = action_gt
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target)
        return loss
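
    # Hedged note (documentation only): two schedulers are deliberately kept in this
    # runner -- the DDPMScheduler drives the forward noising in compute_loss(), while the
    # DPMSolverMultistepScheduler runs the num_inference_timesteps denoising loop used by
    # conditional_sample() / predict_action().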

    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_mask, ctrl_freqs):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim), which should be a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), predicted action sequence
        '''
        # Prepare the state and conditions
        state_tokens = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(lang_tokens, img_tokens, state_tokens)

        # Run sampling
        action_pred = self.conditional_sample(
            lang_cond,
            lang_attn_mask,
            img_cond,
            state_traj,
            action_mask,
            ctrl_freqs,
        )
        return action_pred

    def forward(self, *args, **kwargs) -> torch.Tensor:
        return self.compute_loss(*args, **kwargs)


class RoboticDiffusionTransformerModel_Dump(object):

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        # We do not use the text encoder due to limited GPU memory
        # self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy(pretrained)
        self.reset()

    def get_policy(self, pretrained):
        # Initialize model with arguments
        if pretrained is None or os.path.isfile(pretrained):
            img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
                            self.vision_model.num_patches)
            _model = RDTRunner_Dump(
                action_dim=self.args["common"]["state_dim"],
                pred_horizon=self.args["common"]["action_chunk_size"],
                config=self.args["model"],
                lang_token_dim=self.args["model"]["lang_token_dim"],
                img_token_dim=self.args["model"]["img_token_dim"],
                state_token_dim=self.args["model"]["state_token_dim"],
                max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
                img_cond_len=img_cond_len,
                img_pos_embed_config=[
                    # No initial pos embed in the last grid size
                    # since we've already done in ViT
                    (
                        "image",
                        (
                            self.args["common"]["img_history_size"],
                            self.args["common"]["num_cameras"],
                            -self.vision_model.num_patches,
                        ),
                    ),
                ],
                lang_pos_embed_config=[
                    # Similarly, no initial pos embed for language
                    ("lang", -self.args["dataset"]["tokenizer_max_length"]),
                ],
                dtype=self.dtype,
            )
        else:
            _model = RDTRunner_Dump.from_pretrained(pretrained)

        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
        return tokenizer, text_encoder

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        image_processor = vision_encoder.image_processor
        return image_processor, vision_encoder

    def reset(self):
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        # self.text_model.eval()
        self.vision_model.eval()

        self.policy = self.policy.to(device, dtype=weight_dtype)
        # self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            checkpoint = torch.load(pretrained)
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model
            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)
        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()
        return pred

    def _format_joint_to_state(self, joints):
        # Rescale the gripper to the range of [0, 1]
        joints = joints / torch.tensor(
            [[[180, 180, 180, 180, 180, 180]]],
            device=joints.device,
            dtype=joints.dtype,
        )
        B, N, _ = joints.shape
        state = torch.zeros(
            (B, N, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        # Fill into the unified state vector
        state[:, :, AGILEX_STATE_INDICES] = joints
        # Assemble the mask indicating each dimension's availability
        state_elem_mask = torch.zeros(
            (B, self.args["model"]["state_token_dim"]),
            device=joints.device,
            dtype=joints.dtype,
        )
        state_elem_mask[:, AGILEX_STATE_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        action_indices = AGILEX_STATE_INDICES
        joints = action[:, :, action_indices]
        # Rescale the gripper back to the action range
        # Note that the action range and proprioception range are different
        # for Mobile ALOHA robot
        joints = joints * torch.tensor(
            [[[180, 180, 180, 180, 180, 180]]],
            device=joints.device,
            dtype=joints.dtype,
        )
        return joints

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        device = self.device
        dtype = self.dtype

        # The background image used for padding
        background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
                                    dtype=np.uint8).reshape(1, 1, 3)
        background_image = (np.ones(
            (
                self.image_processor.size["height"],
                self.image_processor.size["width"],
                3,
            ),
            dtype=np.uint8,
        ) * background_color)

        # Preprocess the images by order and encode them
        image_tensor_list = []
        for image in images:
            if image is None:
                # Replace it with the background image
                image = PImage.fromarray(background_image)
            else:
                # Convert numpy array to PIL Image if needed
                if isinstance(image, np.ndarray):
                    image = PImage.fromarray(image)

            if self.image_size is not None:
                image = transforms.Resize(self.image_size)(image)

            if self.args["dataset"].get("auto_adjust_image_brightness", False):
                pixel_values = list(image.getdata())
                average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                if average_brightness <= 0.15:
                    image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

            if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":

                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = PImage.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = PImage.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result

                image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            image_tensor_list.append(image)

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
        image_embeds = self.vision_model(image_tensor).detach()
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # Prepare the proprioception states and the control frequency
        joints = proprio.to(device).unsqueeze(0)  # (1, 1, 14)
        states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        states = states[:, -1:, :]  # (1, 1, 128)
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)
        text_embeds = text_embeds.to(device, dtype=dtype)

        # Predict the next action chunk given the inputs
        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)

        return trajectory
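

# Hedged reference (documentation only; all values below are placeholders): main() at the
# bottom of this file reads a JSON config of roughly this shape, where "model_type" must
# be "170M" or "1B".
#
#   {
#       "task_id": "my_task",
#       "gpu_id": "0",
#       "cal_data_device": "cuda",
#       "export": {
#           "output_path": "./output",
#           "model_path": "/path/to/rdt_checkpoint.pt",
#           "dataset_path": "/path/to/datasets",
#           "calibration_num": 100,
#           "lang_calibration_num": 1,
#           "march": "<bpu_march>",
#           "model_type": "170M",
#           "ctrl_freq": 25
#       }
#   }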


def get_training_samples(data_dirs, num_samples=5, instructions_per_episode=1):
    """
    Get training samples from one or multiple data directories.

    Args:
        data_dirs: A single directory path (str) or a list of directory paths
        num_samples: Total number of samples to generate across all directories
        instructions_per_episode: Number of instructions per episode
    """
    training_samples = []

    # Handle both single directory and list of directories
    if isinstance(data_dirs, str):
        data_dirs = [data_dirs]

    logger.info(f"Get Training Data From: {len(data_dirs)} dataset(s).")

    # First, collect all available episode files from all directories
    episode_files = []
    for data_dir in data_dirs:
        if not os.path.isdir(data_dir):
            logger.warning(f"Directory not found: {data_dir}, skipping")
            continue
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                if file.endswith('.hdf5'):
                    file_path = os.path.join(root, file)
                    episode_files.append(file_path)

    if len(episode_files) == 0:
        logger.warning("No episode files found in the provided directories")
        return training_samples

    logger.info(f"Found {len(episode_files)} episode files across all datasets.")

    # Generate samples by randomly selecting from episodes
    while len(training_samples) < num_samples:
        # Randomly select an episode file
        file_path = np.random.choice(episode_files)
        try:
            # Expected episode layout: an HDF5 file with 'action' and 'observations/qpos'
            # datasets, and encoded frames under 'observations/images/<cam_name>'
            with h5py.File(file_path, 'r') as f:
                observations = f['observations']
                actions = f['action'][:]
                images = observations['images']
                qpos = observations['qpos'][:]
                episode_dir = os.path.dirname(file_path)
                instructions_dir = os.path.join(episode_dir, 'instructions')
                num_steps = len(qpos)
                if num_steps > 1:
                    # The image input needs the high/left/right camera frames plus their
                    # corresponding history frames, 4374 image tokens in total
                    lang_step_idx = int(np.random.randint(0, max(instructions_per_episode, 1)))
                    instructions_dir = os.path.join(os.path.dirname(file_path), "instructions")
                    lang_embed, lang_str = None, None
                    # lang embed (optional)
                    lang_embed_path = os.path.join(instructions_dir, f"lang_embed_{lang_step_idx}.pt")
                    if os.path.exists(lang_embed_path):
                        try:
                            lang_embed = torch.load(lang_embed_path, map_location="cpu")
                        except Exception as e:
                            logger.error(f"Error reading {lang_embed_path}: {e}")
                    # lang string (optional)
                    lang_str_path = os.path.join(instructions_dir, f"txt_lang_embed_{lang_step_idx}.txt")
                    if os.path.exists(lang_str_path):
                        try:
                            with open(lang_str_path, "r", encoding="utf-8") as tf:
                                lang_str = tf.read().strip()
                        except Exception as e:
                            logger.error(f"Error reading {lang_str_path}: {e}")
                    lang_str = lang_str or ""

                    # Gather multi-camera, multi-history-frame images
                    step_idx = np.random.randint(0, num_steps)
                    multi_cam_images = {}
                    ref_frame = images['cam_high'][0]
                    ref_img = cv2.imdecode(np.frombuffer(ref_frame, np.uint8), cv2.IMREAD_COLOR)
                    IMG_HEIGHT, IMG_WIDTH = ref_img.shape[:2]
                    # IMG_HEIGHT, IMG_WIDTH = images['cam_high'][0].shape[:2]
                    ground_image = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.uint8)
                    for cam_name in ['cam_high', 'cam_left_wrist', 'cam_right_wrist']:
                        if cam_name in images:
                            cam_images = []
                            # Collect the 2 history frames for this camera
                            for i in range(max(step_idx - 1, 0), step_idx + 1):  # 2 history frames
                                img_bits = images[cam_name][i]
                                img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)
                                # img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                                cam_images.append(img)
                            if len(cam_images) < 2:
                                cam_images = [cam_images[0]] * 2
                            multi_cam_images[cam_name] = cam_images
                        else:
                            cam_images = []
                            for i in range(max(step_idx - 1, 0), step_idx + 1):  # 2 history frames
                                img_bits = ground_image
                                # img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)
                                cam_images.append(img_bits)
                            if len(cam_images) < 2:
                                cam_images = [cam_images[0]] * 2
                            multi_cam_images[cam_name] = cam_images

                    training_samples.append({
                        'multi_cam_images': multi_cam_images,
                        'joints': actions[step_idx],
                        'lang_embed': lang_embed,
                        'lang_str': lang_str,
                        'source': file_path,
                        'step': step_idx
                    })
                    logger.debug(f"TimeStep: {step_idx}, Sample: {file_path}")
        except Exception as e:
            logger.error(f"Failed: {file_path} : {e}")
            continue

    logger.info(f"Total Num: {len(training_samples)}.")
    return training_samples


def main(config_path):
    with open(config_path, "r") as f:
        cfg = json.load(f)
    export_info = cfg.get("export", {})
    opt = ExportConfig(
        task_id=cfg.get("task_id"),
        output_path=os.path.join(export_info.get("output_path", "."), cfg.get("task_id", "")),
        model_path=export_info.get("model_path"),
        calibration_num=export_info.get("calibration_num", 100),
        dataset_path=export_info.get("dataset_path"),
        gpu_id=cfg.get("gpu_id", "0"),
        march=export_info.get("march"),
        model_type=export_info.get("model_type"),
        pretrained_vision_encoder_name_or_path="/home/qi.xiong/DualArm/Work_Docker/RDT/weights/siglip-so400m-patch14-384",
        ctrl_freq=export_info.get("ctrl_freq", 25),
        cal_data_device=cfg.get("cal_data_device", "cuda"),
        lang_calibration_num=export_info.get("lang_calibration_num", 1)
    )
    if opt.model_type not in ["170M", "1B"]:
        raise ValueError(f"RDT ONLY SUPPORT 170M AND 1B, BUT GOT {opt.model_type}")
    logger.info(f"Export config loaded: {opt}")
    os.makedirs(opt.output_path, exist_ok=True)

    # Prepare Output Workspace
    ## BPU_RDT_Policy
    bpu_rdt_name = "BPU_RDT_Policy_170M" if opt.model_type == "170M" else "BPU_RDT_Policy_1B"
    bpu_rdt_path = os.path.join(opt.output_path, bpu_rdt_name)
    os.makedirs(bpu_rdt_path, exist_ok=True)
    os.system(f"cp configs/base_{opt.model_type}.yaml {bpu_rdt_path}/base.yaml")
    rdt_config_path = os.path.join(bpu_rdt_path, "base.yaml")

    ## Test_Datas
    test_data_name = "test_data"
    test_data_path = os.path.join(opt.output_path, test_data_name)
    os.makedirs(test_data_path, exist_ok=True)

    ## instruction
    instruction_ws_name = "instructions"
    instruction_ws_path = os.path.join(opt.output_path, instruction_ws_name)
    os.makedirs(instruction_ws_path, exist_ok=True)
    for name in os.listdir(opt.dataset_path):
        os.makedirs(os.path.join(instruction_ws_path, name), exist_ok=True)

    ## image adaptor
    global img_adaptor_cal_ws
    img_adaptor_ws_name = "img_adaptor_WorkSpace"
    img_adaptor_cal_name = "rdt_image_adaptor_calibration"
    img_adaptor_name = "rdt_image_adaptor.onnx"
    img_adaptor_config_name = "config.yaml"
    img_adaptor_ws = os.path.join(opt.output_path, img_adaptor_ws_name)
    img_adaptor_path = os.path.join(img_adaptor_ws, img_adaptor_name)
    img_adaptor_cal_ws = os.path.join(img_adaptor_ws, img_adaptor_cal_name)
    os.makedirs(img_adaptor_ws, exist_ok=True)
    os.makedirs(img_adaptor_cal_ws, exist_ok=True)

    ## action adaptor
"rdt_state_adaptor_1x1x256.onnx" state_adaptor_path1 = os.path.join(opt.output_path, bpu_rdt_name, state_adaptor_name1) state_adaptor_name2 = "rdt_state_adaptor_1x64x256.onnx" state_adaptor_path2 = os.path.join(opt.output_path, bpu_rdt_name, state_adaptor_name2) ## lang adaptor lang_adaptor_name = "rdt_lang_adaptor.onnx" lang_adaptor_path = os.path.join(opt.output_path, bpu_rdt_name, lang_adaptor_name) ## DiT Policy dit_ws_name = "DiT_WorkSpace" dit_cal_name = "rdt_dit_calibration" dit_name = "rdt_dit.onnx" dit_config_name = "config.yaml" dit_json_name = "quant_config.json" dit_ws = os.path.join(opt.output_path, dit_ws_name) dit_path = os.path.join(dit_ws, dit_name) dit_cal_ws = os.path.join(dit_ws, dit_cal_name) os.makedirs(dit_ws, exist_ok=True) os.makedirs(dit_cal_ws, exist_ok=True) global dit_cal_path_x, dit_cal_path_freq, dit_cal_path_t, dit_cal_path_lang_c, dit_cal_path_img_c, dit_cal_path_lang_mask dit_cal_path_x = os.path.join(dit_cal_ws, "x") os.makedirs(dit_cal_path_x, exist_ok=True) dit_cal_path_freq = os.path.join(dit_cal_ws, "freq") os.makedirs(dit_cal_path_freq, exist_ok=True) dit_cal_path_t = os.path.join(dit_cal_ws, "t") os.makedirs(dit_cal_path_t, exist_ok=True) dit_cal_path_lang_c = os.path.join(dit_cal_ws, "lang_c") os.makedirs(dit_cal_path_lang_c, exist_ok=True) dit_cal_path_img_c = os.path.join(dit_cal_ws, "img_c") os.makedirs(dit_cal_path_img_c, exist_ok=True) dit_cal_path_lang_mask = os.path.join(dit_cal_ws, "lang_mask") os.makedirs(dit_cal_path_lang_mask, exist_ok=True) # Prepare Calibrate Data with open(rdt_config_path, "r") as f: rdt_config = yaml.safe_load(f) dump_model = create_dump_model( args=rdt_config, dtype=torch.float32, pretrained=opt.model_path, pretrained_vision_encoder_name_or_path=opt.pretrained_vision_encoder_name_or_path, control_frequency=opt.ctrl_freq, device=opt.cal_data_device ) # Prepare Calbriation Data # load training data from all datasets global dump_cnt, dump_dataset_name test_data_cnt = 0 # Collect all dataset paths all_dataset_paths = [] for dump_dataset_name in os.listdir(opt.dataset_path): dump_dataset_path = os.path.join(opt.dataset_path, dump_dataset_name) if os.path.isdir(dump_dataset_path): all_dataset_paths.append(dump_dataset_path) # Get training samples from all datasets together training_samples = get_training_samples(all_dataset_paths, num_samples=opt.calibration_num, instructions_per_episode=opt.lang_calibration_num) if len(training_samples) == 0: logger.warning("No training samples found, skipping calibration data generation") else: # Only process up to the number of samples we actually have num_samples_to_process = min(len(training_samples), opt.calibration_num) for dump_cnt in range(num_samples_to_process): sample = training_samples[dump_cnt] # Extract dataset name from the sample's source path sample_source = sample['source'] dump_dataset_name = os.path.basename(os.path.dirname(os.path.dirname(sample_source))) instruction_emb = { "lang_cond": sample["lang_embed"].float().cpu(), "lang_str": sample["lang_str"] } ins_str_name = sample["lang_str"].replace(" ", "_") + "__" torch.save(instruction_emb, os.path.join(instruction_ws_path, dump_dataset_name, f"{ins_str_name}.pt")) image_arrs = [ sample['multi_cam_images']['cam_high'][0], sample['multi_cam_images']['cam_right_wrist'][0], sample['multi_cam_images']['cam_left_wrist'][0], sample['multi_cam_images']['cam_high'][1], sample['multi_cam_images']['cam_right_wrist'][1], sample['multi_cam_images']['cam_left_wrist'][1], ] test_data_cnt += 1 np.save(os.path.join(test_data_path, 
f"{test_data_cnt}_cam_high_0.npy"), sample['multi_cam_images']['cam_high'][0]) np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_right_wrist_0.npy"), sample['multi_cam_images']['cam_right_wrist'][0]) np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_left_wrist_0.npy"), sample['multi_cam_images']['cam_left_wrist'][0]) np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_high_1.npy"), sample['multi_cam_images']['cam_high'][1]) np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_right_wrist_1.npy"), sample['multi_cam_images']['cam_right_wrist'][1]) np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_left_wrist_1.npy"), sample['multi_cam_images']['cam_left_wrist'][1]) images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs] proprio = torch.from_numpy(sample['joints']).float().unsqueeze(0).to(opt.cal_data_device) np.save(os.path.join(test_data_path, f"{test_data_cnt}_joints.npy"), sample['joints']) lang_embeddings = sample['lang_embed'].float().unsqueeze(0).to(opt.cal_data_device) torch.save(lang_embeddings, os.path.join(test_data_path, f"{test_data_cnt}_lang_embeddings.pt")) dump_model.reset() begin_time = time() actions = dump_model.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy() np.save(os.path.join(test_data_path, f"{test_data_cnt}_actions.npy"), actions) logger.debug(f"Dump: Cost {(1000*(time() - begin_time)):.1f} ms, cnt: {dump_cnt}, name: {dump_dataset_name}") logger.info("End Generate Calibration Data.") del dump_model # Load RDT Policy: CPU Model For ONNX Export with open(rdt_config_path, "r") as f: rdt_config = yaml.safe_load(f) model = create_model( args=rdt_config, dtype=torch.float32, pretrained=opt.model_path, pretrained_vision_encoder_name_or_path=opt.pretrained_vision_encoder_name_or_path, control_frequency=opt.ctrl_freq, device="cpu" ) # image adaptor: ONNX Model m = model.policy.img_adaptor m.eval() input_data = torch.randn(1, 4374, rdt_config['model']['img_token_dim']) # 假设批量大小为1 output = m(input_data) torch.onnx.export( m, input_data, img_adaptor_path, opset_version=14, do_constant_folding=True, input_names=["img_tokens"], output_names=["adapted_img"], dynamic_axes=None, verbose=False ) logger.info("Export RDT [img_adaptor] Model Success.") # DiT hidden_size = rdt_config['model']["rdt"]['hidden_size'] m = model.policy.model m = m.eval().cpu() x = torch.randn(1, 65, hidden_size) freq = torch.tensor([1], dtype=torch.int32) t = torch.tensor([10], dtype=torch.int32) lang_c = torch.randn(1, 64, hidden_size) img_c = torch.randn(1, 4374, hidden_size) lang_mask = torch.ones(1, 64, dtype=torch.float32) dummy_inputs = (x, freq, t, lang_c, img_c, lang_mask) # outputs = m(x, freq, t, lang_c, img_c, lang_mask) torch.onnx.export( m, dummy_inputs, dit_path, # export_params=True, opset_version=14, do_constant_folding=True, input_names=["x", "freq", "t", "lang_c", "img_c", "lang_mask"], output_names=["actions"], verbose=False ) logger.info("Export RDT [DiT] Model Success.") # state adaptor m = model.policy.state_adaptor m.eval() input_data = torch.randn(1, 1, 256) # 假设批量大小为1 output = m(input_data) torch.onnx.export( m, input_data, state_adaptor_path1, export_params=True, opset_version=14, do_constant_folding=True, input_names=["state_tokens"], output_names=["state_traj"], dynamic_axes=None, verbose=False ) logging.info("Export RDT [state 1x1x256] Model Success.") input_data = torch.randn(1, 64, 256) output = m(input_data) torch.onnx.export( m, input_data, state_adaptor_path2, 
    torch.onnx.export(
        m,
        input_data,
        state_adaptor_path2,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['state_tokens'],
        output_names=['state_traj'],
        dynamic_axes=None,
        verbose=False
    )
    logger.info("Export RDT [state 1x64x256] Model Success.")

    # lang adaptor
    m = model.policy.lang_adaptor
    m.eval()
    input_data = torch.randn(1, 14, 4096)
    output = m(input_data)
    torch.onnx.export(
        m,
        input_data,
        lang_adaptor_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=["text_embeds"],
        output_names=["lang_cond"],
        dynamic_axes={
            "text_embeds": {1: "N"},
            "lang_cond": {1: "N"}
        },
        verbose=False
    )
    logger.info("Export RDT [lang adaptor] Model Success.")


######## Prepare Calibration Data

if __name__ == "__main__":
    main("/home/qi.xiong/DualArm/Work_Docker/RDT/rdt-export/input/config.json")
    logger.info("All Models Have Been Exported Successfully.")
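
# Hedged usage sketch (assumption: onnxruntime is installed; this is not part of the
# export flow above). A quick way to smoke-test the exported image adaptor, reusing the
# input/output names and the (1, 4374, img_token_dim) dummy shape from the export call;
# IMG_TOKEN_DIM stands for the config's img_token_dim value.
#
#   import onnxruntime as ort
#   sess = ort.InferenceSession("rdt_image_adaptor.onnx", providers=["CPUExecutionProvider"])
#   dummy = np.random.randn(1, 4374, IMG_TOKEN_DIM).astype(np.float32)
#   (adapted,) = sess.run(["adapted_img"], {"img_tokens": dummy})
#   print(adapted.shape)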