follow用policy_last
This commit is contained in:
60
utils.py
60
utils.py
@@ -17,6 +17,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
text_feature_dim=768,
|
||||
text_tokenizer_name='distilbert-base-uncased',
|
||||
text_max_length=32,
|
||||
real_action_t_minus_1=True,
|
||||
image_augment=False,
|
||||
image_aug_cfg=None):
|
||||
super(EpisodicDataset).__init__()
|
||||
@@ -29,20 +30,21 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
self.use_cached_text_features = use_cached_text_features
|
||||
self.text_feature_dim = text_feature_dim
|
||||
self.text_max_length = text_max_length
|
||||
self.real_action_t_minus_1 = real_action_t_minus_1
|
||||
self.image_augment = image_augment
|
||||
self.image_aug_cfg = {
|
||||
'p_color': 0.8,
|
||||
'p_highlight': 0.5,
|
||||
'p_noise': 0.5,
|
||||
'p_blur': 0.3,
|
||||
'brightness': 0.25,
|
||||
'contrast': 0.25,
|
||||
'saturation': 0.25,
|
||||
'hue': 0.08,
|
||||
'highlight_strength': (0.15, 0.5),
|
||||
'noise_std': (0.005, 0.03),
|
||||
'blur_sigma': (0.1, 1.5),
|
||||
'blur_kernel_choices': (3, 5),
|
||||
'p_color': 0.4,
|
||||
'p_highlight': 0.3,
|
||||
'p_noise': 0.35,
|
||||
'p_blur': 0.15,
|
||||
'brightness': 0.12,
|
||||
'contrast': 0.12,
|
||||
'saturation': 0.12,
|
||||
'hue': 0.03,
|
||||
'highlight_strength': (0.08, 0.25),
|
||||
'noise_std': (0.003, 0.015),
|
||||
'blur_sigma': (0.1, 0.8),
|
||||
'blur_kernel_choices': (3, ),
|
||||
}
|
||||
if image_aug_cfg is not None:
|
||||
self.image_aug_cfg.update(image_aug_cfg)
|
||||
@@ -221,8 +223,9 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
action = root['/action'][start_ts:]
|
||||
action_len = episode_len - start_ts
|
||||
else:
|
||||
action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned
|
||||
action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
|
||||
action_start = max(0, start_ts - 1) if self.real_action_t_minus_1 else start_ts
|
||||
action = root['/action'][action_start:]
|
||||
action_len = episode_len - action_start
|
||||
|
||||
self.is_sim = is_sim
|
||||
padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
|
||||
@@ -334,23 +337,30 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
||||
text_feature_dim=768,
|
||||
text_tokenizer_name='distilbert-base-uncased',
|
||||
text_max_length=32,
|
||||
real_action_t_minus_1=True,
|
||||
image_augment=False,
|
||||
image_aug_cfg=None):
|
||||
print(f'\nData from: {dataset_dir}\n')
|
||||
episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
|
||||
if len(episode_ids) == 0:
|
||||
raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
|
||||
if len(episode_ids) < 2:
|
||||
raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
|
||||
|
||||
# obtain train test split
|
||||
train_ratio = 0.9
|
||||
shuffled_indices = np.random.permutation(len(episode_ids))
|
||||
train_count = int(train_ratio * len(episode_ids))
|
||||
train_indices = shuffled_indices[:train_count]
|
||||
val_indices = shuffled_indices[train_count:]
|
||||
train_episode_ids = np.array(episode_ids)[train_indices]
|
||||
val_episode_ids = np.array(episode_ids)[val_indices]
|
||||
# obtain train/val split
|
||||
if len(episode_ids) == 1:
|
||||
# sanity-check mode: reuse the same episode for both train and val
|
||||
# so training/evaluation loops remain unchanged.
|
||||
train_episode_ids = np.array(episode_ids)
|
||||
val_episode_ids = np.array(episode_ids)
|
||||
print('[load_data] Only 1 episode found. Reusing the same episode for both train and val (sanity-check mode).')
|
||||
else:
|
||||
train_ratio = 0.9
|
||||
shuffled_indices = np.random.permutation(len(episode_ids))
|
||||
train_count = int(train_ratio * len(episode_ids))
|
||||
train_count = max(1, min(len(episode_ids) - 1, train_count))
|
||||
train_indices = shuffled_indices[:train_count]
|
||||
val_indices = shuffled_indices[train_count:]
|
||||
train_episode_ids = np.array(episode_ids)[train_indices]
|
||||
val_episode_ids = np.array(episode_ids)[val_indices]
|
||||
|
||||
# obtain normalization stats for qpos and action
|
||||
norm_stats = get_norm_stats(dataset_dir, episode_ids)
|
||||
@@ -367,6 +377,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
real_action_t_minus_1=real_action_t_minus_1,
|
||||
image_augment=image_augment,
|
||||
image_aug_cfg=image_aug_cfg,
|
||||
)
|
||||
@@ -381,6 +392,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
real_action_t_minus_1=real_action_t_minus_1,
|
||||
image_augment=False,
|
||||
)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||
|
||||
Reference in New Issue
Block a user