import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms

from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
from models.text_encoder import DistilBERTTextEncoder
import IPython
e = IPython.embed  # debugging convenience: call e() to drop into an IPython shell


class ACTPolicy(nn.Module):
    def __init__(self, args_override):
        super().__init__()
        model, optimizer = build_ACT_model_and_optimizer(args_override)
        self.model = model  # CVAE decoder
        self.optimizer = optimizer
        self.kl_weight = args_override['kl_weight']
        # Optional language conditioning: build a DistilBERT text encoder when enabled.
        self.use_text = args_override.get('use_text', False)
        self.text_encoder = None
        if self.use_text:
            text_encoder_type = args_override.get('text_encoder_type', 'distilbert')
            if text_encoder_type != 'distilbert':
                raise NotImplementedError(f'Unsupported text encoder: {text_encoder_type}')
            self.text_encoder = DistilBERTTextEncoder(
                model_name=args_override.get('text_tokenizer_name', 'distilbert-base-uncased'),
                output_dim=args_override.get('text_feature_dim', 768),
                freeze=args_override.get('freeze_text_encoder', True),
            )
        print(f'KL Weight {self.kl_weight}')

    def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None,
                 text_features=None, actions=None, is_pad=None):
        env_state = None
        # Normalize camera images with ImageNet statistics, as expected by the
        # pretrained vision backbone.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        image = normalize(image)

        # Encode the language instruction on the fly unless precomputed text
        # features were passed in directly.
        if self.use_text and text_features is None and text_input_ids is not None and text_attention_mask is not None:
            if self.text_encoder is None:
                raise RuntimeError('Text encoder is not initialized while use_text=True.')
            text_features = self.text_encoder(text_input_ids, text_attention_mask)

        if actions is not None:  # training time
            if is_pad is None:
                raise ValueError('`is_pad` must be provided during training when `actions` is not None.')
            # Truncate action chunks and padding masks to the model's prediction horizon.
            actions = actions[:, :self.model.num_queries]
            is_pad = is_pad[:, :self.model.num_queries]

            a_hat, is_pad_hat, (mu, logvar) = self.model(
                qpos,
                image,
                env_state,
                text_features=text_features,
                actions=actions,
                is_pad=is_pad,
            )
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            loss_dict = dict()
            # Masked L1 reconstruction loss: padded timesteps contribute zero.
            all_l1 = F.l1_loss(actions, a_hat, reduction='none')
            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
            loss_dict['l1'] = l1
            loss_dict['kl'] = total_kld[0]
            loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
            return loss_dict
        else:  # inference time
            # No actions given: the CVAE latent is taken from the prior.
            a_hat, _, (_, _) = self.model(qpos, image, env_state, text_features=text_features)
            return a_hat

    def configure_optimizers(self):
        return self.optimizer
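
# Usage sketch (illustrative only; `config` must contain every key consumed by
# build_ACT_model_and_optimizer, and the tensor shapes below are hypothetical):
#
#   policy = ACTPolicy(config)
#   # training: actions (B, T, action_dim) and is_pad (B, T) yield a loss dict
#   loss_dict = policy(qpos, image, actions=actions, is_pad=is_pad)
#   loss_dict['loss'].backward()
#   # inference: omitting `actions` returns a predicted action chunk
#   a_hat = policy(qpos, image)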


class CNNMLPPolicy(nn.Module):
    def __init__(self, args_override):
        super().__init__()
        model, optimizer = build_CNNMLP_model_and_optimizer(args_override)
        self.model = model  # decoder
        self.optimizer = optimizer

    def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None,
                 text_features=None, actions=None, is_pad=None):
        # Text arguments are accepted for interface parity with ACTPolicy but
        # are currently unused by this model.
        env_state = None  # TODO
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        image = normalize(image)
        if actions is not None:  # training time
            # CNNMLP predicts a single action, so supervise on the first timestep only.
            actions = actions[:, 0]
            a_hat = self.model(qpos, image, env_state, actions)
            mse = F.mse_loss(actions, a_hat)
            loss_dict = dict()
            loss_dict['mse'] = mse
            loss_dict['loss'] = loss_dict['mse']
            return loss_dict
        else:  # inference time
            a_hat = self.model(qpos, image, env_state)  # no actions given; deterministic single-step prediction
            return a_hat

    def configure_optimizers(self):
        return self.optimizer
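
# Interface contrast with ACTPolicy (shapes hypothetical, set by the builder
# config): ACTPolicy returns a chunk of future actions, CNNMLPPolicy a single
# action per step.
#
#   act_policy(qpos, image)     -> a_hat of shape (B, num_queries, action_dim)
#   cnnmlp_policy(qpos, image)  -> a_hat of shape (B, action_dim)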


def kl_divergence(mu, logvar):
    # Closed-form KL divergence between the diagonal Gaussian N(mu, exp(logvar))
    # and the standard normal prior N(0, I):
    #   KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    batch_size = mu.size(0)
    assert batch_size != 0
    # Flatten (B, C, 1, 1)-shaped latents to (B, C).
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)  # sum over latent dims, mean over batch
    dimension_wise_kld = klds.mean(0)      # per-dimension KL, averaged over batch
    mean_kld = klds.mean(1).mean(0, True)  # mean over both dims and batch

    return total_kld, dimension_wise_kld, mean_kld
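
# Minimal sanity check (illustrative, not part of the training pipeline): when
# the approximate posterior equals the prior N(0, I), i.e. mu = 0 and
# logvar = 0, every KL term must be exactly zero.
if __name__ == '__main__':
    import torch

    mu = torch.zeros(4, 32)  # hypothetical batch of 4 with a 32-dim latent
    logvar = torch.zeros(4, 32)
    total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
    assert torch.allclose(total_kld, torch.zeros(1))
    assert torch.allclose(dim_wise_kld, torch.zeros(32))
    assert torch.allclose(mean_kld, torch.zeros(1))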