Fix typo in VAE encoder's joint projection: project qpos with encoder_joint_proj instead of encoder_action_proj
This commit is contained in:
@@ -88,7 +88,7 @@ class DETRVAE(nn.Module):
         if is_training:
             # project action sequence to embedding dim, and concat with a CLS token
             action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
-            qpos_embed = self.encoder_action_proj(qpos) # (bs, hidden_dim)
+            qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
             qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
             cls_embed = self.cls_embed.weight # (1, hidden_dim)
             cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
Reference in New Issue
Block a user