follow uses policy_last

This commit is contained in:
2026-02-20 16:45:16 +08:00
parent 88d0cc5ca2
commit 81e1bf8838
4 changed files with 129 additions and 30 deletions

View File

@@ -17,6 +17,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased',
text_max_length=32,
real_action_t_minus_1=True,
image_augment=False,
image_aug_cfg=None):
super(EpisodicDataset).__init__()
@@ -29,20 +30,21 @@ class EpisodicDataset(torch.utils.data.Dataset):
self.use_cached_text_features = use_cached_text_features
self.text_feature_dim = text_feature_dim
self.text_max_length = text_max_length
self.real_action_t_minus_1 = real_action_t_minus_1
self.image_augment = image_augment
self.image_aug_cfg = {
'p_color': 0.8,
'p_highlight': 0.5,
'p_noise': 0.5,
'p_blur': 0.3,
'brightness': 0.25,
'contrast': 0.25,
'saturation': 0.25,
'hue': 0.08,
'highlight_strength': (0.15, 0.5),
'noise_std': (0.005, 0.03),
'blur_sigma': (0.1, 1.5),
'blur_kernel_choices': (3, 5),
'p_color': 0.4,
'p_highlight': 0.3,
'p_noise': 0.35,
'p_blur': 0.15,
'brightness': 0.12,
'contrast': 0.12,
'saturation': 0.12,
'hue': 0.03,
'highlight_strength': (0.08, 0.25),
'noise_std': (0.003, 0.015),
'blur_sigma': (0.1, 0.8),
'blur_kernel_choices': (3, ),
}
if image_aug_cfg is not None:
self.image_aug_cfg.update(image_aug_cfg)
@@ -221,8 +223,9 @@ class EpisodicDataset(torch.utils.data.Dataset):
action = root['/action'][start_ts:]
action_len = episode_len - start_ts
else:
action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned
action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
action_start = max(0, start_ts - 1) if self.real_action_t_minus_1 else start_ts
action = root['/action'][action_start:]
action_len = episode_len - action_start
self.is_sim = is_sim
padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
@@ -334,23 +337,30 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased',
text_max_length=32,
real_action_t_minus_1=True,
image_augment=False,
image_aug_cfg=None):
print(f'\nData from: {dataset_dir}\n')
episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
if len(episode_ids) == 0:
raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
if len(episode_ids) < 2:
raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
# obtain train test split
train_ratio = 0.9
shuffled_indices = np.random.permutation(len(episode_ids))
train_count = int(train_ratio * len(episode_ids))
train_indices = shuffled_indices[:train_count]
val_indices = shuffled_indices[train_count:]
train_episode_ids = np.array(episode_ids)[train_indices]
val_episode_ids = np.array(episode_ids)[val_indices]
# obtain train/val split
if len(episode_ids) == 1:
# sanity-check mode: reuse the same episode for both train and val
# so training/evaluation loops remain unchanged.
train_episode_ids = np.array(episode_ids)
val_episode_ids = np.array(episode_ids)
print('[load_data] Only 1 episode found. Reusing the same episode for both train and val (sanity-check mode).')
else:
train_ratio = 0.9
shuffled_indices = np.random.permutation(len(episode_ids))
train_count = int(train_ratio * len(episode_ids))
train_count = max(1, min(len(episode_ids) - 1, train_count))
train_indices = shuffled_indices[:train_count]
val_indices = shuffled_indices[train_count:]
train_episode_ids = np.array(episode_ids)[train_indices]
val_episode_ids = np.array(episode_ids)[val_indices]
# obtain normalization stats for qpos and action
norm_stats = get_norm_stats(dataset_dir, episode_ids)
@@ -367,6 +377,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
real_action_t_minus_1=real_action_t_minus_1,
image_augment=image_augment,
image_aug_cfg=image_aug_cfg,
)
@@ -381,6 +392,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
real_action_t_minus_1=real_action_t_minus_1,
image_augment=False,
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)