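"""RDT inference server.

Wraps the Robotic Diffusion Transformer (RDT) policy together with its text and
vision encoders, and serves action predictions through ``cloud_helper.Server``.
"""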
import torch
|
||
import os
|
||
import yaml
|
||
from cloud_helper import Server
|
||
import argparse
|
||
import numpy as np
|
||
from PIL import Image
|
||
from torchvision import transforms
|
||
|
||
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
||
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
||
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||
from models.rdt_runner import RDTRunner
|
||
|
||
# Restrict this process to the first GPU. This overrides any value the caller
# already exported for CUDA_VISIBLE_DEVICES.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Positions, inside the unified state vector, of the six right-arm joint
# position entries (joint 0 through joint 5).
AGILEX_STATE_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{joint_id}_pos"] for joint_id in range(6)]
|
||
|
||
def create_model(args, **kwargs):
    """Build a ``RoboticDiffusionTransformerModel`` and optionally load weights.

    Args:
        args: configuration mapping forwarded to the model wrapper.
        **kwargs: keyword arguments forwarded unchanged to the wrapper's
            constructor. If ``pretrained`` is present and names an existing
            file, its weights are loaded into the freshly built policy.

    Returns:
        The constructed ``RoboticDiffusionTransformerModel``.
    """
    wrapper = RoboticDiffusionTransformerModel(args, **kwargs)

    checkpoint_path = kwargs.get("pretrained", None)
    # Only single-file checkpoints are loaded here; any other value of
    # ``pretrained`` is handled by the wrapper's own constructor.
    if checkpoint_path is None or not os.path.isfile(checkpoint_path):
        return wrapper

    wrapper.load_pretrained_weights(checkpoint_path)
    return wrapper
|
||
|
||
class RoboticDiffusionTransformerModel(object):
    """A wrapper for the RDT model, which handles

    1. Model initialization (text encoder, vision encoder and RDT policy)
    2. Encodings of instructions
    3. Model inference
    """

    def __init__(
        self,
        args,
        device="cuda",
        dtype=torch.bfloat16,
        image_size=None,
        control_frequency=25,
        pretrained=None,
        pretrained_text_encoder_name_or_path=None,
        pretrained_vision_encoder_name_or_path=None,
    ):
        """Create the encoders and the policy, then move them to ``device``.

        Args:
            args: configuration mapping; the ``"common"``, ``"model"`` and
                ``"dataset"`` sections are read by this class.
            device: device the three networks are moved to.
            dtype: weight dtype the three networks are cast to.
            image_size: if not None, every input image is resized with
                ``transforms.Resize(image_size)`` before preprocessing.
            control_frequency: value sent to the policy as ``ctrl_freqs``.
            pretrained: None, a checkpoint file, or another identifier that is
                handed to ``RDTRunner.from_pretrained`` (see ``get_policy``).
                A checkpoint *file* is not loaded by the constructor; call
                ``load_pretrained_weights`` (``create_model`` does this).
            pretrained_text_encoder_name_or_path: passed to ``T5Embedder``.
            pretrained_vision_encoder_name_or_path: passed to
                ``SiglipVisionTower``.
        """
        self.args = args
        self.dtype = dtype
        self.image_size = image_size
        self.device = device
        self.control_frequency = control_frequency
        # The text encoder IS loaded here so that instructions can be encoded
        # at run time via ``encode_instruction``.
        self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
        # The vision encoder must exist before the policy is built because
        # ``get_policy`` reads ``self.vision_model.num_patches``.
        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
        self.policy = self.get_policy(pretrained)

        self.reset()

    def get_policy(self, pretrained):
        """Initialize the RDT policy.

        When ``pretrained`` is None or an existing file, the policy is built
        from the configuration (weights from a file are loaded separately by
        ``load_pretrained_weights``). Otherwise ``pretrained`` is passed to
        ``RDTRunner.from_pretrained``.
        """
        # Initialize model with arguments
        if pretrained is None or os.path.isfile(pretrained):
            common_cfg = self.args["common"]
            model_cfg = self.args["model"]
            max_lang_len = self.args["dataset"]["tokenizer_max_length"]
            num_patches = self.vision_model.num_patches

            img_cond_len = common_cfg["img_history_size"] * common_cfg["num_cameras"] * num_patches

            _model = RDTRunner(
                action_dim=common_cfg["state_dim"],
                pred_horizon=common_cfg["action_chunk_size"],
                config=model_cfg,
                lang_token_dim=model_cfg["lang_token_dim"],
                img_token_dim=model_cfg["img_token_dim"],
                state_token_dim=model_cfg["state_token_dim"],
                max_lang_cond_len=max_lang_len,
                img_cond_len=img_cond_len,
                img_pos_embed_config=[
                    # No initial pos embed in the last grid size
                    # since we've already done in ViT
                    (
                        "image",
                        (
                            common_cfg["img_history_size"],
                            common_cfg["num_cameras"],
                            -num_patches,
                        ),
                    ),
                ],
                lang_pos_embed_config=[
                    # Similarly, no initial pos embed for language
                    ("lang", -max_lang_len),
                ],
                dtype=self.dtype,
            )
        else:
            _model = RDTRunner.from_pretrained(pretrained)

        return _model

    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
        """Load the T5 embedder and return ``(tokenizer, text_encoder)``."""
        text_embedder = T5Embedder(
            from_pretrained=pretrained_text_encoder_name_or_path,
            model_max_length=self.args["dataset"]["tokenizer_max_length"],
            device=self.device,
        )
        return text_embedder.tokenizer, text_embedder.model

    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
        """Load the SigLIP tower and return ``(image_processor, vision_encoder)``."""
        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
        return vision_encoder.image_processor, vision_encoder

    def reset(self):
        """Set all networks to evaluation mode and move them to device/dtype."""
        device = self.device
        weight_dtype = self.dtype
        self.policy.eval()
        self.text_model.eval()
        self.vision_model.eval()

        self.policy = self.policy.to(device, dtype=weight_dtype)
        self.text_model = self.text_model.to(device, dtype=weight_dtype)
        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)

    def load_pretrained_weights(self, pretrained=None):
        """Load policy weights from a ``.pt`` or ``.safetensors`` file.

        Args:
            pretrained: path of the checkpoint; None is a no-op.

        Raises:
            NotImplementedError: if the file extension is neither ``.pt`` nor
                ``.safetensors``.
        """
        if pretrained is None:
            return
        print(f"Loading weights from {pretrained}")
        filename = os.path.basename(pretrained)
        if filename.endswith(".pt"):
            # Deserialize onto the CPU so the checkpoint does not depend on the
            # device it was saved from; ``load_state_dict`` then copies the
            # values into the parameters wherever the policy currently lives.
            checkpoint = torch.load(pretrained, map_location="cpu")
            self.policy.load_state_dict(checkpoint["module"])
        elif filename.endswith(".safetensors"):
            from safetensors.torch import load_model

            load_model(self.policy, pretrained)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")

    def encode_instruction(self, instruction, device="cuda"):
        """Encode string instruction to latent embeddings.

        Args:
            instruction: a string of instruction
            device: device the token ids are moved to before encoding

        Returns:
            pred: the text encoder's ``last_hidden_state`` for the instruction,
                with a leading batch dimension of 1.
        """
        tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
                                     truncation=True)["input_ids"].to(device)

        tokens = tokens.view(1, -1)
        with torch.no_grad():
            pred = self.text_model(tokens).last_hidden_state.detach()

        return pred

    def _format_joint_to_state(self, joints):
        """
        Format the joint proprioception into the unified state vector.

        Args:
            joints (torch.Tensor): joint values whose last dimension has six
                entries, one per index in ``AGILEX_STATE_INDICES``.

        Returns:
            state (torch.Tensor): ``[B, N, state_token_dim]`` tensor that is
                zero everywhere except at ``AGILEX_STATE_INDICES``.
            state_elem_mask (torch.Tensor): ``[B, state_token_dim]`` tensor
                holding 1 at ``AGILEX_STATE_INDICES`` and 0 elsewhere.
        """
        # Scale the six joint values by 1/180 (presumably degrees -> roughly
        # [-1, 1]; confirm against the training data convention).
        # The divisor is deliberately rank 3: it broadcasts a rank-2 input such
        # as (1, 6) up to (1, 1, 6), which the unpacking below relies on.
        joints = joints / torch.tensor(
            [[[180, 180, 180, 180, 180, 180]]],
            device=joints.device,
            dtype=joints.dtype,
        )

        B, N, _ = joints.shape
        state_dim = self.args["model"]["state_token_dim"]
        state = torch.zeros(
            (B, N, state_dim),
            device=joints.device,
            dtype=joints.dtype,
        )
        # Fill into the unified state vector
        state[:, :, AGILEX_STATE_INDICES] = joints
        # Assemble the mask indicating each dimension's availability
        state_elem_mask = torch.zeros(
            (B, state_dim),
            device=joints.device,
            dtype=joints.dtype,
        )
        state_elem_mask[:, AGILEX_STATE_INDICES] = 1
        return state, state_elem_mask

    def _unformat_action_to_joint(self, action):
        """
        Unformat the unified action vector into the joint action to be executed.

        Args:
            action (torch.Tensor): unified action vector, indexed on its last
                dimension by ``AGILEX_STATE_INDICES``.

        Returns:
            joints (torch.Tensor): the six selected entries multiplied by 180,
                i.e. the inverse of the scaling in ``_format_joint_to_state``.
        """
        action_indices = AGILEX_STATE_INDICES
        joints = action[:, :, action_indices]

        # Undo the 1/180 scaling applied to the proprioception.
        joints = joints * torch.tensor(
            [[[180, 180, 180, 180, 180, 180]]],
            device=joints.device,
            dtype=joints.dtype,
        )

        return joints

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Pad a PIL image to a centered square filled with ``background_color``."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        if width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

    @torch.no_grad()
    def step(self, proprio, images, text_embeds):
        """
        Predict the next action chunk given the
        proprioceptive states, images, and instruction embeddings.

        Args:
            proprio: proprioceptive states (``np.ndarray`` or ``torch.Tensor``)
            images: sequence of RGB images (``np.ndarray``, PIL image or None),
                ordered time step by time step with all cameras of a time step
                adjacent, e.g.
                [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},
                ext_{t}, right_wrist_{t}, left_wrist_{t}].
                None entries are replaced by a uniform background image.
            text_embeds: instruction embeddings

        Returns:
            action: predicted action chunk as a float32 tensor
        """
        device = self.device
        dtype = self.dtype
        dataset_cfg = self.args["dataset"]

        # Mean colour of the image processor, as 0-255 integers. It is used
        # both for missing cameras and for square padding.
        mean_rgb = tuple(int(x * 255) for x in self.image_processor.image_mean)

        # The background image used for padding
        background_color = np.array(mean_rgb, dtype=np.uint8).reshape(1, 1, 3)
        background_image = (np.ones(
            (
                self.image_processor.size["height"],
                self.image_processor.size["width"],
                3,
            ),
            dtype=np.uint8,
        ) * background_color)

        # Loop-invariant configuration and transforms, built once per call.
        resize = transforms.Resize(self.image_size) if self.image_size is not None else None
        auto_brightness = dataset_cfg.get("auto_adjust_image_brightness", False)
        # A degenerate (1.75, 1.75) range makes the jitter a fixed 1.75x gain.
        brighten = transforms.ColorJitter(brightness=(1.75, 1.75)) if auto_brightness else None
        pad_to_square = dataset_cfg.get("image_aspect_ratio", "pad") == "pad"

        # Preprocess the images by order and encode them
        image_tensor_list = []
        for image in images:
            if image is None:
                # Replace it with the background image
                image = Image.fromarray(background_image)
            elif isinstance(image, np.ndarray):
                # Convert numpy array to PIL Image if needed
                image = Image.fromarray(image)

            if resize is not None:
                image = resize(image)

            if auto_brightness:
                # Mean intensity over all pixels and the three channels,
                # normalised to [0, 1].
                pixel_sum = np.asarray(image).sum(dtype=np.float64)
                average_brightness = pixel_sum / (image.width * image.height * 255.0 * 3)
                if average_brightness <= 0.15:
                    image = brighten(image)

            if pad_to_square:
                image = self._expand2square(image, mean_rgb)

            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            image_tensor_list.append(image)

        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)

        image_embeds = self.vision_model(image_tensor).detach()
        # Flatten all images' patch tokens into one sequence with batch size 1.
        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)

        # Prepare the proprioception states and the control frequency
        # Convert numpy array to tensor if needed
        if isinstance(proprio, np.ndarray):
            # Copy the array to make it writable
            proprio = torch.from_numpy(proprio.copy())

        joints = proprio.to(device).unsqueeze(0)
        states, state_elem_mask = self._format_joint_to_state(joints)
        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
        # Only the most recent state is fed to the policy.
        states = states[:, -1:, :]
        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)

        text_embeds = text_embeds.to(device, dtype=dtype)

        # Predict the next action chunk given the inputs
        trajectory = self.policy.predict_action(
            lang_tokens=text_embeds,
            lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
            img_tokens=image_embeds,
            state_tokens=states,
            action_mask=state_elem_mask.unsqueeze(1),
            ctrl_freqs=ctrl_freqs,
        )
        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)

        return trajectory
|
||
|
||
class RDTInferenceServer:
    """Expose the RDT policy through a ``cloud_helper.Server`` endpoint."""

    def __init__(
        self,
        pretrained_text_encoder_name_or_path,
        pretrained_vision_encoder_name_or_path,
        pretrained_rdt_model_weights,
        config,
        args,
        lang_model
    ):
        """Build the policy, the server and the optional embedding cache.

        Args:
            pretrained_text_encoder_name_or_path: forwarded to ``create_model``.
            pretrained_vision_encoder_name_or_path: forwarded to ``create_model``.
            pretrained_rdt_model_weights: forwarded to ``create_model`` as
                ``pretrained``.
            config: configuration mapping forwarded to ``create_model`` as
                ``args``.
            args: parsed command-line namespace; ``device``,
                ``control_frequency``, ``host``, ``port`` and ``pre_lang`` are
                read.
            lang_model: path of a file with precomputed language embeddings,
                or None.
        """
        self.device = args.device
        self.policy_type = "rdt"
        self.policy = create_model(
            args=config,
            dtype=torch.bfloat16,
            # Forward the requested device so the model honours ``--device``
            # instead of always falling back to the wrapper's default.
            device=args.device,
            pretrained=pretrained_rdt_model_weights,
            pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
            pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
            control_frequency=args.control_frequency,
        )
        self.server = Server(args.host, args.port)
        print(f"Loaded RDT policy from {pretrained_rdt_model_weights}")

        if args.pre_lang is True and lang_model and os.path.exists(lang_model):
            # Load the precomputed language embeddings.
            self.lang_embeddings = torch.load(lang_model, map_location=self.device)
            print(f"Loaded language embeddings shape: {self.lang_embeddings.shape if self.lang_embeddings is not None else 'None'}")
            print(f"Model expects tokenizer_max_length: {self.policy.args['dataset']['tokenizer_max_length']}")
            print(f"Model lang_token_dim: {self.policy.args['model']['lang_token_dim']}")
        else:
            print("No language model provided, using runtime embeddings")
            self.lang_embeddings = None

    def get_actions(self, batch):
        """Handle a client request and return the predicted action chunk.

        Args:
            batch: dictionary holding the observation and the instruction
                {
                    "observation": {
                        "state": robot state (array-like, converted to float32),
                        "images.<camera>": (IMG_HISTORY_SIZE, H, W, 3) np.ndarray
                            or None, one entry per camera,
                        ...
                    },
                    "instruction": str, or an integer index into the
                        precomputed embeddings
                }

        Returns:
            action: ``np.ndarray`` with the batch dimension removed

        Raises:
            ValueError: if the observation has no ``"state"`` or no
                ``"images.*"`` entry, or if the instruction type is not
                supported while precomputed embeddings are in use.
        """
        observation = batch["observation"]
        instruction = batch["instruction"]

        # 1. Robot state (proprioception). Fail early with a clear message
        #    instead of passing None down into the policy.
        if "state" not in observation:
            raise ValueError("observation is missing the required 'state' entry")
        proprio = np.asarray(observation["state"], dtype=np.float32)

        # 2. Images. Sort the keys so the camera order is deterministic.
        image_keys = sorted([k for k in observation.keys() if k.startswith("images.")])

        # Validate each camera; an unusable camera is stored as None.
        camera_images = {}
        for key in image_keys:
            img_data = observation[key]

            if img_data is None:
                camera_images[key] = None
            elif isinstance(img_data, np.ndarray):
                # Expected layout: (IMG_HISTORY_SIZE, H, W, 3)
                if img_data.ndim == 4 and img_data.shape[-1] == 3:
                    # Treat zero-sized frames as a missing camera.
                    if img_data.shape[1] > 0 and img_data.shape[2] > 0:
                        camera_images[key] = img_data.astype(np.uint8)
                    else:
                        camera_images[key] = None
                else:
                    print(f"警告: {key} 维度不正确,期望 (T, H, W, 3),实际 {img_data.shape}")
                    camera_images[key] = None
            else:
                print(f"警告: {key} 数据类型错误,期望 np.ndarray,实际 {type(img_data)}")
                camera_images[key] = None

        # Flatten time-major:
        # [cam1[t0], cam2[t0], cam3[t0], cam1[t1], cam2[t1], cam3[t1], ...]
        images = []
        if camera_images:
            img_history_size = self.policy.args["common"]["img_history_size"]
            for t in range(img_history_size):
                for key in image_keys:
                    img_array = camera_images.get(key)
                    if img_array is not None:
                        # Frame t of this camera: (H, W, 3)
                        images.append(img_array[t])
                    else:
                        # Unavailable camera: None makes the policy substitute
                        # its background image.
                        images.append(None)

        if not images:
            raise ValueError("observation contains no 'images.*' entries")

        # 3. Instruction embeddings.
        lang_embeddings = getattr(self, "lang_embeddings", None)
        if lang_embeddings is not None:
            # Precomputed embeddings are available.
            if isinstance(instruction, (int, np.integer)):
                text_embeds = lang_embeddings[instruction]
            elif isinstance(instruction, str):
                # No lookup table from strings to precomputed entries exists,
                # so strings are still encoded at run time.
                text_embeds = self.policy.encode_instruction(instruction, device=self.policy.device)
            else:
                raise ValueError(f"指令类型不支持: {type(instruction)}")
        else:
            # Encode the instruction at run time.
            text_embeds = self.policy.encode_instruction(instruction, device=self.policy.device)

        # 4. Run the policy.
        action = self.policy.step(
            proprio=proprio,
            images=images,
            text_embeds=text_embeds,
        )

        # 5. Drop the batch dimension and hand back a NumPy array.
        return action.squeeze(0).cpu().numpy()

    def run(self):
        """Register the ``get_actions`` endpoint and serve until stopped."""
        self.server.register_endpoint("get_actions", self.get_actions)
        print(f"Lerobot {self.policy_type.upper()} Server is running...")
        self.server.loop_forever()
|
||
|
||
|
||
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including "False",
        into True, so the accepted spellings are matched explicitly.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    # The YAML configuration is mandatory: it is opened unconditionally below.
    parser.add_argument("--config_file", type=str, required=True)
    parser.add_argument("--path_to_vision_encoder_model", type=str, default=None)
    parser.add_argument("--path_to_text_encoder_model", type=str, default=None)
    # The original (misspelled) flag is kept; the corrected spelling is an alias
    # that stores into the same destination.
    parser.add_argument(
        "--path_to_rdt_model_wights",
        "--path_to_rdt_model_weights",
        dest="path_to_rdt_model_wights",
        type=str,
        default=None,
    )
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--control_frequency", type=int, default=25)
    parser.add_argument("--pre_lang", type=_str2bool, default=True)
    parser.add_argument("--lang_model", type=str, default=None)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    with open(args.config_file, "r", encoding="utf-8") as fp:
        config = yaml.safe_load(fp)

    server = RDTInferenceServer(
        args.path_to_text_encoder_model,
        args.path_to_vision_encoder_model,
        args.path_to_rdt_model_wights,
        config,
        args,
        args.lang_model,
    )
    server.run()
|