代码可以跑起来了

This commit is contained in:
2026-02-19 15:32:28 +08:00
parent b701d939c2
commit 88d14221ae
11 changed files with 503 additions and 89 deletions

View File

@@ -268,10 +268,30 @@
## 8. 你接下来只需提供的最小信息(进入代码改造前)
1. 2 个电机各自的物理含义与取值范围(单位、上下限)
2. 你当前数据中 `qpos` 与 `action` 的实际定义(是否相同)。
3. text instruction 是每个 episode 一条,还是每个 timestep 一条。
4. 相机数量、分辨率、帧率
5. 是否在训练时冻结 DistilBERT(`freeze_text_encoder=True/False`)。
1. 2 个电机各自的物理含义与取值范围(单位、上下限):电机分别为 motor_x 和 motor_y,x 的范围为 7000-17384,y 的范围为 8000-18884。对应的 action_x 和 action_y 都在 0~65535 之间。
2. 你当前数据中 `qpos` 与 `action` 的实际定义(是否相同):action 和 qpos 定义接近,只不过是将 0~65535 分别映射到电机磁编码器数值上。
3. text instruction 是每个 episode 一条,还是每个 timestep 一条:text instruction 是每个 timestep 一条。
4. 相机数量、分辨率、帧率:相机数量为 1,分辨率为 224*224,帧率为 30Hz,对应的电机控制频率也为 30Hz。
5. 是否在训练时冻结 DistilBERT(`freeze_text_encoder=True/False`):DistilBERT 完全冻结。构建训练集时,先将每一个 frame 的 text instruction 用 DistilBERT 编码以后再保存,这样训练过程中不需要调用 DistilBERT。
> 有了这 5 项,即可进入下一步代码改造。
text_input_ids、text_attention_mask 是什么意思;'instruction_mode': 'episode-level' 没有用到;
```python
instruction = ''
if self.use_text_instruction:
if '/instruction_timestep' in root:
instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
elif '/instruction' in root:
instruction_node = root['/instruction']
if getattr(instruction_node, 'shape', ()) == ():
instruction = self._decode_instruction(instruction_node[()])
else:
if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
instruction = self._decode_instruction(instruction_node[start_ts])
else:
instruction = self._decode_instruction(instruction_node[0])
```
为什么修改了 Transformer 的定义?这里是否会生效?

View File

@@ -21,3 +21,5 @@ dependencies:
- packaging=23.0
- h5py=3.8.0
- ipython=8.12.0
- pip:
- transformers==4.38.2

View File

@@ -1,7 +1,7 @@
import pathlib
### Task parameters
DATA_DIR = '<put your data dir here>'
DATA_DIR = str(pathlib.Path(__file__).parent.resolve() / 'data')
SIM_TASK_CONFIGS = {
'sim_transfer_cube_scripted':{
'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
@@ -32,6 +32,43 @@ SIM_TASK_CONFIGS = {
},
}
ENDOSCOPE_TASK_CONFIGS = {
'endoscope_default': {
'dataset_dir': DATA_DIR + '/endoscope_default',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'use_text_instruction': True,
'instruction_mode': 'timestep-level',
'use_cached_text_features': True,
'text_encoder_type': 'distilbert',
'text_feature_dim': 768,
'text_fusion_type': 'concat_transformer_input',
'freeze_text_encoder': True,
'text_max_length': 32,
'text_tokenizer_name': 'distilbert-base-uncased',
},
'endoscope_follow': {
'dataset_dir': DATA_DIR + '/follow',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'use_text_instruction': True,
'instruction_mode': 'timestep-level',
'use_cached_text_features': True,
'text_encoder_type': 'distilbert',
'text_feature_dim': 768,
'text_fusion_type': 'concat_transformer_input',
'freeze_text_encoder': True,
'text_max_length': 32,
'text_tokenizer_name': 'distilbert-base-uncased',
},
}
### Simulation envs fixed constants
DT = 0.02
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import numpy as np
import torch
from torch.optim.adamw import AdamW
from .models import build_ACT_model, build_CNNMLP_model
import IPython
@@ -30,6 +31,15 @@ def get_args_parser():
help="Type of positional embedding to use on top of the image features")
parser.add_argument('--camera_names', default=[], type=list, # will be overridden
help="A list of camera names")
parser.add_argument('--state_dim', default=14, type=int)
parser.add_argument('--action_dim', default=14, type=int)
parser.add_argument('--use_text', action='store_true')
parser.add_argument('--text_encoder_type', default='distilbert', type=str)
parser.add_argument('--text_feature_dim', default=768, type=int)
parser.add_argument('--text_fusion_type', default='concat_transformer_input', type=str)
parser.add_argument('--freeze_text_encoder', action='store_true')
parser.add_argument('--text_max_length', default=32, type=int)
parser.add_argument('--text_tokenizer_name', default='distilbert-base-uncased', type=str)
# * Transformer
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
@@ -84,7 +94,7 @@ def build_ACT_model_and_optimizer(args_override):
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
optimizer = AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay)
return model, optimizer
@@ -107,7 +117,7 @@ def build_CNNMLP_model_and_optimizer(args_override):
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
optimizer = AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay)
return model, optimizer

View File

@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
backbone_builder = getattr(torchvision.models, name)
weights = None
if is_main_process():
weight_enum_name_map = {
'resnet18': 'ResNet18_Weights',
'resnet34': 'ResNet34_Weights',
'resnet50': 'ResNet50_Weights',
'resnet101': 'ResNet101_Weights',
}
enum_name = weight_enum_name_map.get(name)
if enum_name is not None and hasattr(torchvision.models, enum_name):
weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')
try:
backbone = backbone_builder(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
weights=weights,
norm_layer=FrozenBatchNorm2d,
)
except TypeError:
# Backward compatibility for older torchvision that still expects `pretrained`.
backbone = backbone_builder(
replace_stride_with_dilation=[False, False, dilation],
pretrained=(weights is not None),
norm_layer=FrozenBatchNorm2d,
)
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

View File

@@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid):
class DETRVAE(nn.Module):
""" This is the DETR module that performs object detection """
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
@@ -48,17 +49,18 @@ class DETRVAE(nn.Module):
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
self.use_text = use_text
self.text_fusion_type = text_fusion_type
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, state_dim)
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
@@ -66,16 +68,18 @@ class DETRVAE(nn.Module):
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
num_extra_tokens = 3 if self.use_text else 2
self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text
self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
@@ -125,10 +129,25 @@ class DETRVAE(nn.Module):
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
extra_input_tokens = None
if self.use_text and text_features is not None:
if self.text_fusion_type != 'concat_transformer_input':
raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
text_input = self.text_proj(text_features)
extra_input_tokens = text_input.unsqueeze(0)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
hs = self.transformer(
src,
None,
self.query_embed.weight,
pos,
latent_input,
proprio_input,
self.additional_pos_embed.weight,
extra_input_tokens=extra_input_tokens,
)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
@@ -141,7 +160,7 @@ class DETRVAE(nn.Module):
class CNNMLP(nn.Module):
def __init__(self, backbones, state_dim, camera_names):
def __init__(self, backbones, state_dim, action_dim, camera_names):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
@@ -153,7 +172,7 @@ class CNNMLP(nn.Module):
"""
super().__init__()
self.camera_names = camera_names
self.action_head = nn.Linear(1000, state_dim) # TODO add more
self.action_head = nn.Linear(1000, action_dim) # TODO add more
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
backbone_down_projs = []
@@ -166,8 +185,8 @@ class CNNMLP(nn.Module):
backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + 14
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
mlp_in_dim = 768 * len(backbones) + state_dim
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)
else:
raise NotImplementedError
@@ -192,7 +211,7 @@ class CNNMLP(nn.Module):
for cam_feature in all_cam_features:
flattened_features.append(cam_feature.reshape([bs, -1]))
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
features = torch.cat([flattened_features, qpos], axis=1)
a_hat = self.mlp(features)
return a_hat
@@ -227,7 +246,8 @@ def build_encoder(args):
def build(args):
state_dim = 14 # TODO hardcode
state_dim = args.state_dim
action_dim = args.action_dim
# From state
# backbone = None # from state for now, no need for conv nets
@@ -245,8 +265,12 @@ def build(args):
transformer,
encoder,
state_dim=state_dim,
action_dim=action_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
use_text=args.use_text,
text_feature_dim=args.text_feature_dim,
text_fusion_type=args.text_fusion_type,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -255,7 +279,8 @@ def build(args):
return model
def build_cnnmlp(args):
state_dim = 14 # TODO hardcode
state_dim = args.state_dim
action_dim = args.action_dim
# From state
# backbone = None # from state for now, no need for conv nets
@@ -268,6 +293,7 @@ def build_cnnmlp(args):
model = CNNMLP(
backbones,
state_dim=state_dim,
action_dim=action_dim,
camera_names=args.camera_names,
)

View File

@@ -46,7 +46,7 @@ class Transformer(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
@@ -56,10 +56,19 @@ class Transformer(nn.Module):
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_inputs = [latent_input, proprio_input]
if extra_input_tokens is not None:
if len(extra_input_tokens.shape) == 2:
extra_input_tokens = extra_input_tokens.unsqueeze(0)
for i in range(extra_input_tokens.shape[0]):
additional_inputs.append(extra_input_tokens[i])
addition_input = torch.stack(additional_inputs, axis=0)
if additional_pos_embed is not None:
additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3

View File

@@ -10,14 +10,13 @@ from einops import rearrange
from constants import DT
from constants import PUPPET_GRIPPER_JOINT_OPEN
from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
from utils import load_data # data functions
from utils import sample_box_pose, sample_insertion_pose # robot functions
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
from policy import ACTPolicy, CNNMLPPolicy
from visualize_episodes import save_videos
from sim_env import BOX_POSE
import IPython
e = IPython.embed
@@ -34,25 +33,47 @@ def main(args):
num_epochs = args['num_epochs']
# get task parameters
is_sim = task_name[:4] == 'sim_'
if is_sim:
from constants import SIM_TASK_CONFIGS
is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
if is_endoscope:
task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
is_sim = False
elif task_name in SIM_TASK_CONFIGS:
task_config = SIM_TASK_CONFIGS[task_name]
is_sim = True
else:
from aloha_scripts.constants import TASK_CONFIGS
task_config = TASK_CONFIGS[task_name]
is_sim = False
dataset_dir = task_config['dataset_dir']
num_episodes = task_config['num_episodes']
episode_len = task_config['episode_len']
camera_names = task_config['camera_names']
state_dim = task_config.get('state_dim', 14)
action_dim = task_config.get('action_dim', state_dim)
use_text_instruction = task_config.get('use_text_instruction', False)
instruction_mode = task_config.get('instruction_mode', 'timestep-level')
use_cached_text_features = task_config.get('use_cached_text_features', True)
text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
text_feature_dim = task_config.get('text_feature_dim', 768)
text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
text_max_length = task_config.get('text_max_length', 32)
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
freeze_text_encoder = task_config.get('freeze_text_encoder', True)
if args.get('text_encoder_type') is not None:
text_encoder_type = args['text_encoder_type']
if args.get('text_max_length') is not None:
text_max_length = args['text_max_length']
if args.get('freeze_text_encoder', False):
freeze_text_encoder = True
# fixed parameters
state_dim = 14
lr_backbone = 1e-5
backbone = 'resnet18'
if policy_class == 'ACT':
enc_layers = 4
dec_layers = 7
enc_layers = 2
dec_layers = 4
nheads = 8
policy_config = {'lr': args['lr'],
'num_queries': args['chunk_size'],
@@ -65,10 +86,25 @@ def main(args):
'dec_layers': dec_layers,
'nheads': nheads,
'camera_names': camera_names,
'state_dim': state_dim,
'action_dim': action_dim,
'use_text': use_text_instruction,
'text_encoder_type': text_encoder_type,
'text_feature_dim': text_feature_dim,
'text_fusion_type': text_fusion_type,
'freeze_text_encoder': freeze_text_encoder,
'instruction_mode': instruction_mode,
'use_cached_text_features': use_cached_text_features,
'text_max_length': text_max_length,
'text_tokenizer_name': text_tokenizer_name,
}
elif policy_class == 'CNNMLP':
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
'camera_names': camera_names,}
'camera_names': camera_names,
'state_dim': state_dim,
'action_dim': action_dim,
'use_text': use_text_instruction,
}
else:
raise NotImplementedError
@@ -77,6 +113,7 @@ def main(args):
'ckpt_dir': ckpt_dir,
'episode_len': episode_len,
'state_dim': state_dim,
'action_dim': action_dim,
'lr': args['lr'],
'policy_class': policy_class,
'onscreen_render': onscreen_render,
@@ -85,7 +122,12 @@ def main(args):
'seed': args['seed'],
'temporal_agg': args['temporal_agg'],
'camera_names': camera_names,
'real_robot': not is_sim
'real_robot': (not is_sim) and (not is_endoscope),
'use_text_instruction': use_text_instruction,
'instruction_mode': instruction_mode,
'use_cached_text_features': use_cached_text_features,
'text_tokenizer_name': text_tokenizer_name,
'text_max_length': text_max_length,
}
if is_eval:
@@ -100,7 +142,19 @@ def main(args):
print()
exit()
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
train_dataloader, val_dataloader, stats, _ = load_data(
dataset_dir,
num_episodes,
camera_names,
batch_size_train,
batch_size_val,
use_text_instruction=use_text_instruction,
instruction_mode=instruction_mode,
use_cached_text_features=use_cached_text_features,
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
)
# save dataset stats
if not os.path.isdir(ckpt_dir):
@@ -152,6 +206,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
set_seed(1000)
ckpt_dir = config['ckpt_dir']
state_dim = config['state_dim']
action_dim = config['action_dim']
real_robot = config['real_robot']
policy_class = config['policy_class']
onscreen_render = config['onscreen_render']
@@ -161,6 +216,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
task_name = config['task_name']
temporal_agg = config['temporal_agg']
onscreen_cam = 'angle'
BOX_POSE = None
# load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
@@ -202,8 +258,14 @@ def eval_bc(config, ckpt_name, save_episode=True):
rollout_id += 0
### set task
if 'sim_transfer_cube' in task_name:
if BOX_POSE is None:
from sim_env import BOX_POSE as _BOX_POSE
BOX_POSE = _BOX_POSE
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif 'sim_insertion' in task_name:
if BOX_POSE is None:
from sim_env import BOX_POSE as _BOX_POSE
BOX_POSE = _BOX_POSE
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
ts = env.reset()
@@ -216,7 +278,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
### evaluation loop
if temporal_agg:
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()
qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
image_list = [] # for visualization
@@ -314,9 +376,29 @@ def eval_bc(config, ckpt_name, save_episode=True):
def forward_pass(data, policy):
image_data, qpos_data, action_data, is_pad = data
image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None
image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
image_data = image_data.cuda()
qpos_data = qpos_data.cuda()
action_data = action_data.cuda()
is_pad = is_pad.cuda()
text_input_ids = text_input_ids.cuda()
text_attention_mask = text_attention_mask.cuda()
text_feature_data = text_feature_data.cuda()
text_feature_valid = text_feature_valid.cuda()
text_features = None
if torch.any(text_feature_valid):
text_features = text_feature_data
return policy(
qpos_data,
image_data,
text_input_ids=text_input_ids,
text_attention_mask=text_attention_mask,
text_features=text_features,
actions=action_data,
is_pad=is_pad,
)
def train_bc(train_dataloader, val_dataloader, config):
@@ -431,5 +513,8 @@ if __name__ == '__main__':
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
parser.add_argument('--temporal_agg', action='store_true')
parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
parser.add_argument('--freeze_text_encoder', action='store_true')
parser.add_argument('--text_max_length', action='store', type=int, required=False)
main(vars(parser.parse_args()))

View File

@@ -3,6 +3,7 @@ from torch.nn import functional as F
import torchvision.transforms as transforms
from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
from models.text_encoder import DistilBERTTextEncoder
import IPython
e = IPython.embed
@@ -13,18 +14,44 @@ class ACTPolicy(nn.Module):
self.model = model # CVAE decoder
self.optimizer = optimizer
self.kl_weight = args_override['kl_weight']
self.use_text = args_override.get('use_text', False)
self.text_encoder = None
if self.use_text:
text_encoder_type = args_override.get('text_encoder_type', 'distilbert')
if text_encoder_type != 'distilbert':
raise NotImplementedError(f'Unsupported text encoder: {text_encoder_type}')
self.text_encoder = DistilBERTTextEncoder(
model_name=args_override.get('text_tokenizer_name', 'distilbert-base-uncased'),
output_dim=args_override.get('text_feature_dim', 768),
freeze=args_override.get('freeze_text_encoder', True),
)
print(f'KL Weight {self.kl_weight}')
def __call__(self, qpos, image, actions=None, is_pad=None):
def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
image = normalize(image)
if self.use_text and text_features is None and text_input_ids is not None and text_attention_mask is not None:
if self.text_encoder is None:
raise RuntimeError('Text encoder is not initialized while use_text=True.')
text_features = self.text_encoder(text_input_ids, text_attention_mask)
if actions is not None: # training time
if is_pad is None:
raise ValueError('`is_pad` must be provided during training when `actions` is not None.')
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
a_hat, is_pad_hat, (mu, logvar) = self.model(
qpos,
image,
env_state,
text_features=text_features,
actions=actions,
is_pad=is_pad,
)
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict = dict()
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
@@ -34,7 +61,7 @@ class ACTPolicy(nn.Module):
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
return loss_dict
else: # inference time
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
a_hat, _, (_, _) = self.model(qpos, image, env_state, text_features=text_features) # no action, sample from prior
return a_hat
def configure_optimizers(self):
@@ -48,7 +75,7 @@ class CNNMLPPolicy(nn.Module):
self.model = model # decoder
self.optimizer = optimizer
def __call__(self, qpos, image, actions=None, is_pad=None):
def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
env_state = None # TODO
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

230
utils.py
View File

@@ -2,21 +2,88 @@ import numpy as np
import torch
import os
import h5py
import re
from torch.utils.data import TensorDataset, DataLoader
import IPython
e = IPython.embed
class EpisodicDataset(torch.utils.data.Dataset):
def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats):
def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats,
use_text_instruction=False,
instruction_mode='timestep-level',
use_cached_text_features=True,
text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased',
text_max_length=32):
super(EpisodicDataset).__init__()
self.episode_ids = episode_ids
self.dataset_dir = dataset_dir
self.camera_names = camera_names
self.norm_stats = norm_stats
self.use_text_instruction = use_text_instruction
self.instruction_mode = instruction_mode
self.use_cached_text_features = use_cached_text_features
self.text_feature_dim = text_feature_dim
self.text_max_length = text_max_length
self.is_sim = None
self.max_episode_len = None
self.action_dim = None
self.text_tokenizer = None
if self.use_text_instruction:
try:
from transformers import DistilBertTokenizerFast
except ImportError as exc:
raise ImportError(
'transformers is required for text instruction loading. '
'Install it with: pip install transformers'
) from exc
self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(text_tokenizer_name)
self._init_episode_shapes()
self.__getitem__(0) # initialize self.is_sim
def _init_episode_shapes(self):
max_len = 0
action_dim = None
for episode_id in self.episode_ids:
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
with h5py.File(dataset_path, 'r') as root:
shape = root['/action'].shape
if len(shape) != 2:
raise ValueError(f'Expected /action to have shape [T, D], got {shape} in {dataset_path}')
max_len = max(max_len, int(shape[0]))
if action_dim is None:
action_dim = int(shape[1])
elif int(shape[1]) != action_dim:
raise ValueError(
f'Inconsistent action dim in dataset. Expected {action_dim}, got {shape[1]} in {dataset_path}'
)
if max_len <= 0 or action_dim is None:
raise ValueError(f'Invalid dataset metadata in {self.dataset_dir}')
self.max_episode_len = max_len
self.action_dim = action_dim
@staticmethod
def _decode_instruction(raw_value):
if raw_value is None:
return ''
if isinstance(raw_value, bytes):
return raw_value.decode('utf-8')
if isinstance(raw_value, np.bytes_):
return raw_value.tobytes().decode('utf-8')
if isinstance(raw_value, np.ndarray):
if raw_value.shape == ():
return EpisodicDataset._decode_instruction(raw_value.item())
if raw_value.size == 0:
return ''
return EpisodicDataset._decode_instruction(raw_value.reshape(-1)[0])
return str(raw_value)
def __len__(self):
return len(self.episode_ids)
@@ -26,7 +93,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
episode_id = self.episode_ids[index]
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim']
is_sim = bool(root.attrs.get('sim', False))
original_action_shape = root['/action'].shape
episode_len = original_action_shape[0]
if sample_full_episode:
@@ -35,10 +102,40 @@ class EpisodicDataset(torch.utils.data.Dataset):
start_ts = np.random.choice(episode_len)
# get observation at start_ts only
qpos = root['/observations/qpos'][start_ts]
qvel = root['/observations/qvel'][start_ts]
image_dict = dict()
for cam_name in self.camera_names:
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
instruction = ''
text_feature = None
if self.use_text_instruction:
effective_mode = self.instruction_mode
if effective_mode == 'timestep-level' and '/instruction_timestep' in root:
instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
elif '/instruction' in root:
instruction_node = root['/instruction']
if getattr(instruction_node, 'shape', ()) == ():
instruction = self._decode_instruction(instruction_node[()])
else:
if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
instruction = self._decode_instruction(instruction_node[start_ts])
else:
instruction = self._decode_instruction(instruction_node[0])
if self.use_cached_text_features:
if effective_mode == 'timestep-level' and '/instruction_features_timestep' in root:
text_feature = root['/instruction_features_timestep'][start_ts]
elif '/instruction_features' in root:
feat_node = root['/instruction_features']
if getattr(feat_node, 'shape', ()) == ():
text_feature = np.array(feat_node[()])
elif len(feat_node.shape) == 1:
text_feature = feat_node[()]
elif len(feat_node.shape) == 2 and feat_node.shape[0] == episode_len:
text_feature = feat_node[start_ts]
else:
text_feature = feat_node[0]
# get all actions after and including start_ts
if is_sim:
action = root['/action'][start_ts:]
@@ -48,10 +145,10 @@ class EpisodicDataset(torch.utils.data.Dataset):
action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
self.is_sim = is_sim
padded_action = np.zeros(original_action_shape, dtype=np.float32)
padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
padded_action[:action_len] = action
is_pad = np.zeros(episode_len)
is_pad[action_len:] = 1
is_pad = np.ones(self.max_episode_len)
is_pad[:action_len] = 0
# new axis for different cameras
all_cam_images = []
@@ -73,55 +170,132 @@ class EpisodicDataset(torch.utils.data.Dataset):
action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
return image_data, qpos_data, action_data, is_pad
if self.use_text_instruction and text_feature is not None:
text_feature_data = torch.from_numpy(np.array(text_feature)).float()
text_feature_valid = torch.tensor(True, dtype=torch.bool)
text_input_ids = torch.zeros(1, dtype=torch.long)
text_attention_mask = torch.zeros(1, dtype=torch.long)
elif self.use_text_instruction:
tokenized = self.text_tokenizer(
instruction,
padding='max_length',
truncation=True,
max_length=self.text_max_length,
return_tensors='pt',
)
text_input_ids = tokenized['input_ids'].squeeze(0).long()
text_attention_mask = tokenized['attention_mask'].squeeze(0).long()
text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
text_feature_valid = torch.tensor(False, dtype=torch.bool)
else:
text_input_ids = torch.zeros(1, dtype=torch.long)
text_attention_mask = torch.zeros(1, dtype=torch.long)
text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
text_feature_valid = torch.tensor(False, dtype=torch.bool)
return image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid
def get_norm_stats(dataset_dir, num_episodes):
def _discover_episode_ids(dataset_dir, num_episodes=None):
pattern = re.compile(r'^episode_(\d+)\.hdf5$')
episode_ids = []
for fname in os.listdir(dataset_dir):
m = pattern.match(fname)
if m:
episode_ids.append(int(m.group(1)))
episode_ids.sort()
if num_episodes is not None:
episode_ids = episode_ids[:num_episodes]
return episode_ids
def get_norm_stats(dataset_dir, episode_ids):
    """Compute z-score normalization statistics over the given episodes.

    Episodes may have different lengths, so per-timestep rows are concatenated
    along the time axis before computing per-dimension mean/std.

    Args:
        dataset_dir: directory containing ``episode_<N>.hdf5`` files.
        episode_ids: iterable of integer episode ids to include.

    Returns:
        dict with per-dimension 'action_mean', 'action_std', 'qpos_mean',
        'qpos_std' (stds clipped to >= 1e-2 to avoid near-zero divisors),
        plus 'example_qpos' — the first qpos frame seen, or None if every
        episode was empty.
    """
    all_qpos_data = []
    all_action_data = []
    example_qpos = None
    for episode_idx in episode_ids:
        dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
        with h5py.File(dataset_path, 'r') as root:
            # qvel is deliberately not read here: it is unused for the stats
            # and may be absent in datasets that only record qpos/action.
            qpos = root['/observations/qpos'][()]
            action = root['/action'][()]
        all_qpos_data.append(torch.from_numpy(qpos))
        all_action_data.append(torch.from_numpy(action))
        if example_qpos is None and len(qpos) > 0:
            example_qpos = qpos[0]
    # Episodes may have different lengths; concatenate over the time axis.
    all_qpos_data = torch.cat(all_qpos_data, dim=0)
    all_action_data = torch.cat(all_action_data, dim=0)
    # normalize action data
    action_mean = all_action_data.mean(dim=0, keepdim=True)
    action_std = all_action_data.std(dim=0, keepdim=True)
    action_std = torch.clip(action_std, 1e-2, np.inf)  # clipping
    # normalize qpos data
    qpos_mean = all_qpos_data.mean(dim=0, keepdim=True)
    qpos_std = all_qpos_data.std(dim=0, keepdim=True)
    qpos_std = torch.clip(qpos_std, 1e-2, np.inf)  # clipping
    stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
             "qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
             "example_qpos": example_qpos}
    return stats
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val,
              use_text_instruction=False,
              instruction_mode='timestep-level',
              use_cached_text_features=True,
              text_feature_dim=768,
              text_tokenizer_name='distilbert-base-uncased',
              text_max_length=32):
    """Discover episodes, split train/val, and build datasets + dataloaders.

    Args:
        dataset_dir: directory containing ``episode_<N>.hdf5`` files.
        num_episodes: upper bound on how many discovered episodes to use
            (``None`` uses all of them).
        camera_names: camera keys to read under /observations/images/.
        batch_size_train, batch_size_val: dataloader batch sizes.
        use_text_instruction: enable per-timestep text conditioning.
        instruction_mode: how instructions are stored in the HDF5 files.
        use_cached_text_features: prefer precomputed text-encoder features
            over on-the-fly tokenization.
        text_feature_dim: dimensionality of a cached text feature vector.
        text_tokenizer_name: HF tokenizer name for the fallback path.
        text_max_length: max token length for the fallback path.

    Raises:
        FileNotFoundError: no episode files found in ``dataset_dir``.
        ValueError: fewer than 2 episodes (cannot split train/val).
    """
    print(f'\nData from: {dataset_dir}\n')
    episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
    if len(episode_ids) == 0:
        raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
    if len(episode_ids) < 2:
        raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
    # obtain train test split
    train_ratio = 0.8
    shuffled_indices = np.random.permutation(len(episode_ids))
    train_count = int(train_ratio * len(episode_ids))
    train_indices = shuffled_indices[:train_count]
    val_indices = shuffled_indices[train_count:]
    train_episode_ids = np.array(episode_ids)[train_indices]
    val_episode_ids = np.array(episode_ids)[val_indices]
    # obtain normalization stats for qpos and action (computed over ALL episodes,
    # so train and val share one normalization)
    norm_stats = get_norm_stats(dataset_dir, episode_ids)
    # construct dataset and dataloader
    train_dataset = EpisodicDataset(
        train_episode_ids,
        dataset_dir,
        camera_names,
        norm_stats,
        use_text_instruction=use_text_instruction,
        instruction_mode=instruction_mode,
        use_cached_text_features=use_cached_text_features,
        text_feature_dim=text_feature_dim,
        text_tokenizer_name=text_tokenizer_name,
        text_max_length=text_max_length,
    )
    val_dataset = EpisodicDataset(
        val_episode_ids,
        dataset_dir,
        camera_names,
        norm_stats,
        use_text_instruction=use_text_instruction,
        instruction_mode=instruction_mode,
        use_cached_text_features=use_cached_text_features,
        text_feature_dim=text_feature_dim,
        text_tokenizer_name=text_tokenizer_name,
        text_max_length=text_max_length,
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
    # NOTE(review): the function's return statement is outside this hunk.

View File

@@ -21,32 +21,33 @@ def load_hdf5(dataset_dir, dataset_name):
# Inside load_hdf5 (function header is outside this hunk): read one episode file.
with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim']
# dt falls back to the global DT when the file does not record its own rate.
dt = float(root.attrs.get('dt', DT))
qpos = root['/observations/qpos'][()]
# NOTE(review): the first qvel line below is the pre-edit read left by the
# diff; the guarded line after it is the replacement that tolerates datasets
# recorded without qvel.
qvel = root['/observations/qvel'][()]
qvel = root['/observations/qvel'][()] if '/observations/qvel' in root else np.zeros_like(qpos)
action = root['/action'][()]
image_dict = dict()
for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
# NOTE(review): superseded pre-edit 4-tuple return; the 5-tuple return below adds dt.
return qpos, qvel, action, image_dict
return qpos, qvel, action, image_dict, dt
def main(args):
    """Export one recorded episode's cameras to MP4 and plot qpos vs action.

    Args:
        args: dict with 'dataset_dir' (directory of episode files) and
            'episode_idx' (integer id of the episode to visualize).
    """
    dataset_dir = args['dataset_dir']
    episode_idx = args['episode_idx']
    dataset_name = f'episode_{episode_idx}'
    # load_hdf5 also returns the recording period dt, so playback speed
    # follows the file's own rate rather than the global DT constant.
    qpos, qvel, action, image_dict, dt = load_hdf5(dataset_dir, dataset_name)
    save_videos(image_dict, dt, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
    visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
    # visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
# NOTE(review): two signatures survive from the diff; the second one
# (video_path required, no default) is the post-edit definition.
def save_videos(video, dt, video_path=None):
def save_videos(video, dt, video_path):
# List input: one {cam_name: image} dict per timestep -> cameras tiled side by side.
if isinstance(video, list):
cam_names = list(video[0].keys())
h, w, _ = video[0][cam_names[0]].shape
w = w * len(cam_names)  # output frame width spans all cameras horizontally
# NOTE(review): first fps line is the pre-edit version; the rounded and
# clamped (>= 1) computation below replaces it.
fps = int(1/dt)
fps = max(1, int(round(1 / dt)))
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for ts, image_dict in enumerate(video):
images = []
@@ -66,7 +67,7 @@ def save_videos(video, dt, video_path=None):
all_cam_videos = np.concatenate(all_cam_videos, axis=2)  # width dimension
n_frames, h, w, _ = all_cam_videos.shape
# NOTE(review): pre-edit fps line, superseded by the clamped version below.
fps = int(1 / dt)
fps = max(1, int(round(1 / dt)))
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for t in range(n_frames):
image = all_cam_videos[t]