代码可以跑起来了

This commit is contained in:
2026-02-19 15:32:28 +08:00
parent b701d939c2
commit 88d14221ae
11 changed files with 503 additions and 89 deletions

View File

@@ -268,10 +268,30 @@
## 8. 你接下来只需提供的最小信息(进入代码改造前) ## 8. 你接下来只需提供的最小信息(进入代码改造前)
1. 2 个电机各自的物理含义与取值范围(单位、上下限) 1. 2 个电机各自的物理含义与取值范围(单位、上下限):电机分别为 motor_x 和 motor_y,x 的范围为 7000-17384,y 的范围为 8000-18884。对应的 action_x 和 action_y 都为 0~65535 之间
2. 你当前数据中 `qpos`、`action` 的实际定义(是否相同)。 2. 你当前数据中 `qpos`、`action` 的实际定义(是否相同):action 和 qpos 定义接近,只不过是将 0~65535 分别映射到电机磁编码器数值上
3. text instruction 是每个 episode 一条,还是每个 timestep 一条。 3. text instruction 是每个 episode 一条,还是每个 timestep 一条:text instruction 是每个 timestep 一条。
4. 相机数量、分辨率、帧率 4. 相机数量、分辨率、帧率:相机数量为 1,分辨率为 224*224,帧率为 30Hz,对应的电机控制频率也为 30Hz
5. 是否在训练时冻结 DistilBERT(`freeze_text_encoder=True/False`)。 5. 是否在训练时冻结 DistilBERT(`freeze_text_encoder=True/False`):DistilBERT 完全冻结。构建训练集时,先将每一个 frame 的 text instruction 用 DistilBERT 编码以后再保存。这样训练过程中不需要调用 DistilBERT
> 有了这 5 项,即可进入下一步代码改造。 > 有了这 5 项,即可进入下一步代码改造。
text_input_ids、text_attention_mask 什么意思;'instruction_mode': 'episode-level'没有用到;
```python
instruction = ''
if self.use_text_instruction:
if '/instruction_timestep' in root:
instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
elif '/instruction' in root:
instruction_node = root['/instruction']
if getattr(instruction_node, 'shape', ()) == ():
instruction = self._decode_instruction(instruction_node[()])
else:
if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
instruction = self._decode_instruction(instruction_node[start_ts])
else:
instruction = self._decode_instruction(instruction_node[0])
```
为什么修改了 Transformer 的定义?这里是否会生效?

View File

@@ -21,3 +21,5 @@ dependencies:
- packaging=23.0 - packaging=23.0
- h5py=3.8.0 - h5py=3.8.0
- ipython=8.12.0 - ipython=8.12.0
- pip:
- transformers==4.38.2

View File

@@ -1,7 +1,7 @@
import pathlib import pathlib
### Task parameters ### Task parameters
DATA_DIR = '<put your data dir here>' DATA_DIR = str(pathlib.Path(__file__).parent.resolve() / 'data')
SIM_TASK_CONFIGS = { SIM_TASK_CONFIGS = {
'sim_transfer_cube_scripted':{ 'sim_transfer_cube_scripted':{
'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted', 'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
@@ -32,6 +32,43 @@ SIM_TASK_CONFIGS = {
}, },
} }
ENDOSCOPE_TASK_CONFIGS = {
'endoscope_default': {
'dataset_dir': DATA_DIR + '/endoscope_default',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'use_text_instruction': True,
'instruction_mode': 'timestep-level',
'use_cached_text_features': True,
'text_encoder_type': 'distilbert',
'text_feature_dim': 768,
'text_fusion_type': 'concat_transformer_input',
'freeze_text_encoder': True,
'text_max_length': 32,
'text_tokenizer_name': 'distilbert-base-uncased',
},
'endoscope_follow': {
'dataset_dir': DATA_DIR + '/follow',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'use_text_instruction': True,
'instruction_mode': 'timestep-level',
'use_cached_text_features': True,
'text_encoder_type': 'distilbert',
'text_feature_dim': 768,
'text_fusion_type': 'concat_transformer_input',
'freeze_text_encoder': True,
'text_max_length': 32,
'text_tokenizer_name': 'distilbert-base-uncased',
},
}
### Simulation envs fixed constants ### Simulation envs fixed constants
DT = 0.02 DT = 0.02
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from torch.optim.adamw import AdamW
from .models import build_ACT_model, build_CNNMLP_model from .models import build_ACT_model, build_CNNMLP_model
import IPython import IPython
@@ -30,6 +31,15 @@ def get_args_parser():
help="Type of positional embedding to use on top of the image features") help="Type of positional embedding to use on top of the image features")
parser.add_argument('--camera_names', default=[], type=list, # will be overridden parser.add_argument('--camera_names', default=[], type=list, # will be overridden
help="A list of camera names") help="A list of camera names")
parser.add_argument('--state_dim', default=14, type=int)
parser.add_argument('--action_dim', default=14, type=int)
parser.add_argument('--use_text', action='store_true')
parser.add_argument('--text_encoder_type', default='distilbert', type=str)
parser.add_argument('--text_feature_dim', default=768, type=int)
parser.add_argument('--text_fusion_type', default='concat_transformer_input', type=str)
parser.add_argument('--freeze_text_encoder', action='store_true')
parser.add_argument('--text_max_length', default=32, type=int)
parser.add_argument('--text_tokenizer_name', default='distilbert-base-uncased', type=str)
# * Transformer # * Transformer
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
@@ -84,8 +94,8 @@ def build_ACT_model_and_optimizer(args_override):
"lr": args.lr_backbone, "lr": args.lr_backbone,
}, },
] ]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, optimizer = AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay) weight_decay=args.weight_decay)
return model, optimizer return model, optimizer
@@ -107,8 +117,8 @@ def build_CNNMLP_model_and_optimizer(args_override):
"lr": args.lr_backbone, "lr": args.lr_backbone,
}, },
] ]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, optimizer = AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay) weight_decay=args.weight_decay)
return model, optimizer return model, optimizer

View File

@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
train_backbone: bool, train_backbone: bool,
return_interm_layers: bool, return_interm_layers: bool,
dilation: bool): dilation: bool):
backbone = getattr(torchvision.models, name)( backbone_builder = getattr(torchvision.models, name)
replace_stride_with_dilation=[False, False, dilation], weights = None
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? if is_main_process():
weight_enum_name_map = {
'resnet18': 'ResNet18_Weights',
'resnet34': 'ResNet34_Weights',
'resnet50': 'ResNet50_Weights',
'resnet101': 'ResNet101_Weights',
}
enum_name = weight_enum_name_map.get(name)
if enum_name is not None and hasattr(torchvision.models, enum_name):
weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')
try:
backbone = backbone_builder(
replace_stride_with_dilation=[False, False, dilation],
weights=weights,
norm_layer=FrozenBatchNorm2d,
)
except TypeError:
# Backward compatibility for older torchvision that still expects `pretrained`.
backbone = backbone_builder(
replace_stride_with_dilation=[False, False, dilation],
pretrained=(weights is not None),
norm_layer=FrozenBatchNorm2d,
)
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers) super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

View File

@@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid):
class DETRVAE(nn.Module): class DETRVAE(nn.Module):
""" This is the DETR module that performs object detection """ """ This is the DETR module that performs object detection """
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names): def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
""" Initializes the model. """ Initializes the model.
Parameters: Parameters:
backbones: torch module of the backbone to be used. See backbone.py backbones: torch module of the backbone to be used. See backbone.py
@@ -48,17 +49,18 @@ class DETRVAE(nn.Module):
self.camera_names = camera_names self.camera_names = camera_names
self.transformer = transformer self.transformer = transformer
self.encoder = encoder self.encoder = encoder
self.use_text = use_text
self.text_fusion_type = text_fusion_type
hidden_dim = transformer.d_model hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, state_dim) self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1) self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim) self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None: if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones) self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(14, hidden_dim) self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else: else:
# input_dim = 14 + 7 # robot_state + env_state self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim) self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim) self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None self.backbones = None
@@ -66,16 +68,18 @@ class DETRVAE(nn.Module):
# encoder extra parameters # encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
# decoder extra parameters # decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent num_extra_tokens = 3 if self.use_text else 2
self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text
self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None
def forward(self, qpos, image, env_state, actions=None, is_pad=None): def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
""" """
qpos: batch, qpos_dim qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width image: batch, num_cam, channel, height, width
@@ -125,10 +129,25 @@ class DETRVAE(nn.Module):
all_cam_pos.append(pos) all_cam_pos.append(pos)
# proprioception features # proprioception features
proprio_input = self.input_proj_robot_state(qpos) proprio_input = self.input_proj_robot_state(qpos)
extra_input_tokens = None
if self.use_text and text_features is not None:
if self.text_fusion_type != 'concat_transformer_input':
raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
text_input = self.text_proj(text_features)
extra_input_tokens = text_input.unsqueeze(0)
# fold camera dimension into width dimension # fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3) src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3) pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] hs = self.transformer(
src,
None,
self.query_embed.weight,
pos,
latent_input,
proprio_input,
self.additional_pos_embed.weight,
extra_input_tokens=extra_input_tokens,
)[0]
else: else:
qpos = self.input_proj_robot_state(qpos) qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state) env_state = self.input_proj_env_state(env_state)
@@ -141,7 +160,7 @@ class DETRVAE(nn.Module):
class CNNMLP(nn.Module): class CNNMLP(nn.Module):
def __init__(self, backbones, state_dim, camera_names): def __init__(self, backbones, state_dim, action_dim, camera_names):
""" Initializes the model. """ Initializes the model.
Parameters: Parameters:
backbones: torch module of the backbone to be used. See backbone.py backbones: torch module of the backbone to be used. See backbone.py
@@ -153,7 +172,7 @@ class CNNMLP(nn.Module):
""" """
super().__init__() super().__init__()
self.camera_names = camera_names self.camera_names = camera_names
self.action_head = nn.Linear(1000, state_dim) # TODO add more self.action_head = nn.Linear(1000, action_dim) # TODO add more
if backbones is not None: if backbones is not None:
self.backbones = nn.ModuleList(backbones) self.backbones = nn.ModuleList(backbones)
backbone_down_projs = [] backbone_down_projs = []
@@ -166,8 +185,8 @@ class CNNMLP(nn.Module):
backbone_down_projs.append(down_proj) backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs) self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + 14 mlp_in_dim = 768 * len(backbones) + state_dim
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2) self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)
else: else:
raise NotImplementedError raise NotImplementedError
@@ -192,7 +211,7 @@ class CNNMLP(nn.Module):
for cam_feature in all_cam_features: for cam_feature in all_cam_features:
flattened_features.append(cam_feature.reshape([bs, -1])) flattened_features.append(cam_feature.reshape([bs, -1]))
flattened_features = torch.cat(flattened_features, axis=1) # 768 each flattened_features = torch.cat(flattened_features, axis=1) # 768 each
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14 features = torch.cat([flattened_features, qpos], axis=1)
a_hat = self.mlp(features) a_hat = self.mlp(features)
return a_hat return a_hat
@@ -227,7 +246,8 @@ def build_encoder(args):
def build(args): def build(args):
state_dim = 14 # TODO hardcode state_dim = args.state_dim
action_dim = args.action_dim
# From state # From state
# backbone = None # from state for now, no need for conv nets # backbone = None # from state for now, no need for conv nets
@@ -245,8 +265,12 @@ def build(args):
transformer, transformer,
encoder, encoder,
state_dim=state_dim, state_dim=state_dim,
action_dim=action_dim,
num_queries=args.num_queries, num_queries=args.num_queries,
camera_names=args.camera_names, camera_names=args.camera_names,
use_text=args.use_text,
text_feature_dim=args.text_feature_dim,
text_fusion_type=args.text_fusion_type,
) )
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -255,7 +279,8 @@ def build(args):
return model return model
def build_cnnmlp(args): def build_cnnmlp(args):
state_dim = 14 # TODO hardcode state_dim = args.state_dim
action_dim = args.action_dim
# From state # From state
# backbone = None # from state for now, no need for conv nets # backbone = None # from state for now, no need for conv nets
@@ -268,6 +293,7 @@ def build_cnnmlp(args):
model = CNNMLP( model = CNNMLP(
backbones, backbones,
state_dim=state_dim, state_dim=state_dim,
action_dim=action_dim,
camera_names=args.camera_names, camera_names=args.camera_names,
) )

View File

@@ -46,7 +46,7 @@ class Transformer(nn.Module):
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None): def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
# TODO flatten only when input has H and W # TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC # flatten NxCxHxW to HWxNxC
@@ -56,10 +56,19 @@ class Transformer(nn.Module):
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1) # mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim additional_inputs = [latent_input, proprio_input]
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) if extra_input_tokens is not None:
if len(extra_input_tokens.shape) == 2:
extra_input_tokens = extra_input_tokens.unsqueeze(0)
for i in range(extra_input_tokens.shape[0]):
additional_inputs.append(extra_input_tokens[i])
addition_input = torch.stack(additional_inputs, axis=0)
if additional_pos_embed is not None:
additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0) src = torch.cat([addition_input, src], axis=0)
else: else:
assert len(src.shape) == 3 assert len(src.shape) == 3

View File

@@ -10,14 +10,13 @@ from einops import rearrange
from constants import DT from constants import DT
from constants import PUPPET_GRIPPER_JOINT_OPEN from constants import PUPPET_GRIPPER_JOINT_OPEN
from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
from utils import load_data # data functions from utils import load_data # data functions
from utils import sample_box_pose, sample_insertion_pose # robot functions from utils import sample_box_pose, sample_insertion_pose # robot functions
from utils import compute_dict_mean, set_seed, detach_dict # helper functions from utils import compute_dict_mean, set_seed, detach_dict # helper functions
from policy import ACTPolicy, CNNMLPPolicy from policy import ACTPolicy, CNNMLPPolicy
from visualize_episodes import save_videos from visualize_episodes import save_videos
from sim_env import BOX_POSE
import IPython import IPython
e = IPython.embed e = IPython.embed
@@ -34,25 +33,47 @@ def main(args):
num_epochs = args['num_epochs'] num_epochs = args['num_epochs']
# get task parameters # get task parameters
is_sim = task_name[:4] == 'sim_' is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
if is_sim: if is_endoscope:
from constants import SIM_TASK_CONFIGS task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
is_sim = False
elif task_name in SIM_TASK_CONFIGS:
task_config = SIM_TASK_CONFIGS[task_name] task_config = SIM_TASK_CONFIGS[task_name]
is_sim = True
else: else:
from aloha_scripts.constants import TASK_CONFIGS from aloha_scripts.constants import TASK_CONFIGS
task_config = TASK_CONFIGS[task_name] task_config = TASK_CONFIGS[task_name]
is_sim = False
dataset_dir = task_config['dataset_dir'] dataset_dir = task_config['dataset_dir']
num_episodes = task_config['num_episodes'] num_episodes = task_config['num_episodes']
episode_len = task_config['episode_len'] episode_len = task_config['episode_len']
camera_names = task_config['camera_names'] camera_names = task_config['camera_names']
state_dim = task_config.get('state_dim', 14)
action_dim = task_config.get('action_dim', state_dim)
use_text_instruction = task_config.get('use_text_instruction', False)
instruction_mode = task_config.get('instruction_mode', 'timestep-level')
use_cached_text_features = task_config.get('use_cached_text_features', True)
text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
text_feature_dim = task_config.get('text_feature_dim', 768)
text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
text_max_length = task_config.get('text_max_length', 32)
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
freeze_text_encoder = task_config.get('freeze_text_encoder', True)
if args.get('text_encoder_type') is not None:
text_encoder_type = args['text_encoder_type']
if args.get('text_max_length') is not None:
text_max_length = args['text_max_length']
if args.get('freeze_text_encoder', False):
freeze_text_encoder = True
# fixed parameters # fixed parameters
state_dim = 14
lr_backbone = 1e-5 lr_backbone = 1e-5
backbone = 'resnet18' backbone = 'resnet18'
if policy_class == 'ACT': if policy_class == 'ACT':
enc_layers = 4 enc_layers = 2
dec_layers = 7 dec_layers = 4
nheads = 8 nheads = 8
policy_config = {'lr': args['lr'], policy_config = {'lr': args['lr'],
'num_queries': args['chunk_size'], 'num_queries': args['chunk_size'],
@@ -65,10 +86,25 @@ def main(args):
'dec_layers': dec_layers, 'dec_layers': dec_layers,
'nheads': nheads, 'nheads': nheads,
'camera_names': camera_names, 'camera_names': camera_names,
'state_dim': state_dim,
'action_dim': action_dim,
'use_text': use_text_instruction,
'text_encoder_type': text_encoder_type,
'text_feature_dim': text_feature_dim,
'text_fusion_type': text_fusion_type,
'freeze_text_encoder': freeze_text_encoder,
'instruction_mode': instruction_mode,
'use_cached_text_features': use_cached_text_features,
'text_max_length': text_max_length,
'text_tokenizer_name': text_tokenizer_name,
} }
elif policy_class == 'CNNMLP': elif policy_class == 'CNNMLP':
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1, policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
'camera_names': camera_names,} 'camera_names': camera_names,
'state_dim': state_dim,
'action_dim': action_dim,
'use_text': use_text_instruction,
}
else: else:
raise NotImplementedError raise NotImplementedError
@@ -77,6 +113,7 @@ def main(args):
'ckpt_dir': ckpt_dir, 'ckpt_dir': ckpt_dir,
'episode_len': episode_len, 'episode_len': episode_len,
'state_dim': state_dim, 'state_dim': state_dim,
'action_dim': action_dim,
'lr': args['lr'], 'lr': args['lr'],
'policy_class': policy_class, 'policy_class': policy_class,
'onscreen_render': onscreen_render, 'onscreen_render': onscreen_render,
@@ -85,7 +122,12 @@ def main(args):
'seed': args['seed'], 'seed': args['seed'],
'temporal_agg': args['temporal_agg'], 'temporal_agg': args['temporal_agg'],
'camera_names': camera_names, 'camera_names': camera_names,
'real_robot': not is_sim 'real_robot': (not is_sim) and (not is_endoscope),
'use_text_instruction': use_text_instruction,
'instruction_mode': instruction_mode,
'use_cached_text_features': use_cached_text_features,
'text_tokenizer_name': text_tokenizer_name,
'text_max_length': text_max_length,
} }
if is_eval: if is_eval:
@@ -100,7 +142,19 @@ def main(args):
print() print()
exit() exit()
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val) train_dataloader, val_dataloader, stats, _ = load_data(
dataset_dir,
num_episodes,
camera_names,
batch_size_train,
batch_size_val,
use_text_instruction=use_text_instruction,
instruction_mode=instruction_mode,
use_cached_text_features=use_cached_text_features,
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
)
# save dataset stats # save dataset stats
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
@@ -152,6 +206,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
set_seed(1000) set_seed(1000)
ckpt_dir = config['ckpt_dir'] ckpt_dir = config['ckpt_dir']
state_dim = config['state_dim'] state_dim = config['state_dim']
action_dim = config['action_dim']
real_robot = config['real_robot'] real_robot = config['real_robot']
policy_class = config['policy_class'] policy_class = config['policy_class']
onscreen_render = config['onscreen_render'] onscreen_render = config['onscreen_render']
@@ -161,6 +216,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
task_name = config['task_name'] task_name = config['task_name']
temporal_agg = config['temporal_agg'] temporal_agg = config['temporal_agg']
onscreen_cam = 'angle' onscreen_cam = 'angle'
BOX_POSE = None
# load policy and stats # load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name) ckpt_path = os.path.join(ckpt_dir, ckpt_name)
@@ -202,8 +258,14 @@ def eval_bc(config, ckpt_name, save_episode=True):
rollout_id += 0 rollout_id += 0
### set task ### set task
if 'sim_transfer_cube' in task_name: if 'sim_transfer_cube' in task_name:
if BOX_POSE is None:
from sim_env import BOX_POSE as _BOX_POSE
BOX_POSE = _BOX_POSE
BOX_POSE[0] = sample_box_pose() # used in sim reset BOX_POSE[0] = sample_box_pose() # used in sim reset
elif 'sim_insertion' in task_name: elif 'sim_insertion' in task_name:
if BOX_POSE is None:
from sim_env import BOX_POSE as _BOX_POSE
BOX_POSE = _BOX_POSE
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
ts = env.reset() ts = env.reset()
@@ -216,7 +278,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
### evaluation loop ### evaluation loop
if temporal_agg: if temporal_agg:
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda() all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()
qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda() qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
image_list = [] # for visualization image_list = [] # for visualization
@@ -314,9 +376,29 @@ def eval_bc(config, ckpt_name, save_episode=True):
def forward_pass(data, policy): def forward_pass(data, policy):
image_data, qpos_data, action_data, is_pad = data image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda() image_data = image_data.cuda()
return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None qpos_data = qpos_data.cuda()
action_data = action_data.cuda()
is_pad = is_pad.cuda()
text_input_ids = text_input_ids.cuda()
text_attention_mask = text_attention_mask.cuda()
text_feature_data = text_feature_data.cuda()
text_feature_valid = text_feature_valid.cuda()
text_features = None
if torch.any(text_feature_valid):
text_features = text_feature_data
return policy(
qpos_data,
image_data,
text_input_ids=text_input_ids,
text_attention_mask=text_attention_mask,
text_features=text_features,
actions=action_data,
is_pad=is_pad,
)
def train_bc(train_dataloader, val_dataloader, config): def train_bc(train_dataloader, val_dataloader, config):
@@ -431,5 +513,8 @@ if __name__ == '__main__':
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False) parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False) parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
parser.add_argument('--temporal_agg', action='store_true') parser.add_argument('--temporal_agg', action='store_true')
parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
parser.add_argument('--freeze_text_encoder', action='store_true')
parser.add_argument('--text_max_length', action='store', type=int, required=False)
main(vars(parser.parse_args())) main(vars(parser.parse_args()))

View File

@@ -3,6 +3,7 @@ from torch.nn import functional as F
import torchvision.transforms as transforms import torchvision.transforms as transforms
from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
from models.text_encoder import DistilBERTTextEncoder
import IPython import IPython
e = IPython.embed e = IPython.embed
@@ -13,18 +14,44 @@ class ACTPolicy(nn.Module):
self.model = model # CVAE decoder self.model = model # CVAE decoder
self.optimizer = optimizer self.optimizer = optimizer
self.kl_weight = args_override['kl_weight'] self.kl_weight = args_override['kl_weight']
self.use_text = args_override.get('use_text', False)
self.text_encoder = None
if self.use_text:
text_encoder_type = args_override.get('text_encoder_type', 'distilbert')
if text_encoder_type != 'distilbert':
raise NotImplementedError(f'Unsupported text encoder: {text_encoder_type}')
self.text_encoder = DistilBERTTextEncoder(
model_name=args_override.get('text_tokenizer_name', 'distilbert-base-uncased'),
output_dim=args_override.get('text_feature_dim', 768),
freeze=args_override.get('freeze_text_encoder', True),
)
print(f'KL Weight {self.kl_weight}') print(f'KL Weight {self.kl_weight}')
def __call__(self, qpos, image, actions=None, is_pad=None): def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
env_state = None env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]) std=[0.229, 0.224, 0.225])
image = normalize(image) image = normalize(image)
if self.use_text and text_features is None and text_input_ids is not None and text_attention_mask is not None:
if self.text_encoder is None:
raise RuntimeError('Text encoder is not initialized while use_text=True.')
text_features = self.text_encoder(text_input_ids, text_attention_mask)
if actions is not None: # training time if actions is not None: # training time
if is_pad is None:
raise ValueError('`is_pad` must be provided during training when `actions` is not None.')
actions = actions[:, :self.model.num_queries] actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries] is_pad = is_pad[:, :self.model.num_queries]
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad) a_hat, is_pad_hat, (mu, logvar) = self.model(
qpos,
image,
env_state,
text_features=text_features,
actions=actions,
is_pad=is_pad,
)
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict = dict() loss_dict = dict()
all_l1 = F.l1_loss(actions, a_hat, reduction='none') all_l1 = F.l1_loss(actions, a_hat, reduction='none')
@@ -34,7 +61,7 @@ class ACTPolicy(nn.Module):
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
return loss_dict return loss_dict
else: # inference time else: # inference time
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior a_hat, _, (_, _) = self.model(qpos, image, env_state, text_features=text_features) # no action, sample from prior
return a_hat return a_hat
def configure_optimizers(self): def configure_optimizers(self):
@@ -48,7 +75,7 @@ class CNNMLPPolicy(nn.Module):
self.model = model # decoder self.model = model # decoder
self.optimizer = optimizer self.optimizer = optimizer
def __call__(self, qpos, image, actions=None, is_pad=None): def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
env_state = None # TODO env_state = None # TODO
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]) std=[0.229, 0.224, 0.225])

230
utils.py
View File

@@ -2,21 +2,88 @@ import numpy as np
import torch import torch
import os import os
import h5py import h5py
import re
from torch.utils.data import TensorDataset, DataLoader from torch.utils.data import TensorDataset, DataLoader
import IPython import IPython
e = IPython.embed e = IPython.embed
class EpisodicDataset(torch.utils.data.Dataset): class EpisodicDataset(torch.utils.data.Dataset):
def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats): def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats,
use_text_instruction=False,
instruction_mode='timestep-level',
use_cached_text_features=True,
text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased',
text_max_length=32):
super(EpisodicDataset).__init__() super(EpisodicDataset).__init__()
self.episode_ids = episode_ids self.episode_ids = episode_ids
self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.camera_names = camera_names self.camera_names = camera_names
self.norm_stats = norm_stats self.norm_stats = norm_stats
self.use_text_instruction = use_text_instruction
self.instruction_mode = instruction_mode
self.use_cached_text_features = use_cached_text_features
self.text_feature_dim = text_feature_dim
self.text_max_length = text_max_length
self.is_sim = None self.is_sim = None
self.max_episode_len = None
self.action_dim = None
self.text_tokenizer = None
if self.use_text_instruction:
try:
from transformers import DistilBertTokenizerFast
except ImportError as exc:
raise ImportError(
'transformers is required for text instruction loading. '
'Install it with: pip install transformers'
) from exc
self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(text_tokenizer_name)
self._init_episode_shapes()
self.__getitem__(0) # initialize self.is_sim self.__getitem__(0) # initialize self.is_sim
def _init_episode_shapes(self):
max_len = 0
action_dim = None
for episode_id in self.episode_ids:
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
with h5py.File(dataset_path, 'r') as root:
shape = root['/action'].shape
if len(shape) != 2:
raise ValueError(f'Expected /action to have shape [T, D], got {shape} in {dataset_path}')
max_len = max(max_len, int(shape[0]))
if action_dim is None:
action_dim = int(shape[1])
elif int(shape[1]) != action_dim:
raise ValueError(
f'Inconsistent action dim in dataset. Expected {action_dim}, got {shape[1]} in {dataset_path}'
)
if max_len <= 0 or action_dim is None:
raise ValueError(f'Invalid dataset metadata in {self.dataset_dir}')
self.max_episode_len = max_len
self.action_dim = action_dim
@staticmethod
def _decode_instruction(raw_value):
if raw_value is None:
return ''
if isinstance(raw_value, bytes):
return raw_value.decode('utf-8')
if isinstance(raw_value, np.bytes_):
return raw_value.tobytes().decode('utf-8')
if isinstance(raw_value, np.ndarray):
if raw_value.shape == ():
return EpisodicDataset._decode_instruction(raw_value.item())
if raw_value.size == 0:
return ''
return EpisodicDataset._decode_instruction(raw_value.reshape(-1)[0])
return str(raw_value)
def __len__(self): def __len__(self):
return len(self.episode_ids) return len(self.episode_ids)
@@ -26,7 +93,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
episode_id = self.episode_ids[index] episode_id = self.episode_ids[index]
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5') dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
with h5py.File(dataset_path, 'r') as root: with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim'] is_sim = bool(root.attrs.get('sim', False))
original_action_shape = root['/action'].shape original_action_shape = root['/action'].shape
episode_len = original_action_shape[0] episode_len = original_action_shape[0]
if sample_full_episode: if sample_full_episode:
@@ -35,10 +102,40 @@ class EpisodicDataset(torch.utils.data.Dataset):
start_ts = np.random.choice(episode_len) start_ts = np.random.choice(episode_len)
# get observation at start_ts only # get observation at start_ts only
qpos = root['/observations/qpos'][start_ts] qpos = root['/observations/qpos'][start_ts]
qvel = root['/observations/qvel'][start_ts]
image_dict = dict() image_dict = dict()
for cam_name in self.camera_names: for cam_name in self.camera_names:
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts] image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
instruction = ''
text_feature = None
if self.use_text_instruction:
effective_mode = self.instruction_mode
if effective_mode == 'timestep-level' and '/instruction_timestep' in root:
instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
elif '/instruction' in root:
instruction_node = root['/instruction']
if getattr(instruction_node, 'shape', ()) == ():
instruction = self._decode_instruction(instruction_node[()])
else:
if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
instruction = self._decode_instruction(instruction_node[start_ts])
else:
instruction = self._decode_instruction(instruction_node[0])
if self.use_cached_text_features:
if effective_mode == 'timestep-level' and '/instruction_features_timestep' in root:
text_feature = root['/instruction_features_timestep'][start_ts]
elif '/instruction_features' in root:
feat_node = root['/instruction_features']
if getattr(feat_node, 'shape', ()) == ():
text_feature = np.array(feat_node[()])
elif len(feat_node.shape) == 1:
text_feature = feat_node[()]
elif len(feat_node.shape) == 2 and feat_node.shape[0] == episode_len:
text_feature = feat_node[start_ts]
else:
text_feature = feat_node[0]
# get all actions after and including start_ts # get all actions after and including start_ts
if is_sim: if is_sim:
action = root['/action'][start_ts:] action = root['/action'][start_ts:]
@@ -48,10 +145,10 @@ class EpisodicDataset(torch.utils.data.Dataset):
action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
self.is_sim = is_sim self.is_sim = is_sim
padded_action = np.zeros(original_action_shape, dtype=np.float32) padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
padded_action[:action_len] = action padded_action[:action_len] = action
is_pad = np.zeros(episode_len) is_pad = np.ones(self.max_episode_len)
is_pad[action_len:] = 1 is_pad[:action_len] = 0
# new axis for different cameras # new axis for different cameras
all_cam_images = [] all_cam_images = []
@@ -73,55 +170,132 @@ class EpisodicDataset(torch.utils.data.Dataset):
action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"] action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"] qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
return image_data, qpos_data, action_data, is_pad if self.use_text_instruction and text_feature is not None:
text_feature_data = torch.from_numpy(np.array(text_feature)).float()
text_feature_valid = torch.tensor(True, dtype=torch.bool)
text_input_ids = torch.zeros(1, dtype=torch.long)
text_attention_mask = torch.zeros(1, dtype=torch.long)
elif self.use_text_instruction:
tokenized = self.text_tokenizer(
instruction,
padding='max_length',
truncation=True,
max_length=self.text_max_length,
return_tensors='pt',
)
text_input_ids = tokenized['input_ids'].squeeze(0).long()
text_attention_mask = tokenized['attention_mask'].squeeze(0).long()
text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
text_feature_valid = torch.tensor(False, dtype=torch.bool)
else:
text_input_ids = torch.zeros(1, dtype=torch.long)
text_attention_mask = torch.zeros(1, dtype=torch.long)
text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
text_feature_valid = torch.tensor(False, dtype=torch.bool)
return image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid
def get_norm_stats(dataset_dir, num_episodes): def _discover_episode_ids(dataset_dir, num_episodes=None):
pattern = re.compile(r'^episode_(\d+)\.hdf5$')
episode_ids = []
for fname in os.listdir(dataset_dir):
m = pattern.match(fname)
if m:
episode_ids.append(int(m.group(1)))
episode_ids.sort()
if num_episodes is not None:
episode_ids = episode_ids[:num_episodes]
return episode_ids
def get_norm_stats(dataset_dir, episode_ids):
all_qpos_data = [] all_qpos_data = []
all_action_data = [] all_action_data = []
for episode_idx in range(num_episodes): example_qpos = None
for episode_idx in episode_ids:
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5') dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
with h5py.File(dataset_path, 'r') as root: with h5py.File(dataset_path, 'r') as root:
qpos = root['/observations/qpos'][()] qpos = root['/observations/qpos'][()]
qvel = root['/observations/qvel'][()]
action = root['/action'][()] action = root['/action'][()]
all_qpos_data.append(torch.from_numpy(qpos)) qpos_t = torch.from_numpy(qpos)
all_action_data.append(torch.from_numpy(action)) action_t = torch.from_numpy(action)
all_qpos_data = torch.stack(all_qpos_data) all_qpos_data.append(qpos_t)
all_action_data = torch.stack(all_action_data) all_action_data.append(action_t)
all_action_data = all_action_data if example_qpos is None and len(qpos) > 0:
example_qpos = qpos[0]
# Episodes may have different lengths; concatenate over time axis.
all_qpos_data = torch.cat(all_qpos_data, dim=0)
all_action_data = torch.cat(all_action_data, dim=0)
# normalize action data # normalize action data
action_mean = all_action_data.mean(dim=[0, 1], keepdim=True) action_mean = all_action_data.mean(dim=0, keepdim=True)
action_std = all_action_data.std(dim=[0, 1], keepdim=True) action_std = all_action_data.std(dim=0, keepdim=True)
action_std = torch.clip(action_std, 1e-2, np.inf) # clipping action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
# normalize qpos data # normalize qpos data
qpos_mean = all_qpos_data.mean(dim=[0, 1], keepdim=True) qpos_mean = all_qpos_data.mean(dim=0, keepdim=True)
qpos_std = all_qpos_data.std(dim=[0, 1], keepdim=True) qpos_std = all_qpos_data.std(dim=0, keepdim=True)
qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(), stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
"qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(), "qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
"example_qpos": qpos} "example_qpos": example_qpos}
return stats return stats
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val,
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val): use_text_instruction=False,
instruction_mode='timestep-level',
use_cached_text_features=True,
text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased',
text_max_length=32):
print(f'\nData from: {dataset_dir}\n') print(f'\nData from: {dataset_dir}\n')
episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
if len(episode_ids) == 0:
raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
if len(episode_ids) < 2:
raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
# obtain train test split # obtain train test split
train_ratio = 0.8 train_ratio = 0.8
shuffled_indices = np.random.permutation(num_episodes) shuffled_indices = np.random.permutation(len(episode_ids))
train_indices = shuffled_indices[:int(train_ratio * num_episodes)] train_count = int(train_ratio * len(episode_ids))
val_indices = shuffled_indices[int(train_ratio * num_episodes):] train_indices = shuffled_indices[:train_count]
val_indices = shuffled_indices[train_count:]
train_episode_ids = np.array(episode_ids)[train_indices]
val_episode_ids = np.array(episode_ids)[val_indices]
# obtain normalization stats for qpos and action # obtain normalization stats for qpos and action
norm_stats = get_norm_stats(dataset_dir, num_episodes) norm_stats = get_norm_stats(dataset_dir, episode_ids)
# construct dataset and dataloader # construct dataset and dataloader
train_dataset = EpisodicDataset(train_indices, dataset_dir, camera_names, norm_stats) train_dataset = EpisodicDataset(
val_dataset = EpisodicDataset(val_indices, dataset_dir, camera_names, norm_stats) train_episode_ids,
dataset_dir,
camera_names,
norm_stats,
use_text_instruction=use_text_instruction,
instruction_mode=instruction_mode,
use_cached_text_features=use_cached_text_features,
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
)
val_dataset = EpisodicDataset(
val_episode_ids,
dataset_dir,
camera_names,
norm_stats,
use_text_instruction=use_text_instruction,
instruction_mode=instruction_mode,
use_cached_text_features=use_cached_text_features,
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)

View File

@@ -21,32 +21,33 @@ def load_hdf5(dataset_dir, dataset_name):
with h5py.File(dataset_path, 'r') as root: with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim'] is_sim = root.attrs['sim']
dt = float(root.attrs.get('dt', DT))
qpos = root['/observations/qpos'][()] qpos = root['/observations/qpos'][()]
qvel = root['/observations/qvel'][()] qvel = root['/observations/qvel'][()] if '/observations/qvel' in root else np.zeros_like(qpos)
action = root['/action'][()] action = root['/action'][()]
image_dict = dict() image_dict = dict()
for cam_name in root[f'/observations/images/'].keys(): for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()] image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
return qpos, qvel, action, image_dict return qpos, qvel, action, image_dict, dt
def main(args): def main(args):
dataset_dir = args['dataset_dir'] dataset_dir = args['dataset_dir']
episode_idx = args['episode_idx'] episode_idx = args['episode_idx']
dataset_name = f'episode_{episode_idx}' dataset_name = f'episode_{episode_idx}'
qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name) qpos, qvel, action, image_dict, dt = load_hdf5(dataset_dir, dataset_name)
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4')) save_videos(image_dict, dt, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png')) visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back # visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
def save_videos(video, dt, video_path=None): def save_videos(video, dt, video_path):
if isinstance(video, list): if isinstance(video, list):
cam_names = list(video[0].keys()) cam_names = list(video[0].keys())
h, w, _ = video[0][cam_names[0]].shape h, w, _ = video[0][cam_names[0]].shape
w = w * len(cam_names) w = w * len(cam_names)
fps = int(1/dt) fps = max(1, int(round(1 / dt)))
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for ts, image_dict in enumerate(video): for ts, image_dict in enumerate(video):
images = [] images = []
@@ -66,7 +67,7 @@ def save_videos(video, dt, video_path=None):
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
n_frames, h, w, _ = all_cam_videos.shape n_frames, h, w, _ = all_cam_videos.shape
fps = int(1 / dt) fps = max(1, int(round(1 / dt)))
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for t in range(n_frames): for t in range(n_frames):
image = all_cam_videos[t] image = all_cam_videos[t]