diff --git a/detr/main.py b/detr/main.py
index d8f85d5..044b2a3 100644
--- a/detr/main.py
+++ b/detr/main.py
@@ -73,6 +73,7 @@ def get_args_parser():
     parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
     parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
     parser.add_argument('--temporal_agg', action='store_true')
+    parser.add_argument('--image_aug', action='store_true', help='Enable training-time image augmentation')
 
     return parser
 
diff --git a/imitate_episodes.py b/imitate_episodes.py
index 77d0077..0b3dd88 100644
--- a/imitate_episodes.py
+++ b/imitate_episodes.py
@@ -154,6 +154,7 @@ def main(args):
         text_feature_dim=text_feature_dim,
         text_tokenizer_name=text_tokenizer_name,
         text_max_length=text_max_length,
+        image_augment=args['image_aug'],
     )
 
     # save dataset stats
@@ -516,5 +517,7 @@ if __name__ == '__main__':
     parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
     parser.add_argument('--freeze_text_encoder', action='store_true')
    parser.add_argument('--text_max_length', action='store', type=int, required=False)
+    parser.add_argument('--image_aug', action='store_true',
+                        help='Enable training-time image augmentation (color/highlight/noise/blur)')
 
     main(vars(parser.parse_args()))
diff --git a/utils.py b/utils.py
index 6b79697..e2ea612 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,7 @@ import os
 import h5py
 import re
 from torch.utils.data import TensorDataset, DataLoader
+import torchvision.transforms.functional as TF
 
 import IPython
 e = IPython.embed
@@ -15,7 +16,9 @@ class EpisodicDataset(torch.utils.data.Dataset):
                  use_cached_text_features=True,
                  text_feature_dim=768,
                  text_tokenizer_name='distilbert-base-uncased',
-                 text_max_length=32):
+                 text_max_length=32,
+                 image_augment=False,
+                 image_aug_cfg=None):
         super(EpisodicDataset).__init__()
         self.episode_ids = episode_ids
         self.dataset_dir = dataset_dir
@@ -26,6 +29,23 @@ class EpisodicDataset(torch.utils.data.Dataset):
         self.use_cached_text_features = use_cached_text_features
         self.text_feature_dim = text_feature_dim
         self.text_max_length = text_max_length
+        self.image_augment = image_augment
+        self.image_aug_cfg = {
+            'p_color': 0.8,
+            'p_highlight': 0.5,
+            'p_noise': 0.5,
+            'p_blur': 0.3,
+            'brightness': 0.25,
+            'contrast': 0.25,
+            'saturation': 0.25,
+            'hue': 0.08,
+            'highlight_strength': (0.15, 0.5),
+            'noise_std': (0.005, 0.03),
+            'blur_sigma': (0.1, 1.5),
+            'blur_kernel_choices': (3, 5),
+        }
+        if image_aug_cfg is not None:
+            self.image_aug_cfg.update(image_aug_cfg)
         self.is_sim = None
         self.max_episode_len = None
         self.action_dim = None
@@ -45,6 +65,71 @@ class EpisodicDataset(torch.utils.data.Dataset):
 
         self.__getitem__(0) # initialize self.is_sim
 
+    def _apply_image_augmentation(self, all_cam_images):
+        """
+        Apply identical augmentation parameters to all camera images for one sample.
+        all_cam_images: np.ndarray [K, H, W, C], uint8
+        """
+        imgs = torch.from_numpy(all_cam_images).float() / 255.0
+        imgs = torch.einsum('k h w c -> k c h w', imgs)
+
+        cfg = self.image_aug_cfg
+        # NOTE(review): np.random state is shared across forked DataLoader workers
+        # and is not reseeded per worker; with num_workers > 1, supply a
+        # worker_init_fn that seeds numpy to avoid duplicated augmentations.
+        # color jitter (shared params)
+        if np.random.rand() < cfg['p_color']:
+            b = 1.0 + np.random.uniform(-cfg['brightness'], cfg['brightness'])
+            c = 1.0 + np.random.uniform(-cfg['contrast'], cfg['contrast'])
+            s = 1.0 + np.random.uniform(-cfg['saturation'], cfg['saturation'])
+            h = np.random.uniform(-cfg['hue'], cfg['hue'])
+            for cam_idx in range(imgs.shape[0]):
+                img = imgs[cam_idx]
+                img = TF.adjust_brightness(img, b)
+                img = TF.adjust_contrast(img, c)
+                img = TF.adjust_saturation(img, s)
+                img = TF.adjust_hue(img, h)
+                imgs[cam_idx] = img
+
+        # synthetic highlight / glare (shared parameters)
+        if np.random.rand() < cfg['p_highlight']:
+            _, h_img, w_img = imgs[0].shape
+            cx = np.random.uniform(0.2 * w_img, 0.8 * w_img)
+            cy = np.random.uniform(0.2 * h_img, 0.8 * h_img)
+            sigma = np.random.uniform(0.08, 0.2) * min(h_img, w_img)
+            strength = np.random.uniform(*cfg['highlight_strength'])
+            yy, xx = torch.meshgrid(
+                torch.arange(h_img, dtype=torch.float32),
+                torch.arange(w_img, dtype=torch.float32),
+                indexing='ij',
+            )
+            gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
+            gauss = (gauss * strength).unsqueeze(0)
+            imgs = imgs + gauss
+
+        # gaussian noise
+        if np.random.rand() < cfg['p_noise']:
+            noise_std = np.random.uniform(*cfg['noise_std'])
+            imgs = imgs + torch.randn_like(imgs) * noise_std
+
+        # gaussian blur
+        if np.random.rand() < cfg['p_blur']:
+            kernel = int(np.random.choice(cfg['blur_kernel_choices']))
+            sigma = float(np.random.uniform(*cfg['blur_sigma']))
+            for cam_idx in range(imgs.shape[0]):
+                imgs[cam_idx] = TF.gaussian_blur(
+                    imgs[cam_idx],
+                    kernel_size=[kernel, kernel],
+                    sigma=[sigma, sigma],
+                )
+
+        imgs = imgs.clamp(0.0, 1.0)
+        imgs = torch.einsum('k c h w -> k h w c', imgs)
+        # round to nearest instead of truncating toward zero; .byte() alone
+        # would bias every pixel downward by ~0.5/255 (tensor is already on CPU).
+        imgs = (imgs * 255.0).round().byte().numpy()
+        return imgs
+
     def _init_episode_shapes(self):
         max_len = 0
         action_dim = None
@@ -155,6 +240,8 @@ class EpisodicDataset(torch.utils.data.Dataset):
         for cam_name in self.camera_names:
             all_cam_images.append(image_dict[cam_name])
         all_cam_images = np.stack(all_cam_images, axis=0)
+        if self.image_augment:
+            all_cam_images = self._apply_image_augmentation(all_cam_images)
 
         # construct observations
         image_data = torch.from_numpy(all_cam_images)
@@ -251,7 +338,9 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
               use_cached_text_features=True,
               text_feature_dim=768,
               text_tokenizer_name='distilbert-base-uncased',
-              text_max_length=32):
+              text_max_length=32,
+              image_augment=False,
+              image_aug_cfg=None):
     print(f'\nData from: {dataset_dir}\n')
     episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
     if len(episode_ids) == 0:
@@ -283,6 +372,8 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
         text_feature_dim=text_feature_dim,
         text_tokenizer_name=text_tokenizer_name,
         text_max_length=text_max_length,
+        image_augment=image_augment,
+        image_aug_cfg=image_aug_cfg,
     )
     val_dataset = EpisodicDataset(
         val_episode_ids,
@@ -295,6 +386,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
         text_feature_dim=text_feature_dim,
         text_tokenizer_name=text_tokenizer_name,
         text_max_length=text_max_length,
+        image_augment=False,
     )
     train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
     val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)