# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR-style policy models (DETRVAE, CNNMLP) and their builder functions.
"""
import torch
from torch import nn
from torch.autograd import Variable
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer

import numpy as np

import IPython
e = IPython.embed  # debugging convenience: call e() to drop into an IPython shell


def reparametrize(mu, logvar):
    """Draw a sample from N(mu, exp(logvar)) with the reparameterization trick.

    Gradients flow to ``mu`` and ``logvar``; the Gaussian noise itself is not
    differentiated.
    """
    std = logvar.div(2).exp()  # exp(logvar / 2) == standard deviation
    # torch.randn_like replaces the deprecated Variable(std.data.new(...).normal_()):
    # same shape/dtype/device as std and requires_grad=False.
    eps = torch.randn_like(std)
    return mu + std * eps


def get_sinusoid_encoding_table(n_position, d_hid):
    """Build a fixed sinusoidal position table.

    Returns a float32 tensor of shape (1, n_position, d_hid): even columns hold
    sin(pos / 10000^(2i/d_hid)) and odd columns the matching cos.
    """
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)


class DETRVAE(nn.Module):
    """Action-chunking policy: a DETR-style transformer decoder conditioned on a VAE latent.

    During training a transformer encoder summarises (qpos, action sequence) into a
    Gaussian (mu, logvar) and a reparameterised sample conditions the decoder; at
    inference the latent sample is set to zeros.
    """

    def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
                 use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
        """Initialize the model.

        Parameters:
            backbones: list of image backbone modules (see backbone.py), or None for the
                state-only variant. Only ``backbones[0]`` is used in forward; it is
                shared across all cameras and its ``num_channels`` sets the width of
                ``input_proj``.
            transformer: decoder transformer (see transformer.py); ``transformer.d_model``
                sets the hidden width of every projection in this module.
            encoder: transformer encoder used during training to infer the latent
                distribution from (qpos, action sequence).
            state_dim: dimension of the robot state vector ``qpos``.
            action_dim: dimension of a single action.
            num_queries: number of decoder queries (rows of ``query_embed``); the
                encoder position table is sized for an action sequence of this length.
            camera_names: one name per camera slot along dim 1 of ``image``.
            use_text: if True, allocate ``text_proj`` and a third ``additional_pos_embed``
                row so projected text can be passed to the transformer as an extra token.
            text_feature_dim: width of incoming text features (input size of ``text_proj``).
            text_fusion_type: text fusion strategy; only 'concat_transformer_input' is
                supported (checked in forward when text features are supplied).
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        self.use_text = use_text
        self.text_fusion_type = text_fusion_type
        hidden_dim = transformer.d_model
        # NOTE: keep the submodule creation order below unchanged; it fixes the order of
        # self.parameters() (optimizer state) and the RNG draws used for initialisation.
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
        else:
            # state-only variant
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)  # env_state width hard-coded to 7
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(action_dim, hidden_dim)  # project action to embedding
        self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim)  # project qpos to embedding
        # project CLS output to [mu, logvar] of the latent
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim * 2)
        # rows: [CLS], qpos, a_seq (num_queries steps)
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim))

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim)  # project latent sample to embedding
        num_extra_tokens = 3 if self.use_text else 2
        self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim)  # latent, proprio, optional text
        self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None

    def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
        """Predict an action chunk.

        Parameters:
            qpos: (batch, state_dim) robot state.
            image: (batch, num_cam, channel, height, width) camera images; unused by the
                state-only variant.
            env_state: environment state for the state-only variant (last dim 7); unused
                (may be None) when image backbones are present.
            text_features: optional text features fed to ``text_proj``; presumably
                (batch, text_feature_dim) — TODO confirm. Only used when the model was
                built with ``use_text=True`` and has image backbones.
            actions: (batch, seq, action_dim) ground-truth actions. Passing them selects
                training mode (latent inferred by the encoder); None selects inference.
            is_pad: (batch, seq) bool mask, True for padded action steps. Treated as
                all-False when omitted in training mode.

        Returns:
            (a_hat, is_pad_hat, [mu, logvar]); mu and logvar are None at inference.

        Raises:
            NotImplementedError: text features were given with an unsupported
                ``text_fusion_type``.
        """
        latent_input, mu, logvar = self._encode_latent(qpos, actions, is_pad)

        if self.backbones is not None:
            hs = self._decode_with_images(qpos, image, latent_input, text_features)
        else:
            # State-only variant: the latent is not passed to the transformer here.
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            # NOTE(review): for 2-D inputs this concatenates along the feature dim,
            # giving (bs, 2 * hidden_dim); the original comment called this
            # "seq length = 2" — confirm how the transformer consumes it.
            transformer_input = torch.cat([qpos, env_state], dim=1)
            hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
        a_hat = self.action_head(hs)
        is_pad_hat = self.is_pad_head(hs)
        return a_hat, is_pad_hat, [mu, logvar]

    def _encode_latent(self, qpos, actions, is_pad):
        """Return (latent_input, mu, logvar): the decoder's latent embedding and its posterior.

        With ``actions`` the encoder infers (mu, logvar) and a reparameterised sample is
        projected; without them a zero latent is projected and mu/logvar are None.
        """
        bs, _ = qpos.shape
        if actions is None:
            latent_sample = torch.zeros(bs, self.latent_dim, dtype=torch.float32, device=qpos.device)
            return self.latent_out_proj(latent_sample), None, None

        # project action sequence to embedding dim, and concat with CLS and qpos tokens
        action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
        qpos_embed = self.encoder_joint_proj(qpos).unsqueeze(1)  # (bs, 1, hidden_dim)
        cls_embed = self.cls_embed.weight.unsqueeze(0).repeat(bs, 1, 1)  # (bs, 1, hidden_dim)
        encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], dim=1)  # (bs, seq+2, hidden_dim)
        encoder_input = encoder_input.permute(1, 0, 2)  # (seq+2, bs, hidden_dim)

        # never mask the CLS and qpos tokens; a missing mask means no action step is padded
        if is_pad is None:
            is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=qpos.device)
        cls_joint_is_pad = torch.zeros((bs, 2), dtype=torch.bool, device=qpos.device)  # False: not a padding
        is_pad = torch.cat([cls_joint_is_pad, is_pad], dim=1)  # (bs, seq+2)

        # fixed sinusoidal positions; the table has 2 + num_queries rows, so presumably
        # seq must equal num_queries for the encoder to combine it with encoder_input
        pos_embed = self.pos_table.clone().detach()
        pos_embed = pos_embed.permute(1, 0, 2)  # (2 + num_queries, 1, hidden_dim)

        encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
        encoder_output = encoder_output[0]  # take cls output only
        latent_info = self.latent_proj(encoder_output)
        mu = latent_info[:, :self.latent_dim]
        logvar = latent_info[:, self.latent_dim:]
        latent_sample = reparametrize(mu, logvar)
        return self.latent_out_proj(latent_sample), mu, logvar

    def _decode_with_images(self, qpos, image, latent_input, text_features):
        """Run the transformer on camera features plus latent, proprio and optional text tokens."""
        # Image observation features and position embeddings
        all_cam_features = []
        all_cam_pos = []
        for cam_id, _cam_name in enumerate(self.camera_names):
            features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED: backbone 0 shared by all cameras
            features = features[0]  # take the last layer feature
            pos = pos[0]
            all_cam_features.append(self.input_proj(features))
            all_cam_pos.append(pos)

        # proprioception features
        proprio_input = self.input_proj_robot_state(qpos)

        extra_input_tokens = None
        if self.use_text and text_features is not None:
            if self.text_fusion_type != 'concat_transformer_input':
                raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
            text_input = self.text_proj(text_features)
            extra_input_tokens = text_input.unsqueeze(0)  # add a leading token axis
        # NOTE(review): with use_text=True but text_features=None, additional_pos_embed
        # still has 3 rows while no extra token is passed — confirm the transformer
        # handles this case.

        # fold camera dimension into width dimension
        src = torch.cat(all_cam_features, dim=3)
        pos = torch.cat(all_cam_pos, dim=3)
        return self.transformer(
            src,
            None,
            self.query_embed.weight,
            pos,
            latent_input,
            proprio_input,
            self.additional_pos_embed.weight,
            extra_input_tokens=extra_input_tokens,
        )[0]


class CNNMLP(nn.Module):
    """Baseline policy: per-camera CNN features, flattened and concatenated with qpos, fed to an MLP."""

    def __init__(self, backbones, state_dim, action_dim, camera_names):
        """Initialize the model.

        Parameters:
            backbones: list with one image backbone per camera (see backbone.py);
                ``backbones[i]`` processes camera ``i``. None is not supported.
            state_dim: dimension of the robot state vector ``qpos``.
            action_dim: dimension of the predicted action.
            camera_names: one name per camera slot along dim 1 of ``image``.

        Raises:
            NotImplementedError: if ``backbones`` is None.
        """
        super().__init__()
        self.camera_names = camera_names
        # NOTE: not used by forward(); kept so parameter lists and checkpoints stay compatible.
        self.action_head = nn.Linear(1000, action_dim)  # TODO add more
        if backbones is None:
            raise NotImplementedError

        self.backbones = nn.ModuleList(backbones)
        backbone_down_projs = []
        for backbone in backbones:
            down_proj = nn.Sequential(
                nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
                nn.Conv2d(128, 64, kernel_size=5),
                nn.Conv2d(64, 32, kernel_size=5),
            )
            backbone_down_projs.append(down_proj)
        self.backbone_down_projs = nn.ModuleList(backbone_down_projs)

        # NOTE(review): 768 is the hard-coded flattened size of one down-projected feature
        # map (32 channels x remaining spatial positions); it only holds for a specific
        # input resolution — confirm against the backbone / image size in use.
        mlp_in_dim = 768 * len(backbones) + state_dim
        self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)

    def forward(self, qpos, image, env_state, actions=None):
        """Predict actions from camera images and robot state.

        Parameters:
            qpos: (batch, state_dim) robot state.
            image: (batch, num_cam, channel, height, width) camera images.
            env_state: unused; accepted for interface parity with DETRVAE.
            actions: unused; accepted for interface parity with DETRVAE.

        Returns:
            a_hat: output of ``self.mlp``, last dim ``action_dim``.
        """
        bs, _ = qpos.shape
        # Per-camera image features; the backbone position embeddings are not used.
        flattened_features = []
        for cam_id, _cam_name in enumerate(self.camera_names):
            features, _pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0]  # take the last layer feature
            cam_feature = self.backbone_down_projs[cam_id](features)
            flattened_features.append(cam_feature.reshape([bs, -1]))  # 768 each
        features = torch.cat(flattened_features + [qpos], dim=1)
        return self.mlp(features)


def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    """Build an MLP: ``hidden_depth`` Linear+ReLU hidden layers followed by a Linear output.

    ``hidden_depth == 0`` yields a single ``Linear(input_dim, output_dim)``.
    """
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for _ in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    trunk = nn.Sequential(*mods)
    return trunk


def build_encoder(args):
    """Build the TransformerEncoder used by DETRVAE to infer the latent during training."""
    d_model = args.hidden_dim  # 256
    dropout = args.dropout  # 0.1
    nhead = args.nheads  # 8
    dim_feedforward = args.dim_feedforward  # 2048
    num_encoder_layers = args.enc_layers  # 4 # TODO shared with VAE decoder
    normalize_before = args.pre_norm  # False
    activation = "relu"

    encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                            dropout, activation, normalize_before)
    # a final LayerNorm is only needed for the pre-norm layout
    encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
    encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

    return encoder


def build(args):
    """Build a DETRVAE from ``args`` and print its trainable parameter count.

    The text options (``use_text``, ``text_feature_dim``, ``text_fusion_type``) fall
    back to the DETRVAE defaults when ``args`` does not define them, so older configs
    keep working.
    """
    state_dim = args.state_dim
    action_dim = args.action_dim

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    backbone = build_backbone(args)
    backbones.append(backbone)

    transformer = build_transformer(args)

    encoder = build_encoder(args)

    model = DETRVAE(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        action_dim=action_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        use_text=getattr(args, 'use_text', False),
        text_feature_dim=getattr(args, 'text_feature_dim', 768),
        text_fusion_type=getattr(args, 'text_fusion_type', 'concat_transformer_input'),
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model


def build_cnnmlp(args):
    """Build a CNNMLP with one backbone per camera and print its trainable parameter count."""
    state_dim = args.state_dim
    action_dim = args.action_dim

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = [build_backbone(args) for _ in args.camera_names]

    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        action_dim=action_dim,
        camera_names=args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model