small
This commit is contained in:
@@ -27,7 +27,7 @@ SIM_TASK_CONFIGS = {
|
||||
'sim_insertion_human': {
|
||||
'dataset_dir': DATA_DIR + '/sim_insertion_human',
|
||||
'num_episodes': 50,
|
||||
'episode_len': 400,
|
||||
'episode_len': 500,
|
||||
'camera_names': ['top']
|
||||
},
|
||||
}
|
||||
|
||||
@@ -88,8 +88,6 @@ def main(args):
|
||||
'real_robot': not is_sim
|
||||
}
|
||||
|
||||
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
|
||||
|
||||
if is_eval:
|
||||
ckpt_names = [f'policy_best.ckpt']
|
||||
results = []
|
||||
@@ -102,6 +100,8 @@ def main(args):
|
||||
print()
|
||||
exit()
|
||||
|
||||
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
|
||||
|
||||
# save dataset stats
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
os.makedirs(ckpt_dir)
|
||||
|
||||
1
utils.py
1
utils.py
@@ -109,6 +109,7 @@ def get_norm_stats(dataset_dir, num_episodes):
|
||||
|
||||
|
||||
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val):
|
||||
print(f'\nData from: {dataset_dir}\n')
|
||||
# obtain train test split
|
||||
train_ratio = 0.8
|
||||
shuffled_indices = np.random.permutation(num_episodes)
|
||||
|
||||
Reference in New Issue
Block a user