aloha/detr/models/detr_vae.py
Commit 0f54d10d7f by Tony Zhao (2023-04-14 11:50:27 -07:00): Change network architecture to have joints as ACT encoder input
- does not affect performance for position control
- not backward compatible: policies trained before this commit will not load because of the additional params


# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
import numpy as np

import IPython
e = IPython.embed  # interactive debugging hook
def reparametrize(mu, logvar):
    """Sample z ~ N(mu, exp(logvar)) via the reparameterization trick."""
    std = logvar.div(2).exp()
    eps = torch.randn_like(std)  # replaces the deprecated Variable-based sampling
    return mu + std * eps
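# A shape note on reparametrize() (restating the code above): for mu, logvar of shape
# (bs, latent_dim), it returns z = mu + exp(logvar / 2) * eps with eps ~ N(0, I), so
# gradients flow into mu and logvar while the noise itself stays non-differentiable.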
def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
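# Shape note: get_sinusoid_encoding_table(n_position, d_hid) returns a fixed
# (1, n_position, d_hid) table; below it is built with n_position = 1 + 1 + num_queries
# to cover the [CLS], qpos, and action-sequence tokens of the CVAE encoder.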
class DETRVAE(nn.Module):
""" This is the DETR module that performs object detection """
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, state_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent mean and logvar
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
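    # Data flow in brief (summarizing __init__ above): the CVAE encoder consumes
    # [CLS, proj(qpos), proj(action_seq)] and its CLS output is projected to (mu, logvar);
    # the decoder transformer then conditions on image features, proj(qpos), and proj(z).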
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
### Obtain latent z from action sequence
if is_training:
            # project qpos and the action sequence to the embedding dim, and prepend a CLS token
            action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim); joints get their own projection
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
            encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+2, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2) # (seq+2, bs, hidden_dim)
            # do not mask the cls or qpos tokens
            cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+2)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2) # (seq+2, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
mu = latent_info[:, :self.latent_dim]
logvar = latent_info[:, self.latent_dim:]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
        else:
            # at test time there is no action sequence: use the prior mean, z = 0
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
            latent_input = self.latent_out_proj(latent_sample)
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED: a single backbone is shared across all cameras
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar]
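    # How the outputs above are consumed (a sketch; the loss itself lives in the caller):
    # a_hat is supervised against the ground-truth action chunk, is_pad_hat predicts the
    # padding mask, and (mu, logvar) feed a KL term against the unit-Gaussian prior.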
class CNNMLP(nn.Module):
def __init__(self, backbones, state_dim, camera_names):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim) # TODO add more; note: unused in forward(), which predicts through self.mlp
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
backbone_down_projs = []
for backbone in backbones:
down_proj = nn.Sequential(
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
nn.Conv2d(128, 64, kernel_size=5),
nn.Conv2d(64, 32, kernel_size=5)
)
backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + 14
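            # NOTE (assumption): 768 presumes roughly 480x640 input images. The backbone's
            # stride-32 feature map is then 15x20; three valid 5x5 convs shrink it to
            # 3x8, and 32 channels * 3 * 8 = 768. Other resolutions change mlp_in_dim.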
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
else:
raise NotImplementedError
def forward(self, qpos, image, env_state, actions=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
# Image observation features and position embeddings
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0] # not used
all_cam_features.append(self.backbone_down_projs[cam_id](features))
# flatten everything
flattened_features = []
for cam_feature in all_cam_features:
flattened_features.append(cam_feature.reshape([bs, -1]))
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
a_hat = self.mlp(features)
return a_hat
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
if hidden_depth == 0:
mods = [nn.Linear(input_dim, output_dim)]
else:
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
for i in range(hidden_depth - 1):
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
mods.append(nn.Linear(hidden_dim, output_dim))
trunk = nn.Sequential(*mods)
return trunk
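# A quick shape sketch of mlp() above: mlp(input_dim=782, hidden_dim=1024, output_dim=14,
# hidden_depth=2) builds Linear(782, 1024) -> ReLU -> Linear(1024, 1024) -> ReLU ->
# Linear(1024, 14); 782 = 768 + 14 is the single-camera CNNMLP input size.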
def build_encoder(args):
d_model = args.hidden_dim # 256
dropout = args.dropout # 0.1
nhead = args.nheads # 8
dim_feedforward = args.dim_feedforward # 2048
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
normalize_before = args.pre_norm # False
activation = "relu"
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
return encoder
def build(args):
    state_dim = 14 # TODO hardcode; 14 = 2 arms x (6 joints + 1 gripper)
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
backbone = build_backbone(args)
backbones.append(backbone)
transformer = build_transformer(args)
encoder = build_encoder(args)
model = DETRVAE(
backbones,
transformer,
encoder,
state_dim=state_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters/1e6,))
return model
def build_cnnmlp(args):
    state_dim = 14 # TODO hardcode; 14 = 2 arms x (6 joints + 1 gripper)
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
model = CNNMLP(
backbones,
state_dim=state_dim,
camera_names=args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters/1e6,))
return model
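

# A minimal smoke-test sketch (not part of the original file). The Namespace fields
# below are exactly the ones build_encoder() reads above, with the default values its
# comments mention; a full build() would additionally need the backbone/transformer
# args consumed by build_backbone() and build_transformer() in this repo.
if __name__ == '__main__':
    from argparse import Namespace
    enc_args = Namespace(hidden_dim=256, dropout=0.1, nheads=8,
                         dim_feedforward=2048, enc_layers=4, pre_norm=False)
    enc = build_encoder(enc_args)
    n_params = sum(p.numel() for p in enc.parameters())
    print('encoder parameters: %.2fM' % (n_params / 1e6,))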