The code can run now

2026-02-19 15:32:28 +08:00
parent b701d939c2
commit 88d14221ae
11 changed files with 503 additions and 89 deletions

@@ -46,7 +46,7 @@ class Transformer(nn.Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
+    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
         # TODO flatten only when input has H and W
         if len(src.shape) == 4: # has H and W
             # flatten NxCxHxW to HWxNxC
@@ -56,10 +56,19 @@ class Transformer(nn.Module):
             query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
             # mask = mask.flatten(1)
 
-            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
-            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
+            additional_inputs = [latent_input, proprio_input]
+            if extra_input_tokens is not None:
+                if len(extra_input_tokens.shape) == 2:
+                    extra_input_tokens = extra_input_tokens.unsqueeze(0)
+                for i in range(extra_input_tokens.shape[0]):
+                    additional_inputs.append(extra_input_tokens[i])
+            addition_input = torch.stack(additional_inputs, axis=0)
+            if additional_pos_embed is not None:
+                additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
+                additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
+                pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
 
-            addition_input = torch.stack([latent_input, proprio_input], axis=0)
             src = torch.cat([addition_input, src], axis=0)
         else:
             assert len(src.shape) == 3
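
For reference, a minimal sketch of what the new extra_input_tokens path does to the token prefix that is prepended to src. The shapes, the hidden_dim value, and the Embedding table size below are illustrative assumptions, not values taken from this repo:

import torch

# Hypothetical shapes for illustration only.
bs, hidden_dim = 8, 512

latent_input = torch.randn(bs, hidden_dim)    # first prefix token
proprio_input = torch.randn(bs, hidden_dim)   # second prefix token

# A single extra conditioning token; forward() also accepts a stacked
# (n, bs, hidden_dim) tensor holding n such tokens.
extra_input_tokens = torch.randn(bs, hidden_dim)

# Same logic as the new forward() body:
additional_inputs = [latent_input, proprio_input]
if extra_input_tokens is not None:
    if len(extra_input_tokens.shape) == 2:        # (bs, dim) -> (1, bs, dim)
        extra_input_tokens = extra_input_tokens.unsqueeze(0)
    for i in range(extra_input_tokens.shape[0]):  # append each extra token
        additional_inputs.append(extra_input_tokens[i])
addition_input = torch.stack(additional_inputs, dim=0)
print(addition_input.shape)                       # torch.Size([3, 8, 512])

# The positional-embedding table only needs to be at least as long as the
# prefix; forward() slices it down before broadcasting over the batch.
# A 4-row table is a hypothetical stand-in for the module's own embedding:
additional_pos_embed = torch.nn.Embedding(4, hidden_dim).weight
additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1)
print(additional_pos_embed.shape)                 # torch.Size([3, 8, 512])

Note the slice additional_pos_embed[:addition_input.shape[0]]: it lets the registered table be sized for the maximum expected prefix length while still working when fewer extra tokens (or none) are passed, in which case the prefix stays at the original two tokens.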