代码可以跑起来了

This commit is contained in:
2026-02-19 15:32:28 +08:00
parent b701d939c2
commit 88d14221ae
11 changed files with 503 additions and 89 deletions

View File

@@ -3,6 +3,7 @@ from torch.nn import functional as F
import torchvision.transforms as transforms
from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
from models.text_encoder import DistilBERTTextEncoder
import IPython
e = IPython.embed
@@ -13,18 +14,44 @@ class ACTPolicy(nn.Module):
self.model = model # CVAE decoder
self.optimizer = optimizer
self.kl_weight = args_override['kl_weight']
self.use_text = args_override.get('use_text', False)
self.text_encoder = None
if self.use_text:
text_encoder_type = args_override.get('text_encoder_type', 'distilbert')
if text_encoder_type != 'distilbert':
raise NotImplementedError(f'Unsupported text encoder: {text_encoder_type}')
self.text_encoder = DistilBERTTextEncoder(
model_name=args_override.get('text_tokenizer_name', 'distilbert-base-uncased'),
output_dim=args_override.get('text_feature_dim', 768),
freeze=args_override.get('freeze_text_encoder', True),
)
print(f'KL Weight {self.kl_weight}')
def __call__(self, qpos, image, actions=None, is_pad=None):
def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
image = normalize(image)
if self.use_text and text_features is None and text_input_ids is not None and text_attention_mask is not None:
if self.text_encoder is None:
raise RuntimeError('Text encoder is not initialized while use_text=True.')
text_features = self.text_encoder(text_input_ids, text_attention_mask)
if actions is not None: # training time
if is_pad is None:
raise ValueError('`is_pad` must be provided during training when `actions` is not None.')
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
a_hat, is_pad_hat, (mu, logvar) = self.model(
qpos,
image,
env_state,
text_features=text_features,
actions=actions,
is_pad=is_pad,
)
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict = dict()
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
@@ -34,7 +61,7 @@ class ACTPolicy(nn.Module):
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
return loss_dict
else: # inference time
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
a_hat, _, (_, _) = self.model(qpos, image, env_state, text_features=text_features) # no action, sample from prior
return a_hat
def configure_optimizers(self):
@@ -48,7 +75,7 @@ class CNNMLPPolicy(nn.Module):
self.model = model # decoder
self.optimizer = optimizer
def __call__(self, qpos, image, actions=None, is_pad=None):
def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
env_state = None # TODO
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])