From 88d14221aeeeede9292b85e55b3463aa1d4edd69 Mon Sep 17 00:00:00 2001 From: JC6123 Date: Thu, 19 Feb 2026 15:32:28 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=8F=AF=E4=BB=A5=E8=B7=91?= =?UTF-8?q?=E8=B5=B7=E6=9D=A5=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ENDOSCOPE_ACT_ADAPTATION_PLAN.md | 32 ++++- conda_env.yaml | 2 + constants.py | 39 +++++- detr/main.py | 18 ++- detr/models/backbone.py | 29 +++- detr/models/detr_vae.py | 60 +++++--- detr/models/transformer.py | 17 ++- imitate_episodes.py | 115 ++++++++++++++-- policy.py | 35 ++++- utils.py | 230 +++++++++++++++++++++++++++---- visualize_episodes.py | 15 +- 11 files changed, 503 insertions(+), 89 deletions(-) diff --git a/ENDOSCOPE_ACT_ADAPTATION_PLAN.md b/ENDOSCOPE_ACT_ADAPTATION_PLAN.md index d6d1dd9..29b96fd 100644 --- a/ENDOSCOPE_ACT_ADAPTATION_PLAN.md +++ b/ENDOSCOPE_ACT_ADAPTATION_PLAN.md @@ -268,10 +268,30 @@ ## 8. 你接下来只需提供的最小信息(进入代码改造前) -1. 2 个电机各自的物理含义与取值范围(单位、上下限)。 -2. 你当前数据中 `qpos` 和 `action` 的实际定义(是否相同)。 -3. text instruction 是每个 episode 一条,还是每个 timestep 一条。 -4. 相机数量、分辨率、帧率。 -5. 是否在训练时冻结 DistilBERT(`freeze_text_encoder=True/False`)。 +1. 2 个电机各自的物理含义与取值范围(单位、上下限):电机分别为 motor_x 和 motor_y,x 的范围为 7000-17384,y 的范围为 8000-18884。对应的 action_x 和 action_y 都为 0~65535 之间 +2. 你当前数据中 `qpos` 和 `action` 的实际定义(是否相同):action 和 qpos 定义接近,只不过是将 0~65535 分别映射到电机磁编码器数值上。 +3. text instruction 是每个 episode 一条,还是每个 timestep 一条:text instruction 是每个 timestep 一条。 +4. 相机数量、分辨率、帧率:相机数量为 1,分辨率为 224*224,帧率为 30Hz;对应的电机控制频率也为 30Hz +5. 是否在训练时冻结 DistilBERT(`freeze_text_encoder=True/False`):DistilBERT 完全冻结。构建训练集时,先将每一个 frame 的 text instruction 用 DistilBERT 编码以后再保存。这样训练过程中不需要调用 DistilBERT。 -> 有了这 5 项,即可进入下一步代码改造。 \ No newline at end of file +> 有了这 5 项,即可进入下一步代码改造。 + +text_input_ids、text_attention_mask 什么意思;'instruction_mode': 'episode-level'没有用到; + +```python +instruction = '' +if self.use_text_instruction: + if '/instruction_timestep' in root: + instruction = self._decode_instruction(root['/instruction_timestep'][start_ts]) + elif '/instruction' in root: + instruction_node = root['/instruction'] + if getattr(instruction_node, 'shape', ()) == (): + instruction = self._decode_instruction(instruction_node[()]) + else: + if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len: + instruction = self._decode_instruction(instruction_node[start_ts]) + else: + instruction = self._decode_instruction(instruction_node[0]) +``` + +为什么修改了 Transformer 的定义?这里是否会生效? \ No newline at end of file diff --git a/conda_env.yaml b/conda_env.yaml index 0f44d6b..477b4b3 100644 --- a/conda_env.yaml +++ b/conda_env.yaml @@ -21,3 +21,5 @@ dependencies: - packaging=23.0 - h5py=3.8.0 - ipython=8.12.0 + - pip: + - transformers==4.38.2 diff --git a/constants.py b/constants.py index f445350..33debc8 100644 --- a/constants.py +++ b/constants.py @@ -1,7 +1,7 @@ import pathlib ### Task parameters -DATA_DIR = '' +DATA_DIR = str(pathlib.Path(__file__).parent.resolve() / 'data') SIM_TASK_CONFIGS = { 'sim_transfer_cube_scripted':{ 'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted', @@ -32,6 +32,43 @@ SIM_TASK_CONFIGS = { }, } +ENDOSCOPE_TASK_CONFIGS = { + 'endoscope_default': { + 'dataset_dir': DATA_DIR + '/endoscope_default', + 'num_episodes': 50, + 'episode_len': 400, + 'camera_names': ['top'], + 'state_dim': 2, + 'action_dim': 2, + 'use_text_instruction': True, + 'instruction_mode': 'timestep-level', + 'use_cached_text_features': True, + 'text_encoder_type': 'distilbert', + 'text_feature_dim': 768, + 'text_fusion_type': 'concat_transformer_input', + 'freeze_text_encoder': True, + 'text_max_length': 32, + 'text_tokenizer_name': 'distilbert-base-uncased', + }, + 'endoscope_follow': { + 'dataset_dir': DATA_DIR + '/follow', + 'num_episodes': 3, + 'episode_len': 400, + 'camera_names': ['top'], + 'state_dim': 2, + 'action_dim': 2, + 'use_text_instruction': True, + 'instruction_mode': 'timestep-level', + 'use_cached_text_features': True, + 'text_encoder_type': 'distilbert', + 'text_feature_dim': 768, + 'text_fusion_type': 'concat_transformer_input', + 'freeze_text_encoder': True, + 'text_max_length': 32, + 'text_tokenizer_name': 'distilbert-base-uncased', + }, +} + ### Simulation envs fixed constants DT = 0.02 JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] diff --git a/detr/main.py b/detr/main.py index 3c4a339..d8f85d5 100644 --- a/detr/main.py +++ b/detr/main.py @@ -4,6 +4,7 @@ from pathlib import Path import numpy as np import torch +from torch.optim.adamw import AdamW from .models import build_ACT_model, build_CNNMLP_model import IPython @@ -30,6 +31,15 @@ def get_args_parser(): help="Type of positional embedding to use on top of the image features") parser.add_argument('--camera_names', default=[], type=list, # will be overridden help="A list of camera names") + parser.add_argument('--state_dim', default=14, type=int) + parser.add_argument('--action_dim', default=14, type=int) + parser.add_argument('--use_text', action='store_true') + parser.add_argument('--text_encoder_type', default='distilbert', type=str) + parser.add_argument('--text_feature_dim', default=768, type=int) + parser.add_argument('--text_fusion_type', default='concat_transformer_input', type=str) + parser.add_argument('--freeze_text_encoder', action='store_true') + parser.add_argument('--text_max_length', default=32, type=int) + parser.add_argument('--text_tokenizer_name', default='distilbert-base-uncased', type=str) # * Transformer parser.add_argument('--enc_layers', default=4, type=int, # will be overridden @@ -84,8 +94,8 @@ def build_ACT_model_and_optimizer(args_override): "lr": args.lr_backbone, }, ] - optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, - weight_decay=args.weight_decay) + optimizer = AdamW(param_dicts, lr=args.lr, + weight_decay=args.weight_decay) return model, optimizer @@ -107,8 +117,8 @@ def build_CNNMLP_model_and_optimizer(args_override): "lr": args.lr_backbone, }, ] - optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, - weight_decay=args.weight_decay) + optimizer = AdamW(param_dicts, lr=args.lr, + weight_decay=args.weight_decay) return model, optimizer diff --git a/detr/models/backbone.py b/detr/models/backbone.py index f28637e..de4ace3 100644 --- a/detr/models/backbone.py +++ b/detr/models/backbone.py @@ -89,9 +89,32 @@ class Backbone(BackboneBase): train_backbone: bool, return_interm_layers: bool, dilation: bool): - backbone = getattr(torchvision.models, name)( - replace_stride_with_dilation=[False, False, dilation], - pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? + backbone_builder = getattr(torchvision.models, name) + weights = None + if is_main_process(): + weight_enum_name_map = { + 'resnet18': 'ResNet18_Weights', + 'resnet34': 'ResNet34_Weights', + 'resnet50': 'ResNet50_Weights', + 'resnet101': 'ResNet101_Weights', + } + enum_name = weight_enum_name_map.get(name) + if enum_name is not None and hasattr(torchvision.models, enum_name): + weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT') + + try: + backbone = backbone_builder( + replace_stride_with_dilation=[False, False, dilation], + weights=weights, + norm_layer=FrozenBatchNorm2d, + ) + except TypeError: + # Backward compatibility for older torchvision that still expects `pretrained`. + backbone = backbone_builder( + replace_stride_with_dilation=[False, False, dilation], + pretrained=(weights is not None), + norm_layer=FrozenBatchNorm2d, + ) num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 super().__init__(backbone, train_backbone, num_channels, return_interm_layers) diff --git a/detr/models/detr_vae.py b/detr/models/detr_vae.py index bccfca7..9db76ac 100644 --- a/detr/models/detr_vae.py +++ b/detr/models/detr_vae.py @@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid): class DETRVAE(nn.Module): """ This is the DETR module that performs object detection """ - def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names): + def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, + use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'): """ Initializes the model. Parameters: backbones: torch module of the backbone to be used. See backbone.py @@ -48,17 +49,18 @@ class DETRVAE(nn.Module): self.camera_names = camera_names self.transformer = transformer self.encoder = encoder + self.use_text = use_text + self.text_fusion_type = text_fusion_type hidden_dim = transformer.d_model - self.action_head = nn.Linear(hidden_dim, state_dim) + self.action_head = nn.Linear(hidden_dim, action_dim) self.is_pad_head = nn.Linear(hidden_dim, 1) self.query_embed = nn.Embedding(num_queries, hidden_dim) if backbones is not None: self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) self.backbones = nn.ModuleList(backbones) - self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) else: - # input_dim = 14 + 7 # robot_state + env_state - self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) self.input_proj_env_state = nn.Linear(7, hidden_dim) self.pos = torch.nn.Embedding(2, hidden_dim) self.backbones = None @@ -66,16 +68,18 @@ class DETRVAE(nn.Module): # encoder extra parameters self.latent_dim = 32 # final size of latent z # TODO tune self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding - self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding - self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding + self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq # decoder extra parameters self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding - self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent + num_extra_tokens = 3 if self.use_text else 2 + self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text + self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None - def forward(self, qpos, image, env_state, actions=None, is_pad=None): + def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None): """ qpos: batch, qpos_dim image: batch, num_cam, channel, height, width @@ -125,10 +129,25 @@ class DETRVAE(nn.Module): all_cam_pos.append(pos) # proprioception features proprio_input = self.input_proj_robot_state(qpos) + extra_input_tokens = None + if self.use_text and text_features is not None: + if self.text_fusion_type != 'concat_transformer_input': + raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}') + text_input = self.text_proj(text_features) + extra_input_tokens = text_input.unsqueeze(0) # fold camera dimension into width dimension src = torch.cat(all_cam_features, axis=3) pos = torch.cat(all_cam_pos, axis=3) - hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] + hs = self.transformer( + src, + None, + self.query_embed.weight, + pos, + latent_input, + proprio_input, + self.additional_pos_embed.weight, + extra_input_tokens=extra_input_tokens, + )[0] else: qpos = self.input_proj_robot_state(qpos) env_state = self.input_proj_env_state(env_state) @@ -141,7 +160,7 @@ class DETRVAE(nn.Module): class CNNMLP(nn.Module): - def __init__(self, backbones, state_dim, camera_names): + def __init__(self, backbones, state_dim, action_dim, camera_names): """ Initializes the model. Parameters: backbones: torch module of the backbone to be used. See backbone.py @@ -153,7 +172,7 @@ class CNNMLP(nn.Module): """ super().__init__() self.camera_names = camera_names - self.action_head = nn.Linear(1000, state_dim) # TODO add more + self.action_head = nn.Linear(1000, action_dim) # TODO add more if backbones is not None: self.backbones = nn.ModuleList(backbones) backbone_down_projs = [] @@ -166,8 +185,8 @@ class CNNMLP(nn.Module): backbone_down_projs.append(down_proj) self.backbone_down_projs = nn.ModuleList(backbone_down_projs) - mlp_in_dim = 768 * len(backbones) + 14 - self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2) + mlp_in_dim = 768 * len(backbones) + state_dim + self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2) else: raise NotImplementedError @@ -192,7 +211,7 @@ class CNNMLP(nn.Module): for cam_feature in all_cam_features: flattened_features.append(cam_feature.reshape([bs, -1])) flattened_features = torch.cat(flattened_features, axis=1) # 768 each - features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14 + features = torch.cat([flattened_features, qpos], axis=1) a_hat = self.mlp(features) return a_hat @@ -227,7 +246,8 @@ def build_encoder(args): def build(args): - state_dim = 14 # TODO hardcode + state_dim = args.state_dim + action_dim = args.action_dim # From state # backbone = None # from state for now, no need for conv nets @@ -245,8 +265,12 @@ def build(args): transformer, encoder, state_dim=state_dim, + action_dim=action_dim, num_queries=args.num_queries, camera_names=args.camera_names, + use_text=args.use_text, + text_feature_dim=args.text_feature_dim, + text_fusion_type=args.text_fusion_type, ) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -255,7 +279,8 @@ def build(args): return model def build_cnnmlp(args): - state_dim = 14 # TODO hardcode + state_dim = args.state_dim + action_dim = args.action_dim # From state # backbone = None # from state for now, no need for conv nets @@ -268,6 +293,7 @@ def build_cnnmlp(args): model = CNNMLP( backbones, state_dim=state_dim, + action_dim=action_dim, camera_names=args.camera_names, ) diff --git a/detr/models/transformer.py b/detr/models/transformer.py index f38afd0..8866da0 100644 --- a/detr/models/transformer.py +++ b/detr/models/transformer.py @@ -46,7 +46,7 @@ class Transformer(nn.Module): if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None): + def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None): # TODO flatten only when input has H and W if len(src.shape) == 4: # has H and W # flatten NxCxHxW to HWxNxC @@ -56,10 +56,19 @@ class Transformer(nn.Module): query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # mask = mask.flatten(1) - additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim - pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + additional_inputs = [latent_input, proprio_input] + if extra_input_tokens is not None: + if len(extra_input_tokens.shape) == 2: + extra_input_tokens = extra_input_tokens.unsqueeze(0) + for i in range(extra_input_tokens.shape[0]): + additional_inputs.append(extra_input_tokens[i]) + + addition_input = torch.stack(additional_inputs, axis=0) + if additional_pos_embed is not None: + additional_pos_embed = additional_pos_embed[:addition_input.shape[0]] + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) - addition_input = torch.stack([latent_input, proprio_input], axis=0) src = torch.cat([addition_input, src], axis=0) else: assert len(src.shape) == 3 diff --git a/imitate_episodes.py b/imitate_episodes.py index 34f9a37..77d0077 100644 --- a/imitate_episodes.py +++ b/imitate_episodes.py @@ -10,14 +10,13 @@ from einops import rearrange from constants import DT from constants import PUPPET_GRIPPER_JOINT_OPEN +from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS from utils import load_data # data functions from utils import sample_box_pose, sample_insertion_pose # robot functions from utils import compute_dict_mean, set_seed, detach_dict # helper functions from policy import ACTPolicy, CNNMLPPolicy from visualize_episodes import save_videos -from sim_env import BOX_POSE - import IPython e = IPython.embed @@ -34,25 +33,47 @@ def main(args): num_epochs = args['num_epochs'] # get task parameters - is_sim = task_name[:4] == 'sim_' - if is_sim: - from constants import SIM_TASK_CONFIGS + is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS + if is_endoscope: + task_config = ENDOSCOPE_TASK_CONFIGS[task_name] + is_sim = False + elif task_name in SIM_TASK_CONFIGS: task_config = SIM_TASK_CONFIGS[task_name] + is_sim = True else: from aloha_scripts.constants import TASK_CONFIGS task_config = TASK_CONFIGS[task_name] + is_sim = False + dataset_dir = task_config['dataset_dir'] num_episodes = task_config['num_episodes'] episode_len = task_config['episode_len'] camera_names = task_config['camera_names'] + state_dim = task_config.get('state_dim', 14) + action_dim = task_config.get('action_dim', state_dim) + use_text_instruction = task_config.get('use_text_instruction', False) + instruction_mode = task_config.get('instruction_mode', 'timestep-level') + use_cached_text_features = task_config.get('use_cached_text_features', True) + text_encoder_type = task_config.get('text_encoder_type', 'distilbert') + text_feature_dim = task_config.get('text_feature_dim', 768) + text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input') + text_max_length = task_config.get('text_max_length', 32) + text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased') + freeze_text_encoder = task_config.get('freeze_text_encoder', True) + + if args.get('text_encoder_type') is not None: + text_encoder_type = args['text_encoder_type'] + if args.get('text_max_length') is not None: + text_max_length = args['text_max_length'] + if args.get('freeze_text_encoder', False): + freeze_text_encoder = True # fixed parameters - state_dim = 14 lr_backbone = 1e-5 backbone = 'resnet18' if policy_class == 'ACT': - enc_layers = 4 - dec_layers = 7 + enc_layers = 2 + dec_layers = 4 nheads = 8 policy_config = {'lr': args['lr'], 'num_queries': args['chunk_size'], @@ -65,10 +86,25 @@ def main(args): 'dec_layers': dec_layers, 'nheads': nheads, 'camera_names': camera_names, + 'state_dim': state_dim, + 'action_dim': action_dim, + 'use_text': use_text_instruction, + 'text_encoder_type': text_encoder_type, + 'text_feature_dim': text_feature_dim, + 'text_fusion_type': text_fusion_type, + 'freeze_text_encoder': freeze_text_encoder, + 'instruction_mode': instruction_mode, + 'use_cached_text_features': use_cached_text_features, + 'text_max_length': text_max_length, + 'text_tokenizer_name': text_tokenizer_name, } elif policy_class == 'CNNMLP': policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1, - 'camera_names': camera_names,} + 'camera_names': camera_names, + 'state_dim': state_dim, + 'action_dim': action_dim, + 'use_text': use_text_instruction, + } else: raise NotImplementedError @@ -77,6 +113,7 @@ def main(args): 'ckpt_dir': ckpt_dir, 'episode_len': episode_len, 'state_dim': state_dim, + 'action_dim': action_dim, 'lr': args['lr'], 'policy_class': policy_class, 'onscreen_render': onscreen_render, @@ -85,7 +122,12 @@ def main(args): 'seed': args['seed'], 'temporal_agg': args['temporal_agg'], 'camera_names': camera_names, - 'real_robot': not is_sim + 'real_robot': (not is_sim) and (not is_endoscope), + 'use_text_instruction': use_text_instruction, + 'instruction_mode': instruction_mode, + 'use_cached_text_features': use_cached_text_features, + 'text_tokenizer_name': text_tokenizer_name, + 'text_max_length': text_max_length, } if is_eval: @@ -100,7 +142,19 @@ def main(args): print() exit() - train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val) + train_dataloader, val_dataloader, stats, _ = load_data( + dataset_dir, + num_episodes, + camera_names, + batch_size_train, + batch_size_val, + use_text_instruction=use_text_instruction, + instruction_mode=instruction_mode, + use_cached_text_features=use_cached_text_features, + text_feature_dim=text_feature_dim, + text_tokenizer_name=text_tokenizer_name, + text_max_length=text_max_length, + ) # save dataset stats if not os.path.isdir(ckpt_dir): @@ -152,6 +206,7 @@ def eval_bc(config, ckpt_name, save_episode=True): set_seed(1000) ckpt_dir = config['ckpt_dir'] state_dim = config['state_dim'] + action_dim = config['action_dim'] real_robot = config['real_robot'] policy_class = config['policy_class'] onscreen_render = config['onscreen_render'] @@ -161,6 +216,7 @@ def eval_bc(config, ckpt_name, save_episode=True): task_name = config['task_name'] temporal_agg = config['temporal_agg'] onscreen_cam = 'angle' + BOX_POSE = None # load policy and stats ckpt_path = os.path.join(ckpt_dir, ckpt_name) @@ -202,8 +258,14 @@ def eval_bc(config, ckpt_name, save_episode=True): rollout_id += 0 ### set task if 'sim_transfer_cube' in task_name: + if BOX_POSE is None: + from sim_env import BOX_POSE as _BOX_POSE + BOX_POSE = _BOX_POSE BOX_POSE[0] = sample_box_pose() # used in sim reset elif 'sim_insertion' in task_name: + if BOX_POSE is None: + from sim_env import BOX_POSE as _BOX_POSE + BOX_POSE = _BOX_POSE BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset ts = env.reset() @@ -216,7 +278,7 @@ def eval_bc(config, ckpt_name, save_episode=True): ### evaluation loop if temporal_agg: - all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda() + all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda() qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda() image_list = [] # for visualization @@ -314,9 +376,29 @@ def eval_bc(config, ckpt_name, save_episode=True): def forward_pass(data, policy): - image_data, qpos_data, action_data, is_pad = data - image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda() - return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None + image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data + image_data = image_data.cuda() + qpos_data = qpos_data.cuda() + action_data = action_data.cuda() + is_pad = is_pad.cuda() + text_input_ids = text_input_ids.cuda() + text_attention_mask = text_attention_mask.cuda() + text_feature_data = text_feature_data.cuda() + text_feature_valid = text_feature_valid.cuda() + + text_features = None + if torch.any(text_feature_valid): + text_features = text_feature_data + + return policy( + qpos_data, + image_data, + text_input_ids=text_input_ids, + text_attention_mask=text_attention_mask, + text_features=text_features, + actions=action_data, + is_pad=is_pad, + ) def train_bc(train_dataloader, val_dataloader, config): @@ -431,5 +513,8 @@ if __name__ == '__main__': parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False) parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False) parser.add_argument('--temporal_agg', action='store_true') + parser.add_argument('--text_encoder_type', action='store', type=str, required=False) + parser.add_argument('--freeze_text_encoder', action='store_true') + parser.add_argument('--text_max_length', action='store', type=int, required=False) main(vars(parser.parse_args())) diff --git a/policy.py b/policy.py index 7b091e5..8fbf031 100644 --- a/policy.py +++ b/policy.py @@ -3,6 +3,7 @@ from torch.nn import functional as F import torchvision.transforms as transforms from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer +from models.text_encoder import DistilBERTTextEncoder import IPython e = IPython.embed @@ -13,18 +14,44 @@ class ACTPolicy(nn.Module): self.model = model # CVAE decoder self.optimizer = optimizer self.kl_weight = args_override['kl_weight'] + self.use_text = args_override.get('use_text', False) + self.text_encoder = None + if self.use_text: + text_encoder_type = args_override.get('text_encoder_type', 'distilbert') + if text_encoder_type != 'distilbert': + raise NotImplementedError(f'Unsupported text encoder: {text_encoder_type}') + self.text_encoder = DistilBERTTextEncoder( + model_name=args_override.get('text_tokenizer_name', 'distilbert-base-uncased'), + output_dim=args_override.get('text_feature_dim', 768), + freeze=args_override.get('freeze_text_encoder', True), + ) print(f'KL Weight {self.kl_weight}') - def __call__(self, qpos, image, actions=None, is_pad=None): + def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None): env_state = None normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image = normalize(image) + + if self.use_text and text_features is None and text_input_ids is not None and text_attention_mask is not None: + if self.text_encoder is None: + raise RuntimeError('Text encoder is not initialized while use_text=True.') + text_features = self.text_encoder(text_input_ids, text_attention_mask) + if actions is not None: # training time + if is_pad is None: + raise ValueError('`is_pad` must be provided during training when `actions` is not None.') actions = actions[:, :self.model.num_queries] is_pad = is_pad[:, :self.model.num_queries] - a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad) + a_hat, is_pad_hat, (mu, logvar) = self.model( + qpos, + image, + env_state, + text_features=text_features, + actions=actions, + is_pad=is_pad, + ) total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) loss_dict = dict() all_l1 = F.l1_loss(actions, a_hat, reduction='none') @@ -34,7 +61,7 @@ class ACTPolicy(nn.Module): loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight return loss_dict else: # inference time - a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior + a_hat, _, (_, _) = self.model(qpos, image, env_state, text_features=text_features) # no action, sample from prior return a_hat def configure_optimizers(self): @@ -48,7 +75,7 @@ class CNNMLPPolicy(nn.Module): self.model = model # decoder self.optimizer = optimizer - def __call__(self, qpos, image, actions=None, is_pad=None): + def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None): env_state = None # TODO normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) diff --git a/utils.py b/utils.py index d90b782..6b79697 100644 --- a/utils.py +++ b/utils.py @@ -2,21 +2,88 @@ import numpy as np import torch import os import h5py +import re from torch.utils.data import TensorDataset, DataLoader import IPython e = IPython.embed class EpisodicDataset(torch.utils.data.Dataset): - def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats): + def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats, + use_text_instruction=False, + instruction_mode='timestep-level', + use_cached_text_features=True, + text_feature_dim=768, + text_tokenizer_name='distilbert-base-uncased', + text_max_length=32): super(EpisodicDataset).__init__() self.episode_ids = episode_ids self.dataset_dir = dataset_dir self.camera_names = camera_names self.norm_stats = norm_stats + self.use_text_instruction = use_text_instruction + self.instruction_mode = instruction_mode + self.use_cached_text_features = use_cached_text_features + self.text_feature_dim = text_feature_dim + self.text_max_length = text_max_length self.is_sim = None + self.max_episode_len = None + self.action_dim = None + + self.text_tokenizer = None + if self.use_text_instruction: + try: + from transformers import DistilBertTokenizerFast + except ImportError as exc: + raise ImportError( + 'transformers is required for text instruction loading. ' + 'Install it with: pip install transformers' + ) from exc + self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(text_tokenizer_name) + + self._init_episode_shapes() + self.__getitem__(0) # initialize self.is_sim + def _init_episode_shapes(self): + max_len = 0 + action_dim = None + for episode_id in self.episode_ids: + dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5') + with h5py.File(dataset_path, 'r') as root: + shape = root['/action'].shape + if len(shape) != 2: + raise ValueError(f'Expected /action to have shape [T, D], got {shape} in {dataset_path}') + max_len = max(max_len, int(shape[0])) + if action_dim is None: + action_dim = int(shape[1]) + elif int(shape[1]) != action_dim: + raise ValueError( + f'Inconsistent action dim in dataset. Expected {action_dim}, got {shape[1]} in {dataset_path}' + ) + + if max_len <= 0 or action_dim is None: + raise ValueError(f'Invalid dataset metadata in {self.dataset_dir}') + + self.max_episode_len = max_len + self.action_dim = action_dim + + @staticmethod + def _decode_instruction(raw_value): + if raw_value is None: + return '' + if isinstance(raw_value, bytes): + return raw_value.decode('utf-8') + if isinstance(raw_value, np.bytes_): + return raw_value.tobytes().decode('utf-8') + if isinstance(raw_value, np.ndarray): + if raw_value.shape == (): + return EpisodicDataset._decode_instruction(raw_value.item()) + if raw_value.size == 0: + return '' + return EpisodicDataset._decode_instruction(raw_value.reshape(-1)[0]) + return str(raw_value) + def __len__(self): return len(self.episode_ids) @@ -26,7 +93,7 @@ class EpisodicDataset(torch.utils.data.Dataset): episode_id = self.episode_ids[index] dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5') with h5py.File(dataset_path, 'r') as root: - is_sim = root.attrs['sim'] + is_sim = bool(root.attrs.get('sim', False)) original_action_shape = root['/action'].shape episode_len = original_action_shape[0] if sample_full_episode: @@ -35,10 +102,40 @@ class EpisodicDataset(torch.utils.data.Dataset): start_ts = np.random.choice(episode_len) # get observation at start_ts only qpos = root['/observations/qpos'][start_ts] - qvel = root['/observations/qvel'][start_ts] image_dict = dict() for cam_name in self.camera_names: image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts] + + instruction = '' + text_feature = None + if self.use_text_instruction: + effective_mode = self.instruction_mode + if effective_mode == 'timestep-level' and '/instruction_timestep' in root: + instruction = self._decode_instruction(root['/instruction_timestep'][start_ts]) + elif '/instruction' in root: + instruction_node = root['/instruction'] + if getattr(instruction_node, 'shape', ()) == (): + instruction = self._decode_instruction(instruction_node[()]) + else: + if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len: + instruction = self._decode_instruction(instruction_node[start_ts]) + else: + instruction = self._decode_instruction(instruction_node[0]) + + if self.use_cached_text_features: + if effective_mode == 'timestep-level' and '/instruction_features_timestep' in root: + text_feature = root['/instruction_features_timestep'][start_ts] + elif '/instruction_features' in root: + feat_node = root['/instruction_features'] + if getattr(feat_node, 'shape', ()) == (): + text_feature = np.array(feat_node[()]) + elif len(feat_node.shape) == 1: + text_feature = feat_node[()] + elif len(feat_node.shape) == 2 and feat_node.shape[0] == episode_len: + text_feature = feat_node[start_ts] + else: + text_feature = feat_node[0] + # get all actions after and including start_ts if is_sim: action = root['/action'][start_ts:] @@ -48,10 +145,10 @@ class EpisodicDataset(torch.utils.data.Dataset): action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned self.is_sim = is_sim - padded_action = np.zeros(original_action_shape, dtype=np.float32) + padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32) padded_action[:action_len] = action - is_pad = np.zeros(episode_len) - is_pad[action_len:] = 1 + is_pad = np.ones(self.max_episode_len) + is_pad[:action_len] = 0 # new axis for different cameras all_cam_images = [] @@ -73,55 +170,132 @@ class EpisodicDataset(torch.utils.data.Dataset): action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"] qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"] - return image_data, qpos_data, action_data, is_pad + if self.use_text_instruction and text_feature is not None: + text_feature_data = torch.from_numpy(np.array(text_feature)).float() + text_feature_valid = torch.tensor(True, dtype=torch.bool) + text_input_ids = torch.zeros(1, dtype=torch.long) + text_attention_mask = torch.zeros(1, dtype=torch.long) + elif self.use_text_instruction: + tokenized = self.text_tokenizer( + instruction, + padding='max_length', + truncation=True, + max_length=self.text_max_length, + return_tensors='pt', + ) + text_input_ids = tokenized['input_ids'].squeeze(0).long() + text_attention_mask = tokenized['attention_mask'].squeeze(0).long() + text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32) + text_feature_valid = torch.tensor(False, dtype=torch.bool) + else: + text_input_ids = torch.zeros(1, dtype=torch.long) + text_attention_mask = torch.zeros(1, dtype=torch.long) + text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32) + text_feature_valid = torch.tensor(False, dtype=torch.bool) + + return image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid -def get_norm_stats(dataset_dir, num_episodes): +def _discover_episode_ids(dataset_dir, num_episodes=None): + pattern = re.compile(r'^episode_(\d+)\.hdf5$') + episode_ids = [] + for fname in os.listdir(dataset_dir): + m = pattern.match(fname) + if m: + episode_ids.append(int(m.group(1))) + episode_ids.sort() + if num_episodes is not None: + episode_ids = episode_ids[:num_episodes] + return episode_ids + + +def get_norm_stats(dataset_dir, episode_ids): all_qpos_data = [] all_action_data = [] - for episode_idx in range(num_episodes): + example_qpos = None + for episode_idx in episode_ids: dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5') with h5py.File(dataset_path, 'r') as root: qpos = root['/observations/qpos'][()] - qvel = root['/observations/qvel'][()] action = root['/action'][()] - all_qpos_data.append(torch.from_numpy(qpos)) - all_action_data.append(torch.from_numpy(action)) - all_qpos_data = torch.stack(all_qpos_data) - all_action_data = torch.stack(all_action_data) - all_action_data = all_action_data + qpos_t = torch.from_numpy(qpos) + action_t = torch.from_numpy(action) + all_qpos_data.append(qpos_t) + all_action_data.append(action_t) + if example_qpos is None and len(qpos) > 0: + example_qpos = qpos[0] + + # Episodes may have different lengths; concatenate over time axis. + all_qpos_data = torch.cat(all_qpos_data, dim=0) + all_action_data = torch.cat(all_action_data, dim=0) # normalize action data - action_mean = all_action_data.mean(dim=[0, 1], keepdim=True) - action_std = all_action_data.std(dim=[0, 1], keepdim=True) + action_mean = all_action_data.mean(dim=0, keepdim=True) + action_std = all_action_data.std(dim=0, keepdim=True) action_std = torch.clip(action_std, 1e-2, np.inf) # clipping # normalize qpos data - qpos_mean = all_qpos_data.mean(dim=[0, 1], keepdim=True) - qpos_std = all_qpos_data.std(dim=[0, 1], keepdim=True) + qpos_mean = all_qpos_data.mean(dim=0, keepdim=True) + qpos_std = all_qpos_data.std(dim=0, keepdim=True) qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(), "qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(), - "example_qpos": qpos} + "example_qpos": example_qpos} return stats - -def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val): +def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val, + use_text_instruction=False, + instruction_mode='timestep-level', + use_cached_text_features=True, + text_feature_dim=768, + text_tokenizer_name='distilbert-base-uncased', + text_max_length=32): print(f'\nData from: {dataset_dir}\n') + episode_ids = _discover_episode_ids(dataset_dir, num_episodes) + if len(episode_ids) == 0: + raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}') + if len(episode_ids) < 2: + raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}') + # obtain train test split train_ratio = 0.8 - shuffled_indices = np.random.permutation(num_episodes) - train_indices = shuffled_indices[:int(train_ratio * num_episodes)] - val_indices = shuffled_indices[int(train_ratio * num_episodes):] + shuffled_indices = np.random.permutation(len(episode_ids)) + train_count = int(train_ratio * len(episode_ids)) + train_indices = shuffled_indices[:train_count] + val_indices = shuffled_indices[train_count:] + train_episode_ids = np.array(episode_ids)[train_indices] + val_episode_ids = np.array(episode_ids)[val_indices] # obtain normalization stats for qpos and action - norm_stats = get_norm_stats(dataset_dir, num_episodes) + norm_stats = get_norm_stats(dataset_dir, episode_ids) # construct dataset and dataloader - train_dataset = EpisodicDataset(train_indices, dataset_dir, camera_names, norm_stats) - val_dataset = EpisodicDataset(val_indices, dataset_dir, camera_names, norm_stats) + train_dataset = EpisodicDataset( + train_episode_ids, + dataset_dir, + camera_names, + norm_stats, + use_text_instruction=use_text_instruction, + instruction_mode=instruction_mode, + use_cached_text_features=use_cached_text_features, + text_feature_dim=text_feature_dim, + text_tokenizer_name=text_tokenizer_name, + text_max_length=text_max_length, + ) + val_dataset = EpisodicDataset( + val_episode_ids, + dataset_dir, + camera_names, + norm_stats, + use_text_instruction=use_text_instruction, + instruction_mode=instruction_mode, + use_cached_text_features=use_cached_text_features, + text_feature_dim=text_feature_dim, + text_tokenizer_name=text_tokenizer_name, + text_max_length=text_max_length, + ) train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) diff --git a/visualize_episodes.py b/visualize_episodes.py index 4e55e47..edbad13 100644 --- a/visualize_episodes.py +++ b/visualize_episodes.py @@ -21,32 +21,33 @@ def load_hdf5(dataset_dir, dataset_name): with h5py.File(dataset_path, 'r') as root: is_sim = root.attrs['sim'] + dt = float(root.attrs.get('dt', DT)) qpos = root['/observations/qpos'][()] - qvel = root['/observations/qvel'][()] + qvel = root['/observations/qvel'][()] if '/observations/qvel' in root else np.zeros_like(qpos) action = root['/action'][()] image_dict = dict() for cam_name in root[f'/observations/images/'].keys(): image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()] - return qpos, qvel, action, image_dict + return qpos, qvel, action, image_dict, dt def main(args): dataset_dir = args['dataset_dir'] episode_idx = args['episode_idx'] dataset_name = f'episode_{episode_idx}' - qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name) - save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4')) + qpos, qvel, action, image_dict, dt = load_hdf5(dataset_dir, dataset_name) + save_videos(image_dict, dt, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4')) visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png')) # visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back -def save_videos(video, dt, video_path=None): +def save_videos(video, dt, video_path): if isinstance(video, list): cam_names = list(video[0].keys()) h, w, _ = video[0][cam_names[0]].shape w = w * len(cam_names) - fps = int(1/dt) + fps = max(1, int(round(1 / dt))) out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) for ts, image_dict in enumerate(video): images = [] @@ -66,7 +67,7 @@ def save_videos(video, dt, video_path=None): all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension n_frames, h, w, _ = all_cam_videos.shape - fps = int(1 / dt) + fps = max(1, int(round(1 / dt))) out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) for t in range(n_frames): image = all_cam_videos[t]