代码可以跑起来了
This commit is contained in:
@@ -10,14 +10,13 @@ from einops import rearrange
|
||||
|
||||
from constants import DT
|
||||
from constants import PUPPET_GRIPPER_JOINT_OPEN
|
||||
from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
|
||||
from utils import load_data # data functions
|
||||
from utils import sample_box_pose, sample_insertion_pose # robot functions
|
||||
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
|
||||
from policy import ACTPolicy, CNNMLPPolicy
|
||||
from visualize_episodes import save_videos
|
||||
|
||||
from sim_env import BOX_POSE
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
@@ -34,25 +33,47 @@ def main(args):
|
||||
num_epochs = args['num_epochs']
|
||||
|
||||
# get task parameters
|
||||
is_sim = task_name[:4] == 'sim_'
|
||||
if is_sim:
|
||||
from constants import SIM_TASK_CONFIGS
|
||||
is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
|
||||
if is_endoscope:
|
||||
task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
|
||||
is_sim = False
|
||||
elif task_name in SIM_TASK_CONFIGS:
|
||||
task_config = SIM_TASK_CONFIGS[task_name]
|
||||
is_sim = True
|
||||
else:
|
||||
from aloha_scripts.constants import TASK_CONFIGS
|
||||
task_config = TASK_CONFIGS[task_name]
|
||||
is_sim = False
|
||||
|
||||
dataset_dir = task_config['dataset_dir']
|
||||
num_episodes = task_config['num_episodes']
|
||||
episode_len = task_config['episode_len']
|
||||
camera_names = task_config['camera_names']
|
||||
state_dim = task_config.get('state_dim', 14)
|
||||
action_dim = task_config.get('action_dim', state_dim)
|
||||
use_text_instruction = task_config.get('use_text_instruction', False)
|
||||
instruction_mode = task_config.get('instruction_mode', 'timestep-level')
|
||||
use_cached_text_features = task_config.get('use_cached_text_features', True)
|
||||
text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
|
||||
text_feature_dim = task_config.get('text_feature_dim', 768)
|
||||
text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
|
||||
text_max_length = task_config.get('text_max_length', 32)
|
||||
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
|
||||
freeze_text_encoder = task_config.get('freeze_text_encoder', True)
|
||||
|
||||
if args.get('text_encoder_type') is not None:
|
||||
text_encoder_type = args['text_encoder_type']
|
||||
if args.get('text_max_length') is not None:
|
||||
text_max_length = args['text_max_length']
|
||||
if args.get('freeze_text_encoder', False):
|
||||
freeze_text_encoder = True
|
||||
|
||||
# fixed parameters
|
||||
state_dim = 14
|
||||
lr_backbone = 1e-5
|
||||
backbone = 'resnet18'
|
||||
if policy_class == 'ACT':
|
||||
enc_layers = 4
|
||||
dec_layers = 7
|
||||
enc_layers = 2
|
||||
dec_layers = 4
|
||||
nheads = 8
|
||||
policy_config = {'lr': args['lr'],
|
||||
'num_queries': args['chunk_size'],
|
||||
@@ -65,10 +86,25 @@ def main(args):
|
||||
'dec_layers': dec_layers,
|
||||
'nheads': nheads,
|
||||
'camera_names': camera_names,
|
||||
'state_dim': state_dim,
|
||||
'action_dim': action_dim,
|
||||
'use_text': use_text_instruction,
|
||||
'text_encoder_type': text_encoder_type,
|
||||
'text_feature_dim': text_feature_dim,
|
||||
'text_fusion_type': text_fusion_type,
|
||||
'freeze_text_encoder': freeze_text_encoder,
|
||||
'instruction_mode': instruction_mode,
|
||||
'use_cached_text_features': use_cached_text_features,
|
||||
'text_max_length': text_max_length,
|
||||
'text_tokenizer_name': text_tokenizer_name,
|
||||
}
|
||||
elif policy_class == 'CNNMLP':
|
||||
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
|
||||
'camera_names': camera_names,}
|
||||
'camera_names': camera_names,
|
||||
'state_dim': state_dim,
|
||||
'action_dim': action_dim,
|
||||
'use_text': use_text_instruction,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -77,6 +113,7 @@ def main(args):
|
||||
'ckpt_dir': ckpt_dir,
|
||||
'episode_len': episode_len,
|
||||
'state_dim': state_dim,
|
||||
'action_dim': action_dim,
|
||||
'lr': args['lr'],
|
||||
'policy_class': policy_class,
|
||||
'onscreen_render': onscreen_render,
|
||||
@@ -85,7 +122,12 @@ def main(args):
|
||||
'seed': args['seed'],
|
||||
'temporal_agg': args['temporal_agg'],
|
||||
'camera_names': camera_names,
|
||||
'real_robot': not is_sim
|
||||
'real_robot': (not is_sim) and (not is_endoscope),
|
||||
'use_text_instruction': use_text_instruction,
|
||||
'instruction_mode': instruction_mode,
|
||||
'use_cached_text_features': use_cached_text_features,
|
||||
'text_tokenizer_name': text_tokenizer_name,
|
||||
'text_max_length': text_max_length,
|
||||
}
|
||||
|
||||
if is_eval:
|
||||
@@ -100,7 +142,19 @@ def main(args):
|
||||
print()
|
||||
exit()
|
||||
|
||||
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
|
||||
train_dataloader, val_dataloader, stats, _ = load_data(
|
||||
dataset_dir,
|
||||
num_episodes,
|
||||
camera_names,
|
||||
batch_size_train,
|
||||
batch_size_val,
|
||||
use_text_instruction=use_text_instruction,
|
||||
instruction_mode=instruction_mode,
|
||||
use_cached_text_features=use_cached_text_features,
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
)
|
||||
|
||||
# save dataset stats
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
@@ -152,6 +206,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
set_seed(1000)
|
||||
ckpt_dir = config['ckpt_dir']
|
||||
state_dim = config['state_dim']
|
||||
action_dim = config['action_dim']
|
||||
real_robot = config['real_robot']
|
||||
policy_class = config['policy_class']
|
||||
onscreen_render = config['onscreen_render']
|
||||
@@ -161,6 +216,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
task_name = config['task_name']
|
||||
temporal_agg = config['temporal_agg']
|
||||
onscreen_cam = 'angle'
|
||||
BOX_POSE = None
|
||||
|
||||
# load policy and stats
|
||||
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
|
||||
@@ -202,8 +258,14 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
rollout_id += 0
|
||||
### set task
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
if BOX_POSE is None:
|
||||
from sim_env import BOX_POSE as _BOX_POSE
|
||||
BOX_POSE = _BOX_POSE
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif 'sim_insertion' in task_name:
|
||||
if BOX_POSE is None:
|
||||
from sim_env import BOX_POSE as _BOX_POSE
|
||||
BOX_POSE = _BOX_POSE
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
|
||||
ts = env.reset()
|
||||
@@ -216,7 +278,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
|
||||
### evaluation loop
|
||||
if temporal_agg:
|
||||
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
|
||||
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()
|
||||
|
||||
qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
|
||||
image_list = [] # for visualization
|
||||
@@ -314,9 +376,29 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
|
||||
|
||||
def forward_pass(data, policy):
|
||||
image_data, qpos_data, action_data, is_pad = data
|
||||
image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
|
||||
return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None
|
||||
image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
|
||||
image_data = image_data.cuda()
|
||||
qpos_data = qpos_data.cuda()
|
||||
action_data = action_data.cuda()
|
||||
is_pad = is_pad.cuda()
|
||||
text_input_ids = text_input_ids.cuda()
|
||||
text_attention_mask = text_attention_mask.cuda()
|
||||
text_feature_data = text_feature_data.cuda()
|
||||
text_feature_valid = text_feature_valid.cuda()
|
||||
|
||||
text_features = None
|
||||
if torch.any(text_feature_valid):
|
||||
text_features = text_feature_data
|
||||
|
||||
return policy(
|
||||
qpos_data,
|
||||
image_data,
|
||||
text_input_ids=text_input_ids,
|
||||
text_attention_mask=text_attention_mask,
|
||||
text_features=text_features,
|
||||
actions=action_data,
|
||||
is_pad=is_pad,
|
||||
)
|
||||
|
||||
|
||||
def train_bc(train_dataloader, val_dataloader, config):
|
||||
@@ -431,5 +513,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
|
||||
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
|
||||
parser.add_argument('--temporal_agg', action='store_true')
|
||||
parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
|
||||
parser.add_argument('--freeze_text_encoder', action='store_true')
|
||||
parser.add_argument('--text_max_length', action='store', type=int, required=False)
|
||||
|
||||
main(vars(parser.parse_args()))
|
||||
|
||||
Reference in New Issue
Block a user