The code now runs.

2026-02-19 15:32:28 +08:00
parent b701d939c2
commit 88d14221ae
11 changed files with 503 additions and 89 deletions

View File

@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
backbone_builder = getattr(torchvision.models, name)
weights = None
if is_main_process():
weight_enum_name_map = {
'resnet18': 'ResNet18_Weights',
'resnet34': 'ResNet34_Weights',
'resnet50': 'ResNet50_Weights',
'resnet101': 'ResNet101_Weights',
}
enum_name = weight_enum_name_map.get(name)
if enum_name is not None and hasattr(torchvision.models, enum_name):
weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')
try:
backbone = backbone_builder(
replace_stride_with_dilation=[False, False, dilation],
weights=weights,
norm_layer=FrozenBatchNorm2d,
)
except TypeError:
# Backward compatibility for older torchvision that still expects `pretrained`.
backbone = backbone_builder(
replace_stride_with_dilation=[False, False, dilation],
pretrained=(weights is not None),
norm_layer=FrozenBatchNorm2d,
)
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
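
A minimal standalone sketch of the weight-resolution pattern used above (the helper name resolve_default_weights is illustrative, not part of the repo): on torchvision >= 0.13 it returns the DEFAULT member of the backbone's weights enum, and on older releases it returns None so the caller can fall back to the legacy pretrained= keyword, exactly as the try/except does.

import torchvision

def resolve_default_weights(name: str):
    # Map the torchvision model name to its weights enum (new-style API).
    weight_enum_name_map = {
        'resnet18': 'ResNet18_Weights',
        'resnet34': 'ResNet34_Weights',
        'resnet50': 'ResNet50_Weights',
        'resnet101': 'ResNet101_Weights',
    }
    enum_name = weight_enum_name_map.get(name)
    if enum_name is None or not hasattr(torchvision.models, enum_name):
        return None  # old torchvision or unknown backbone: caller uses the `pretrained` kwarg instead
    return getattr(getattr(torchvision.models, enum_name), 'DEFAULT')

print(resolve_default_weights('resnet18'))  # e.g. ResNet18_Weights.IMAGENET1K_V1 on recent torchvision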

View File

@@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid):
class DETRVAE(nn.Module):
""" This is the DETR module that performs object detection """
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
@@ -48,17 +49,18 @@ class DETRVAE(nn.Module):
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
self.use_text = use_text
self.text_fusion_type = text_fusion_type
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, state_dim)
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
@@ -66,16 +68,18 @@ class DETRVAE(nn.Module):
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
num_extra_tokens = 3 if self.use_text else 2
self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text
self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
@@ -125,10 +129,25 @@ class DETRVAE(nn.Module):
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
extra_input_tokens = None
if self.use_text and text_features is not None:
if self.text_fusion_type != 'concat_transformer_input':
raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
text_input = self.text_proj(text_features)
extra_input_tokens = text_input.unsqueeze(0)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
hs = self.transformer(
src,
None,
self.query_embed.weight,
pos,
latent_input,
proprio_input,
self.additional_pos_embed.weight,
extra_input_tokens=extra_input_tokens,
)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
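
A quick shape check for the 'concat_transformer_input' path above, using toy sizes (hidden_dim=512 is an assumption; text_feature_dim=768 is the constructor default, and the text features themselves would come from whatever frozen text encoder the training script provides):

import torch
import torch.nn as nn

bs, hidden_dim, text_feature_dim = 4, 512, 768
text_proj = nn.Linear(text_feature_dim, hidden_dim)

text_features = torch.randn(bs, text_feature_dim)           # (bs, text_feature_dim) per-sample text embedding
extra_input_tokens = text_proj(text_features).unsqueeze(0)  # (1, bs, hidden_dim): one extra transformer token
print(extra_input_tokens.shape)                             # torch.Size([1, 4, 512])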
@@ -141,7 +160,7 @@ class DETRVAE(nn.Module):
class CNNMLP(nn.Module):
def __init__(self, backbones, state_dim, camera_names):
def __init__(self, backbones, state_dim, action_dim, camera_names):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
@@ -153,7 +172,7 @@ class CNNMLP(nn.Module):
"""
super().__init__()
self.camera_names = camera_names
self.action_head = nn.Linear(1000, state_dim) # TODO add more
self.action_head = nn.Linear(1000, action_dim) # TODO add more
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
backbone_down_projs = []
@@ -166,8 +185,8 @@ class CNNMLP(nn.Module):
backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + 14
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
mlp_in_dim = 768 * len(backbones) + state_dim
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)
else:
raise NotImplementedError
@@ -192,7 +211,7 @@ class CNNMLP(nn.Module):
for cam_feature in all_cam_features:
flattened_features.append(cam_feature.reshape([bs, -1]))
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
features = torch.cat([flattened_features, qpos], axis=1)
a_hat = self.mlp(features)
return a_hat
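
For reference, the CNNMLP input width now follows the configured state dimension rather than a hard-coded 14. A worked example with hypothetical sizes, assuming one backbone per camera as the per-camera down-projections above suggest:

num_cameras = 2
state_dim = 14
mlp_in_dim = 768 * num_cameras + state_dim  # 768 flattened features per camera, plus qpos
print(mlp_in_dim)                           # 1550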
@@ -227,7 +246,8 @@ def build_encoder(args):
def build(args):
state_dim = 14 # TODO hardcode
state_dim = args.state_dim
action_dim = args.action_dim
# From state
# backbone = None # from state for now, no need for conv nets
@@ -245,8 +265,12 @@ def build(args):
transformer,
encoder,
state_dim=state_dim,
action_dim=action_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
use_text=args.use_text,
text_feature_dim=args.text_feature_dim,
text_fusion_type=args.text_fusion_type,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -255,7 +279,8 @@ def build(args):
return model
def build_cnnmlp(args):
state_dim = 14 # TODO hardcode
state_dim = args.state_dim
action_dim = args.action_dim
# From state
# backbone = None # from state for now, no need for conv nets
@@ -268,6 +293,7 @@ def build_cnnmlp(args):
model = CNNMLP(
backbones,
state_dim=state_dim,
action_dim=action_dim,
camera_names=args.camera_names,
)
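
A hypothetical snippet listing the argument fields the two builders now read (placeholder values; the real args object also carries the unchanged DETR/transformer hyperparameters such as hidden_dim and lr_backbone):

from argparse import Namespace

args = Namespace(
    state_dim=14,                                 # width of qpos fed to input_proj_robot_state
    action_dim=16,                                # output width of action_head / encoder_action_proj
    num_queries=100,
    camera_names=['top'],
    use_text=True,                                # enables the extra text token path
    text_feature_dim=768,
    text_fusion_type='concat_transformer_input',  # the only fusion type implemented so far
)
# model = build(args)  # would additionally need the existing backbone/transformer arguments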

View File

@@ -46,7 +46,7 @@ class Transformer(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
@@ -56,10 +56,19 @@ class Transformer(nn.Module):
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
additional_inputs = [latent_input, proprio_input]
if extra_input_tokens is not None:
if len(extra_input_tokens.shape) == 2:
extra_input_tokens = extra_input_tokens.unsqueeze(0)
for i in range(extra_input_tokens.shape[0]):
additional_inputs.append(extra_input_tokens[i])
addition_input = torch.stack(additional_inputs, axis=0)
if additional_pos_embed is not None:
additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3
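
A toy-sized sketch of the token stacking and position-embedding slicing above (hidden_dim=8 and bs=2 are arbitrary): when no text token is passed, the 3-row embedding table created with use_text=True is simply sliced down to the two rows consumed by the latent and proprio tokens, so both code paths stay shape-consistent.

import torch

hidden_dim, bs = 8, 2
latent_input = torch.randn(bs, hidden_dim)
proprio_input = torch.randn(bs, hidden_dim)
extra_input_tokens = torch.randn(1, bs, hidden_dim)  # projected text token; set to None to drop it

additional_inputs = [latent_input, proprio_input]
if extra_input_tokens is not None:
    additional_inputs.extend(tok for tok in extra_input_tokens)
addition_input = torch.stack(additional_inputs, dim=0)            # (2 or 3, bs, hidden_dim)

additional_pos_embed = torch.randn(3, hidden_dim)                 # stands in for nn.Embedding(3, hidden_dim).weight
additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1)
print(addition_input.shape, additional_pos_embed.shape)           # both torch.Size([3, 2, 8]) here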