add task configs to constant.py to reduce command line arguments

This commit is contained in:
Tony Zhao
2023-03-05 16:52:47 -08:00
parent 092735ddb9
commit 5a33ee8db0
11 changed files with 131 additions and 116 deletions

View File

@@ -8,8 +8,8 @@ from copy import deepcopy
from tqdm import tqdm
from einops import rearrange
from constants import DT, SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, EPISODE_LEN
from constants import PUPPET_GRIPPER_JOINT_OPEN, CAMERA_NAMES, SIM_CAMERA_NAMES
from constants import DT
from constants import PUPPET_GRIPPER_JOINT_OPEN
from utils import load_data # data functions
from utils import sample_box_pose, sample_insertion_pose # robot functions
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
@@ -26,7 +26,6 @@ def main(args):
# command line parameters
is_eval = args['eval']
ckpt_dir = args['ckpt_dir']
dataset_dir = args['dataset_dir']
policy_class = args['policy_class']
onscreen_render = args['onscreen_render']
task_name = args['task_name']
@@ -34,8 +33,20 @@ def main(args):
batch_size_val = args['batch_size']
num_epochs = args['num_epochs']
# get task parameters
is_sim = task_name[:4] == 'sim_'
if is_sim:
from constants import SIM_TASK_CONFIGS
task_config = SIM_TASK_CONFIGS[task_name]
else:
from aloha_scripts.constants import TASK_CONFIGS
task_config = TASK_CONFIGS[task_name]
dataset_dir = task_config['dataset_dir']
num_episodes = task_config['num_episodes']
episode_len = task_config['episode_len']
camera_names = task_config['camera_names']
# fixed parameters
num_episodes = 50
state_dim = 14
lr_backbone = 1e-5
backbone = 'resnet18'
@@ -53,41 +64,31 @@ def main(args):
'enc_layers': enc_layers,
'dec_layers': dec_layers,
'nheads': nheads,
'camera_names': camera_names,
}
elif policy_class == 'CNNMLP':
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1}
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
'camera_names': camera_names,}
else:
raise NotImplementedError
config = {
'num_epochs': num_epochs,
'ckpt_dir': ckpt_dir,
'episode_len': episode_len,
'state_dim': state_dim,
'lr': args['lr'],
'real_robot': 'TBD',
'policy_class': policy_class,
'onscreen_render': onscreen_render,
'policy_config': policy_config,
'task_name': task_name,
'seed': args['seed'],
'temporal_agg': args['temporal_agg']
'temporal_agg': args['temporal_agg'],
'camera_names': camera_names,
'real_robot': not is_sim
}
train_dataloader, val_dataloader, stats, is_sim = load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val)
if is_sim:
policy_config['camera_names'] = SIM_CAMERA_NAMES
config['camera_names'] = SIM_CAMERA_NAMES
config['real_robot'] = False
if task_name == 'transfer_cube':
config['episode_len'] = SIM_EPISODE_LEN_TRANSFER_CUBE
elif task_name == 'insertion':
config['episode_len'] = SIM_EPISODE_LEN_INSERTION
else:
policy_config['camera_names'] = CAMERA_NAMES
config['camera_names'] = CAMERA_NAMES
config['real_robot'] = True
config['episode_len'] = EPISODE_LEN
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
if is_eval:
ckpt_names = [f'policy_best.ckpt']
@@ -159,7 +160,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
max_timesteps = config['episode_len']
task_name = config['task_name']
temporal_agg = config['temporal_agg']
onscreen_cam = 'main'
onscreen_cam = 'angle'
# load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
@@ -178,8 +179,8 @@ def eval_bc(config, ckpt_name, save_episode=True):
# load environment
if real_robot:
from scripts.utils import move_grippers # requires aloha
from scripts.real_env import make_real_env # requires aloha
from aloha_scripts.robot_utils import move_grippers # requires aloha
from aloha_scripts.real_env import make_real_env # requires aloha
env = make_real_env(init_node=True)
env_max_reward = 0
else:
@@ -200,12 +201,11 @@ def eval_bc(config, ckpt_name, save_episode=True):
for rollout_id in range(num_rollouts):
rollout_id += 0
### set task
if task_name == 'transfer_cube':
if 'sim_transfer_cube' in task_name:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif task_name == 'insertion':
elif 'sim_insertion' in task_name:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
else:
raise NotImplementedError
ts = env.reset()
### onscreen render
@@ -417,7 +417,6 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_true')
parser.add_argument('--onscreen_render', action='store_true')
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True)
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)