diff --git a/README.md b/README.md index 4cc91c2..7cb1b01 100644 --- a/README.md +++ b/README.md @@ -47,11 +47,11 @@ To set up a new terminal, run: ### Simulated experiments -We use ``transfer_cube`` task in the examples below. Another option is ``insertion``. +We use ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``. To generated 50 episodes of scripted data, run: python3 record_sim_episodes.py \ - --task_name transfer_cube \ + --task_name sim_transfer_cube_scripted \ --dataset_dir <data save dir> \ --num_episodes 50 @@ -64,24 +64,15 @@ To train ACT: # Transfer Cube task python3 imitate_episodes.py \ - --dataset_dir <data save dir> \ + --task_name sim_transfer_cube_scripted \ --ckpt_dir <ckpt dir> \ - --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ - --task_name transfer_cube --seed 0 \ - --temporal_agg \ - --num_epochs 1000 --lr 1e-4 - - # Bimanual Insertion task - python3 imitate_episodes.py \ - --dataset_dir <data save dir> \ - --ckpt_dir <ckpt dir> \ - --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ - --task_name insertion --seed 0 \ - --temporal_agg \ - --num_epochs 2000 --lr 1e-5 + --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \ + --num_epochs 2000 --lr 1e-5 \ + --seed 0 + To evaluate the policy, run the same command but add ``--eval``. The success rate -should be around 85% for transfer cube, and around 50% for insertion. +should be around 90% for transfer cube, and around 50% for insertion. Videos will be saved to ``<ckpt_dir>`` for each rollout. You can also add ``--onscreen_render`` to see real-time rendering during evaluation. 
diff --git a/assets/scene.xml b/assets/scene.xml index 5f596bf..ae59450 100644 --- a/assets/scene.xml +++ b/assets/scene.xml @@ -27,7 +27,6 @@ - diff --git a/constants.py b/constants.py index e626194..7dfaaae 100644 --- a/constants.py +++ b/constants.py @@ -1,18 +1,41 @@ import pathlib -### Parameters that changes across tasks -EPISODE_LEN = 600 +### Task parameters +DATA_DIR = '' +SIM_TASK_CONFIGS = { + 'sim_transfer_cube_scripted':{ + 'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted', + 'num_episodes': 50, + 'episode_len': 400, + 'camera_names': ['top'] + }, -### ALOHA fixed constants + 'sim_transfer_cube_human':{ + 'dataset_dir': DATA_DIR + '/sim_transfer_cube_human', + 'num_episodes': 50, + 'episode_len': 400, + 'camera_names': ['top'] + }, + + 'sim_insertion_scripted': { + 'dataset_dir': DATA_DIR + '/sim_insertion_scripted', + 'num_episodes': 50, + 'episode_len': 400, + 'camera_names': ['top'] + }, + + 'sim_insertion_human': { + 'dataset_dir': DATA_DIR + '/sim_insertion_human', + 'num_episodes': 50, + 'episode_len': 400, + 'camera_names': ['top'] + }, +} + +### Simulation envs fixed constants DT = 0.02 JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] -CAMERA_NAMES = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] # defines the number and ordering of cameras -BOX_INIT_POSE = [0.2, 0.5, 0.05, 1, 0, 0, 0] START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239] -SIM_CAMERA_NAMES = ['main'] - -SIM_EPISODE_LEN_TRANSFER_CUBE = 400 -SIM_EPISODE_LEN_INSERTION = 400 XML_DIR = str(pathlib.Path(__file__).parent.resolve()) + '/assets/' # note: absolute path diff --git a/detr/main.py b/detr/main.py index 213c5fb..3c4a339 100644 --- a/detr/main.py +++ b/detr/main.py @@ -55,7 +55,6 @@ def get_args_parser(): # repeat args in imitate_episodes just to avoid error. 
Will not be used parser.add_argument('--eval', action='store_true') parser.add_argument('--onscreen_render', action='store_true') - parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True) parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True) parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True) diff --git a/ee_sim_env.py b/ee_sim_env.py index a51b0ab..01df233 100644 --- a/ee_sim_env.py +++ b/ee_sim_env.py @@ -35,13 +35,15 @@ def make_ee_sim_env(task_name): right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' """ - xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_{task_name}.xml') - physics = mujoco.Physics.from_xml_path(xml_path) - if task_name == 'transfer_cube': + if 'sim_transfer_cube' in task_name: + xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_transfer_cube.xml') + physics = mujoco.Physics.from_xml_path(xml_path) task = TransferCubeEETask(random=False) env = control.Environment(physics, task, time_limit=20, control_timestep=DT, n_sub_steps=None, flat_observation=False) - elif task_name == 'insertion': + elif 'sim_insertion' in task_name: + xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_insertion.xml') + physics = mujoco.Physics.from_xml_path(xml_path) task = InsertionEETask(random=False) env = control.Environment(physics, task, time_limit=20, control_timestep=DT, n_sub_steps=None, flat_observation=False) @@ -133,8 +135,9 @@ class BimanualViperXEETask(base.Task): obs['qvel'] = self.get_qvel(physics) obs['env_state'] = self.get_env_state(physics) obs['images'] = dict() - obs['images']['main'] = physics.render(height=480, width=640, camera_id='main') # TODO hardcoded camera name - + obs['images']['top'] = 
physics.render(height=480, width=640, camera_id='top') + obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle') + obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close') # used in scripted policy to obtain starting pose obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy() obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy() diff --git a/imitate_episodes.py b/imitate_episodes.py index 8f13401..5500c28 100644 --- a/imitate_episodes.py +++ b/imitate_episodes.py @@ -8,8 +8,8 @@ from copy import deepcopy from tqdm import tqdm from einops import rearrange -from constants import DT, SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, EPISODE_LEN -from constants import PUPPET_GRIPPER_JOINT_OPEN, CAMERA_NAMES, SIM_CAMERA_NAMES +from constants import DT +from constants import PUPPET_GRIPPER_JOINT_OPEN from utils import load_data # data functions from utils import sample_box_pose, sample_insertion_pose # robot functions from utils import compute_dict_mean, set_seed, detach_dict # helper functions @@ -26,7 +26,6 @@ def main(args): # command line parameters is_eval = args['eval'] ckpt_dir = args['ckpt_dir'] - dataset_dir = args['dataset_dir'] policy_class = args['policy_class'] onscreen_render = args['onscreen_render'] task_name = args['task_name'] @@ -34,8 +33,20 @@ def main(args): batch_size_val = args['batch_size'] num_epochs = args['num_epochs'] + # get task parameters + is_sim = task_name[:4] == 'sim_' + if is_sim: + from constants import SIM_TASK_CONFIGS + task_config = SIM_TASK_CONFIGS[task_name] + else: + from aloha_scripts.constants import TASK_CONFIGS + task_config = TASK_CONFIGS[task_name] + dataset_dir = task_config['dataset_dir'] + num_episodes = task_config['num_episodes'] + episode_len = task_config['episode_len'] + camera_names = task_config['camera_names'] + # fixed parameters - num_episodes = 50 
state_dim = 14 lr_backbone = 1e-5 backbone = 'resnet18' @@ -53,41 +64,31 @@ def main(args): 'enc_layers': enc_layers, 'dec_layers': dec_layers, 'nheads': nheads, + 'camera_names': camera_names, } elif policy_class == 'CNNMLP': - policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1} + policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1, + 'camera_names': camera_names,} else: raise NotImplementedError config = { 'num_epochs': num_epochs, 'ckpt_dir': ckpt_dir, + 'episode_len': episode_len, 'state_dim': state_dim, 'lr': args['lr'], - 'real_robot': 'TBD', 'policy_class': policy_class, 'onscreen_render': onscreen_render, 'policy_config': policy_config, 'task_name': task_name, 'seed': args['seed'], - 'temporal_agg': args['temporal_agg'] + 'temporal_agg': args['temporal_agg'], + 'camera_names': camera_names, + 'real_robot': not is_sim } - train_dataloader, val_dataloader, stats, is_sim = load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val) - - if is_sim: - policy_config['camera_names'] = SIM_CAMERA_NAMES - config['camera_names'] = SIM_CAMERA_NAMES - config['real_robot'] = False - if task_name == 'transfer_cube': - config['episode_len'] = SIM_EPISODE_LEN_TRANSFER_CUBE - elif task_name == 'insertion': - config['episode_len'] = SIM_EPISODE_LEN_INSERTION - else: - policy_config['camera_names'] = CAMERA_NAMES - config['camera_names'] = CAMERA_NAMES - config['real_robot'] = True - config['episode_len'] = EPISODE_LEN + train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val) if is_eval: ckpt_names = [f'policy_best.ckpt'] @@ -159,7 +160,7 @@ def eval_bc(config, ckpt_name, save_episode=True): max_timesteps = config['episode_len'] task_name = config['task_name'] temporal_agg = config['temporal_agg'] - onscreen_cam = 'main' + onscreen_cam = 'angle' # load policy and stats ckpt_path = 
os.path.join(ckpt_dir, ckpt_name) @@ -178,8 +179,8 @@ def eval_bc(config, ckpt_name, save_episode=True): # load environment if real_robot: - from scripts.utils import move_grippers # requires aloha - from scripts.real_env import make_real_env # requires aloha + from aloha_scripts.robot_utils import move_grippers # requires aloha + from aloha_scripts.real_env import make_real_env # requires aloha env = make_real_env(init_node=True) env_max_reward = 0 else: @@ -200,12 +201,11 @@ def eval_bc(config, ckpt_name, save_episode=True): for rollout_id in range(num_rollouts): rollout_id += 0 ### set task - if task_name == 'transfer_cube': + if 'sim_transfer_cube' in task_name: BOX_POSE[0] = sample_box_pose() # used in sim reset - elif task_name == 'insertion': + elif 'sim_insertion' in task_name: BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset - else: - raise NotImplementedError + ts = env.reset() ### onscreen render @@ -417,7 +417,6 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--eval', action='store_true') parser.add_argument('--onscreen_render', action='store_true') - parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True) parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True) parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True) diff --git a/record_sim_episodes.py b/record_sim_episodes.py index 9bbebe0..a8f8491 100644 --- a/record_sim_episodes.py +++ b/record_sim_episodes.py @@ -5,8 +5,7 @@ import argparse import matplotlib.pyplot as plt import h5py_cache -from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN -from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, SIM_CAMERA_NAMES +from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, 
SIM_TASK_CONFIGS from ee_sim_env import make_ee_sim_env from sim_env import make_sim_env, BOX_POSE from scripted_policy import PickAndTransferPolicy, InsertionPolicy @@ -29,21 +28,24 @@ def main(args): num_episodes = args['num_episodes'] onscreen_render = args['onscreen_render'] inject_noise = False + render_cam_name = 'angle' if not os.path.isdir(dataset_dir): os.makedirs(dataset_dir, exist_ok=True) - if task_name == 'transfer_cube': + episode_len = SIM_TASK_CONFIGS[task_name]['episode_len'] + camera_names = SIM_TASK_CONFIGS[task_name]['camera_names'] + if task_name == 'sim_transfer_cube_scripted': policy_cls = PickAndTransferPolicy - episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE - elif task_name == 'insertion': + elif task_name == 'sim_insertion_scripted': policy_cls = InsertionPolicy - episode_len = SIM_EPISODE_LEN_INSERTION else: raise NotImplementedError success = [] for episode_idx in range(num_episodes): + print(f'{episode_idx=}') + print('Rollout out EE space scripted policy') # setup the environment env = make_ee_sim_env(task_name) ts = env.reset() @@ -52,14 +54,14 @@ def main(args): # setup plotting if onscreen_render: ax = plt.subplot() - plt_img = ax.imshow(ts.observation['images']['main']) + plt_img = ax.imshow(ts.observation['images'][render_cam_name]) plt.ion() for step in range(episode_len): action = policy(ts) ts = env.step(action) episode.append(ts) if onscreen_render: - plt_img.set_data(ts.observation['images']['main']) + plt_img.set_data(ts.observation['images'][render_cam_name]) plt.pause(0.002) plt.close() @@ -87,7 +89,7 @@ def main(args): del policy # setup the environment - print(f'====== Start Replaying ======') + print('Replaying joint commands') env = make_sim_env(task_name) BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env ts = env.reset() @@ -96,14 +98,14 @@ def main(args): # setup plotting if onscreen_render: ax = plt.subplot() - plt_img = ax.imshow(ts.observation['images']['main']) + 
plt_img = ax.imshow(ts.observation['images'][render_cam_name]) plt.ion() for t in range(len(joint_traj)): # note: this will increase episode length by 1 action = joint_traj[t] ts = env.step(action) episode_replay.append(ts) if onscreen_render: - plt_img.set_data(ts.observation['images']['main']) + plt_img.set_data(ts.observation['images'][render_cam_name]) plt.pause(0.02) episode_return = np.sum([ts.reward for ts in episode_replay[1:]]) @@ -121,7 +123,7 @@ def main(args): For each timestep: observations - images - - main (480, 640, 3) 'uint8' + - each_cam_name (480, 640, 3) 'uint8' - qpos (14,) 'float64' - qvel (14,) 'float64' @@ -133,7 +135,7 @@ def main(args): '/observations/qvel': [], '/action': [], } - for cam_name in SIM_CAMERA_NAMES: + for cam_name in camera_names: data_dict[f'/observations/images/{cam_name}'] = [] # because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps @@ -150,7 +152,7 @@ def main(args): data_dict['/observations/qpos'].append(ts.observation['qpos']) data_dict['/observations/qvel'].append(ts.observation['qvel']) data_dict['/action'].append(action) - for cam_name in SIM_CAMERA_NAMES: + for cam_name in camera_names: data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name]) # HDF5 @@ -161,8 +163,9 @@ def main(args): root.attrs['sim'] = True obs = root.create_group('observations') image = obs.create_group('images') - cam_main = image.create_dataset('main', (max_timesteps, 480, 640, 3), dtype='uint8', - chunks=(1, 480, 640, 3), ) + for cam_name in camera_names: + _ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8', + chunks=(1, 480, 640, 3), ) # compression='gzip',compression_opts=2,) # compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False) qpos = obs.create_dataset('qpos', (max_timesteps, 14)) diff --git a/scripted_policy.py b/scripted_policy.py index dcd612e..4fd8f00 100644 --- a/scripted_policy.py +++ b/scripted_policy.py @@ -2,6 +2,7 @@ 
import numpy as np import matplotlib.pyplot as plt from pyquaternion import Quaternion +from constants import SIM_TASK_CONFIGS from ee_sim_env import make_ee_sim_env import IPython @@ -154,13 +155,11 @@ def test_policy(task_name): inject_noise = False # setup the environment - from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION - if task_name == 'transfer_cube': - env = make_ee_sim_env('transfer_cube') - episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE - elif task_name == 'insertion': - env = make_ee_sim_env('insertion') - episode_len = SIM_EPISODE_LEN_INSERTION + episode_len = SIM_TASK_CONFIGS[task_name]['episode_len'] + if 'sim_transfer_cube' in task_name: + env = make_ee_sim_env('sim_transfer_cube') + elif 'sim_insertion' in task_name: + env = make_ee_sim_env('sim_insertion') else: raise NotImplementedError @@ -169,7 +168,7 @@ def test_policy(task_name): episode = [ts] if onscreen_render: ax = plt.subplot() - plt_img = ax.imshow(ts.observation['images']['main']) + plt_img = ax.imshow(ts.observation['images']['angle']) plt.ion() policy = PickAndTransferPolicy(inject_noise) @@ -178,7 +177,7 @@ def test_policy(task_name): ts = env.step(action) episode.append(ts) if onscreen_render: - plt_img.set_data(ts.observation['images']['main']) + plt_img.set_data(ts.observation['images']['angle']) plt.pause(0.02) plt.close() @@ -190,6 +189,6 @@ def test_policy(task_name): if __name__ == '__main__': - test_task_name = 'transfer_cube' + test_task_name = 'sim_transfer_cube_scripted' test_policy(test_task_name) diff --git a/sim_env.py b/sim_env.py index 55828e7..b79b935 100644 --- a/sim_env.py +++ b/sim_env.py @@ -6,7 +6,7 @@ from dm_control import mujoco from dm_control.rl import control from dm_control.suite import base -from constants import DT, XML_DIR, START_ARM_POSE, BOX_INIT_POSE +from constants import DT, XML_DIR, START_ARM_POSE from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN 
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN @@ -35,13 +35,15 @@ def make_sim_env(task_name): right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' """ - xml_path = os.path.join(XML_DIR, f'bimanual_viperx_{task_name}.xml') - physics = mujoco.Physics.from_xml_path(xml_path) - if task_name == 'transfer_cube': + if 'sim_transfer_cube' in task_name: + xml_path = os.path.join(XML_DIR, f'bimanual_viperx_transfer_cube.xml') + physics = mujoco.Physics.from_xml_path(xml_path) task = TransferCubeTask(random=False) env = control.Environment(physics, task, time_limit=20, control_timestep=DT, n_sub_steps=None, flat_observation=False) - elif task_name == 'insertion': + elif 'sim_insertion' in task_name: + xml_path = os.path.join(XML_DIR, f'bimanual_viperx_insertion.xml') + physics = mujoco.Physics.from_xml_path(xml_path) task = InsertionTask(random=False) env = control.Environment(physics, task, time_limit=20, control_timestep=DT, n_sub_steps=None, flat_observation=False) @@ -105,8 +107,9 @@ class BimanualViperXTask(base.Task): obs['qvel'] = self.get_qvel(physics) obs['env_state'] = self.get_env_state(physics) obs['images'] = dict() - obs['images']['main'] = physics.render(height=480, width=640, camera_id='top') # TODO hardcoded camera name - obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close') # TODO hardcoded camera name + obs['images']['top'] = physics.render(height=480, width=640, camera_id='top') + obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle') + obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close') return obs @@ -241,9 +244,10 @@ def get_action(master_bot_left, master_bot_right): return action def test_sim_teleop(): + """ Testing teleoperation in sim with ALOHA. Requires hardware and ALOHA repo to work. 
""" from interbotix_xs_modules.arm import InterbotixManipulatorXS - BOX_POSE[0] = BOX_INIT_POSE + BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0] # source of data master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", @@ -252,12 +256,12 @@ def test_sim_teleop(): robot_name=f'master_right', init_node=False) # setup the environment - env = make_sim_env() + env = make_sim_env('sim_transfer_cube') ts = env.reset() episode = [ts] # setup plotting ax = plt.subplot() - plt_img = ax.imshow(ts.observation['image']) + plt_img = ax.imshow(ts.observation['images']['angle']) plt.ion() for t in range(1000): @@ -265,7 +269,7 @@ def test_sim_teleop(): ts = env.step(action) episode.append(ts) - plt_img.set_data(ts.observation['image']) + plt_img.set_data(ts.observation['images']['angle']) plt.pause(0.02) diff --git a/utils.py b/utils.py index d3851b5..fa7e3cb 100644 --- a/utils.py +++ b/utils.py @@ -3,16 +3,16 @@ import torch import os import h5py from torch.utils.data import TensorDataset, DataLoader -from constants import SIM_CAMERA_NAMES, CAMERA_NAMES import IPython e = IPython.embed class EpisodicDataset(torch.utils.data.Dataset): - def __init__(self, episode_ids, dataset_dir, norm_stats): + def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats): super(EpisodicDataset).__init__() self.episode_ids = episode_ids self.dataset_dir = dataset_dir + self.camera_names = camera_names self.norm_stats = norm_stats self.is_sim = None self.__getitem__(0) # initialize self.is_sim @@ -27,10 +27,6 @@ class EpisodicDataset(torch.utils.data.Dataset): dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5') with h5py.File(dataset_path, 'r') as root: is_sim = root.attrs['sim'] - if is_sim: - camera_names = SIM_CAMERA_NAMES - else: - camera_names = CAMERA_NAMES original_action_shape = root['/action'].shape episode_len = original_action_shape[0] if sample_full_episode: @@ -41,7 +37,7 @@ class 
EpisodicDataset(torch.utils.data.Dataset): qpos = root['/observations/qpos'][start_ts] qvel = root['/observations/qvel'][start_ts] image_dict = dict() - for cam_name in camera_names: + for cam_name in self.camera_names: image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts] # get all actions after and including start_ts if is_sim: @@ -59,7 +55,7 @@ class EpisodicDataset(torch.utils.data.Dataset): # new axis for different cameras all_cam_images = [] - for cam_name in camera_names: + for cam_name in self.camera_names: all_cam_images.append(image_dict[cam_name]) all_cam_images = np.stack(all_cam_images, axis=0) @@ -112,9 +108,9 @@ def get_norm_stats(dataset_dir, num_episodes): return stats -def load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val): +def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val): # obtain train test split - train_ratio = 0.8 # TODO + train_ratio = 0.8 shuffled_indices = np.random.permutation(num_episodes) train_indices = shuffled_indices[:int(train_ratio * num_episodes)] val_indices = shuffled_indices[int(train_ratio * num_episodes):] @@ -123,8 +119,8 @@ def load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val): norm_stats = get_norm_stats(dataset_dir, num_episodes) # construct dataset and dataloader - train_dataset = EpisodicDataset(train_indices, dataset_dir, norm_stats) - val_dataset = EpisodicDataset(val_indices, dataset_dir, norm_stats) + train_dataset = EpisodicDataset(train_indices, dataset_dir, camera_names, norm_stats) + val_dataset = EpisodicDataset(val_indices, dataset_dir, camera_names, norm_stats) train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) diff --git a/visualize_episodes.py b/visualize_episodes.py index 9fb315e..4e55e47 
100644 --- a/visualize_episodes.py +++ b/visualize_episodes.py @@ -5,7 +5,7 @@ import h5py import argparse import matplotlib.pyplot as plt -from constants import DT, CAMERA_NAMES, SIM_CAMERA_NAMES +from constants import DT import IPython e = IPython.embed @@ -25,8 +25,7 @@ def load_hdf5(dataset_dir, dataset_name): qvel = root['/observations/qvel'][()] action = root['/action'][()] image_dict = dict() - camera_names = SIM_CAMERA_NAMES if is_sim else CAMERA_NAMES - for cam_name in camera_names: + for cam_name in root[f'/observations/images/'].keys(): image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()] return qpos, qvel, action, image_dict