数据增强
This commit is contained in:
91
utils.py
91
utils.py
@@ -4,6 +4,7 @@ import os
|
||||
import h5py
|
||||
import re
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
@@ -15,7 +16,9 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
use_cached_text_features=True,
|
||||
text_feature_dim=768,
|
||||
text_tokenizer_name='distilbert-base-uncased',
|
||||
text_max_length=32):
|
||||
text_max_length=32,
|
||||
image_augment=False,
|
||||
image_aug_cfg=None):
|
||||
super(EpisodicDataset).__init__()
|
||||
self.episode_ids = episode_ids
|
||||
self.dataset_dir = dataset_dir
|
||||
@@ -26,6 +29,23 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
self.use_cached_text_features = use_cached_text_features
|
||||
self.text_feature_dim = text_feature_dim
|
||||
self.text_max_length = text_max_length
|
||||
self.image_augment = image_augment
|
||||
self.image_aug_cfg = {
|
||||
'p_color': 0.8,
|
||||
'p_highlight': 0.5,
|
||||
'p_noise': 0.5,
|
||||
'p_blur': 0.3,
|
||||
'brightness': 0.25,
|
||||
'contrast': 0.25,
|
||||
'saturation': 0.25,
|
||||
'hue': 0.08,
|
||||
'highlight_strength': (0.15, 0.5),
|
||||
'noise_std': (0.005, 0.03),
|
||||
'blur_sigma': (0.1, 1.5),
|
||||
'blur_kernel_choices': (3, 5),
|
||||
}
|
||||
if image_aug_cfg is not None:
|
||||
self.image_aug_cfg.update(image_aug_cfg)
|
||||
self.is_sim = None
|
||||
self.max_episode_len = None
|
||||
self.action_dim = None
|
||||
@@ -45,6 +65,66 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.__getitem__(0) # initialize self.is_sim
|
||||
|
||||
def _apply_image_augmentation(self, all_cam_images):
|
||||
"""
|
||||
Apply identical augmentation parameters to all camera images for one sample.
|
||||
all_cam_images: np.ndarray [K, H, W, C], uint8
|
||||
"""
|
||||
imgs = torch.from_numpy(all_cam_images).float() / 255.0
|
||||
imgs = torch.einsum('k h w c -> k c h w', imgs)
|
||||
|
||||
cfg = self.image_aug_cfg
|
||||
# color jitter (shared params)
|
||||
if np.random.rand() < cfg['p_color']:
|
||||
b = 1.0 + np.random.uniform(-cfg['brightness'], cfg['brightness'])
|
||||
c = 1.0 + np.random.uniform(-cfg['contrast'], cfg['contrast'])
|
||||
s = 1.0 + np.random.uniform(-cfg['saturation'], cfg['saturation'])
|
||||
h = np.random.uniform(-cfg['hue'], cfg['hue'])
|
||||
for cam_idx in range(imgs.shape[0]):
|
||||
img = imgs[cam_idx]
|
||||
img = TF.adjust_brightness(img, b)
|
||||
img = TF.adjust_contrast(img, c)
|
||||
img = TF.adjust_saturation(img, s)
|
||||
img = TF.adjust_hue(img, h)
|
||||
imgs[cam_idx] = img
|
||||
|
||||
# synthetic highlight / glare (shared parameters)
|
||||
if np.random.rand() < cfg['p_highlight']:
|
||||
_, h_img, w_img = imgs[0].shape
|
||||
cx = np.random.uniform(0.2 * w_img, 0.8 * w_img)
|
||||
cy = np.random.uniform(0.2 * h_img, 0.8 * h_img)
|
||||
sigma = np.random.uniform(0.08, 0.2) * min(h_img, w_img)
|
||||
strength = np.random.uniform(*cfg['highlight_strength'])
|
||||
yy, xx = torch.meshgrid(
|
||||
torch.arange(h_img, dtype=torch.float32),
|
||||
torch.arange(w_img, dtype=torch.float32),
|
||||
indexing='ij',
|
||||
)
|
||||
gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
|
||||
gauss = (gauss * strength).unsqueeze(0)
|
||||
imgs = imgs + gauss
|
||||
|
||||
# gaussian noise
|
||||
if np.random.rand() < cfg['p_noise']:
|
||||
noise_std = np.random.uniform(*cfg['noise_std'])
|
||||
imgs = imgs + torch.randn_like(imgs) * noise_std
|
||||
|
||||
# gaussian blur
|
||||
if np.random.rand() < cfg['p_blur']:
|
||||
kernel = int(np.random.choice(cfg['blur_kernel_choices']))
|
||||
sigma = float(np.random.uniform(*cfg['blur_sigma']))
|
||||
for cam_idx in range(imgs.shape[0]):
|
||||
imgs[cam_idx] = TF.gaussian_blur(
|
||||
imgs[cam_idx],
|
||||
kernel_size=[kernel, kernel],
|
||||
sigma=[sigma, sigma],
|
||||
)
|
||||
|
||||
imgs = imgs.clamp(0.0, 1.0)
|
||||
imgs = torch.einsum('k c h w -> k h w c', imgs)
|
||||
imgs = (imgs * 255.0).byte().cpu().numpy()
|
||||
return imgs
|
||||
|
||||
def _init_episode_shapes(self):
|
||||
max_len = 0
|
||||
action_dim = None
|
||||
@@ -155,6 +235,8 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
for cam_name in self.camera_names:
|
||||
all_cam_images.append(image_dict[cam_name])
|
||||
all_cam_images = np.stack(all_cam_images, axis=0)
|
||||
if self.image_augment:
|
||||
all_cam_images = self._apply_image_augmentation(all_cam_images)
|
||||
|
||||
# construct observations
|
||||
image_data = torch.from_numpy(all_cam_images)
|
||||
@@ -251,7 +333,9 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
||||
use_cached_text_features=True,
|
||||
text_feature_dim=768,
|
||||
text_tokenizer_name='distilbert-base-uncased',
|
||||
text_max_length=32):
|
||||
text_max_length=32,
|
||||
image_augment=False,
|
||||
image_aug_cfg=None):
|
||||
print(f'\nData from: {dataset_dir}\n')
|
||||
episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
|
||||
if len(episode_ids) == 0:
|
||||
@@ -283,6 +367,8 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
image_augment=image_augment,
|
||||
image_aug_cfg=image_aug_cfg,
|
||||
)
|
||||
val_dataset = EpisodicDataset(
|
||||
val_episode_ids,
|
||||
@@ -295,6 +381,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
image_augment=False,
|
||||
)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||
|
||||
Reference in New Issue
Block a user