代码可以跑起来了
This commit is contained in:
@@ -268,10 +268,30 @@
|
|||||||
|
|
||||||
## 8. 你接下来只需提供的最小信息(进入代码改造前)
|
## 8. 你接下来只需提供的最小信息(进入代码改造前)
|
||||||
|
|
||||||
1. 2 个电机各自的物理含义与取值范围(单位、上下限)。
|
1. 2 个电机各自的物理含义与取值范围(单位、上下限):电机分别为 motor_x 和 motor_y,x 的范围为 7000-17384,y 的范围为 8000-18884。对应的 action_x 和 action_y 都为 0~65535 之间
|
||||||
2. 你当前数据中 `qpos` 和 `action` 的实际定义(是否相同)。
|
2. 你当前数据中 `qpos` 和 `action` 的实际定义(是否相同):action 和 qpos 定义接近,只不过是将 0~65535 分别映射到电机磁编码器数值上。
|
||||||
3. text instruction 是每个 episode 一条,还是每个 timestep 一条。
|
3. text instruction 是每个 episode 一条,还是每个 timestep 一条:text instruction 是每个 timestep 一条。
|
||||||
4. 相机数量、分辨率、帧率。
|
4. 相机数量、分辨率、帧率:相机数量为 1,分辨率为 224*224,帧率为 30Hz;对应的电机控制频率也为 30Hz
|
||||||
5. 是否在训练时冻结 DistilBERT(`freeze_text_encoder=True/False`)。
|
5. 是否在训练时冻结 DistilBERT(`freeze_text_encoder=True/False`):DistilBERT 完全冻结。构建训练集时,先将每一个 frame 的 text instruction 用 DistilBERT 编码以后再保存。这样训练过程中不需要调用 DistilBERT。
|
||||||
|
|
||||||
> 有了这 5 项,即可进入下一步代码改造。
|
> 有了这 5 项,即可进入下一步代码改造。
|
||||||
|
|
||||||
|
text_input_ids、text_attention_mask 什么意思;'instruction_mode': 'episode-level'没有用到;
|
||||||
|
|
||||||
|
```python
|
||||||
|
instruction = ''
|
||||||
|
if self.use_text_instruction:
|
||||||
|
if '/instruction_timestep' in root:
|
||||||
|
instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
|
||||||
|
elif '/instruction' in root:
|
||||||
|
instruction_node = root['/instruction']
|
||||||
|
if getattr(instruction_node, 'shape', ()) == ():
|
||||||
|
instruction = self._decode_instruction(instruction_node[()])
|
||||||
|
else:
|
||||||
|
if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
|
||||||
|
instruction = self._decode_instruction(instruction_node[start_ts])
|
||||||
|
else:
|
||||||
|
instruction = self._decode_instruction(instruction_node[0])
|
||||||
|
```
|
||||||
|
|
||||||
|
为什么修改了 Transformer 的定义?这里是否会生效?
|
||||||
@@ -21,3 +21,5 @@ dependencies:
|
|||||||
- packaging=23.0
|
- packaging=23.0
|
||||||
- h5py=3.8.0
|
- h5py=3.8.0
|
||||||
- ipython=8.12.0
|
- ipython=8.12.0
|
||||||
|
- pip:
|
||||||
|
- transformers==4.38.2
|
||||||
|
|||||||
39
constants.py
39
constants.py
@@ -1,7 +1,7 @@
|
|||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
### Task parameters
|
### Task parameters
|
||||||
DATA_DIR = '<put your data dir here>'
|
DATA_DIR = str(pathlib.Path(__file__).parent.resolve() / 'data')
|
||||||
SIM_TASK_CONFIGS = {
|
SIM_TASK_CONFIGS = {
|
||||||
'sim_transfer_cube_scripted':{
|
'sim_transfer_cube_scripted':{
|
||||||
'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
|
'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
|
||||||
@@ -32,6 +32,43 @@ SIM_TASK_CONFIGS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ENDOSCOPE_TASK_CONFIGS = {
|
||||||
|
'endoscope_default': {
|
||||||
|
'dataset_dir': DATA_DIR + '/endoscope_default',
|
||||||
|
'num_episodes': 50,
|
||||||
|
'episode_len': 400,
|
||||||
|
'camera_names': ['top'],
|
||||||
|
'state_dim': 2,
|
||||||
|
'action_dim': 2,
|
||||||
|
'use_text_instruction': True,
|
||||||
|
'instruction_mode': 'timestep-level',
|
||||||
|
'use_cached_text_features': True,
|
||||||
|
'text_encoder_type': 'distilbert',
|
||||||
|
'text_feature_dim': 768,
|
||||||
|
'text_fusion_type': 'concat_transformer_input',
|
||||||
|
'freeze_text_encoder': True,
|
||||||
|
'text_max_length': 32,
|
||||||
|
'text_tokenizer_name': 'distilbert-base-uncased',
|
||||||
|
},
|
||||||
|
'endoscope_follow': {
|
||||||
|
'dataset_dir': DATA_DIR + '/follow',
|
||||||
|
'num_episodes': 3,
|
||||||
|
'episode_len': 400,
|
||||||
|
'camera_names': ['top'],
|
||||||
|
'state_dim': 2,
|
||||||
|
'action_dim': 2,
|
||||||
|
'use_text_instruction': True,
|
||||||
|
'instruction_mode': 'timestep-level',
|
||||||
|
'use_cached_text_features': True,
|
||||||
|
'text_encoder_type': 'distilbert',
|
||||||
|
'text_feature_dim': 768,
|
||||||
|
'text_fusion_type': 'concat_transformer_input',
|
||||||
|
'freeze_text_encoder': True,
|
||||||
|
'text_max_length': 32,
|
||||||
|
'text_tokenizer_name': 'distilbert-base-uncased',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
### Simulation envs fixed constants
|
### Simulation envs fixed constants
|
||||||
DT = 0.02
|
DT = 0.02
|
||||||
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||||
|
|||||||
18
detr/main.py
18
detr/main.py
@@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch.optim.adamw import AdamW
|
||||||
from .models import build_ACT_model, build_CNNMLP_model
|
from .models import build_ACT_model, build_CNNMLP_model
|
||||||
|
|
||||||
import IPython
|
import IPython
|
||||||
@@ -30,6 +31,15 @@ def get_args_parser():
|
|||||||
help="Type of positional embedding to use on top of the image features")
|
help="Type of positional embedding to use on top of the image features")
|
||||||
parser.add_argument('--camera_names', default=[], type=list, # will be overridden
|
parser.add_argument('--camera_names', default=[], type=list, # will be overridden
|
||||||
help="A list of camera names")
|
help="A list of camera names")
|
||||||
|
parser.add_argument('--state_dim', default=14, type=int)
|
||||||
|
parser.add_argument('--action_dim', default=14, type=int)
|
||||||
|
parser.add_argument('--use_text', action='store_true')
|
||||||
|
parser.add_argument('--text_encoder_type', default='distilbert', type=str)
|
||||||
|
parser.add_argument('--text_feature_dim', default=768, type=int)
|
||||||
|
parser.add_argument('--text_fusion_type', default='concat_transformer_input', type=str)
|
||||||
|
parser.add_argument('--freeze_text_encoder', action='store_true')
|
||||||
|
parser.add_argument('--text_max_length', default=32, type=int)
|
||||||
|
parser.add_argument('--text_tokenizer_name', default='distilbert-base-uncased', type=str)
|
||||||
|
|
||||||
# * Transformer
|
# * Transformer
|
||||||
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
|
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
|
||||||
@@ -84,8 +94,8 @@ def build_ACT_model_and_optimizer(args_override):
|
|||||||
"lr": args.lr_backbone,
|
"lr": args.lr_backbone,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
optimizer = AdamW(param_dicts, lr=args.lr,
|
||||||
weight_decay=args.weight_decay)
|
weight_decay=args.weight_decay)
|
||||||
|
|
||||||
return model, optimizer
|
return model, optimizer
|
||||||
|
|
||||||
@@ -107,8 +117,8 @@ def build_CNNMLP_model_and_optimizer(args_override):
|
|||||||
"lr": args.lr_backbone,
|
"lr": args.lr_backbone,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
optimizer = AdamW(param_dicts, lr=args.lr,
|
||||||
weight_decay=args.weight_decay)
|
weight_decay=args.weight_decay)
|
||||||
|
|
||||||
return model, optimizer
|
return model, optimizer
|
||||||
|
|
||||||
|
|||||||
@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
|
|||||||
train_backbone: bool,
|
train_backbone: bool,
|
||||||
return_interm_layers: bool,
|
return_interm_layers: bool,
|
||||||
dilation: bool):
|
dilation: bool):
|
||||||
backbone = getattr(torchvision.models, name)(
|
backbone_builder = getattr(torchvision.models, name)
|
||||||
replace_stride_with_dilation=[False, False, dilation],
|
weights = None
|
||||||
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
|
if is_main_process():
|
||||||
|
weight_enum_name_map = {
|
||||||
|
'resnet18': 'ResNet18_Weights',
|
||||||
|
'resnet34': 'ResNet34_Weights',
|
||||||
|
'resnet50': 'ResNet50_Weights',
|
||||||
|
'resnet101': 'ResNet101_Weights',
|
||||||
|
}
|
||||||
|
enum_name = weight_enum_name_map.get(name)
|
||||||
|
if enum_name is not None and hasattr(torchvision.models, enum_name):
|
||||||
|
weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')
|
||||||
|
|
||||||
|
try:
|
||||||
|
backbone = backbone_builder(
|
||||||
|
replace_stride_with_dilation=[False, False, dilation],
|
||||||
|
weights=weights,
|
||||||
|
norm_layer=FrozenBatchNorm2d,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
# Backward compatibility for older torchvision that still expects `pretrained`.
|
||||||
|
backbone = backbone_builder(
|
||||||
|
replace_stride_with_dilation=[False, False, dilation],
|
||||||
|
pretrained=(weights is not None),
|
||||||
|
norm_layer=FrozenBatchNorm2d,
|
||||||
|
)
|
||||||
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
||||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid):
|
|||||||
|
|
||||||
class DETRVAE(nn.Module):
|
class DETRVAE(nn.Module):
|
||||||
""" This is the DETR module that performs object detection """
|
""" This is the DETR module that performs object detection """
|
||||||
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
|
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
|
||||||
|
use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
|
||||||
""" Initializes the model.
|
""" Initializes the model.
|
||||||
Parameters:
|
Parameters:
|
||||||
backbones: torch module of the backbone to be used. See backbone.py
|
backbones: torch module of the backbone to be used. See backbone.py
|
||||||
@@ -48,17 +49,18 @@ class DETRVAE(nn.Module):
|
|||||||
self.camera_names = camera_names
|
self.camera_names = camera_names
|
||||||
self.transformer = transformer
|
self.transformer = transformer
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
|
self.use_text = use_text
|
||||||
|
self.text_fusion_type = text_fusion_type
|
||||||
hidden_dim = transformer.d_model
|
hidden_dim = transformer.d_model
|
||||||
self.action_head = nn.Linear(hidden_dim, state_dim)
|
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||||
if backbones is not None:
|
if backbones is not None:
|
||||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||||
self.backbones = nn.ModuleList(backbones)
|
self.backbones = nn.ModuleList(backbones)
|
||||||
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
|
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||||
else:
|
else:
|
||||||
# input_dim = 14 + 7 # robot_state + env_state
|
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||||
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
|
|
||||||
self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
||||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||||
self.backbones = None
|
self.backbones = None
|
||||||
@@ -66,16 +68,18 @@ class DETRVAE(nn.Module):
|
|||||||
# encoder extra parameters
|
# encoder extra parameters
|
||||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||||
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
|
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
|
||||||
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
|
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
|
||||||
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
|
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
|
||||||
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
|
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
|
||||||
|
|
||||||
# decoder extra parameters
|
# decoder extra parameters
|
||||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||||
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
|
num_extra_tokens = 3 if self.use_text else 2
|
||||||
|
self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text
|
||||||
|
self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None
|
||||||
|
|
||||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
|
||||||
"""
|
"""
|
||||||
qpos: batch, qpos_dim
|
qpos: batch, qpos_dim
|
||||||
image: batch, num_cam, channel, height, width
|
image: batch, num_cam, channel, height, width
|
||||||
@@ -125,10 +129,25 @@ class DETRVAE(nn.Module):
|
|||||||
all_cam_pos.append(pos)
|
all_cam_pos.append(pos)
|
||||||
# proprioception features
|
# proprioception features
|
||||||
proprio_input = self.input_proj_robot_state(qpos)
|
proprio_input = self.input_proj_robot_state(qpos)
|
||||||
|
extra_input_tokens = None
|
||||||
|
if self.use_text and text_features is not None:
|
||||||
|
if self.text_fusion_type != 'concat_transformer_input':
|
||||||
|
raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
|
||||||
|
text_input = self.text_proj(text_features)
|
||||||
|
extra_input_tokens = text_input.unsqueeze(0)
|
||||||
# fold camera dimension into width dimension
|
# fold camera dimension into width dimension
|
||||||
src = torch.cat(all_cam_features, axis=3)
|
src = torch.cat(all_cam_features, axis=3)
|
||||||
pos = torch.cat(all_cam_pos, axis=3)
|
pos = torch.cat(all_cam_pos, axis=3)
|
||||||
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
|
hs = self.transformer(
|
||||||
|
src,
|
||||||
|
None,
|
||||||
|
self.query_embed.weight,
|
||||||
|
pos,
|
||||||
|
latent_input,
|
||||||
|
proprio_input,
|
||||||
|
self.additional_pos_embed.weight,
|
||||||
|
extra_input_tokens=extra_input_tokens,
|
||||||
|
)[0]
|
||||||
else:
|
else:
|
||||||
qpos = self.input_proj_robot_state(qpos)
|
qpos = self.input_proj_robot_state(qpos)
|
||||||
env_state = self.input_proj_env_state(env_state)
|
env_state = self.input_proj_env_state(env_state)
|
||||||
@@ -141,7 +160,7 @@ class DETRVAE(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CNNMLP(nn.Module):
|
class CNNMLP(nn.Module):
|
||||||
def __init__(self, backbones, state_dim, camera_names):
|
def __init__(self, backbones, state_dim, action_dim, camera_names):
|
||||||
""" Initializes the model.
|
""" Initializes the model.
|
||||||
Parameters:
|
Parameters:
|
||||||
backbones: torch module of the backbone to be used. See backbone.py
|
backbones: torch module of the backbone to be used. See backbone.py
|
||||||
@@ -153,7 +172,7 @@ class CNNMLP(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.camera_names = camera_names
|
self.camera_names = camera_names
|
||||||
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
self.action_head = nn.Linear(1000, action_dim) # TODO add more
|
||||||
if backbones is not None:
|
if backbones is not None:
|
||||||
self.backbones = nn.ModuleList(backbones)
|
self.backbones = nn.ModuleList(backbones)
|
||||||
backbone_down_projs = []
|
backbone_down_projs = []
|
||||||
@@ -166,8 +185,8 @@ class CNNMLP(nn.Module):
|
|||||||
backbone_down_projs.append(down_proj)
|
backbone_down_projs.append(down_proj)
|
||||||
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
||||||
|
|
||||||
mlp_in_dim = 768 * len(backbones) + 14
|
mlp_in_dim = 768 * len(backbones) + state_dim
|
||||||
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
|
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -192,7 +211,7 @@ class CNNMLP(nn.Module):
|
|||||||
for cam_feature in all_cam_features:
|
for cam_feature in all_cam_features:
|
||||||
flattened_features.append(cam_feature.reshape([bs, -1]))
|
flattened_features.append(cam_feature.reshape([bs, -1]))
|
||||||
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
||||||
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
features = torch.cat([flattened_features, qpos], axis=1)
|
||||||
a_hat = self.mlp(features)
|
a_hat = self.mlp(features)
|
||||||
return a_hat
|
return a_hat
|
||||||
|
|
||||||
@@ -227,7 +246,8 @@ def build_encoder(args):
|
|||||||
|
|
||||||
|
|
||||||
def build(args):
|
def build(args):
|
||||||
state_dim = 14 # TODO hardcode
|
state_dim = args.state_dim
|
||||||
|
action_dim = args.action_dim
|
||||||
|
|
||||||
# From state
|
# From state
|
||||||
# backbone = None # from state for now, no need for conv nets
|
# backbone = None # from state for now, no need for conv nets
|
||||||
@@ -245,8 +265,12 @@ def build(args):
|
|||||||
transformer,
|
transformer,
|
||||||
encoder,
|
encoder,
|
||||||
state_dim=state_dim,
|
state_dim=state_dim,
|
||||||
|
action_dim=action_dim,
|
||||||
num_queries=args.num_queries,
|
num_queries=args.num_queries,
|
||||||
camera_names=args.camera_names,
|
camera_names=args.camera_names,
|
||||||
|
use_text=args.use_text,
|
||||||
|
text_feature_dim=args.text_feature_dim,
|
||||||
|
text_fusion_type=args.text_fusion_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
@@ -255,7 +279,8 @@ def build(args):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def build_cnnmlp(args):
|
def build_cnnmlp(args):
|
||||||
state_dim = 14 # TODO hardcode
|
state_dim = args.state_dim
|
||||||
|
action_dim = args.action_dim
|
||||||
|
|
||||||
# From state
|
# From state
|
||||||
# backbone = None # from state for now, no need for conv nets
|
# backbone = None # from state for now, no need for conv nets
|
||||||
@@ -268,6 +293,7 @@ def build_cnnmlp(args):
|
|||||||
model = CNNMLP(
|
model = CNNMLP(
|
||||||
backbones,
|
backbones,
|
||||||
state_dim=state_dim,
|
state_dim=state_dim,
|
||||||
|
action_dim=action_dim,
|
||||||
camera_names=args.camera_names,
|
camera_names=args.camera_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class Transformer(nn.Module):
|
|||||||
if p.dim() > 1:
|
if p.dim() > 1:
|
||||||
nn.init.xavier_uniform_(p)
|
nn.init.xavier_uniform_(p)
|
||||||
|
|
||||||
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
|
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
|
||||||
# TODO flatten only when input has H and W
|
# TODO flatten only when input has H and W
|
||||||
if len(src.shape) == 4: # has H and W
|
if len(src.shape) == 4: # has H and W
|
||||||
# flatten NxCxHxW to HWxNxC
|
# flatten NxCxHxW to HWxNxC
|
||||||
@@ -56,10 +56,19 @@ class Transformer(nn.Module):
|
|||||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||||
# mask = mask.flatten(1)
|
# mask = mask.flatten(1)
|
||||||
|
|
||||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
additional_inputs = [latent_input, proprio_input]
|
||||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
if extra_input_tokens is not None:
|
||||||
|
if len(extra_input_tokens.shape) == 2:
|
||||||
|
extra_input_tokens = extra_input_tokens.unsqueeze(0)
|
||||||
|
for i in range(extra_input_tokens.shape[0]):
|
||||||
|
additional_inputs.append(extra_input_tokens[i])
|
||||||
|
|
||||||
|
addition_input = torch.stack(additional_inputs, axis=0)
|
||||||
|
if additional_pos_embed is not None:
|
||||||
|
additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
|
||||||
|
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||||
|
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||||
|
|
||||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
|
||||||
src = torch.cat([addition_input, src], axis=0)
|
src = torch.cat([addition_input, src], axis=0)
|
||||||
else:
|
else:
|
||||||
assert len(src.shape) == 3
|
assert len(src.shape) == 3
|
||||||
|
|||||||
@@ -10,14 +10,13 @@ from einops import rearrange
|
|||||||
|
|
||||||
from constants import DT
|
from constants import DT
|
||||||
from constants import PUPPET_GRIPPER_JOINT_OPEN
|
from constants import PUPPET_GRIPPER_JOINT_OPEN
|
||||||
|
from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
|
||||||
from utils import load_data # data functions
|
from utils import load_data # data functions
|
||||||
from utils import sample_box_pose, sample_insertion_pose # robot functions
|
from utils import sample_box_pose, sample_insertion_pose # robot functions
|
||||||
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
|
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
|
||||||
from policy import ACTPolicy, CNNMLPPolicy
|
from policy import ACTPolicy, CNNMLPPolicy
|
||||||
from visualize_episodes import save_videos
|
from visualize_episodes import save_videos
|
||||||
|
|
||||||
from sim_env import BOX_POSE
|
|
||||||
|
|
||||||
import IPython
|
import IPython
|
||||||
e = IPython.embed
|
e = IPython.embed
|
||||||
|
|
||||||
@@ -34,25 +33,47 @@ def main(args):
|
|||||||
num_epochs = args['num_epochs']
|
num_epochs = args['num_epochs']
|
||||||
|
|
||||||
# get task parameters
|
# get task parameters
|
||||||
is_sim = task_name[:4] == 'sim_'
|
is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
|
||||||
if is_sim:
|
if is_endoscope:
|
||||||
from constants import SIM_TASK_CONFIGS
|
task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
|
||||||
|
is_sim = False
|
||||||
|
elif task_name in SIM_TASK_CONFIGS:
|
||||||
task_config = SIM_TASK_CONFIGS[task_name]
|
task_config = SIM_TASK_CONFIGS[task_name]
|
||||||
|
is_sim = True
|
||||||
else:
|
else:
|
||||||
from aloha_scripts.constants import TASK_CONFIGS
|
from aloha_scripts.constants import TASK_CONFIGS
|
||||||
task_config = TASK_CONFIGS[task_name]
|
task_config = TASK_CONFIGS[task_name]
|
||||||
|
is_sim = False
|
||||||
|
|
||||||
dataset_dir = task_config['dataset_dir']
|
dataset_dir = task_config['dataset_dir']
|
||||||
num_episodes = task_config['num_episodes']
|
num_episodes = task_config['num_episodes']
|
||||||
episode_len = task_config['episode_len']
|
episode_len = task_config['episode_len']
|
||||||
camera_names = task_config['camera_names']
|
camera_names = task_config['camera_names']
|
||||||
|
state_dim = task_config.get('state_dim', 14)
|
||||||
|
action_dim = task_config.get('action_dim', state_dim)
|
||||||
|
use_text_instruction = task_config.get('use_text_instruction', False)
|
||||||
|
instruction_mode = task_config.get('instruction_mode', 'timestep-level')
|
||||||
|
use_cached_text_features = task_config.get('use_cached_text_features', True)
|
||||||
|
text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
|
||||||
|
text_feature_dim = task_config.get('text_feature_dim', 768)
|
||||||
|
text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
|
||||||
|
text_max_length = task_config.get('text_max_length', 32)
|
||||||
|
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
|
||||||
|
freeze_text_encoder = task_config.get('freeze_text_encoder', True)
|
||||||
|
|
||||||
|
if args.get('text_encoder_type') is not None:
|
||||||
|
text_encoder_type = args['text_encoder_type']
|
||||||
|
if args.get('text_max_length') is not None:
|
||||||
|
text_max_length = args['text_max_length']
|
||||||
|
if args.get('freeze_text_encoder', False):
|
||||||
|
freeze_text_encoder = True
|
||||||
|
|
||||||
# fixed parameters
|
# fixed parameters
|
||||||
state_dim = 14
|
|
||||||
lr_backbone = 1e-5
|
lr_backbone = 1e-5
|
||||||
backbone = 'resnet18'
|
backbone = 'resnet18'
|
||||||
if policy_class == 'ACT':
|
if policy_class == 'ACT':
|
||||||
enc_layers = 4
|
enc_layers = 2
|
||||||
dec_layers = 7
|
dec_layers = 4
|
||||||
nheads = 8
|
nheads = 8
|
||||||
policy_config = {'lr': args['lr'],
|
policy_config = {'lr': args['lr'],
|
||||||
'num_queries': args['chunk_size'],
|
'num_queries': args['chunk_size'],
|
||||||
@@ -65,10 +86,25 @@ def main(args):
|
|||||||
'dec_layers': dec_layers,
|
'dec_layers': dec_layers,
|
||||||
'nheads': nheads,
|
'nheads': nheads,
|
||||||
'camera_names': camera_names,
|
'camera_names': camera_names,
|
||||||
|
'state_dim': state_dim,
|
||||||
|
'action_dim': action_dim,
|
||||||
|
'use_text': use_text_instruction,
|
||||||
|
'text_encoder_type': text_encoder_type,
|
||||||
|
'text_feature_dim': text_feature_dim,
|
||||||
|
'text_fusion_type': text_fusion_type,
|
||||||
|
'freeze_text_encoder': freeze_text_encoder,
|
||||||
|
'instruction_mode': instruction_mode,
|
||||||
|
'use_cached_text_features': use_cached_text_features,
|
||||||
|
'text_max_length': text_max_length,
|
||||||
|
'text_tokenizer_name': text_tokenizer_name,
|
||||||
}
|
}
|
||||||
elif policy_class == 'CNNMLP':
|
elif policy_class == 'CNNMLP':
|
||||||
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
|
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
|
||||||
'camera_names': camera_names,}
|
'camera_names': camera_names,
|
||||||
|
'state_dim': state_dim,
|
||||||
|
'action_dim': action_dim,
|
||||||
|
'use_text': use_text_instruction,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -77,6 +113,7 @@ def main(args):
|
|||||||
'ckpt_dir': ckpt_dir,
|
'ckpt_dir': ckpt_dir,
|
||||||
'episode_len': episode_len,
|
'episode_len': episode_len,
|
||||||
'state_dim': state_dim,
|
'state_dim': state_dim,
|
||||||
|
'action_dim': action_dim,
|
||||||
'lr': args['lr'],
|
'lr': args['lr'],
|
||||||
'policy_class': policy_class,
|
'policy_class': policy_class,
|
||||||
'onscreen_render': onscreen_render,
|
'onscreen_render': onscreen_render,
|
||||||
@@ -85,7 +122,12 @@ def main(args):
|
|||||||
'seed': args['seed'],
|
'seed': args['seed'],
|
||||||
'temporal_agg': args['temporal_agg'],
|
'temporal_agg': args['temporal_agg'],
|
||||||
'camera_names': camera_names,
|
'camera_names': camera_names,
|
||||||
'real_robot': not is_sim
|
'real_robot': (not is_sim) and (not is_endoscope),
|
||||||
|
'use_text_instruction': use_text_instruction,
|
||||||
|
'instruction_mode': instruction_mode,
|
||||||
|
'use_cached_text_features': use_cached_text_features,
|
||||||
|
'text_tokenizer_name': text_tokenizer_name,
|
||||||
|
'text_max_length': text_max_length,
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_eval:
|
if is_eval:
|
||||||
@@ -100,7 +142,19 @@ def main(args):
|
|||||||
print()
|
print()
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
|
train_dataloader, val_dataloader, stats, _ = load_data(
|
||||||
|
dataset_dir,
|
||||||
|
num_episodes,
|
||||||
|
camera_names,
|
||||||
|
batch_size_train,
|
||||||
|
batch_size_val,
|
||||||
|
use_text_instruction=use_text_instruction,
|
||||||
|
instruction_mode=instruction_mode,
|
||||||
|
use_cached_text_features=use_cached_text_features,
|
||||||
|
text_feature_dim=text_feature_dim,
|
||||||
|
text_tokenizer_name=text_tokenizer_name,
|
||||||
|
text_max_length=text_max_length,
|
||||||
|
)
|
||||||
|
|
||||||
# save dataset stats
|
# save dataset stats
|
||||||
if not os.path.isdir(ckpt_dir):
|
if not os.path.isdir(ckpt_dir):
|
||||||
@@ -152,6 +206,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
|||||||
set_seed(1000)
|
set_seed(1000)
|
||||||
ckpt_dir = config['ckpt_dir']
|
ckpt_dir = config['ckpt_dir']
|
||||||
state_dim = config['state_dim']
|
state_dim = config['state_dim']
|
||||||
|
action_dim = config['action_dim']
|
||||||
real_robot = config['real_robot']
|
real_robot = config['real_robot']
|
||||||
policy_class = config['policy_class']
|
policy_class = config['policy_class']
|
||||||
onscreen_render = config['onscreen_render']
|
onscreen_render = config['onscreen_render']
|
||||||
@@ -161,6 +216,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
|||||||
task_name = config['task_name']
|
task_name = config['task_name']
|
||||||
temporal_agg = config['temporal_agg']
|
temporal_agg = config['temporal_agg']
|
||||||
onscreen_cam = 'angle'
|
onscreen_cam = 'angle'
|
||||||
|
BOX_POSE = None
|
||||||
|
|
||||||
# load policy and stats
|
# load policy and stats
|
||||||
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
|
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
|
||||||
@@ -202,8 +258,14 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
|||||||
rollout_id += 0
|
rollout_id += 0
|
||||||
### set task
|
### set task
|
||||||
if 'sim_transfer_cube' in task_name:
|
if 'sim_transfer_cube' in task_name:
|
||||||
|
if BOX_POSE is None:
|
||||||
|
from sim_env import BOX_POSE as _BOX_POSE
|
||||||
|
BOX_POSE = _BOX_POSE
|
||||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||||
elif 'sim_insertion' in task_name:
|
elif 'sim_insertion' in task_name:
|
||||||
|
if BOX_POSE is None:
|
||||||
|
from sim_env import BOX_POSE as _BOX_POSE
|
||||||
|
BOX_POSE = _BOX_POSE
|
||||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||||
|
|
||||||
ts = env.reset()
|
ts = env.reset()
|
||||||
@@ -216,7 +278,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
|||||||
|
|
||||||
### evaluation loop
|
### evaluation loop
|
||||||
if temporal_agg:
|
if temporal_agg:
|
||||||
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
|
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()
|
||||||
|
|
||||||
qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
|
qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
|
||||||
image_list = [] # for visualization
|
image_list = [] # for visualization
|
||||||
@@ -314,9 +376,29 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
|||||||
|
|
||||||
|
|
||||||
def forward_pass(data, policy):
|
def forward_pass(data, policy):
|
||||||
image_data, qpos_data, action_data, is_pad = data
|
image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
|
||||||
image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
|
image_data = image_data.cuda()
|
||||||
return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None
|
qpos_data = qpos_data.cuda()
|
||||||
|
action_data = action_data.cuda()
|
||||||
|
is_pad = is_pad.cuda()
|
||||||
|
text_input_ids = text_input_ids.cuda()
|
||||||
|
text_attention_mask = text_attention_mask.cuda()
|
||||||
|
text_feature_data = text_feature_data.cuda()
|
||||||
|
text_feature_valid = text_feature_valid.cuda()
|
||||||
|
|
||||||
|
text_features = None
|
||||||
|
if torch.any(text_feature_valid):
|
||||||
|
text_features = text_feature_data
|
||||||
|
|
||||||
|
return policy(
|
||||||
|
qpos_data,
|
||||||
|
image_data,
|
||||||
|
text_input_ids=text_input_ids,
|
||||||
|
text_attention_mask=text_attention_mask,
|
||||||
|
text_features=text_features,
|
||||||
|
actions=action_data,
|
||||||
|
is_pad=is_pad,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def train_bc(train_dataloader, val_dataloader, config):
|
def train_bc(train_dataloader, val_dataloader, config):
|
||||||
@@ -431,5 +513,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
|
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
|
||||||
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
|
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
|
||||||
parser.add_argument('--temporal_agg', action='store_true')
|
parser.add_argument('--temporal_agg', action='store_true')
|
||||||
|
parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
|
||||||
|
parser.add_argument('--freeze_text_encoder', action='store_true')
|
||||||
|
parser.add_argument('--text_max_length', action='store', type=int, required=False)
|
||||||
|
|
||||||
main(vars(parser.parse_args()))
|
main(vars(parser.parse_args()))
|
||||||
|
|||||||
35
policy.py
35
policy.py
@@ -3,6 +3,7 @@ from torch.nn import functional as F
|
|||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
|
from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
|
||||||
|
from models.text_encoder import DistilBERTTextEncoder
|
||||||
import IPython
|
import IPython
|
||||||
e = IPython.embed
|
e = IPython.embed
|
||||||
|
|
||||||
@@ -13,18 +14,44 @@ class ACTPolicy(nn.Module):
|
|||||||
self.model = model # CVAE decoder
|
self.model = model # CVAE decoder
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.kl_weight = args_override['kl_weight']
|
self.kl_weight = args_override['kl_weight']
|
||||||
|
self.use_text = args_override.get('use_text', False)
|
||||||
|
self.text_encoder = None
|
||||||
|
if self.use_text:
|
||||||
|
text_encoder_type = args_override.get('text_encoder_type', 'distilbert')
|
||||||
|
if text_encoder_type != 'distilbert':
|
||||||
|
raise NotImplementedError(f'Unsupported text encoder: {text_encoder_type}')
|
||||||
|
self.text_encoder = DistilBERTTextEncoder(
|
||||||
|
model_name=args_override.get('text_tokenizer_name', 'distilbert-base-uncased'),
|
||||||
|
output_dim=args_override.get('text_feature_dim', 768),
|
||||||
|
freeze=args_override.get('freeze_text_encoder', True),
|
||||||
|
)
|
||||||
print(f'KL Weight {self.kl_weight}')
|
print(f'KL Weight {self.kl_weight}')
|
||||||
|
|
||||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
|
||||||
env_state = None
|
env_state = None
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
std=[0.229, 0.224, 0.225])
|
std=[0.229, 0.224, 0.225])
|
||||||
image = normalize(image)
|
image = normalize(image)
|
||||||
|
|
||||||
|
if self.use_text and text_features is None and text_input_ids is not None and text_attention_mask is not None:
|
||||||
|
if self.text_encoder is None:
|
||||||
|
raise RuntimeError('Text encoder is not initialized while use_text=True.')
|
||||||
|
text_features = self.text_encoder(text_input_ids, text_attention_mask)
|
||||||
|
|
||||||
if actions is not None: # training time
|
if actions is not None: # training time
|
||||||
|
if is_pad is None:
|
||||||
|
raise ValueError('`is_pad` must be provided during training when `actions` is not None.')
|
||||||
actions = actions[:, :self.model.num_queries]
|
actions = actions[:, :self.model.num_queries]
|
||||||
is_pad = is_pad[:, :self.model.num_queries]
|
is_pad = is_pad[:, :self.model.num_queries]
|
||||||
|
|
||||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
a_hat, is_pad_hat, (mu, logvar) = self.model(
|
||||||
|
qpos,
|
||||||
|
image,
|
||||||
|
env_state,
|
||||||
|
text_features=text_features,
|
||||||
|
actions=actions,
|
||||||
|
is_pad=is_pad,
|
||||||
|
)
|
||||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||||
loss_dict = dict()
|
loss_dict = dict()
|
||||||
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
|
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
|
||||||
@@ -34,7 +61,7 @@ class ACTPolicy(nn.Module):
|
|||||||
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
|
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
|
||||||
return loss_dict
|
return loss_dict
|
||||||
else: # inference time
|
else: # inference time
|
||||||
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
a_hat, _, (_, _) = self.model(qpos, image, env_state, text_features=text_features) # no action, sample from prior
|
||||||
return a_hat
|
return a_hat
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
@@ -48,7 +75,7 @@ class CNNMLPPolicy(nn.Module):
|
|||||||
self.model = model # decoder
|
self.model = model # decoder
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
|
|
||||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
|
||||||
env_state = None # TODO
|
env_state = None # TODO
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
std=[0.229, 0.224, 0.225])
|
std=[0.229, 0.224, 0.225])
|
||||||
|
|||||||
230
utils.py
230
utils.py
@@ -2,21 +2,88 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
import h5py
|
import h5py
|
||||||
|
import re
|
||||||
from torch.utils.data import TensorDataset, DataLoader
|
from torch.utils.data import TensorDataset, DataLoader
|
||||||
|
|
||||||
import IPython
|
import IPython
|
||||||
e = IPython.embed
|
e = IPython.embed
|
||||||
|
|
||||||
class EpisodicDataset(torch.utils.data.Dataset):
|
class EpisodicDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats):
|
def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats,
|
||||||
|
use_text_instruction=False,
|
||||||
|
instruction_mode='timestep-level',
|
||||||
|
use_cached_text_features=True,
|
||||||
|
text_feature_dim=768,
|
||||||
|
text_tokenizer_name='distilbert-base-uncased',
|
||||||
|
text_max_length=32):
|
||||||
super(EpisodicDataset).__init__()
|
super(EpisodicDataset).__init__()
|
||||||
self.episode_ids = episode_ids
|
self.episode_ids = episode_ids
|
||||||
self.dataset_dir = dataset_dir
|
self.dataset_dir = dataset_dir
|
||||||
self.camera_names = camera_names
|
self.camera_names = camera_names
|
||||||
self.norm_stats = norm_stats
|
self.norm_stats = norm_stats
|
||||||
|
self.use_text_instruction = use_text_instruction
|
||||||
|
self.instruction_mode = instruction_mode
|
||||||
|
self.use_cached_text_features = use_cached_text_features
|
||||||
|
self.text_feature_dim = text_feature_dim
|
||||||
|
self.text_max_length = text_max_length
|
||||||
self.is_sim = None
|
self.is_sim = None
|
||||||
|
self.max_episode_len = None
|
||||||
|
self.action_dim = None
|
||||||
|
|
||||||
|
self.text_tokenizer = None
|
||||||
|
if self.use_text_instruction:
|
||||||
|
try:
|
||||||
|
from transformers import DistilBertTokenizerFast
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
'transformers is required for text instruction loading. '
|
||||||
|
'Install it with: pip install transformers'
|
||||||
|
) from exc
|
||||||
|
self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(text_tokenizer_name)
|
||||||
|
|
||||||
|
self._init_episode_shapes()
|
||||||
|
|
||||||
self.__getitem__(0) # initialize self.is_sim
|
self.__getitem__(0) # initialize self.is_sim
|
||||||
|
|
||||||
|
def _init_episode_shapes(self):
|
||||||
|
max_len = 0
|
||||||
|
action_dim = None
|
||||||
|
for episode_id in self.episode_ids:
|
||||||
|
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
|
||||||
|
with h5py.File(dataset_path, 'r') as root:
|
||||||
|
shape = root['/action'].shape
|
||||||
|
if len(shape) != 2:
|
||||||
|
raise ValueError(f'Expected /action to have shape [T, D], got {shape} in {dataset_path}')
|
||||||
|
max_len = max(max_len, int(shape[0]))
|
||||||
|
if action_dim is None:
|
||||||
|
action_dim = int(shape[1])
|
||||||
|
elif int(shape[1]) != action_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f'Inconsistent action dim in dataset. Expected {action_dim}, got {shape[1]} in {dataset_path}'
|
||||||
|
)
|
||||||
|
|
||||||
|
if max_len <= 0 or action_dim is None:
|
||||||
|
raise ValueError(f'Invalid dataset metadata in {self.dataset_dir}')
|
||||||
|
|
||||||
|
self.max_episode_len = max_len
|
||||||
|
self.action_dim = action_dim
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _decode_instruction(raw_value):
|
||||||
|
if raw_value is None:
|
||||||
|
return ''
|
||||||
|
if isinstance(raw_value, bytes):
|
||||||
|
return raw_value.decode('utf-8')
|
||||||
|
if isinstance(raw_value, np.bytes_):
|
||||||
|
return raw_value.tobytes().decode('utf-8')
|
||||||
|
if isinstance(raw_value, np.ndarray):
|
||||||
|
if raw_value.shape == ():
|
||||||
|
return EpisodicDataset._decode_instruction(raw_value.item())
|
||||||
|
if raw_value.size == 0:
|
||||||
|
return ''
|
||||||
|
return EpisodicDataset._decode_instruction(raw_value.reshape(-1)[0])
|
||||||
|
return str(raw_value)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.episode_ids)
|
return len(self.episode_ids)
|
||||||
|
|
||||||
@@ -26,7 +93,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
|||||||
episode_id = self.episode_ids[index]
|
episode_id = self.episode_ids[index]
|
||||||
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
|
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
|
||||||
with h5py.File(dataset_path, 'r') as root:
|
with h5py.File(dataset_path, 'r') as root:
|
||||||
is_sim = root.attrs['sim']
|
is_sim = bool(root.attrs.get('sim', False))
|
||||||
original_action_shape = root['/action'].shape
|
original_action_shape = root['/action'].shape
|
||||||
episode_len = original_action_shape[0]
|
episode_len = original_action_shape[0]
|
||||||
if sample_full_episode:
|
if sample_full_episode:
|
||||||
@@ -35,10 +102,40 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
|||||||
start_ts = np.random.choice(episode_len)
|
start_ts = np.random.choice(episode_len)
|
||||||
# get observation at start_ts only
|
# get observation at start_ts only
|
||||||
qpos = root['/observations/qpos'][start_ts]
|
qpos = root['/observations/qpos'][start_ts]
|
||||||
qvel = root['/observations/qvel'][start_ts]
|
|
||||||
image_dict = dict()
|
image_dict = dict()
|
||||||
for cam_name in self.camera_names:
|
for cam_name in self.camera_names:
|
||||||
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
|
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
|
||||||
|
|
||||||
|
instruction = ''
|
||||||
|
text_feature = None
|
||||||
|
if self.use_text_instruction:
|
||||||
|
effective_mode = self.instruction_mode
|
||||||
|
if effective_mode == 'timestep-level' and '/instruction_timestep' in root:
|
||||||
|
instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
|
||||||
|
elif '/instruction' in root:
|
||||||
|
instruction_node = root['/instruction']
|
||||||
|
if getattr(instruction_node, 'shape', ()) == ():
|
||||||
|
instruction = self._decode_instruction(instruction_node[()])
|
||||||
|
else:
|
||||||
|
if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
|
||||||
|
instruction = self._decode_instruction(instruction_node[start_ts])
|
||||||
|
else:
|
||||||
|
instruction = self._decode_instruction(instruction_node[0])
|
||||||
|
|
||||||
|
if self.use_cached_text_features:
|
||||||
|
if effective_mode == 'timestep-level' and '/instruction_features_timestep' in root:
|
||||||
|
text_feature = root['/instruction_features_timestep'][start_ts]
|
||||||
|
elif '/instruction_features' in root:
|
||||||
|
feat_node = root['/instruction_features']
|
||||||
|
if getattr(feat_node, 'shape', ()) == ():
|
||||||
|
text_feature = np.array(feat_node[()])
|
||||||
|
elif len(feat_node.shape) == 1:
|
||||||
|
text_feature = feat_node[()]
|
||||||
|
elif len(feat_node.shape) == 2 and feat_node.shape[0] == episode_len:
|
||||||
|
text_feature = feat_node[start_ts]
|
||||||
|
else:
|
||||||
|
text_feature = feat_node[0]
|
||||||
|
|
||||||
# get all actions after and including start_ts
|
# get all actions after and including start_ts
|
||||||
if is_sim:
|
if is_sim:
|
||||||
action = root['/action'][start_ts:]
|
action = root['/action'][start_ts:]
|
||||||
@@ -48,10 +145,10 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
|||||||
action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
|
action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
|
||||||
|
|
||||||
self.is_sim = is_sim
|
self.is_sim = is_sim
|
||||||
padded_action = np.zeros(original_action_shape, dtype=np.float32)
|
padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
|
||||||
padded_action[:action_len] = action
|
padded_action[:action_len] = action
|
||||||
is_pad = np.zeros(episode_len)
|
is_pad = np.ones(self.max_episode_len)
|
||||||
is_pad[action_len:] = 1
|
is_pad[:action_len] = 0
|
||||||
|
|
||||||
# new axis for different cameras
|
# new axis for different cameras
|
||||||
all_cam_images = []
|
all_cam_images = []
|
||||||
@@ -73,55 +170,132 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
|||||||
action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
|
action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
|
||||||
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
|
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
|
||||||
|
|
||||||
return image_data, qpos_data, action_data, is_pad
|
if self.use_text_instruction and text_feature is not None:
|
||||||
|
text_feature_data = torch.from_numpy(np.array(text_feature)).float()
|
||||||
|
text_feature_valid = torch.tensor(True, dtype=torch.bool)
|
||||||
|
text_input_ids = torch.zeros(1, dtype=torch.long)
|
||||||
|
text_attention_mask = torch.zeros(1, dtype=torch.long)
|
||||||
|
elif self.use_text_instruction:
|
||||||
|
tokenized = self.text_tokenizer(
|
||||||
|
instruction,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.text_max_length,
|
||||||
|
return_tensors='pt',
|
||||||
|
)
|
||||||
|
text_input_ids = tokenized['input_ids'].squeeze(0).long()
|
||||||
|
text_attention_mask = tokenized['attention_mask'].squeeze(0).long()
|
||||||
|
text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
|
||||||
|
text_feature_valid = torch.tensor(False, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
text_input_ids = torch.zeros(1, dtype=torch.long)
|
||||||
|
text_attention_mask = torch.zeros(1, dtype=torch.long)
|
||||||
|
text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
|
||||||
|
text_feature_valid = torch.tensor(False, dtype=torch.bool)
|
||||||
|
|
||||||
|
return image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid
|
||||||
|
|
||||||
|
|
||||||
def get_norm_stats(dataset_dir, num_episodes):
|
def _discover_episode_ids(dataset_dir, num_episodes=None):
|
||||||
|
pattern = re.compile(r'^episode_(\d+)\.hdf5$')
|
||||||
|
episode_ids = []
|
||||||
|
for fname in os.listdir(dataset_dir):
|
||||||
|
m = pattern.match(fname)
|
||||||
|
if m:
|
||||||
|
episode_ids.append(int(m.group(1)))
|
||||||
|
episode_ids.sort()
|
||||||
|
if num_episodes is not None:
|
||||||
|
episode_ids = episode_ids[:num_episodes]
|
||||||
|
return episode_ids
|
||||||
|
|
||||||
|
|
||||||
|
def get_norm_stats(dataset_dir, episode_ids):
|
||||||
all_qpos_data = []
|
all_qpos_data = []
|
||||||
all_action_data = []
|
all_action_data = []
|
||||||
for episode_idx in range(num_episodes):
|
example_qpos = None
|
||||||
|
for episode_idx in episode_ids:
|
||||||
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
|
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
|
||||||
with h5py.File(dataset_path, 'r') as root:
|
with h5py.File(dataset_path, 'r') as root:
|
||||||
qpos = root['/observations/qpos'][()]
|
qpos = root['/observations/qpos'][()]
|
||||||
qvel = root['/observations/qvel'][()]
|
|
||||||
action = root['/action'][()]
|
action = root['/action'][()]
|
||||||
all_qpos_data.append(torch.from_numpy(qpos))
|
qpos_t = torch.from_numpy(qpos)
|
||||||
all_action_data.append(torch.from_numpy(action))
|
action_t = torch.from_numpy(action)
|
||||||
all_qpos_data = torch.stack(all_qpos_data)
|
all_qpos_data.append(qpos_t)
|
||||||
all_action_data = torch.stack(all_action_data)
|
all_action_data.append(action_t)
|
||||||
all_action_data = all_action_data
|
if example_qpos is None and len(qpos) > 0:
|
||||||
|
example_qpos = qpos[0]
|
||||||
|
|
||||||
|
# Episodes may have different lengths; concatenate over time axis.
|
||||||
|
all_qpos_data = torch.cat(all_qpos_data, dim=0)
|
||||||
|
all_action_data = torch.cat(all_action_data, dim=0)
|
||||||
|
|
||||||
# normalize action data
|
# normalize action data
|
||||||
action_mean = all_action_data.mean(dim=[0, 1], keepdim=True)
|
action_mean = all_action_data.mean(dim=0, keepdim=True)
|
||||||
action_std = all_action_data.std(dim=[0, 1], keepdim=True)
|
action_std = all_action_data.std(dim=0, keepdim=True)
|
||||||
action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
|
action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
|
||||||
|
|
||||||
# normalize qpos data
|
# normalize qpos data
|
||||||
qpos_mean = all_qpos_data.mean(dim=[0, 1], keepdim=True)
|
qpos_mean = all_qpos_data.mean(dim=0, keepdim=True)
|
||||||
qpos_std = all_qpos_data.std(dim=[0, 1], keepdim=True)
|
qpos_std = all_qpos_data.std(dim=0, keepdim=True)
|
||||||
qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
|
qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
|
||||||
|
|
||||||
stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
|
stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
|
||||||
"qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
|
"qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
|
||||||
"example_qpos": qpos}
|
"example_qpos": example_qpos}
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val,
|
||||||
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val):
|
use_text_instruction=False,
|
||||||
|
instruction_mode='timestep-level',
|
||||||
|
use_cached_text_features=True,
|
||||||
|
text_feature_dim=768,
|
||||||
|
text_tokenizer_name='distilbert-base-uncased',
|
||||||
|
text_max_length=32):
|
||||||
print(f'\nData from: {dataset_dir}\n')
|
print(f'\nData from: {dataset_dir}\n')
|
||||||
|
episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
|
||||||
|
if len(episode_ids) == 0:
|
||||||
|
raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
|
||||||
|
if len(episode_ids) < 2:
|
||||||
|
raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
|
||||||
|
|
||||||
# obtain train test split
|
# obtain train test split
|
||||||
train_ratio = 0.8
|
train_ratio = 0.8
|
||||||
shuffled_indices = np.random.permutation(num_episodes)
|
shuffled_indices = np.random.permutation(len(episode_ids))
|
||||||
train_indices = shuffled_indices[:int(train_ratio * num_episodes)]
|
train_count = int(train_ratio * len(episode_ids))
|
||||||
val_indices = shuffled_indices[int(train_ratio * num_episodes):]
|
train_indices = shuffled_indices[:train_count]
|
||||||
|
val_indices = shuffled_indices[train_count:]
|
||||||
|
train_episode_ids = np.array(episode_ids)[train_indices]
|
||||||
|
val_episode_ids = np.array(episode_ids)[val_indices]
|
||||||
|
|
||||||
# obtain normalization stats for qpos and action
|
# obtain normalization stats for qpos and action
|
||||||
norm_stats = get_norm_stats(dataset_dir, num_episodes)
|
norm_stats = get_norm_stats(dataset_dir, episode_ids)
|
||||||
|
|
||||||
# construct dataset and dataloader
|
# construct dataset and dataloader
|
||||||
train_dataset = EpisodicDataset(train_indices, dataset_dir, camera_names, norm_stats)
|
train_dataset = EpisodicDataset(
|
||||||
val_dataset = EpisodicDataset(val_indices, dataset_dir, camera_names, norm_stats)
|
train_episode_ids,
|
||||||
|
dataset_dir,
|
||||||
|
camera_names,
|
||||||
|
norm_stats,
|
||||||
|
use_text_instruction=use_text_instruction,
|
||||||
|
instruction_mode=instruction_mode,
|
||||||
|
use_cached_text_features=use_cached_text_features,
|
||||||
|
text_feature_dim=text_feature_dim,
|
||||||
|
text_tokenizer_name=text_tokenizer_name,
|
||||||
|
text_max_length=text_max_length,
|
||||||
|
)
|
||||||
|
val_dataset = EpisodicDataset(
|
||||||
|
val_episode_ids,
|
||||||
|
dataset_dir,
|
||||||
|
camera_names,
|
||||||
|
norm_stats,
|
||||||
|
use_text_instruction=use_text_instruction,
|
||||||
|
instruction_mode=instruction_mode,
|
||||||
|
use_cached_text_features=use_cached_text_features,
|
||||||
|
text_feature_dim=text_feature_dim,
|
||||||
|
text_tokenizer_name=text_tokenizer_name,
|
||||||
|
text_max_length=text_max_length,
|
||||||
|
)
|
||||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||||
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||||
|
|
||||||
|
|||||||
@@ -21,32 +21,33 @@ def load_hdf5(dataset_dir, dataset_name):
|
|||||||
|
|
||||||
with h5py.File(dataset_path, 'r') as root:
|
with h5py.File(dataset_path, 'r') as root:
|
||||||
is_sim = root.attrs['sim']
|
is_sim = root.attrs['sim']
|
||||||
|
dt = float(root.attrs.get('dt', DT))
|
||||||
qpos = root['/observations/qpos'][()]
|
qpos = root['/observations/qpos'][()]
|
||||||
qvel = root['/observations/qvel'][()]
|
qvel = root['/observations/qvel'][()] if '/observations/qvel' in root else np.zeros_like(qpos)
|
||||||
action = root['/action'][()]
|
action = root['/action'][()]
|
||||||
image_dict = dict()
|
image_dict = dict()
|
||||||
for cam_name in root[f'/observations/images/'].keys():
|
for cam_name in root[f'/observations/images/'].keys():
|
||||||
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
|
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
|
||||||
|
|
||||||
return qpos, qvel, action, image_dict
|
return qpos, qvel, action, image_dict, dt
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
dataset_dir = args['dataset_dir']
|
dataset_dir = args['dataset_dir']
|
||||||
episode_idx = args['episode_idx']
|
episode_idx = args['episode_idx']
|
||||||
dataset_name = f'episode_{episode_idx}'
|
dataset_name = f'episode_{episode_idx}'
|
||||||
|
|
||||||
qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
|
qpos, qvel, action, image_dict, dt = load_hdf5(dataset_dir, dataset_name)
|
||||||
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
|
save_videos(image_dict, dt, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
|
||||||
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
|
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
|
||||||
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
|
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
|
||||||
|
|
||||||
|
|
||||||
def save_videos(video, dt, video_path=None):
|
def save_videos(video, dt, video_path):
|
||||||
if isinstance(video, list):
|
if isinstance(video, list):
|
||||||
cam_names = list(video[0].keys())
|
cam_names = list(video[0].keys())
|
||||||
h, w, _ = video[0][cam_names[0]].shape
|
h, w, _ = video[0][cam_names[0]].shape
|
||||||
w = w * len(cam_names)
|
w = w * len(cam_names)
|
||||||
fps = int(1/dt)
|
fps = max(1, int(round(1 / dt)))
|
||||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||||
for ts, image_dict in enumerate(video):
|
for ts, image_dict in enumerate(video):
|
||||||
images = []
|
images = []
|
||||||
@@ -66,7 +67,7 @@ def save_videos(video, dt, video_path=None):
|
|||||||
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
|
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
|
||||||
|
|
||||||
n_frames, h, w, _ = all_cam_videos.shape
|
n_frames, h, w, _ = all_cam_videos.shape
|
||||||
fps = int(1 / dt)
|
fps = max(1, int(round(1 / dt)))
|
||||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||||
for t in range(n_frames):
|
for t in range(n_frames):
|
||||||
image = all_cam_videos[t]
|
image = all_cam_videos[t]
|
||||||
|
|||||||
Reference in New Issue
Block a user