# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
# --------------------------------------------------------
|
|
# References:
|
|
# DiT: https://github.com/facebookresearch/DiT
|
|
# GLIDE: https://github.com/openai/glide-text2im
|
|
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
|
# --------------------------------------------------------
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from pathlib import Path
|
|
import sys, os
|
|
# Make the project root (the parent of this file's directory) importable so
# the `rdt` package imported below resolves even when this file is executed
# outside the installed package.
current_file = Path(__file__)
sys.path.append(str(current_file.parent.parent))
|
|
|
|
from rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder, get_1d_sincos_pos_embed_from_grid,
|
|
get_multimodal_cond_pos_embed)
|
|
|
|
|
|
class RDT(nn.Module):
    """
    Class for Robotics Diffusion Transformers.

    A DiT-style denoising transformer over a token sequence of
    [timestep; ctrl_freq; state; action], conditioned on language and image
    tokens via the per-block condition input of ``RDTBlock`` (blocks alternate
    between the language and the image condition).

    Args:
        output_dim: per-token output channel dimension of the final layer.
        horizon: number of action tokens (the prediction horizon).
        hidden_size: transformer embedding dimension D.
        depth: number of ``RDTBlock`` layers.
        num_heads: attention heads per block.
        max_lang_cond_len: maximum language-condition sequence length
            (language inputs are variable-length up to this bound).
        img_cond_len: fixed image-condition sequence length.
        lang_pos_embed_config: optional (name, length) pairs passed to
            ``get_multimodal_cond_pos_embed``; when None a plain 1-D sin-cos
            embedding over ``max_lang_cond_len`` positions is used.
        img_pos_embed_config: same, for the image condition.
        dtype: parameter dtype; the whole module is cast to it at the end of
            ``initialize_weights``.
    """

    def __init__(self,
                 output_dim=128,
                 horizon=32,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 max_lang_cond_len=1024,
                 img_cond_len=4096,
                 lang_pos_embed_config=None,
                 img_pos_embed_config=None,
                 dtype=torch.bfloat16):
        super().__init__()
        self.horizon = horizon
        self.hidden_size = hidden_size
        self.max_lang_cond_len = max_lang_cond_len
        self.img_cond_len = img_cond_len
        self.dtype = dtype
        self.lang_pos_embed_config = lang_pos_embed_config
        self.img_pos_embed_config = img_pos_embed_config

        # Scalar-condition embedders for the diffusion timestep and the
        # control frequency; both map a scalar to a D-dim token.
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype)
        self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype)

        # We will use trainable sin-cos embeddings (initialized from sin-cos
        # in initialize_weights(), then trained as ordinary parameters).
        # Main sequence layout: [timestep; ctrl_freq; state; action]
        # -> 1 + 1 + 1 + horizon = horizon + 3 tokens.
        self.x_pos_embed = nn.Parameter(torch.zeros(1, horizon + 3, hidden_size))
        # Language conditions
        self.lang_cond_pos_embed = nn.Parameter(torch.zeros(1, max_lang_cond_len, hidden_size))
        # Image conditions
        self.img_cond_pos_embed = nn.Parameter(torch.zeros(1, img_cond_len, hidden_size))

        self.blocks = nn.ModuleList([RDTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, output_dim)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize all parameters.

        Xavier-uniform for every linear layer, fixed sin-cos values copied
        into the (trainable) position-embedding parameters, small-normal for
        the timestep/frequency embedder MLPs, and zeros for the final output
        projection; finally casts the whole module to ``self.dtype``.
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize pos_embed by sin-cos embedding.  The lengths here must
        # match the `horizon + 3` layout of self.x_pos_embed above.
        x_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                    mm_cond_lens=OrderedDict([
                                                        ('timestep', 1),
                                                        ('ctrl_freq', 1),
                                                        ('state', 1),
                                                        ('action', self.horizon),
                                                    ]))
        # NOTE(review): the pos-embed helpers evidently return numpy arrays
        # (hence from_numpy) — confirm against rdt.blocks.
        self.x_pos_embed.data.copy_(torch.from_numpy(x_pos_embed).float().unsqueeze(0))

        if self.lang_pos_embed_config is None:
            lang_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size,
                                                                    torch.arange(self.max_lang_cond_len))
        else:
            lang_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                                mm_cond_lens=OrderedDict(self.lang_pos_embed_config),
                                                                embed_modality=False)
        self.lang_cond_pos_embed.data.copy_(torch.from_numpy(lang_cond_pos_embed).float().unsqueeze(0))

        if self.img_pos_embed_config is None:
            img_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.img_cond_len))
        else:
            img_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,
                                                               mm_cond_lens=OrderedDict(self.img_pos_embed_config),
                                                               embed_modality=False)
        self.img_cond_pos_embed.data.copy_(torch.from_numpy(img_cond_pos_embed).float().unsqueeze(0))

        # Initialize timestep and control freq embedding MLP
        # (indices 0 and 2 are the two Linear layers of the embedder MLP).
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02)

        # Initialize the final layer: zero-out the final linear layer so the
        # model starts by predicting zeros.
        nn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0)
        nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0)

        # Move all the params to given data type:
        self.to(self.dtype)

    def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
        """
        Forward pass of RDT.

        x: (B, T, D), state + action token sequence, T = horizon + 1,
            dimension D is assumed to be the same as the hidden size.
        freq: (B,), a scalar indicating control frequency.
        t: (B,) or (1,), diffusion timesteps.
        lang_c: (B, L_lang, D) or None, language condition tokens (variable length),
            dimension D is assumed to be the same as the hidden size.
        img_c: (B, L_img, D) or None, image condition tokens (fixed length),
            dimension D is assumed to be the same as the hidden size.
        lang_mask: (B, L_lang) or None, language condition mask (True for valid).
        img_mask: (B, L_img) or None, image condition mask (True for valid).

        Returns:
            (B, horizon, output_dim) — only the action tokens are kept.

        NOTE(review): despite the "or None" above, lang_c and img_c are added
        to their position embeddings unconditionally below, so passing None
        for either would raise — confirm intended contract with callers.
        """
        t = self.t_embedder(t).unsqueeze(1)  # (B, 1, D) or (1, 1, D)
        freq = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
        # Append timestep to the input tokens
        if t.shape[0] == 1:
            # A single shared timestep embedding is broadcast over the batch.
            t = t.expand(x.shape[0], -1, -1)
        # (B, T+2, D): [timestep; ctrl_freq; state; action] = horizon + 3 tokens
        x = torch.cat([t, freq, x], dim=1)

        # Add multimodal position embeddings
        x = x + self.x_pos_embed
        # Note the lang is of variable length: slice the embedding to L_lang.
        lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
        img_c = img_c + self.img_cond_pos_embed

        # Forward pass: even-indexed blocks attend to the language condition,
        # odd-indexed blocks to the image condition.
        conds = [lang_c, img_c]
        masks = [lang_mask, img_mask]
        for i, block in enumerate(self.blocks):
            c, mask = conds[i % 2], masks[i % 2]
            x = block(x, c, mask)  # (B, T+2, D)
        # Project every token to the output dimension.
        x = self.final_layer(x)  # (B, T+2, out_channels)

        # Only preserve the action tokens
        x = x[:, -self.horizon:]
        return x
|