The code can run now
@@ -46,7 +46,7 @@ class Transformer(nn.Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
+    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
         # TODO flatten only when input has H and W
         if len(src.shape) == 4: # has H and W
             # flatten NxCxHxW to HWxNxC
@@ -56,10 +56,19 @@ class Transformer(nn.Module):
             query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
             # mask = mask.flatten(1)
 
-            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
-            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
+            additional_inputs = [latent_input, proprio_input]
+            if extra_input_tokens is not None:
+                if len(extra_input_tokens.shape) == 2:
+                    extra_input_tokens = extra_input_tokens.unsqueeze(0)
+                for i in range(extra_input_tokens.shape[0]):
+                    additional_inputs.append(extra_input_tokens[i])
+
+            addition_input = torch.stack(additional_inputs, axis=0)
+            if additional_pos_embed is not None:
+                additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
+                additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
+                pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
 
-            addition_input = torch.stack([latent_input, proprio_input], axis=0)
             src = torch.cat([addition_input, src], axis=0)
         else:
             assert len(src.shape) == 3
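
For reference, a minimal standalone sketch of the new token-prepending path, assuming latent_input and proprio_input are (bs, dim) tensors and extra_input_tokens is either (bs, dim) or (k, bs, dim); the batch size, hidden dim, and flattened H*W below are made-up illustration values, not taken from the repo's calling code.

import torch

bs, dim, hw = 2, 512, 300                    # assumed sizes, for illustration only
src = torch.randn(hw, bs, dim)               # flattened image features, (HW, bs, dim)
latent_input = torch.randn(bs, dim)
proprio_input = torch.randn(bs, dim)
extra_input_tokens = torch.randn(bs, dim)    # one hypothetical extra token per sample

additional_inputs = [latent_input, proprio_input]
if extra_input_tokens is not None:
    if len(extra_input_tokens.shape) == 2:   # (bs, dim) -> (1, bs, dim)
        extra_input_tokens = extra_input_tokens.unsqueeze(0)
    for i in range(extra_input_tokens.shape[0]):
        additional_inputs.append(extra_input_tokens[i])

addition_input = torch.stack(additional_inputs, axis=0)   # (2 + k, bs, dim)
src = torch.cat([addition_input, src], axis=0)            # (2 + k + HW, bs, dim)
print(src.shape)                                          # torch.Size([303, 2, 512])

The slice additional_pos_embed[:addition_input.shape[0]] in the diff then trims the extra positional embeddings to however many tokens were actually prepended, before repeating them across the batch and concatenating with pos_embed.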