diff --git a/constants.py b/constants.py index 7dfaaae..f445350 100644 --- a/constants.py +++ b/constants.py @@ -27,7 +27,7 @@ SIM_TASK_CONFIGS = { 'sim_insertion_human': { 'dataset_dir': DATA_DIR + '/sim_insertion_human', 'num_episodes': 50, - 'episode_len': 400, + 'episode_len': 500, 'camera_names': ['top'] }, } diff --git a/imitate_episodes.py b/imitate_episodes.py index 5500c28..34f9a37 100644 --- a/imitate_episodes.py +++ b/imitate_episodes.py @@ -88,8 +88,6 @@ def main(args): 'real_robot': not is_sim } - train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val) - if is_eval: ckpt_names = [f'policy_best.ckpt'] results = [] @@ -102,6 +100,8 @@ def main(args): print() exit() + train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val) + # save dataset stats if not os.path.isdir(ckpt_dir): os.makedirs(ckpt_dir) diff --git a/utils.py b/utils.py index fa7e3cb..673cbb1 100644 --- a/utils.py +++ b/utils.py @@ -109,6 +109,7 @@ def get_norm_stats(dataset_dir, num_episodes): def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val): + print(f'\nData from: {dataset_dir}\n') # obtain train test split train_ratio = 0.8 shuffled_indices = np.random.permutation(num_episodes)