add task configs to constant.py to reduce command line arguments

This commit is contained in:
Tony Zhao
2023-03-05 16:52:47 -08:00
parent 092735ddb9
commit 5a33ee8db0
11 changed files with 131 additions and 116 deletions

View File

@@ -47,11 +47,11 @@ To set up a new terminal, run:
### Simulated experiments
We use ``transfer_cube`` task in the examples below. Another option is ``insertion``.
We use ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
To generated 50 episodes of scripted data, run:
python3 record_sim_episodes.py \
--task_name transfer_cube \
--task_name sim_transfer_cube_scripted \
--dataset_dir <data save dir> \
--num_episodes 50
@@ -64,24 +64,15 @@ To train ACT:
# Transfer Cube task
python3 imitate_episodes.py \
--dataset_dir <data save dir> \
--task_name sim_transfer_cube_scripted \
--ckpt_dir <ckpt dir> \
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \
--task_name transfer_cube --seed 0 \
--temporal_agg \
--num_epochs 1000 --lr 1e-4
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
--num_epochs 2000 --lr 1e-5 \
--seed 0
# Bimanual Insertion task
python3 imitate_episodes.py \
--dataset_dir <data save dir> \
--ckpt_dir <ckpt dir> \
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \
--task_name insertion --seed 0 \
--temporal_agg \
--num_epochs 2000 --lr 1e-5
To evaluate the policy, run the same command but add ``--eval``. The success rate
should be around 85% for transfer cube, and around 50% for insertion.
should be around 90% for transfer cube, and around 50% for insertion.
Videos will be saved to ``<ckpt_dir>`` for each rollout.
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.

View File

@@ -27,7 +27,6 @@
<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="main" pos="0 -0.2 0.4" fovy="78" mode="targetbody" target="midair"/>
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>

View File

@@ -1,18 +1,41 @@
import pathlib
### Parameters that changes across tasks
EPISODE_LEN = 600
### Task parameters
DATA_DIR = '<put your data dir here>'
SIM_TASK_CONFIGS = {
'sim_transfer_cube_scripted':{
'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top']
},
### ALOHA fixed constants
'sim_transfer_cube_human':{
'dataset_dir': DATA_DIR + '/sim_transfer_cube_human',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top']
},
'sim_insertion_scripted': {
'dataset_dir': DATA_DIR + '/sim_insertion_scripted',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top']
},
'sim_insertion_human': {
'dataset_dir': DATA_DIR + '/sim_insertion_human',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top']
},
}
### Simulation envs fixed constants
DT = 0.02
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
CAMERA_NAMES = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] # defines the number and ordering of cameras
BOX_INIT_POSE = [0.2, 0.5, 0.05, 1, 0, 0, 0]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
SIM_CAMERA_NAMES = ['main']
SIM_EPISODE_LEN_TRANSFER_CUBE = 400
SIM_EPISODE_LEN_INSERTION = 400
XML_DIR = str(pathlib.Path(__file__).parent.resolve()) + '/assets/' # note: absolute path

View File

@@ -55,7 +55,6 @@ def get_args_parser():
# repeat args in imitate_episodes just to avoid error. Will not be used
parser.add_argument('--eval', action='store_true')
parser.add_argument('--onscreen_render', action='store_true')
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True)
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)

View File

@@ -35,13 +35,15 @@ def make_ee_sim_env(task_name):
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_{task_name}.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
if task_name == 'transfer_cube':
if 'sim_transfer_cube' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_transfer_cube.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = TransferCubeEETask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
elif task_name == 'insertion':
elif 'sim_insertion' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_insertion.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = InsertionEETask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
@@ -133,8 +135,9 @@ class BimanualViperXEETask(base.Task):
obs['qvel'] = self.get_qvel(physics)
obs['env_state'] = self.get_env_state(physics)
obs['images'] = dict()
obs['images']['main'] = physics.render(height=480, width=640, camera_id='main') # TODO hardcoded camera name
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
# used in scripted policy to obtain starting pose
obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()

View File

@@ -8,8 +8,8 @@ from copy import deepcopy
from tqdm import tqdm
from einops import rearrange
from constants import DT, SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, EPISODE_LEN
from constants import PUPPET_GRIPPER_JOINT_OPEN, CAMERA_NAMES, SIM_CAMERA_NAMES
from constants import DT
from constants import PUPPET_GRIPPER_JOINT_OPEN
from utils import load_data # data functions
from utils import sample_box_pose, sample_insertion_pose # robot functions
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
@@ -26,7 +26,6 @@ def main(args):
# command line parameters
is_eval = args['eval']
ckpt_dir = args['ckpt_dir']
dataset_dir = args['dataset_dir']
policy_class = args['policy_class']
onscreen_render = args['onscreen_render']
task_name = args['task_name']
@@ -34,8 +33,20 @@ def main(args):
batch_size_val = args['batch_size']
num_epochs = args['num_epochs']
# get task parameters
is_sim = task_name[:4] == 'sim_'
if is_sim:
from constants import SIM_TASK_CONFIGS
task_config = SIM_TASK_CONFIGS[task_name]
else:
from aloha_scripts.constants import TASK_CONFIGS
task_config = TASK_CONFIGS[task_name]
dataset_dir = task_config['dataset_dir']
num_episodes = task_config['num_episodes']
episode_len = task_config['episode_len']
camera_names = task_config['camera_names']
# fixed parameters
num_episodes = 50
state_dim = 14
lr_backbone = 1e-5
backbone = 'resnet18'
@@ -53,41 +64,31 @@ def main(args):
'enc_layers': enc_layers,
'dec_layers': dec_layers,
'nheads': nheads,
'camera_names': camera_names,
}
elif policy_class == 'CNNMLP':
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1}
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
'camera_names': camera_names,}
else:
raise NotImplementedError
config = {
'num_epochs': num_epochs,
'ckpt_dir': ckpt_dir,
'episode_len': episode_len,
'state_dim': state_dim,
'lr': args['lr'],
'real_robot': 'TBD',
'policy_class': policy_class,
'onscreen_render': onscreen_render,
'policy_config': policy_config,
'task_name': task_name,
'seed': args['seed'],
'temporal_agg': args['temporal_agg']
'temporal_agg': args['temporal_agg'],
'camera_names': camera_names,
'real_robot': not is_sim
}
train_dataloader, val_dataloader, stats, is_sim = load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val)
if is_sim:
policy_config['camera_names'] = SIM_CAMERA_NAMES
config['camera_names'] = SIM_CAMERA_NAMES
config['real_robot'] = False
if task_name == 'transfer_cube':
config['episode_len'] = SIM_EPISODE_LEN_TRANSFER_CUBE
elif task_name == 'insertion':
config['episode_len'] = SIM_EPISODE_LEN_INSERTION
else:
policy_config['camera_names'] = CAMERA_NAMES
config['camera_names'] = CAMERA_NAMES
config['real_robot'] = True
config['episode_len'] = EPISODE_LEN
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
if is_eval:
ckpt_names = [f'policy_best.ckpt']
@@ -159,7 +160,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
max_timesteps = config['episode_len']
task_name = config['task_name']
temporal_agg = config['temporal_agg']
onscreen_cam = 'main'
onscreen_cam = 'angle'
# load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
@@ -178,8 +179,8 @@ def eval_bc(config, ckpt_name, save_episode=True):
# load environment
if real_robot:
from scripts.utils import move_grippers # requires aloha
from scripts.real_env import make_real_env # requires aloha
from aloha_scripts.robot_utils import move_grippers # requires aloha
from aloha_scripts.real_env import make_real_env # requires aloha
env = make_real_env(init_node=True)
env_max_reward = 0
else:
@@ -200,12 +201,11 @@ def eval_bc(config, ckpt_name, save_episode=True):
for rollout_id in range(num_rollouts):
rollout_id += 0
### set task
if task_name == 'transfer_cube':
if 'sim_transfer_cube' in task_name:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif task_name == 'insertion':
elif 'sim_insertion' in task_name:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
else:
raise NotImplementedError
ts = env.reset()
### onscreen render
@@ -417,7 +417,6 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_true')
parser.add_argument('--onscreen_render', action='store_true')
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True)
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)

View File

@@ -5,8 +5,7 @@ import argparse
import matplotlib.pyplot as plt
import h5py_cache
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, SIM_CAMERA_NAMES
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
from sim_env import make_sim_env, BOX_POSE
from scripted_policy import PickAndTransferPolicy, InsertionPolicy
@@ -29,21 +28,24 @@ def main(args):
num_episodes = args['num_episodes']
onscreen_render = args['onscreen_render']
inject_noise = False
render_cam_name = 'angle'
if not os.path.isdir(dataset_dir):
os.makedirs(dataset_dir, exist_ok=True)
if task_name == 'transfer_cube':
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
if task_name == 'sim_transfer_cube_scripted':
policy_cls = PickAndTransferPolicy
episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE
elif task_name == 'insertion':
elif task_name == 'sim_insertion_scripted':
policy_cls = InsertionPolicy
episode_len = SIM_EPISODE_LEN_INSERTION
else:
raise NotImplementedError
success = []
for episode_idx in range(num_episodes):
print(f'{episode_idx=}')
print('Rollout out EE space scripted policy')
# setup the environment
env = make_ee_sim_env(task_name)
ts = env.reset()
@@ -52,14 +54,14 @@ def main(args):
# setup plotting
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['main'])
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
plt.ion()
for step in range(episode_len):
action = policy(ts)
ts = env.step(action)
episode.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images']['main'])
plt_img.set_data(ts.observation['images'][render_cam_name])
plt.pause(0.002)
plt.close()
@@ -87,7 +89,7 @@ def main(args):
del policy
# setup the environment
print(f'====== Start Replaying ======')
print('Replaying joint commands')
env = make_sim_env(task_name)
BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
ts = env.reset()
@@ -96,14 +98,14 @@ def main(args):
# setup plotting
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['main'])
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
plt.ion()
for t in range(len(joint_traj)): # note: this will increase episode length by 1
action = joint_traj[t]
ts = env.step(action)
episode_replay.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images']['main'])
plt_img.set_data(ts.observation['images'][render_cam_name])
plt.pause(0.02)
episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
@@ -121,7 +123,7 @@ def main(args):
For each timestep:
observations
- images
- main (480, 640, 3) 'uint8'
- each_cam_name (480, 640, 3) 'uint8'
- qpos (14,) 'float64'
- qvel (14,) 'float64'
@@ -133,7 +135,7 @@ def main(args):
'/observations/qvel': [],
'/action': [],
}
for cam_name in SIM_CAMERA_NAMES:
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'] = []
# because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
@@ -150,7 +152,7 @@ def main(args):
data_dict['/observations/qpos'].append(ts.observation['qpos'])
data_dict['/observations/qvel'].append(ts.observation['qvel'])
data_dict['/action'].append(action)
for cam_name in SIM_CAMERA_NAMES:
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
# HDF5
@@ -161,8 +163,9 @@ def main(args):
root.attrs['sim'] = True
obs = root.create_group('observations')
image = obs.create_group('images')
cam_main = image.create_dataset('main', (max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
for cam_name in camera_names:
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
# compression='gzip',compression_opts=2,)
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
qpos = obs.create_dataset('qpos', (max_timesteps, 14))

View File

@@ -2,6 +2,7 @@ import numpy as np
import matplotlib.pyplot as plt
from pyquaternion import Quaternion
from constants import SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
import IPython
@@ -154,13 +155,11 @@ def test_policy(task_name):
inject_noise = False
# setup the environment
from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION
if task_name == 'transfer_cube':
env = make_ee_sim_env('transfer_cube')
episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE
elif task_name == 'insertion':
env = make_ee_sim_env('insertion')
episode_len = SIM_EPISODE_LEN_INSERTION
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
if 'sim_transfer_cube' in task_name:
env = make_ee_sim_env('sim_transfer_cube')
elif 'sim_insertion' in task_name:
env = make_ee_sim_env('sim_insertion')
else:
raise NotImplementedError
@@ -169,7 +168,7 @@ def test_policy(task_name):
episode = [ts]
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['main'])
plt_img = ax.imshow(ts.observation['images']['angle'])
plt.ion()
policy = PickAndTransferPolicy(inject_noise)
@@ -178,7 +177,7 @@ def test_policy(task_name):
ts = env.step(action)
episode.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images']['main'])
plt_img.set_data(ts.observation['images']['angle'])
plt.pause(0.02)
plt.close()
@@ -190,6 +189,6 @@ def test_policy(task_name):
if __name__ == '__main__':
test_task_name = 'transfer_cube'
test_task_name = 'sim_transfer_cube_scripted'
test_policy(test_task_name)

View File

@@ -6,7 +6,7 @@ from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from constants import DT, XML_DIR, START_ARM_POSE, BOX_INIT_POSE
from constants import DT, XML_DIR, START_ARM_POSE
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
@@ -35,13 +35,15 @@ def make_sim_env(task_name):
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_{task_name}.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
if task_name == 'transfer_cube':
if 'sim_transfer_cube' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_transfer_cube.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = TransferCubeTask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
elif task_name == 'insertion':
elif 'sim_insertion' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_insertion.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = InsertionTask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
@@ -105,8 +107,9 @@ class BimanualViperXTask(base.Task):
obs['qvel'] = self.get_qvel(physics)
obs['env_state'] = self.get_env_state(physics)
obs['images'] = dict()
obs['images']['main'] = physics.render(height=480, width=640, camera_id='top') # TODO hardcoded camera name
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close') # TODO hardcoded camera name
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
return obs
@@ -241,9 +244,10 @@ def get_action(master_bot_left, master_bot_right):
return action
def test_sim_teleop():
""" Testing teleoperation in sim with ALOHA. Requires hardware and ALOHA repo to work. """
from interbotix_xs_modules.arm import InterbotixManipulatorXS
BOX_POSE[0] = BOX_INIT_POSE
BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]
# source of data
master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
@@ -252,12 +256,12 @@ def test_sim_teleop():
robot_name=f'master_right', init_node=False)
# setup the environment
env = make_sim_env()
env = make_sim_env('sim_transfer_cube')
ts = env.reset()
episode = [ts]
# setup plotting
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['image'])
plt_img = ax.imshow(ts.observation['images']['angle'])
plt.ion()
for t in range(1000):
@@ -265,7 +269,7 @@ def test_sim_teleop():
ts = env.step(action)
episode.append(ts)
plt_img.set_data(ts.observation['image'])
plt_img.set_data(ts.observation['images']['angle'])
plt.pause(0.02)

View File

@@ -3,16 +3,16 @@ import torch
import os
import h5py
from torch.utils.data import TensorDataset, DataLoader
from constants import SIM_CAMERA_NAMES, CAMERA_NAMES
import IPython
e = IPython.embed
class EpisodicDataset(torch.utils.data.Dataset):
def __init__(self, episode_ids, dataset_dir, norm_stats):
def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats):
super(EpisodicDataset).__init__()
self.episode_ids = episode_ids
self.dataset_dir = dataset_dir
self.camera_names = camera_names
self.norm_stats = norm_stats
self.is_sim = None
self.__getitem__(0) # initialize self.is_sim
@@ -27,10 +27,6 @@ class EpisodicDataset(torch.utils.data.Dataset):
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim']
if is_sim:
camera_names = SIM_CAMERA_NAMES
else:
camera_names = CAMERA_NAMES
original_action_shape = root['/action'].shape
episode_len = original_action_shape[0]
if sample_full_episode:
@@ -41,7 +37,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
qpos = root['/observations/qpos'][start_ts]
qvel = root['/observations/qvel'][start_ts]
image_dict = dict()
for cam_name in camera_names:
for cam_name in self.camera_names:
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
# get all actions after and including start_ts
if is_sim:
@@ -59,7 +55,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
# new axis for different cameras
all_cam_images = []
for cam_name in camera_names:
for cam_name in self.camera_names:
all_cam_images.append(image_dict[cam_name])
all_cam_images = np.stack(all_cam_images, axis=0)
@@ -112,9 +108,9 @@ def get_norm_stats(dataset_dir, num_episodes):
return stats
def load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val):
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val):
# obtain train test split
train_ratio = 0.8 # TODO
train_ratio = 0.8
shuffled_indices = np.random.permutation(num_episodes)
train_indices = shuffled_indices[:int(train_ratio * num_episodes)]
val_indices = shuffled_indices[int(train_ratio * num_episodes):]
@@ -123,8 +119,8 @@ def load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val):
norm_stats = get_norm_stats(dataset_dir, num_episodes)
# construct dataset and dataloader
train_dataset = EpisodicDataset(train_indices, dataset_dir, norm_stats)
val_dataset = EpisodicDataset(val_indices, dataset_dir, norm_stats)
train_dataset = EpisodicDataset(train_indices, dataset_dir, camera_names, norm_stats)
val_dataset = EpisodicDataset(val_indices, dataset_dir, camera_names, norm_stats)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)

View File

@@ -5,7 +5,7 @@ import h5py
import argparse
import matplotlib.pyplot as plt
from constants import DT, CAMERA_NAMES, SIM_CAMERA_NAMES
from constants import DT
import IPython
e = IPython.embed
@@ -25,8 +25,7 @@ def load_hdf5(dataset_dir, dataset_name):
qvel = root['/observations/qvel'][()]
action = root['/action'][()]
image_dict = dict()
camera_names = SIM_CAMERA_NAMES if is_sim else CAMERA_NAMES
for cam_name in camera_names:
for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
return qpos, qvel, action, image_dict