diff --git a/detr/models/detr_vae.py b/detr/models/detr_vae.py index 42005c5..bccfca7 100644 --- a/detr/models/detr_vae.py +++ b/detr/models/detr_vae.py @@ -88,7 +88,7 @@ class DETRVAE(nn.Module): if is_training: # project action sequence to embedding dim, and concat with a CLS token action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) - qpos_embed = self.encoder_action_proj(qpos) # (bs, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) cls_embed = self.cls_embed.weight # (1, hidden_dim) cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)