From 40e175529638dddc1656147d9cbc7a29814bda3c Mon Sep 17 00:00:00 2001 From: Tony Zhao Date: Fri, 23 Jun 2023 13:30:39 -0700 Subject: [PATCH] typo in VAE encoder's joint projection --- detr/models/detr_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/detr/models/detr_vae.py b/detr/models/detr_vae.py index 42005c5..bccfca7 100644 --- a/detr/models/detr_vae.py +++ b/detr/models/detr_vae.py @@ -88,7 +88,7 @@ class DETRVAE(nn.Module): if is_training: # project action sequence to embedding dim, and concat with a CLS token action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) - qpos_embed = self.encoder_action_proj(qpos) # (bs, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) cls_embed = self.cls_embed.weight # (1, hidden_dim) cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)