import re
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import \
    DPMSolverMultistepScheduler

# Make the sibling modules importable regardless of the working directory
current_file = Path(__file__)
sys.path.append(str(current_file.parent))

from hub_mixin import CompatiblePyTorchModelHubMixin
from rdt.model import RDT


class RDTRunner(nn.Module,
                CompatiblePyTorchModelHubMixin,
                repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):
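    '''Diffusion policy runner for RDT (Robotics Diffusion Transformer).

    Wraps the RDT backbone with condition adapters that project language,
    image, and state/action tokens into the hidden size, plus a DDPM
    scheduler for training and a multistep DPM-Solver scheduler for sampling.
    '''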

    def __init__(self,
                 *,
                 action_dim,
                 pred_horizon,
                 config,
                 lang_token_dim,
                 img_token_dim,
                 state_token_dim,
                 max_lang_cond_len,
                 img_cond_len,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super().__init__()
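
        # Illustrative sketch of the `config` keys this class reads (the
        # actual values come from the checkpoint's configuration file):
        #   config['rdt']: hidden_size, depth, num_heads
        #   config['lang_adaptor'] / ['img_adaptor'] / ['state_adaptor']:
        #       a projector type string such as 'linear' or 'mlp2x_gelu'
        #   config['noise_scheduler']: num_train_timesteps,
        #       num_inference_timesteps, beta_schedule, prediction_type,
        #       clip_sample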

        # Create diffusion model
        hidden_size = config['rdt']['hidden_size']
        self.model = RDT(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=config['rdt']['depth'],
            num_heads=config['rdt']['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Create adapters for the various conditional inputs
        self.lang_adaptor = self.build_condition_adapter(config['lang_adaptor'],
                                                         in_features=lang_token_dim,
                                                         out_features=hidden_size)
        self.img_adaptor = self.build_condition_adapter(config['img_adaptor'],
                                                        in_features=img_token_dim,
                                                        out_features=hidden_size)
        # A `state` refers to an action or a proprioception vector
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,  # state + state mask (indicator)
            out_features=hidden_size)

        # Create the noise scheduler
        noise_scheduler_config = config['noise_scheduler']
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
            clip_sample=noise_scheduler_config['clip_sample'],
        )
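
        # A separate multistep DPM-Solver scheduler, used only at inference
        # time so that sampling takes far fewer denoising steps than training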
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
        )

        self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.prediction_type = noise_scheduler_config['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        # Report the total number of diffusion-related parameters
        num_params = sum(p.numel() for p in self.model.parameters()) \
            + sum(p.numel() for p in self.lang_adaptor.parameters()) \
            + sum(p.numel() for p in self.img_adaptor.parameters()) \
            + sum(p.numel() for p in self.state_adaptor.parameters())
        print("Diffusion params: %e" % num_params)

    def build_condition_adapter(self, projector_type, in_features, out_features):
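        # Map condition tokens to the hidden size: 'linear' yields a single
        # Linear layer; 'mlpNx_gelu' (e.g. 'mlp2x_gelu') yields an N-layer
        # MLP with tanh-approximated GELU activations in between.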
        projector = None
        if projector_type == 'linear':
            projector = nn.Linear(in_features, out_features)
        else:
            mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
            if mlp_gelu_match:
                mlp_depth = int(mlp_gelu_match.group(1))
                modules = [nn.Linear(in_features, out_features)]
                for _ in range(1, mlp_depth):
                    modules.append(nn.GELU(approximate="tanh"))
                    modules.append(nn.Linear(out_features, out_features))
                projector = nn.Sequential(*modules)

        if projector is None:
            raise ValueError(f'Unknown projector type: {projector_type}')

        return projector

    def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, state_len, state_token_dim)

        return: adapted (..., hidden_size) for all input tokens
        '''
        adapted_lang = self.lang_adaptor(lang_tokens)
        adapted_img = self.img_adaptor(img_tokens)
        adapted_state = self.state_adaptor(state_tokens)

        return adapted_lang, adapted_img, adapted_state

    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond, state_traj, action_mask, ctrl_freqs):
        '''
        lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
        lang_attn_mask: (batch_size, lang_len), a boolean mask for valid language tokens.
        img_cond: image conditional data, (batch_size, img_len, hidden_size).
        state_traj: (batch_size, 1, hidden_size), state trajectory.
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
            indicating the valid action dimensions.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim)
        '''
        device = state_traj.device
        dtype = state_traj.dtype
        noisy_action = torch.randn(size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
                                   dtype=dtype,
                                   device=device)
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        # Set step values
        self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)

        for t in self.noise_scheduler_sample.timesteps:
            # Prepare the state-action trajectory
            action_traj = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(action_traj)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # Predict the model output
            model_output = self.model(state_action_traj,
                                      ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond,
                                      img_cond,
                                      lang_mask=lang_attn_mask)

            # Compute the previous noisy sample: x_t -> x_{t-1}
            noisy_action = self.noise_scheduler_sample.step(model_output, t, noisy_action).prev_sample
            noisy_action = noisy_action.to(state_traj.dtype)

        # Finally, apply the action mask to zero out invalid action dimensions
        noisy_action = noisy_action * action_mask

        return noisy_action

    # ========= Train ============
    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_gt, action_mask,
                     ctrl_freqs) -> torch.Tensor:
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a boolean mask for valid language tokens.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: loss_value, a scalar tensor
        '''
        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device
        # Sample noise that we'll add to the actions
        noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)
        # Sample random diffusion timesteps
        timesteps = torch.randint(0, self.num_train_timesteps, (batch_size, ), device=device).long()
        # Add noise to the clean actions according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
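        #   x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise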
        noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps)

        # Concatenate the state and action tokens to form the input sequence
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        # Append the action mask to the input sequence
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
        # Align the dimension with the hidden size
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj)
        # Predict the denoised result
        pred = self.model(state_action_traj, ctrl_freqs, timesteps, lang_cond, img_cond, lang_mask=lang_attn_mask)

        # Regress the injected noise ('epsilon') or the clean actions ('sample')
        pred_type = self.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = action_gt
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target)
        return loss

    # ========= Inference ============
    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_mask, ctrl_freqs):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a boolean mask for valid language tokens.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
            indicating the valid action dimensions.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), predicted action sequence
        '''
        # Prepare the state and conditions
        state_tokens = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(lang_tokens, img_tokens, state_tokens)

        # Run sampling
        action_pred = self.conditional_sample(
            lang_cond,
            lang_attn_mask,
            img_cond,
            state_traj,
            action_mask,
            ctrl_freqs,
        )

        return action_pred

    def forward(self, *args, **kwargs) -> torch.Tensor:
        return self.compute_loss(*args, **kwargs)
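
# A minimal usage sketch (hypothetical tensors; shapes follow the docstrings
# above, and `from_pretrained` is provided by the Hub mixin):
#
#   runner = RDTRunner.from_pretrained("robotics-diffusion-transformer/rdt-1b")
#   actions = runner.predict_action(
#       lang_tokens,     # (B, lang_len, lang_token_dim)
#       lang_attn_mask,  # (B, lang_len), bool
#       img_tokens,      # (B, img_len, img_token_dim)
#       state_tokens,    # (B, 1, state_token_dim)
#       action_mask,     # (B, 1, action_dim), 0-1 float
#       ctrl_freqs,      # (B,)
#   )                    # -> (B, pred_horizon, action_dim)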