add task configs to constant.py to reduce command line arguments
This commit is contained in:
@@ -2,6 +2,7 @@ import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pyquaternion import Quaternion
|
||||
|
||||
from constants import SIM_TASK_CONFIGS
|
||||
from ee_sim_env import make_ee_sim_env
|
||||
|
||||
import IPython
|
||||
@@ -154,13 +155,11 @@ def test_policy(task_name):
|
||||
inject_noise = False
|
||||
|
||||
# setup the environment
|
||||
from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION
|
||||
if task_name == 'transfer_cube':
|
||||
env = make_ee_sim_env('transfer_cube')
|
||||
episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE
|
||||
elif task_name == 'insertion':
|
||||
env = make_ee_sim_env('insertion')
|
||||
episode_len = SIM_EPISODE_LEN_INSERTION
|
||||
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
env = make_ee_sim_env('sim_transfer_cube')
|
||||
elif 'sim_insertion' in task_name:
|
||||
env = make_ee_sim_env('sim_insertion')
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -169,7 +168,7 @@ def test_policy(task_name):
|
||||
episode = [ts]
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['main'])
|
||||
plt_img = ax.imshow(ts.observation['images']['angle'])
|
||||
plt.ion()
|
||||
|
||||
policy = PickAndTransferPolicy(inject_noise)
|
||||
@@ -178,7 +177,7 @@ def test_policy(task_name):
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images']['main'])
|
||||
plt_img.set_data(ts.observation['images']['angle'])
|
||||
plt.pause(0.02)
|
||||
plt.close()
|
||||
|
||||
@@ -190,6 +189,6 @@ def test_policy(task_name):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_task_name = 'transfer_cube'
|
||||
test_task_name = 'sim_transfer_cube_scripted'
|
||||
test_policy(test_task_name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user