数据增强

This commit is contained in:
2026-02-19 21:29:32 +08:00
parent 88d14221ae
commit 7023d5dde4
3 changed files with 93 additions and 2 deletions

View File

@@ -73,6 +73,7 @@ def get_args_parser():
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False) parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False) parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
parser.add_argument('--temporal_agg', action='store_true') parser.add_argument('--temporal_agg', action='store_true')
parser.add_argument('--image_aug', action='store_true')
return parser return parser

View File

@@ -154,6 +154,7 @@ def main(args):
text_feature_dim=text_feature_dim, text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name, text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length, text_max_length=text_max_length,
image_augment=args['image_aug'],
) )
# save dataset stats # save dataset stats
@@ -516,5 +517,7 @@ if __name__ == '__main__':
parser.add_argument('--text_encoder_type', action='store', type=str, required=False) parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
parser.add_argument('--freeze_text_encoder', action='store_true') parser.add_argument('--freeze_text_encoder', action='store_true')
parser.add_argument('--text_max_length', action='store', type=int, required=False) parser.add_argument('--text_max_length', action='store', type=int, required=False)
parser.add_argument('--image_aug', action='store_true',
help='Enable training-time image augmentation (color/highlight/noise/blur)')
main(vars(parser.parse_args())) main(vars(parser.parse_args()))

View File

@@ -4,6 +4,7 @@ import os
import h5py import h5py
import re import re
from torch.utils.data import TensorDataset, DataLoader from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms.functional as TF
import IPython import IPython
e = IPython.embed e = IPython.embed
@@ -15,7 +16,9 @@ class EpisodicDataset(torch.utils.data.Dataset):
use_cached_text_features=True, use_cached_text_features=True,
text_feature_dim=768, text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased', text_tokenizer_name='distilbert-base-uncased',
text_max_length=32): text_max_length=32,
image_augment=False,
image_aug_cfg=None):
super(EpisodicDataset).__init__() super(EpisodicDataset).__init__()
self.episode_ids = episode_ids self.episode_ids = episode_ids
self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
@@ -26,6 +29,23 @@ class EpisodicDataset(torch.utils.data.Dataset):
self.use_cached_text_features = use_cached_text_features self.use_cached_text_features = use_cached_text_features
self.text_feature_dim = text_feature_dim self.text_feature_dim = text_feature_dim
self.text_max_length = text_max_length self.text_max_length = text_max_length
self.image_augment = image_augment
self.image_aug_cfg = {
'p_color': 0.8,
'p_highlight': 0.5,
'p_noise': 0.5,
'p_blur': 0.3,
'brightness': 0.25,
'contrast': 0.25,
'saturation': 0.25,
'hue': 0.08,
'highlight_strength': (0.15, 0.5),
'noise_std': (0.005, 0.03),
'blur_sigma': (0.1, 1.5),
'blur_kernel_choices': (3, 5),
}
if image_aug_cfg is not None:
self.image_aug_cfg.update(image_aug_cfg)
self.is_sim = None self.is_sim = None
self.max_episode_len = None self.max_episode_len = None
self.action_dim = None self.action_dim = None
@@ -45,6 +65,66 @@ class EpisodicDataset(torch.utils.data.Dataset):
self.__getitem__(0) # initialize self.is_sim self.__getitem__(0) # initialize self.is_sim
def _apply_image_augmentation(self, all_cam_images):
"""
Apply identical augmentation parameters to all camera images for one sample.
all_cam_images: np.ndarray [K, H, W, C], uint8
"""
imgs = torch.from_numpy(all_cam_images).float() / 255.0
imgs = torch.einsum('k h w c -> k c h w', imgs)
cfg = self.image_aug_cfg
# color jitter (shared params)
if np.random.rand() < cfg['p_color']:
b = 1.0 + np.random.uniform(-cfg['brightness'], cfg['brightness'])
c = 1.0 + np.random.uniform(-cfg['contrast'], cfg['contrast'])
s = 1.0 + np.random.uniform(-cfg['saturation'], cfg['saturation'])
h = np.random.uniform(-cfg['hue'], cfg['hue'])
for cam_idx in range(imgs.shape[0]):
img = imgs[cam_idx]
img = TF.adjust_brightness(img, b)
img = TF.adjust_contrast(img, c)
img = TF.adjust_saturation(img, s)
img = TF.adjust_hue(img, h)
imgs[cam_idx] = img
# synthetic highlight / glare (shared parameters)
if np.random.rand() < cfg['p_highlight']:
_, h_img, w_img = imgs[0].shape
cx = np.random.uniform(0.2 * w_img, 0.8 * w_img)
cy = np.random.uniform(0.2 * h_img, 0.8 * h_img)
sigma = np.random.uniform(0.08, 0.2) * min(h_img, w_img)
strength = np.random.uniform(*cfg['highlight_strength'])
yy, xx = torch.meshgrid(
torch.arange(h_img, dtype=torch.float32),
torch.arange(w_img, dtype=torch.float32),
indexing='ij',
)
gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
gauss = (gauss * strength).unsqueeze(0)
imgs = imgs + gauss
# gaussian noise
if np.random.rand() < cfg['p_noise']:
noise_std = np.random.uniform(*cfg['noise_std'])
imgs = imgs + torch.randn_like(imgs) * noise_std
# gaussian blur
if np.random.rand() < cfg['p_blur']:
kernel = int(np.random.choice(cfg['blur_kernel_choices']))
sigma = float(np.random.uniform(*cfg['blur_sigma']))
for cam_idx in range(imgs.shape[0]):
imgs[cam_idx] = TF.gaussian_blur(
imgs[cam_idx],
kernel_size=[kernel, kernel],
sigma=[sigma, sigma],
)
imgs = imgs.clamp(0.0, 1.0)
imgs = torch.einsum('k c h w -> k h w c', imgs)
imgs = (imgs * 255.0).byte().cpu().numpy()
return imgs
def _init_episode_shapes(self): def _init_episode_shapes(self):
max_len = 0 max_len = 0
action_dim = None action_dim = None
@@ -155,6 +235,8 @@ class EpisodicDataset(torch.utils.data.Dataset):
for cam_name in self.camera_names: for cam_name in self.camera_names:
all_cam_images.append(image_dict[cam_name]) all_cam_images.append(image_dict[cam_name])
all_cam_images = np.stack(all_cam_images, axis=0) all_cam_images = np.stack(all_cam_images, axis=0)
if self.image_augment:
all_cam_images = self._apply_image_augmentation(all_cam_images)
# construct observations # construct observations
image_data = torch.from_numpy(all_cam_images) image_data = torch.from_numpy(all_cam_images)
@@ -251,7 +333,9 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
use_cached_text_features=True, use_cached_text_features=True,
text_feature_dim=768, text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased', text_tokenizer_name='distilbert-base-uncased',
text_max_length=32): text_max_length=32,
image_augment=False,
image_aug_cfg=None):
print(f'\nData from: {dataset_dir}\n') print(f'\nData from: {dataset_dir}\n')
episode_ids = _discover_episode_ids(dataset_dir, num_episodes) episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
if len(episode_ids) == 0: if len(episode_ids) == 0:
@@ -283,6 +367,8 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
text_feature_dim=text_feature_dim, text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name, text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length, text_max_length=text_max_length,
image_augment=image_augment,
image_aug_cfg=image_aug_cfg,
) )
val_dataset = EpisodicDataset( val_dataset = EpisodicDataset(
val_episode_ids, val_episode_ids,
@@ -295,6 +381,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
text_feature_dim=text_feature_dim, text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name, text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length, text_max_length=text_max_length,
image_augment=False,
) )
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)