add task configs to constant.py to reduce command line arguments

This commit is contained in:
Tony Zhao
2023-03-05 16:52:47 -08:00
parent 092735ddb9
commit 5a33ee8db0
11 changed files with 131 additions and 116 deletions

View File

@@ -2,6 +2,7 @@ import numpy as np
import matplotlib.pyplot as plt
from pyquaternion import Quaternion
from constants import SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
import IPython
@@ -154,13 +155,11 @@ def test_policy(task_name):
inject_noise = False
# setup the environment
from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION
if task_name == 'transfer_cube':
env = make_ee_sim_env('transfer_cube')
episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE
elif task_name == 'insertion':
env = make_ee_sim_env('insertion')
episode_len = SIM_EPISODE_LEN_INSERTION
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
if 'sim_transfer_cube' in task_name:
env = make_ee_sim_env('sim_transfer_cube')
elif 'sim_insertion' in task_name:
env = make_ee_sim_env('sim_insertion')
else:
raise NotImplementedError
@@ -169,7 +168,7 @@ def test_policy(task_name):
episode = [ts]
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['main'])
plt_img = ax.imshow(ts.observation['images']['angle'])
plt.ion()
policy = PickAndTransferPolicy(inject_noise)
@@ -178,7 +177,7 @@ def test_policy(task_name):
ts = env.step(action)
episode.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images']['main'])
plt_img.set_data(ts.observation['images']['angle'])
plt.pause(0.02)
plt.close()
@@ -190,6 +189,6 @@ def test_policy(task_name):
if __name__ == '__main__':
test_task_name = 'transfer_cube'
test_task_name = 'sim_transfer_cube_scripted'
test_policy(test_task_name)