数据增强
This commit is contained in:
@@ -73,6 +73,7 @@ def get_args_parser():
|
|||||||
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
|
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
|
||||||
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
|
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
|
||||||
parser.add_argument('--temporal_agg', action='store_true')
|
parser.add_argument('--temporal_agg', action='store_true')
|
||||||
|
parser.add_argument('--image_aug', action='store_true')
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
@@ -154,6 +154,7 @@ def main(args):
|
|||||||
text_feature_dim=text_feature_dim,
|
text_feature_dim=text_feature_dim,
|
||||||
text_tokenizer_name=text_tokenizer_name,
|
text_tokenizer_name=text_tokenizer_name,
|
||||||
text_max_length=text_max_length,
|
text_max_length=text_max_length,
|
||||||
|
image_augment=args['image_aug'],
|
||||||
)
|
)
|
||||||
|
|
||||||
# save dataset stats
|
# save dataset stats
|
||||||
@@ -516,5 +517,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
|
parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
|
||||||
parser.add_argument('--freeze_text_encoder', action='store_true')
|
parser.add_argument('--freeze_text_encoder', action='store_true')
|
||||||
parser.add_argument('--text_max_length', action='store', type=int, required=False)
|
parser.add_argument('--text_max_length', action='store', type=int, required=False)
|
||||||
|
parser.add_argument('--image_aug', action='store_true',
|
||||||
|
help='Enable training-time image augmentation (color/highlight/noise/blur)')
|
||||||
|
|
||||||
main(vars(parser.parse_args()))
|
main(vars(parser.parse_args()))
|
||||||
|
|||||||
91
utils.py
91
utils.py
@@ -4,6 +4,7 @@ import os
|
|||||||
import h5py
|
import h5py
|
||||||
import re
|
import re
|
||||||
from torch.utils.data import TensorDataset, DataLoader
|
from torch.utils.data import TensorDataset, DataLoader
|
||||||
|
import torchvision.transforms.functional as TF
|
||||||
|
|
||||||
import IPython
|
import IPython
|
||||||
e = IPython.embed
|
e = IPython.embed
|
||||||
@@ -15,7 +16,9 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
|||||||
use_cached_text_features=True,
|
use_cached_text_features=True,
|
||||||
text_feature_dim=768,
|
text_feature_dim=768,
|
||||||
text_tokenizer_name='distilbert-base-uncased',
|
text_tokenizer_name='distilbert-base-uncased',
|
||||||
text_max_length=32):
|
text_max_length=32,
|
||||||
|
image_augment=False,
|
||||||
|
image_aug_cfg=None):
|
||||||
super(EpisodicDataset).__init__()
|
super(EpisodicDataset).__init__()
|
||||||
self.episode_ids = episode_ids
|
self.episode_ids = episode_ids
|
||||||
self.dataset_dir = dataset_dir
|
self.dataset_dir = dataset_dir
|
||||||
@@ -26,6 +29,23 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
|||||||
self.use_cached_text_features = use_cached_text_features
|
self.use_cached_text_features = use_cached_text_features
|
||||||
self.text_feature_dim = text_feature_dim
|
self.text_feature_dim = text_feature_dim
|
||||||
self.text_max_length = text_max_length
|
self.text_max_length = text_max_length
|
||||||
|
self.image_augment = image_augment
|
||||||
|
self.image_aug_cfg = {
|
||||||
|
'p_color': 0.8,
|
||||||
|
'p_highlight': 0.5,
|
||||||
|
'p_noise': 0.5,
|
||||||
|
'p_blur': 0.3,
|
||||||
|
'brightness': 0.25,
|
||||||
|
'contrast': 0.25,
|
||||||
|
'saturation': 0.25,
|
||||||
|
'hue': 0.08,
|
||||||
|
'highlight_strength': (0.15, 0.5),
|
||||||
|
'noise_std': (0.005, 0.03),
|
||||||
|
'blur_sigma': (0.1, 1.5),
|
||||||
|
'blur_kernel_choices': (3, 5),
|
||||||
|
}
|
||||||
|
if image_aug_cfg is not None:
|
||||||
|
self.image_aug_cfg.update(image_aug_cfg)
|
||||||
self.is_sim = None
|
self.is_sim = None
|
||||||
self.max_episode_len = None
|
self.max_episode_len = None
|
||||||
self.action_dim = None
|
self.action_dim = None
|
||||||
@@ -45,6 +65,66 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.__getitem__(0) # initialize self.is_sim
|
self.__getitem__(0) # initialize self.is_sim
|
||||||
|
|
||||||
|
def _apply_image_augmentation(self, all_cam_images):
|
||||||
|
"""
|
||||||
|
Apply identical augmentation parameters to all camera images for one sample.
|
||||||
|
all_cam_images: np.ndarray [K, H, W, C], uint8
|
||||||
|
"""
|
||||||
|
imgs = torch.from_numpy(all_cam_images).float() / 255.0
|
||||||
|
imgs = torch.einsum('k h w c -> k c h w', imgs)
|
||||||
|
|
||||||
|
cfg = self.image_aug_cfg
|
||||||
|
# color jitter (shared params)
|
||||||
|
if np.random.rand() < cfg['p_color']:
|
||||||
|
b = 1.0 + np.random.uniform(-cfg['brightness'], cfg['brightness'])
|
||||||
|
c = 1.0 + np.random.uniform(-cfg['contrast'], cfg['contrast'])
|
||||||
|
s = 1.0 + np.random.uniform(-cfg['saturation'], cfg['saturation'])
|
||||||
|
h = np.random.uniform(-cfg['hue'], cfg['hue'])
|
||||||
|
for cam_idx in range(imgs.shape[0]):
|
||||||
|
img = imgs[cam_idx]
|
||||||
|
img = TF.adjust_brightness(img, b)
|
||||||
|
img = TF.adjust_contrast(img, c)
|
||||||
|
img = TF.adjust_saturation(img, s)
|
||||||
|
img = TF.adjust_hue(img, h)
|
||||||
|
imgs[cam_idx] = img
|
||||||
|
|
||||||
|
# synthetic highlight / glare (shared parameters)
|
||||||
|
if np.random.rand() < cfg['p_highlight']:
|
||||||
|
_, h_img, w_img = imgs[0].shape
|
||||||
|
cx = np.random.uniform(0.2 * w_img, 0.8 * w_img)
|
||||||
|
cy = np.random.uniform(0.2 * h_img, 0.8 * h_img)
|
||||||
|
sigma = np.random.uniform(0.08, 0.2) * min(h_img, w_img)
|
||||||
|
strength = np.random.uniform(*cfg['highlight_strength'])
|
||||||
|
yy, xx = torch.meshgrid(
|
||||||
|
torch.arange(h_img, dtype=torch.float32),
|
||||||
|
torch.arange(w_img, dtype=torch.float32),
|
||||||
|
indexing='ij',
|
||||||
|
)
|
||||||
|
gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
|
||||||
|
gauss = (gauss * strength).unsqueeze(0)
|
||||||
|
imgs = imgs + gauss
|
||||||
|
|
||||||
|
# gaussian noise
|
||||||
|
if np.random.rand() < cfg['p_noise']:
|
||||||
|
noise_std = np.random.uniform(*cfg['noise_std'])
|
||||||
|
imgs = imgs + torch.randn_like(imgs) * noise_std
|
||||||
|
|
||||||
|
# gaussian blur
|
||||||
|
if np.random.rand() < cfg['p_blur']:
|
||||||
|
kernel = int(np.random.choice(cfg['blur_kernel_choices']))
|
||||||
|
sigma = float(np.random.uniform(*cfg['blur_sigma']))
|
||||||
|
for cam_idx in range(imgs.shape[0]):
|
||||||
|
imgs[cam_idx] = TF.gaussian_blur(
|
||||||
|
imgs[cam_idx],
|
||||||
|
kernel_size=[kernel, kernel],
|
||||||
|
sigma=[sigma, sigma],
|
||||||
|
)
|
||||||
|
|
||||||
|
imgs = imgs.clamp(0.0, 1.0)
|
||||||
|
imgs = torch.einsum('k c h w -> k h w c', imgs)
|
||||||
|
imgs = (imgs * 255.0).byte().cpu().numpy()
|
||||||
|
return imgs
|
||||||
|
|
||||||
def _init_episode_shapes(self):
|
def _init_episode_shapes(self):
|
||||||
max_len = 0
|
max_len = 0
|
||||||
action_dim = None
|
action_dim = None
|
||||||
@@ -155,6 +235,8 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
|||||||
for cam_name in self.camera_names:
|
for cam_name in self.camera_names:
|
||||||
all_cam_images.append(image_dict[cam_name])
|
all_cam_images.append(image_dict[cam_name])
|
||||||
all_cam_images = np.stack(all_cam_images, axis=0)
|
all_cam_images = np.stack(all_cam_images, axis=0)
|
||||||
|
if self.image_augment:
|
||||||
|
all_cam_images = self._apply_image_augmentation(all_cam_images)
|
||||||
|
|
||||||
# construct observations
|
# construct observations
|
||||||
image_data = torch.from_numpy(all_cam_images)
|
image_data = torch.from_numpy(all_cam_images)
|
||||||
@@ -251,7 +333,9 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
|||||||
use_cached_text_features=True,
|
use_cached_text_features=True,
|
||||||
text_feature_dim=768,
|
text_feature_dim=768,
|
||||||
text_tokenizer_name='distilbert-base-uncased',
|
text_tokenizer_name='distilbert-base-uncased',
|
||||||
text_max_length=32):
|
text_max_length=32,
|
||||||
|
image_augment=False,
|
||||||
|
image_aug_cfg=None):
|
||||||
print(f'\nData from: {dataset_dir}\n')
|
print(f'\nData from: {dataset_dir}\n')
|
||||||
episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
|
episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
|
||||||
if len(episode_ids) == 0:
|
if len(episode_ids) == 0:
|
||||||
@@ -283,6 +367,8 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
|||||||
text_feature_dim=text_feature_dim,
|
text_feature_dim=text_feature_dim,
|
||||||
text_tokenizer_name=text_tokenizer_name,
|
text_tokenizer_name=text_tokenizer_name,
|
||||||
text_max_length=text_max_length,
|
text_max_length=text_max_length,
|
||||||
|
image_augment=image_augment,
|
||||||
|
image_aug_cfg=image_aug_cfg,
|
||||||
)
|
)
|
||||||
val_dataset = EpisodicDataset(
|
val_dataset = EpisodicDataset(
|
||||||
val_episode_ids,
|
val_episode_ids,
|
||||||
@@ -295,6 +381,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
|||||||
text_feature_dim=text_feature_dim,
|
text_feature_dim=text_feature_dim,
|
||||||
text_tokenizer_name=text_tokenizer_name,
|
text_tokenizer_name=text_tokenizer_name,
|
||||||
text_max_length=text_max_length,
|
text_max_length=text_max_length,
|
||||||
|
image_augment=False,
|
||||||
)
|
)
|
||||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||||
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user