import argparse
import os
import time

import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms

from cloud_helper import Server
from configs.state_vec import STATE_VEC_IDX_MAPPING
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.multimodal_encoder.t5_encoder import T5Embedder
from models.rdt_runner import RDTRunner

# Restrict the process to the first GPU. This only takes effect if CUDA has
# not been initialized yet in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Positions of the six right-arm joint values inside the unified RDT state vector.
AGILEX_STATE_INDICES = [
    STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
]

# Joint readings are divided by this value before entering the model and the
# predicted actions are multiplied by it on the way out.
# NOTE(review): presumably the joints are reported in degrees — confirm.
JOINT_SCALE = 180.0


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``, so an explicit parser is required.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def create_model(args, **kwargs):
    """Build a RoboticDiffusionTransformerModel and load local weights if present.

    Args:
        args: parsed YAML config dict (``common``/``model``/``dataset`` sections).
        **kwargs: forwarded to ``RoboticDiffusionTransformerModel``. When
            ``pretrained`` is a path to an existing file, its weights are
            loaded into the policy after construction.

    Returns:
        The constructed ``RoboticDiffusionTransformerModel``.
    """
    model = RoboticDiffusionTransformerModel(args, **kwargs)
    pretrained = kwargs.get("pretrained", None)
    if pretrained is not None and os.path.isfile(pretrained):
        model.load_pretrained_weights(pretrained)
    return model


class RoboticDiffusionTransformerModel(object):
    """Inference wrapper bundling the T5 text encoder, the SigLIP vision
    encoder and the RDT diffusion policy."""

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained=None,
        pretrained_text_encoder_name_or_path=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        # The text encoder is always loaded so that free-form instructions can
        # be encoded at runtime (see encode_instruction).
        self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy(pretrained)

        self.reset()

    def get_policy(self, pretrained):
        """Instantiate the RDT policy network.

        When ``pretrained`` is None or a local checkpoint file, the network is
        built from the YAML config (weights from a file are loaded afterwards
        by ``create_model``). Otherwise ``pretrained`` is handed to
        ``RDTRunner.from_pretrained``.
        """
        if pretrained is None or os.path.isfile(pretrained):
            common_cfg = self.args["common"]
            model_cfg = self.args["model"]
            max_lang_len = self.args["dataset"]["tokenizer_max_length"]
            num_patches = self.vision_model.num_patches
            img_cond_len = common_cfg["img_history_size"] * common_cfg["num_cameras"] * num_patches
            return RDTRunner(
                action_dim=common_cfg["state_dim"],
                pred_horizon=common_cfg["action_chunk_size"],
                config=model_cfg,
                lang_token_dim=model_cfg["lang_token_dim"],
                img_token_dim=model_cfg["img_token_dim"],
                state_token_dim=model_cfg["state_token_dim"],
                max_lang_cond_len=max_lang_len,
                img_cond_len=img_cond_len,
                img_pos_embed_config=[
                    # No initial pos embed in the last grid size,
                    # since the ViT has already added one.
                    (
                        "image",
                        (
                            common_cfg["img_history_size"],
                            common_cfg["num_cameras"],
                            -num_patches,
                        ),
                    ),
                ],
                lang_pos_embed_config=[
                    # Similarly, no initial pos embed for language.
                    ("lang", -max_lang_len),
                ],
                dtype=self.dtype,
            )
        return RDTRunner.from_pretrained(pretrained)

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        """Load the T5 embedder and return ``(tokenizer, text_encoder)``."""
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        return text_embedder.tokenizer, text_embedder.model

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        """Load the SigLIP vision tower and return ``(image_processor, vision_encoder)``."""
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        return vision_encoder.image_processor, vision_encoder

    def reset(self):
        """Put every sub-model in eval mode and move it to ``self.device`` / ``self.dtype``."""
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        self.text_model.eval()
        self.vision_model.eval()

        self.policy = self.policy.to(device, dtype=weight_dtype)
        self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        """Load policy weights from a ``.pt`` or ``.safetensors`` checkpoint.

        Raises:
            NotImplementedError: for any other file extension.
        """
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            # torch.load unpickles the file: only load trusted checkpoints.
            # Map to CPU first so checkpoints saved on another GPU index still
            # load; load_state_dict copies onto the policy's own device.
            checkpoint = torch.load(pretrained, map_location="cpu")
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model

            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        """Encode a string instruction into text-encoder hidden states.

        Args:
            instruction: the instruction string.
            device: device to place the token ids on.

        Returns:
            A tensor of shape (1, num_tokens, hidden_dim) taken from the text
            encoder's ``last_hidden_state``.
        """
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)

        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()

        return pred

    def _format_joint_to_state(self, joints):
        """Scatter normalized joint readings into the unified RDT state vector.

        Args:
            joints (torch.Tensor): joint proprioception of shape
                (B, N, len(AGILEX_STATE_INDICES)).

        Returns:
            tuple: ``(state, state_elem_mask)``. ``state`` has shape
            (B, N, state_token_dim) and ``state_elem_mask`` has shape
            (B, state_token_dim), with ones at the filled indices.
        """
        # Normalize the raw joint values into the model's input range.
        joints = joints / JOINT_SCALE

        B, N, _ = joints.shape
        state_dim = self.args["model"]["state_token_dim"]
        state = torch.zeros((B, N, state_dim), device=joints.device, dtype=joints.dtype)
        state[:, :, AGILEX_STATE_INDICES] = joints

        # Mask marking which entries of the unified vector are populated.
        state_elem_mask = torch.zeros((B, state_dim), device=joints.device, dtype=joints.dtype)
        state_elem_mask[:, AGILEX_STATE_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        """Extract joint actions from the unified action vector and undo normalization.

        Args:
            action (torch.Tensor): unified action tensor of shape (B, N, state_token_dim).

        Returns:
            torch.Tensor: joint actions of shape (B, N, len(AGILEX_STATE_INDICES)).
        """
        joints = action[:, :, AGILEX_STATE_INDICES]
        # Inverse of the normalization in _format_joint_to_state.
        return joints * JOINT_SCALE

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Center ``pil_img`` on a square canvas filled with ``background_color``."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    def _preprocess_image(self, image, background_image, background_rgb):
        """Turn one camera frame (ndarray, PIL image or None) into pixel values for the vision encoder."""
        if image is None:
            # Missing camera: substitute the mean-colour canvas.
            image = Image.fromarray(background_image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if self.image_size is not None:
            image = transforms.Resize(self.image_size)(image)

        if self.args["dataset"].get("auto_adjust_image_brightness", False):
            # Mean over all channel values scaled to [0, 1]; vectorized rather
            # than iterating every pixel tuple in Python.
            average_brightness = np.asarray(image, dtype=np.float64).mean() / 255.0
            if average_brightness <= 0.15:
                image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

        if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":
            image = self._expand2square(image, background_rgb)

        return self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        """Predict the next action chunk from state, images and instruction embeddings.

        Args:
            proprio: proprioceptive state as an ndarray or tensor of shape
                (D,) or (N, D); only the last row is fed to the policy.
            images: list of frames (ndarray / PIL image / None for a missing
                camera), ordered time-major: all cameras at t-1, then all
                cameras at t. NOTE(review): the camera order must match the
                order used during training — confirm.
            text_embeds: instruction embeddings of shape (1, L, lang_token_dim).

        Returns:
            torch.Tensor: float32 joint actions of shape (1, chunk, len(AGILEX_STATE_INDICES)).

        Raises:
            ValueError: if ``proprio`` is None or ``images`` is empty.
        """
        if proprio is None:
            raise ValueError("step() needs the proprioceptive state (observation['state'])")
        if not images:
            raise ValueError("step() needs at least one image entry (use None for a missing camera)")

        device = self.device
        dtype = self.dtype

        # Mean-colour canvas used both for missing cameras and as pad colour.
        background_rgb = tuple(int(x * 255) for x in self.image_processor.image_mean)
        background_image = np.full(
            (self.image_processor.size["height"], self.image_processor.size["width"], 3),
            background_rgb,
            dtype=np.uint8,
        )

        pixel_values = [self._preprocess_image(image, background_image, background_rgb) for image in images]
        image_tensor = torch.stack(pixel_values, dim=0).to(device, dtype=dtype)
        image_embeds = self.vision_model(image_tensor).detach()
        # Concatenate all frames' patch tokens into one sequence with batch size 1.
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        if isinstance(proprio, np.ndarray):
            # Copy so torch receives a writable buffer it owns.
            proprio = torch.from_numpy(proprio.copy())
        if proprio.ndim == 1:
            # A single state vector (D,): add the time axis -> (1, D).
            proprio = proprio.unsqueeze(0)
        joints = proprio.to(device).unsqueeze(0)  # (1, N, D)
        states, state_elem_mask = self._format_joint_to_state(joints)
        states = states.to(device, dtype=dtype)[:, -1:, :]  # keep only the latest step
        state_elem_mask = state_elem_mask.to(device, dtype=dtype)
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        return self._unformat_action_to_joint(trajectory).to(torch.float32)


class RDTInferenceServer:
    """Serves RDT action predictions through a ``cloud_helper.Server`` endpoint."""

    def __init__(
        self,
        pretrained_text_encoder_name_or_path,
        pretrained_vision_encoder_name_or_path,
        pretrained_rdt_model_weights,
        config,
        args,
        lang_model,
    ):
        self.device = args.device
        self.policy_type = "rdt"
        self.policy = create_model(
            args=config,
            device=args.device,  # honour --device instead of the "cuda" default
            dtype=torch.bfloat16,
            pretrained=pretrained_rdt_model_weights,
            pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
            pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
            control_frequency=args.control_frequency,
        )
        self.server = Server(args.host, args.port)
        print(f"Loaded RDT policy from {pretrained_rdt_model_weights}")

        self.lang_embeddings = None
        if args.pre_lang and lang_model and os.path.exists(lang_model):
            # torch.load unpickles the file: only load trusted embedding files.
            self.lang_embeddings = torch.load(lang_model, map_location=self.device)
            shape = getattr(self.lang_embeddings, "shape", type(self.lang_embeddings).__name__)
            print(f"Loaded language embeddings shape: {shape}")
            print(f"Model expects tokenizer_max_length: {self.policy.args['dataset']['tokenizer_max_length']}")
            print(f"Model lang_token_dim: {self.policy.args['model']['lang_token_dim']}")
        else:
            print("No language model provided, using runtime embeddings")

        # Cache of the last runtime-encoded instruction and its embeddings.
        self.cached_instruction = None
        self.cached_text_embeds = None

    @staticmethod
    def _extract_proprio(observation):
        """Return ``observation["state"]`` as a float32 ndarray, or None if absent."""
        if "state" not in observation:
            return None
        return np.asarray(observation["state"], dtype=np.float32)

    @staticmethod
    def _validate_camera_frames(key, img_data):
        """Return the (T, H, W, 3) uint8 frames for one camera, or None if unusable."""
        if img_data is None:
            return None
        if not isinstance(img_data, np.ndarray):
            print(f"Warning: {key} data type is incorrect, expected np.ndarray, actual {type(img_data)}")
            return None
        if img_data.ndim != 4 or img_data.shape[-1] != 3:
            print(f"Warning: {key} dimension is incorrect, expected (T, H, W, 3), actual {img_data.shape}")
            return None
        if img_data.shape[1] == 0 or img_data.shape[2] == 0:
            return None
        return img_data.astype(np.uint8)

    def _collect_images(self, observation):
        """Flatten ``images.*`` entries into the time-major list expected by ``step``.

        Cameras are ordered by sorted key name. NOTE(review): sorted order puts
        ``cam_left_wrist`` before ``cam_right_wrist``; confirm it matches training.
        Each camera array must hold at least ``img_history_size`` frames.
        """
        image_keys = sorted(k for k in observation if k.startswith("images."))
        camera_images = {key: self._validate_camera_frames(key, observation[key]) for key in image_keys}
        if not camera_images:
            return []

        img_history_size = self.policy.args["common"]["img_history_size"]
        return [
            None if camera_images[key] is None else camera_images[key][t]
            for t in range(img_history_size)
            for key in image_keys
        ]

    def _resolve_text_embeds(self, instruction):
        """Return instruction embeddings from the precomputed table or the (cached) text encoder.

        Raises:
            ValueError: if precomputed embeddings are loaded and ``instruction`` is not an int.
        """
        if self.lang_embeddings is not None:
            if not isinstance(instruction, int):
                raise ValueError(f"Instruction type not supported: {type(instruction)}")
            text_embeds = self.lang_embeddings[instruction]
            if torch.is_tensor(text_embeds) and text_embeds.ndim == 2:
                # Entries indexed out of a stacked (N, L, D) tensor lack the
                # batch axis that encode_instruction's (1, L, D) output has.
                text_embeds = text_embeds.unsqueeze(0)
            return text_embeds

        if instruction == self.cached_instruction and self.cached_text_embeds is not None:
            print("the instruction is not changed, use the cached instruction")
            return self.cached_text_embeds

        print("the instruction is changed, re-encode the instruction")
        text_embeds = self.policy.encode_instruction(instruction, device=self.policy.device)
        self.cached_instruction = instruction
        self.cached_text_embeds = text_embeds
        return text_embeds

    def get_actions(self, batch):
        """Predict an action chunk for one request.

        Args:
            batch: {
                "observation": {
                    "state": (STATE_DIM,) or (1, STATE_DIM) np.ndarray,
                    "images.cam_high": (IMG_HISTORY_SIZE, H, W, 3) np.uint8,
                    "images.cam_right_wrist": (IMG_HISTORY_SIZE, H, W, 3) np.uint8,
                    ...
                },
                "instruction": str, or int index into the precomputed embeddings
            }

        Returns:
            action: (chunk_size, action_dim) np.ndarray
        """
        observation = batch["observation"]
        instruction = batch["instruction"]

        proprio = self._extract_proprio(observation)
        images = self._collect_images(observation)
        text_embeds = self._resolve_text_embeds(instruction)

        begin_time = time.perf_counter()
        action = self.policy.step(
            proprio=proprio,
            images=images,
            text_embeds=text_embeds,
        )
        elapsed = time.perf_counter() - begin_time
        print(f"The time cost of the policy.step is {elapsed} seconds")
        return action.squeeze(0).cpu().numpy()  # (chunk_size, action_dim)

    def run(self):
        """Register the ``get_actions`` endpoint and serve requests forever."""
        self.server.register_endpoint("get_actions", self.get_actions)
        print(f"Lerobot {self.policy_type.upper()} Server is running...")
        self.server.loop_forever()


def main():
    """Parse CLI arguments, load the YAML config and start the inference server."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", type=str, required=True)
    parser.add_argument("--path_to_vision_encoder_model", type=str, default=None)
    parser.add_argument("--path_to_text_encoder_model", type=str, default=None)
    # The misspelled flag is kept for backward compatibility; the correctly
    # spelled alias maps to the same destination.
    parser.add_argument(
        "--path_to_rdt_model_wights",
        "--path_to_rdt_model_weights",
        dest="path_to_rdt_model_wights",
        type=str,
        default=None,
    )
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--control_frequency", type=int, default=25)
    parser.add_argument("--pre_lang", type=_str2bool, default=True)
    parser.add_argument("--lang_model", type=str, default=None)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    with open(args.config_file, "r", encoding="utf-8") as fp:
        config = yaml.safe_load(fp)

    server = RDTInferenceServer(
        args.path_to_text_encoder_model,
        args.path_to_vision_encoder_model,
        args.path_to_rdt_model_wights,
        config,
        args,
        args.lang_model,
    )
    server.run()


if __name__ == "__main__":
    main()