add task configs to constant.py to reduce command line arguments
This commit is contained in:
@@ -8,8 +8,8 @@ from copy import deepcopy
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
|
||||
from constants import DT, SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, EPISODE_LEN
|
||||
from constants import PUPPET_GRIPPER_JOINT_OPEN, CAMERA_NAMES, SIM_CAMERA_NAMES
|
||||
from constants import DT
|
||||
from constants import PUPPET_GRIPPER_JOINT_OPEN
|
||||
from utils import load_data # data functions
|
||||
from utils import sample_box_pose, sample_insertion_pose # robot functions
|
||||
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
|
||||
@@ -26,7 +26,6 @@ def main(args):
|
||||
# command line parameters
|
||||
is_eval = args['eval']
|
||||
ckpt_dir = args['ckpt_dir']
|
||||
dataset_dir = args['dataset_dir']
|
||||
policy_class = args['policy_class']
|
||||
onscreen_render = args['onscreen_render']
|
||||
task_name = args['task_name']
|
||||
@@ -34,8 +33,20 @@ def main(args):
|
||||
batch_size_val = args['batch_size']
|
||||
num_epochs = args['num_epochs']
|
||||
|
||||
# get task parameters
|
||||
is_sim = task_name[:4] == 'sim_'
|
||||
if is_sim:
|
||||
from constants import SIM_TASK_CONFIGS
|
||||
task_config = SIM_TASK_CONFIGS[task_name]
|
||||
else:
|
||||
from aloha_scripts.constants import TASK_CONFIGS
|
||||
task_config = TASK_CONFIGS[task_name]
|
||||
dataset_dir = task_config['dataset_dir']
|
||||
num_episodes = task_config['num_episodes']
|
||||
episode_len = task_config['episode_len']
|
||||
camera_names = task_config['camera_names']
|
||||
|
||||
# fixed parameters
|
||||
num_episodes = 50
|
||||
state_dim = 14
|
||||
lr_backbone = 1e-5
|
||||
backbone = 'resnet18'
|
||||
@@ -53,41 +64,31 @@ def main(args):
|
||||
'enc_layers': enc_layers,
|
||||
'dec_layers': dec_layers,
|
||||
'nheads': nheads,
|
||||
'camera_names': camera_names,
|
||||
}
|
||||
elif policy_class == 'CNNMLP':
|
||||
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1}
|
||||
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
|
||||
'camera_names': camera_names,}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
config = {
|
||||
'num_epochs': num_epochs,
|
||||
'ckpt_dir': ckpt_dir,
|
||||
'episode_len': episode_len,
|
||||
'state_dim': state_dim,
|
||||
'lr': args['lr'],
|
||||
'real_robot': 'TBD',
|
||||
'policy_class': policy_class,
|
||||
'onscreen_render': onscreen_render,
|
||||
'policy_config': policy_config,
|
||||
'task_name': task_name,
|
||||
'seed': args['seed'],
|
||||
'temporal_agg': args['temporal_agg']
|
||||
'temporal_agg': args['temporal_agg'],
|
||||
'camera_names': camera_names,
|
||||
'real_robot': not is_sim
|
||||
}
|
||||
|
||||
train_dataloader, val_dataloader, stats, is_sim = load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val)
|
||||
|
||||
if is_sim:
|
||||
policy_config['camera_names'] = SIM_CAMERA_NAMES
|
||||
config['camera_names'] = SIM_CAMERA_NAMES
|
||||
config['real_robot'] = False
|
||||
if task_name == 'transfer_cube':
|
||||
config['episode_len'] = SIM_EPISODE_LEN_TRANSFER_CUBE
|
||||
elif task_name == 'insertion':
|
||||
config['episode_len'] = SIM_EPISODE_LEN_INSERTION
|
||||
else:
|
||||
policy_config['camera_names'] = CAMERA_NAMES
|
||||
config['camera_names'] = CAMERA_NAMES
|
||||
config['real_robot'] = True
|
||||
config['episode_len'] = EPISODE_LEN
|
||||
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
|
||||
|
||||
if is_eval:
|
||||
ckpt_names = [f'policy_best.ckpt']
|
||||
@@ -159,7 +160,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
max_timesteps = config['episode_len']
|
||||
task_name = config['task_name']
|
||||
temporal_agg = config['temporal_agg']
|
||||
onscreen_cam = 'main'
|
||||
onscreen_cam = 'angle'
|
||||
|
||||
# load policy and stats
|
||||
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
|
||||
@@ -178,8 +179,8 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
|
||||
# load environment
|
||||
if real_robot:
|
||||
from scripts.utils import move_grippers # requires aloha
|
||||
from scripts.real_env import make_real_env # requires aloha
|
||||
from aloha_scripts.robot_utils import move_grippers # requires aloha
|
||||
from aloha_scripts.real_env import make_real_env # requires aloha
|
||||
env = make_real_env(init_node=True)
|
||||
env_max_reward = 0
|
||||
else:
|
||||
@@ -200,12 +201,11 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
for rollout_id in range(num_rollouts):
|
||||
rollout_id += 0
|
||||
### set task
|
||||
if task_name == 'transfer_cube':
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif task_name == 'insertion':
|
||||
elif 'sim_insertion' in task_name:
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
ts = env.reset()
|
||||
|
||||
### onscreen render
|
||||
@@ -417,7 +417,6 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--eval', action='store_true')
|
||||
parser.add_argument('--onscreen_render', action='store_true')
|
||||
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True)
|
||||
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
|
||||
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
|
||||
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
|
||||
|
||||
Reference in New Issue
Block a user