add task configs to constant.py to reduce command line arguments
This commit is contained in:
@@ -5,8 +5,7 @@ import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
import h5py_cache
|
||||
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, SIM_CAMERA_NAMES
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
|
||||
from ee_sim_env import make_ee_sim_env
|
||||
from sim_env import make_sim_env, BOX_POSE
|
||||
from scripted_policy import PickAndTransferPolicy, InsertionPolicy
|
||||
@@ -29,21 +28,24 @@ def main(args):
|
||||
num_episodes = args['num_episodes']
|
||||
onscreen_render = args['onscreen_render']
|
||||
inject_noise = False
|
||||
render_cam_name = 'angle'
|
||||
|
||||
if not os.path.isdir(dataset_dir):
|
||||
os.makedirs(dataset_dir, exist_ok=True)
|
||||
|
||||
if task_name == 'transfer_cube':
|
||||
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
|
||||
if task_name == 'sim_transfer_cube_scripted':
|
||||
policy_cls = PickAndTransferPolicy
|
||||
episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE
|
||||
elif task_name == 'insertion':
|
||||
elif task_name == 'sim_insertion_scripted':
|
||||
policy_cls = InsertionPolicy
|
||||
episode_len = SIM_EPISODE_LEN_INSERTION
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
success = []
|
||||
for episode_idx in range(num_episodes):
|
||||
print(f'{episode_idx=}')
|
||||
print('Rollout out EE space scripted policy')
|
||||
# setup the environment
|
||||
env = make_ee_sim_env(task_name)
|
||||
ts = env.reset()
|
||||
@@ -52,14 +54,14 @@ def main(args):
|
||||
# setup plotting
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['main'])
|
||||
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
|
||||
plt.ion()
|
||||
for step in range(episode_len):
|
||||
action = policy(ts)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images']['main'])
|
||||
plt_img.set_data(ts.observation['images'][render_cam_name])
|
||||
plt.pause(0.002)
|
||||
plt.close()
|
||||
|
||||
@@ -87,7 +89,7 @@ def main(args):
|
||||
del policy
|
||||
|
||||
# setup the environment
|
||||
print(f'====== Start Replaying ======')
|
||||
print('Replaying joint commands')
|
||||
env = make_sim_env(task_name)
|
||||
BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
|
||||
ts = env.reset()
|
||||
@@ -96,14 +98,14 @@ def main(args):
|
||||
# setup plotting
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['main'])
|
||||
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
|
||||
plt.ion()
|
||||
for t in range(len(joint_traj)): # note: this will increase episode length by 1
|
||||
action = joint_traj[t]
|
||||
ts = env.step(action)
|
||||
episode_replay.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images']['main'])
|
||||
plt_img.set_data(ts.observation['images'][render_cam_name])
|
||||
plt.pause(0.02)
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
|
||||
@@ -121,7 +123,7 @@ def main(args):
|
||||
For each timestep:
|
||||
observations
|
||||
- images
|
||||
- main (480, 640, 3) 'uint8'
|
||||
- each_cam_name (480, 640, 3) 'uint8'
|
||||
- qpos (14,) 'float64'
|
||||
- qvel (14,) 'float64'
|
||||
|
||||
@@ -133,7 +135,7 @@ def main(args):
|
||||
'/observations/qvel': [],
|
||||
'/action': [],
|
||||
}
|
||||
for cam_name in SIM_CAMERA_NAMES:
|
||||
for cam_name in camera_names:
|
||||
data_dict[f'/observations/images/{cam_name}'] = []
|
||||
|
||||
# because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
|
||||
@@ -150,7 +152,7 @@ def main(args):
|
||||
data_dict['/observations/qpos'].append(ts.observation['qpos'])
|
||||
data_dict['/observations/qvel'].append(ts.observation['qvel'])
|
||||
data_dict['/action'].append(action)
|
||||
for cam_name in SIM_CAMERA_NAMES:
|
||||
for cam_name in camera_names:
|
||||
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
|
||||
|
||||
# HDF5
|
||||
@@ -161,8 +163,9 @@ def main(args):
|
||||
root.attrs['sim'] = True
|
||||
obs = root.create_group('observations')
|
||||
image = obs.create_group('images')
|
||||
cam_main = image.create_dataset('main', (max_timesteps, 480, 640, 3), dtype='uint8',
|
||||
chunks=(1, 480, 640, 3), )
|
||||
for cam_name in camera_names:
|
||||
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
|
||||
chunks=(1, 480, 640, 3), )
|
||||
# compression='gzip',compression_opts=2,)
|
||||
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
|
||||
qpos = obs.create_dataset('qpos', (max_timesteps, 14))
|
||||
|
||||
Reference in New Issue
Block a user