代码可以跑起来了
This commit is contained in:
@@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
""" This is the DETR module that performs object detection """
|
||||
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
|
||||
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
|
||||
use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
|
||||
""" Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
@@ -48,17 +49,18 @@ class DETRVAE(nn.Module):
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
self.use_text = use_text
|
||||
self.text_fusion_type = text_fusion_type
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, state_dim)
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
else:
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
self.backbones = None
|
||||
@@ -66,16 +68,18 @@ class DETRVAE(nn.Module):
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
|
||||
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
|
||||
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
|
||||
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
|
||||
|
||||
# decoder extra parameters
|
||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
|
||||
num_extra_tokens = 3 if self.use_text else 2
|
||||
self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text
|
||||
self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
||||
def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
@@ -125,10 +129,25 @@ class DETRVAE(nn.Module):
|
||||
all_cam_pos.append(pos)
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
extra_input_tokens = None
|
||||
if self.use_text and text_features is not None:
|
||||
if self.text_fusion_type != 'concat_transformer_input':
|
||||
raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
|
||||
text_input = self.text_proj(text_features)
|
||||
extra_input_tokens = text_input.unsqueeze(0)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
|
||||
hs = self.transformer(
|
||||
src,
|
||||
None,
|
||||
self.query_embed.weight,
|
||||
pos,
|
||||
latent_input,
|
||||
proprio_input,
|
||||
self.additional_pos_embed.weight,
|
||||
extra_input_tokens=extra_input_tokens,
|
||||
)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
@@ -141,7 +160,7 @@ class DETRVAE(nn.Module):
|
||||
|
||||
|
||||
class CNNMLP(nn.Module):
|
||||
def __init__(self, backbones, state_dim, camera_names):
|
||||
def __init__(self, backbones, state_dim, action_dim, camera_names):
|
||||
""" Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
@@ -153,7 +172,7 @@ class CNNMLP(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.camera_names = camera_names
|
||||
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
||||
self.action_head = nn.Linear(1000, action_dim) # TODO add more
|
||||
if backbones is not None:
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
backbone_down_projs = []
|
||||
@@ -166,8 +185,8 @@ class CNNMLP(nn.Module):
|
||||
backbone_down_projs.append(down_proj)
|
||||
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
||||
|
||||
mlp_in_dim = 768 * len(backbones) + 14
|
||||
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
|
||||
mlp_in_dim = 768 * len(backbones) + state_dim
|
||||
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -192,7 +211,7 @@ class CNNMLP(nn.Module):
|
||||
for cam_feature in all_cam_features:
|
||||
flattened_features.append(cam_feature.reshape([bs, -1]))
|
||||
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
||||
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
||||
features = torch.cat([flattened_features, qpos], axis=1)
|
||||
a_hat = self.mlp(features)
|
||||
return a_hat
|
||||
|
||||
@@ -227,7 +246,8 @@ def build_encoder(args):
|
||||
|
||||
|
||||
def build(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
state_dim = args.state_dim
|
||||
action_dim = args.action_dim
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
@@ -245,8 +265,12 @@ def build(args):
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
use_text=args.use_text,
|
||||
text_feature_dim=args.text_feature_dim,
|
||||
text_fusion_type=args.text_fusion_type,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
@@ -255,7 +279,8 @@ def build(args):
|
||||
return model
|
||||
|
||||
def build_cnnmlp(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
state_dim = args.state_dim
|
||||
action_dim = args.action_dim
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
@@ -268,6 +293,7 @@ def build_cnnmlp(args):
|
||||
model = CNNMLP(
|
||||
backbones,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
camera_names=args.camera_names,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user