This commit is contained in:
Tony Zhao
2023-03-14 14:04:08 -07:00
parent 5a33ee8db0
commit 76cf30b4fe
3 changed files with 4 additions and 3 deletions

View File

@@ -27,7 +27,7 @@ SIM_TASK_CONFIGS = {
'sim_insertion_human': { 'sim_insertion_human': {
'dataset_dir': DATA_DIR + '/sim_insertion_human', 'dataset_dir': DATA_DIR + '/sim_insertion_human',
'num_episodes': 50, 'num_episodes': 50,
'episode_len': 400, 'episode_len': 500,
'camera_names': ['top'] 'camera_names': ['top']
}, },
} }

View File

@@ -88,8 +88,6 @@ def main(args):
'real_robot': not is_sim 'real_robot': not is_sim
} }
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
if is_eval: if is_eval:
ckpt_names = [f'policy_best.ckpt'] ckpt_names = [f'policy_best.ckpt']
results = [] results = []
@@ -102,6 +100,8 @@ def main(args):
print() print()
exit() exit()
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
# save dataset stats # save dataset stats
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir) os.makedirs(ckpt_dir)

View File

@@ -109,6 +109,7 @@ def get_norm_stats(dataset_dir, num_episodes):
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val): def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val):
print(f'\nData from: {dataset_dir}\n')
# obtain train test split # obtain train test split
train_ratio = 0.8 train_ratio = 0.8
shuffled_indices = np.random.permutation(num_episodes) shuffled_indices = np.random.permutation(num_episodes)