small
This commit is contained in:
@@ -27,7 +27,7 @@ SIM_TASK_CONFIGS = {
|
|||||||
'sim_insertion_human': {
|
'sim_insertion_human': {
|
||||||
'dataset_dir': DATA_DIR + '/sim_insertion_human',
|
'dataset_dir': DATA_DIR + '/sim_insertion_human',
|
||||||
'num_episodes': 50,
|
'num_episodes': 50,
|
||||||
'episode_len': 400,
|
'episode_len': 500,
|
||||||
'camera_names': ['top']
|
'camera_names': ['top']
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -88,8 +88,6 @@ def main(args):
|
|||||||
'real_robot': not is_sim
|
'real_robot': not is_sim
|
||||||
}
|
}
|
||||||
|
|
||||||
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
|
|
||||||
|
|
||||||
if is_eval:
|
if is_eval:
|
||||||
ckpt_names = [f'policy_best.ckpt']
|
ckpt_names = [f'policy_best.ckpt']
|
||||||
results = []
|
results = []
|
||||||
@@ -102,6 +100,8 @@ def main(args):
|
|||||||
print()
|
print()
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
|
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
|
||||||
|
|
||||||
# save dataset stats
|
# save dataset stats
|
||||||
if not os.path.isdir(ckpt_dir):
|
if not os.path.isdir(ckpt_dir):
|
||||||
os.makedirs(ckpt_dir)
|
os.makedirs(ckpt_dir)
|
||||||
|
|||||||
1
utils.py
1
utils.py
@@ -109,6 +109,7 @@ def get_norm_stats(dataset_dir, num_episodes):
|
|||||||
|
|
||||||
|
|
||||||
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val):
|
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val):
|
||||||
|
print(f'\nData from: {dataset_dir}\n')
|
||||||
# obtain train test split
|
# obtain train test split
|
||||||
train_ratio = 0.8
|
train_ratio = 0.8
|
||||||
shuffled_indices = np.random.permutation(num_episodes)
|
shuffled_indices = np.random.permutation(num_episodes)
|
||||||
|
|||||||
Reference in New Issue
Block a user