typo in VAE encoder's joint projection

This commit is contained in:
Tony Zhao
2023-06-23 13:30:39 -07:00
parent 28d51cb83f
commit 40e1755296

View File

@@ -88,7 +88,7 @@ class DETRVAE(nn.Module):
if is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_action_proj(qpos) # (bs, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)