代码可以跑起来了
This commit is contained in:
18
detr/main.py
18
detr/main.py
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.optim.adamw import AdamW
|
||||
from .models import build_ACT_model, build_CNNMLP_model
|
||||
|
||||
import IPython
|
||||
@@ -30,6 +31,15 @@ def get_args_parser():
|
||||
help="Type of positional embedding to use on top of the image features")
|
||||
parser.add_argument('--camera_names', default=[], type=list, # will be overridden
|
||||
help="A list of camera names")
|
||||
parser.add_argument('--state_dim', default=14, type=int)
|
||||
parser.add_argument('--action_dim', default=14, type=int)
|
||||
parser.add_argument('--use_text', action='store_true')
|
||||
parser.add_argument('--text_encoder_type', default='distilbert', type=str)
|
||||
parser.add_argument('--text_feature_dim', default=768, type=int)
|
||||
parser.add_argument('--text_fusion_type', default='concat_transformer_input', type=str)
|
||||
parser.add_argument('--freeze_text_encoder', action='store_true')
|
||||
parser.add_argument('--text_max_length', default=32, type=int)
|
||||
parser.add_argument('--text_tokenizer_name', default='distilbert-base-uncased', type=str)
|
||||
|
||||
# * Transformer
|
||||
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
|
||||
@@ -84,8 +94,8 @@ def build_ACT_model_and_optimizer(args_override):
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
optimizer = AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
@@ -107,8 +117,8 @@ def build_CNNMLP_model_and_optimizer(args_override):
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
optimizer = AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
|
||||
@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
|
||||
train_backbone: bool,
|
||||
return_interm_layers: bool,
|
||||
dilation: bool):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
|
||||
backbone_builder = getattr(torchvision.models, name)
|
||||
weights = None
|
||||
if is_main_process():
|
||||
weight_enum_name_map = {
|
||||
'resnet18': 'ResNet18_Weights',
|
||||
'resnet34': 'ResNet34_Weights',
|
||||
'resnet50': 'ResNet50_Weights',
|
||||
'resnet101': 'ResNet101_Weights',
|
||||
}
|
||||
enum_name = weight_enum_name_map.get(name)
|
||||
if enum_name is not None and hasattr(torchvision.models, enum_name):
|
||||
weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')
|
||||
|
||||
try:
|
||||
backbone = backbone_builder(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
weights=weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
except TypeError:
|
||||
# Backward compatibility for older torchvision that still expects `pretrained`.
|
||||
backbone = backbone_builder(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=(weights is not None),
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
@@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
""" This is the DETR module that performs object detection """
|
||||
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
|
||||
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
|
||||
use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
|
||||
""" Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
@@ -48,17 +49,18 @@ class DETRVAE(nn.Module):
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
self.use_text = use_text
|
||||
self.text_fusion_type = text_fusion_type
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, state_dim)
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
else:
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
self.backbones = None
|
||||
@@ -66,16 +68,18 @@ class DETRVAE(nn.Module):
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
|
||||
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
|
||||
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
|
||||
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
|
||||
|
||||
# decoder extra parameters
|
||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
|
||||
num_extra_tokens = 3 if self.use_text else 2
|
||||
self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text
|
||||
self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
||||
def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
@@ -125,10 +129,25 @@ class DETRVAE(nn.Module):
|
||||
all_cam_pos.append(pos)
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
extra_input_tokens = None
|
||||
if self.use_text and text_features is not None:
|
||||
if self.text_fusion_type != 'concat_transformer_input':
|
||||
raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
|
||||
text_input = self.text_proj(text_features)
|
||||
extra_input_tokens = text_input.unsqueeze(0)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
|
||||
hs = self.transformer(
|
||||
src,
|
||||
None,
|
||||
self.query_embed.weight,
|
||||
pos,
|
||||
latent_input,
|
||||
proprio_input,
|
||||
self.additional_pos_embed.weight,
|
||||
extra_input_tokens=extra_input_tokens,
|
||||
)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
@@ -141,7 +160,7 @@ class DETRVAE(nn.Module):
|
||||
|
||||
|
||||
class CNNMLP(nn.Module):
|
||||
def __init__(self, backbones, state_dim, camera_names):
|
||||
def __init__(self, backbones, state_dim, action_dim, camera_names):
|
||||
""" Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
@@ -153,7 +172,7 @@ class CNNMLP(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.camera_names = camera_names
|
||||
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
||||
self.action_head = nn.Linear(1000, action_dim) # TODO add more
|
||||
if backbones is not None:
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
backbone_down_projs = []
|
||||
@@ -166,8 +185,8 @@ class CNNMLP(nn.Module):
|
||||
backbone_down_projs.append(down_proj)
|
||||
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
||||
|
||||
mlp_in_dim = 768 * len(backbones) + 14
|
||||
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
|
||||
mlp_in_dim = 768 * len(backbones) + state_dim
|
||||
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -192,7 +211,7 @@ class CNNMLP(nn.Module):
|
||||
for cam_feature in all_cam_features:
|
||||
flattened_features.append(cam_feature.reshape([bs, -1]))
|
||||
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
||||
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
||||
features = torch.cat([flattened_features, qpos], axis=1)
|
||||
a_hat = self.mlp(features)
|
||||
return a_hat
|
||||
|
||||
@@ -227,7 +246,8 @@ def build_encoder(args):
|
||||
|
||||
|
||||
def build(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
state_dim = args.state_dim
|
||||
action_dim = args.action_dim
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
@@ -245,8 +265,12 @@ def build(args):
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
use_text=args.use_text,
|
||||
text_feature_dim=args.text_feature_dim,
|
||||
text_fusion_type=args.text_fusion_type,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
@@ -255,7 +279,8 @@ def build(args):
|
||||
return model
|
||||
|
||||
def build_cnnmlp(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
state_dim = args.state_dim
|
||||
action_dim = args.action_dim
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
@@ -268,6 +293,7 @@ def build_cnnmlp(args):
|
||||
model = CNNMLP(
|
||||
backbones,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
camera_names=args.camera_names,
|
||||
)
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class Transformer(nn.Module):
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
|
||||
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
@@ -56,10 +56,19 @@ class Transformer(nn.Module):
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
additional_inputs = [latent_input, proprio_input]
|
||||
if extra_input_tokens is not None:
|
||||
if len(extra_input_tokens.shape) == 2:
|
||||
extra_input_tokens = extra_input_tokens.unsqueeze(0)
|
||||
for i in range(extra_input_tokens.shape[0]):
|
||||
additional_inputs.append(extra_input_tokens[i])
|
||||
|
||||
addition_input = torch.stack(additional_inputs, axis=0)
|
||||
if additional_pos_embed is not None:
|
||||
additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
|
||||
Reference in New Issue
Block a user