import numpy as np
import torch
import os
import h5py
import re
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms.functional as TF
import IPython
e = IPython.embed


class EpisodicDataset(torch.utils.data.Dataset):
    """Dataset over per-episode HDF5 files named ``episode_<id>.hdf5``.

    Each item samples a single timestep of one episode and returns:
    camera images, qpos at that timestep, the normalized action sequence
    from that timestep to the episode end (zero-padded to the longest
    episode), a padding mask, and optional text-instruction tensors.
    """

    def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats,
                 use_text_instruction=False, instruction_mode='timestep-level',
                 use_cached_text_features=True, text_feature_dim=768,
                 text_tokenizer_name='distilbert-base-uncased', text_max_length=32,
                 image_augment=False, image_aug_cfg=None):
        """
        Args:
            episode_ids: iterable of integer episode ids to index.
            dataset_dir: directory containing ``episode_<id>.hdf5`` files.
            camera_names: list of camera keys under ``/observations/images/``.
            norm_stats: dict with ``action_mean/std`` and ``qpos_mean/std``
                (as produced by :func:`get_norm_stats`).
            use_text_instruction: if True, load instruction text/features.
            instruction_mode: ``'timestep-level'`` prefers per-timestep
                instruction datasets when present in the file.
            use_cached_text_features: prefer precomputed instruction features
                stored in the HDF5 file over on-the-fly tokenization.
            text_feature_dim: dimension of the (zero) feature placeholder.
            text_tokenizer_name: HuggingFace tokenizer name (DistilBERT family).
            text_max_length: tokenizer max length (padding/truncation target).
            image_augment: enable per-sample image augmentation.
            image_aug_cfg: optional dict overriding the default augmentation
                configuration (merged on top of the defaults).
        """
        # FIX: the original used ``super(EpisodicDataset).__init__()``, which
        # creates an *unbound* super object and never calls Dataset.__init__.
        super().__init__()
        self.episode_ids = episode_ids
        self.dataset_dir = dataset_dir
        self.camera_names = camera_names
        self.norm_stats = norm_stats
        self.use_text_instruction = use_text_instruction
        self.instruction_mode = instruction_mode
        self.use_cached_text_features = use_cached_text_features
        self.text_feature_dim = text_feature_dim
        self.text_max_length = text_max_length
        self.image_augment = image_augment
        # Default augmentation config; user overrides are merged on top.
        self.image_aug_cfg = {
            'p_color': 0.8,
            'p_highlight': 0.5,
            'p_noise': 0.5,
            'p_blur': 0.3,
            'brightness': 0.25,
            'contrast': 0.25,
            'saturation': 0.25,
            'hue': 0.08,
            'highlight_strength': (0.15, 0.5),
            'noise_std': (0.005, 0.03),
            'blur_sigma': (0.1, 1.5),
            'blur_kernel_choices': (3, 5),
        }
        if image_aug_cfg is not None:
            self.image_aug_cfg.update(image_aug_cfg)
        self.is_sim = None
        self.max_episode_len = None
        self.action_dim = None
        self.text_tokenizer = None
        if self.use_text_instruction:
            try:
                from transformers import DistilBertTokenizerFast
            except ImportError as exc:
                raise ImportError(
                    'transformers is required for text instruction loading. '
                    'Install it with: pip install transformers'
                ) from exc
            self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(text_tokenizer_name)
        self._init_episode_shapes()
        self.__getitem__(0)  # initialize self.is_sim

    def _apply_image_augmentation(self, all_cam_images):
        """Apply identical augmentation parameters to all camera images for one sample.

        all_cam_images: np.ndarray [K, H, W, C], uint8.
        Returns an augmented uint8 array of the same shape. All random
        parameters are drawn once per sample so every camera view receives
        the same perturbation.
        """
        imgs = torch.from_numpy(all_cam_images).float() / 255.0
        imgs = torch.einsum('k h w c -> k c h w', imgs)
        cfg = self.image_aug_cfg
        # color jitter (shared params)
        if np.random.rand() < cfg['p_color']:
            b = 1.0 + np.random.uniform(-cfg['brightness'], cfg['brightness'])
            c = 1.0 + np.random.uniform(-cfg['contrast'], cfg['contrast'])
            s = 1.0 + np.random.uniform(-cfg['saturation'], cfg['saturation'])
            h = np.random.uniform(-cfg['hue'], cfg['hue'])
            for cam_idx in range(imgs.shape[0]):
                img = imgs[cam_idx]
                img = TF.adjust_brightness(img, b)
                img = TF.adjust_contrast(img, c)
                img = TF.adjust_saturation(img, s)
                img = TF.adjust_hue(img, h)
                imgs[cam_idx] = img
        # synthetic highlight / glare (shared parameters)
        if np.random.rand() < cfg['p_highlight']:
            _, h_img, w_img = imgs[0].shape
            cx = np.random.uniform(0.2 * w_img, 0.8 * w_img)
            cy = np.random.uniform(0.2 * h_img, 0.8 * h_img)
            sigma = np.random.uniform(0.08, 0.2) * min(h_img, w_img)
            strength = np.random.uniform(*cfg['highlight_strength'])
            yy, xx = torch.meshgrid(
                torch.arange(h_img, dtype=torch.float32),
                torch.arange(w_img, dtype=torch.float32),
                indexing='ij',
            )
            gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
            gauss = (gauss * strength).unsqueeze(0)
            imgs = imgs + gauss
        # gaussian noise
        if np.random.rand() < cfg['p_noise']:
            noise_std = np.random.uniform(*cfg['noise_std'])
            imgs = imgs + torch.randn_like(imgs) * noise_std
        # gaussian blur
        if np.random.rand() < cfg['p_blur']:
            kernel = int(np.random.choice(cfg['blur_kernel_choices']))
            sigma = float(np.random.uniform(*cfg['blur_sigma']))
            for cam_idx in range(imgs.shape[0]):
                imgs[cam_idx] = TF.gaussian_blur(
                    imgs[cam_idx],
                    kernel_size=[kernel, kernel],
                    sigma=[sigma, sigma],
                )
        imgs = imgs.clamp(0.0, 1.0)
        imgs = torch.einsum('k c h w -> k h w c', imgs)
        imgs = (imgs * 255.0).byte().cpu().numpy()
        return imgs

    def _init_episode_shapes(self):
        """Scan all episodes to determine ``max_episode_len`` and ``action_dim``.

        Raises ValueError if /action is not [T, D] or if action dims are
        inconsistent across episodes.
        """
        max_len = 0
        action_dim = None
        for episode_id in self.episode_ids:
            dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
            with h5py.File(dataset_path, 'r') as root:
                shape = root['/action'].shape
                if len(shape) != 2:
                    raise ValueError(f'Expected /action to have shape [T, D], got {shape} in {dataset_path}')
                max_len = max(max_len, int(shape[0]))
                if action_dim is None:
                    action_dim = int(shape[1])
                elif int(shape[1]) != action_dim:
                    raise ValueError(
                        f'Inconsistent action dim in dataset. Expected {action_dim}, got {shape[1]} in {dataset_path}'
                    )
        if max_len <= 0 or action_dim is None:
            raise ValueError(f'Invalid dataset metadata in {self.dataset_dir}')
        self.max_episode_len = max_len
        self.action_dim = action_dim

    @staticmethod
    def _decode_instruction(raw_value):
        """Best-effort decode of an HDF5-stored instruction into a str.

        Accepts None, bytes (including numpy byte scalars, which are a
        ``bytes`` subclass), numpy arrays (scalar, empty, or 1-D — only the
        first element is used), or anything ``str()`` can render.
        """
        if raw_value is None:
            return ''
        # np.bytes_ is a subclass of bytes, so this branch also covers
        # numpy byte scalars (the original had a separate, unreachable
        # np.bytes_ branch after this one — removed as dead code).
        if isinstance(raw_value, bytes):
            return raw_value.decode('utf-8')
        if isinstance(raw_value, np.ndarray):
            if raw_value.shape == ():
                return EpisodicDataset._decode_instruction(raw_value.item())
            if raw_value.size == 0:
                return ''
            return EpisodicDataset._decode_instruction(raw_value.reshape(-1)[0])
        return str(raw_value)

    def __len__(self):
        return len(self.episode_ids)

    def __getitem__(self, index):
        """Sample one timestep from episode ``index``.

        Returns:
            image_data: float tensor [K, C, H, W] in [0, 1].
            qpos_data: normalized qpos, float tensor.
            action_data: normalized, zero-padded action chunk
                [max_episode_len, action_dim].
            is_pad: bool tensor [max_episode_len], True where padded.
            text_input_ids / text_attention_mask: tokenized instruction
                (placeholders of length 1 when unused or cached features exist).
            text_feature_data: cached instruction feature, or zeros.
            text_feature_valid: bool scalar, True if a cached feature was found.
        """
        sample_full_episode = False  # hardcode
        episode_id = self.episode_ids[index]
        dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
        with h5py.File(dataset_path, 'r') as root:
            is_sim = bool(root.attrs.get('sim', False))
            original_action_shape = root['/action'].shape
            episode_len = original_action_shape[0]
            if sample_full_episode:
                start_ts = 0
            else:
                start_ts = np.random.choice(episode_len)
            # get observation at start_ts only
            qpos = root['/observations/qpos'][start_ts]
            image_dict = dict()
            for cam_name in self.camera_names:
                image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
            instruction = ''
            text_feature = None
            if self.use_text_instruction:
                effective_mode = self.instruction_mode
                # Prefer per-timestep instruction datasets in timestep mode,
                # falling back to episode-level / scalar layouts.
                if effective_mode == 'timestep-level' and '/instruction_timestep' in root:
                    instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
                elif '/instruction' in root:
                    instruction_node = root['/instruction']
                    if getattr(instruction_node, 'shape', ()) == ():
                        instruction = self._decode_instruction(instruction_node[()])
                    else:
                        if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
                            instruction = self._decode_instruction(instruction_node[start_ts])
                        else:
                            instruction = self._decode_instruction(instruction_node[0])
                if self.use_cached_text_features:
                    if effective_mode == 'timestep-level' and '/instruction_features_timestep' in root:
                        text_feature = root['/instruction_features_timestep'][start_ts]
                    elif '/instruction_features' in root:
                        feat_node = root['/instruction_features']
                        if getattr(feat_node, 'shape', ()) == ():
                            text_feature = np.array(feat_node[()])
                        elif len(feat_node.shape) == 1:
                            text_feature = feat_node[()]
                        elif len(feat_node.shape) == 2 and feat_node.shape[0] == episode_len:
                            text_feature = feat_node[start_ts]
                        else:
                            text_feature = feat_node[0]
            # get all actions after and including start_ts
            if is_sim:
                action = root['/action'][start_ts:]
                action_len = episode_len - start_ts
            else:
                action = root['/action'][max(0, start_ts - 1):]  # hack, to make timesteps more aligned
                action_len = episode_len - max(0, start_ts - 1)  # hack, to make timesteps more aligned

        self.is_sim = is_sim
        padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
        padded_action[:action_len] = action
        is_pad = np.ones(self.max_episode_len)
        is_pad[:action_len] = 0

        # new axis for different cameras
        all_cam_images = []
        for cam_name in self.camera_names:
            all_cam_images.append(image_dict[cam_name])
        all_cam_images = np.stack(all_cam_images, axis=0)
        if self.image_augment:
            all_cam_images = self._apply_image_augmentation(all_cam_images)

        # construct observations
        image_data = torch.from_numpy(all_cam_images)
        qpos_data = torch.from_numpy(qpos).float()
        action_data = torch.from_numpy(padded_action).float()
        is_pad = torch.from_numpy(is_pad).bool()

        # channel last -> channel first
        image_data = torch.einsum('k h w c -> k c h w', image_data)

        # normalize image and change dtype to float
        image_data = image_data / 255.0
        action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
        qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]

        if self.use_text_instruction and text_feature is not None:
            # Cached feature found: tokenization placeholders stay trivial.
            text_feature_data = torch.from_numpy(np.array(text_feature)).float()
            text_feature_valid = torch.tensor(True, dtype=torch.bool)
            text_input_ids = torch.zeros(1, dtype=torch.long)
            text_attention_mask = torch.zeros(1, dtype=torch.long)
        elif self.use_text_instruction:
            # No cached feature: tokenize the instruction text instead.
            tokenized = self.text_tokenizer(
                instruction,
                padding='max_length',
                truncation=True,
                max_length=self.text_max_length,
                return_tensors='pt',
            )
            text_input_ids = tokenized['input_ids'].squeeze(0).long()
            text_attention_mask = tokenized['attention_mask'].squeeze(0).long()
            text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
            text_feature_valid = torch.tensor(False, dtype=torch.bool)
        else:
            text_input_ids = torch.zeros(1, dtype=torch.long)
            text_attention_mask = torch.zeros(1, dtype=torch.long)
            text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
            text_feature_valid = torch.tensor(False, dtype=torch.bool)

        return image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid


def _discover_episode_ids(dataset_dir, num_episodes=None):
    """Return sorted episode ids found as ``episode_<N>.hdf5`` in dataset_dir.

    If num_episodes is given, only the first num_episodes ids are returned.
    """
    pattern = re.compile(r'^episode_(\d+)\.hdf5$')
    episode_ids = []
    for fname in os.listdir(dataset_dir):
        m = pattern.match(fname)
        if m:
            episode_ids.append(int(m.group(1)))
    episode_ids.sort()
    if num_episodes is not None:
        episode_ids = episode_ids[:num_episodes]
    return episode_ids


def get_norm_stats(dataset_dir, episode_ids):
    """Compute per-dimension mean/std of qpos and action over all episodes.

    Episodes may have different lengths; all timesteps are concatenated
    along the time axis before computing statistics. Stds are clipped to
    a minimum of 1e-2 to avoid division blow-up on near-constant dims.
    """
    all_qpos_data = []
    all_action_data = []
    example_qpos = None
    for episode_idx in episode_ids:
        dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
        with h5py.File(dataset_path, 'r') as root:
            qpos = root['/observations/qpos'][()]
            action = root['/action'][()]
        qpos_t = torch.from_numpy(qpos)
        action_t = torch.from_numpy(action)
        all_qpos_data.append(qpos_t)
        all_action_data.append(action_t)
        if example_qpos is None and len(qpos) > 0:
            example_qpos = qpos[0]
    # Episodes may have different lengths; concatenate over time axis.
    all_qpos_data = torch.cat(all_qpos_data, dim=0)
    all_action_data = torch.cat(all_action_data, dim=0)

    # normalize action data
    action_mean = all_action_data.mean(dim=0, keepdim=True)
    action_std = all_action_data.std(dim=0, keepdim=True)
    action_std = torch.clip(action_std, 1e-2, np.inf)  # clipping

    # normalize qpos data
    qpos_mean = all_qpos_data.mean(dim=0, keepdim=True)
    qpos_std = all_qpos_data.std(dim=0, keepdim=True)
    qpos_std = torch.clip(qpos_std, 1e-2, np.inf)  # clipping

    stats = {"action_mean": action_mean.numpy().squeeze(),
             "action_std": action_std.numpy().squeeze(),
             "qpos_mean": qpos_mean.numpy().squeeze(),
             "qpos_std": qpos_std.numpy().squeeze(),
             "example_qpos": example_qpos}
    return stats


def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val,
              use_text_instruction=False, instruction_mode='timestep-level',
              use_cached_text_features=True, text_feature_dim=768,
              text_tokenizer_name='distilbert-base-uncased', text_max_length=32,
              image_augment=False, image_aug_cfg=None):
    """Build train/val dataloaders over the episodes in dataset_dir.

    Discovers episodes, splits 80/20 (shuffled), computes normalization
    stats over ALL discovered episodes, and constructs EpisodicDatasets.
    Image augmentation is applied to the training split only.

    Returns (train_dataloader, val_dataloader, norm_stats, is_sim).
    """
    print(f'\nData from: {dataset_dir}\n')
    episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
    if len(episode_ids) == 0:
        raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
    if len(episode_ids) < 2:
        raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')

    # obtain train test split
    train_ratio = 0.8
    shuffled_indices = np.random.permutation(len(episode_ids))
    train_count = int(train_ratio * len(episode_ids))
    train_indices = shuffled_indices[:train_count]
    val_indices = shuffled_indices[train_count:]
    train_episode_ids = np.array(episode_ids)[train_indices]
    val_episode_ids = np.array(episode_ids)[val_indices]

    # obtain normalization stats for qpos and action
    norm_stats = get_norm_stats(dataset_dir, episode_ids)

    # construct dataset and dataloader
    train_dataset = EpisodicDataset(
        train_episode_ids, dataset_dir, camera_names, norm_stats,
        use_text_instruction=use_text_instruction,
        instruction_mode=instruction_mode,
        use_cached_text_features=use_cached_text_features,
        text_feature_dim=text_feature_dim,
        text_tokenizer_name=text_tokenizer_name,
        text_max_length=text_max_length,
        image_augment=image_augment,
        image_aug_cfg=image_aug_cfg,
    )
    val_dataset = EpisodicDataset(
        val_episode_ids, dataset_dir, camera_names, norm_stats,
        use_text_instruction=use_text_instruction,
        instruction_mode=instruction_mode,
        use_cached_text_features=use_cached_text_features,
        text_feature_dim=text_feature_dim,
        text_tokenizer_name=text_tokenizer_name,
        text_max_length=text_max_length,
        image_augment=False,  # never augment the validation split
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True,
                                  pin_memory=True, num_workers=1, prefetch_factor=1)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True,
                                pin_memory=True, num_workers=1, prefetch_factor=1)

    return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim


### env utils

def sample_box_pose():
    """Sample a random cube pose [x, y, z, qw, qx, qy, qz] (identity rotation)."""
    x_range = [0.0, 0.2]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    ranges = np.vstack([x_range, y_range, z_range])
    cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])

    cube_quat = np.array([1, 0, 0, 0])
    return np.concatenate([cube_position, cube_quat])


def sample_insertion_pose():
    """Sample random (peg_pose, socket_pose) for the insertion task.

    Peg spawns on the right half (x in [0.1, 0.2]), socket on the left
    (x in [-0.2, -0.1]); both use identity rotation.
    """
    # Peg
    x_range = [0.1, 0.2]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    ranges = np.vstack([x_range, y_range, z_range])
    peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])

    peg_quat = np.array([1, 0, 0, 0])
    peg_pose = np.concatenate([peg_position, peg_quat])

    # Socket
    x_range = [-0.2, -0.1]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    ranges = np.vstack([x_range, y_range, z_range])
    socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])

    socket_quat = np.array([1, 0, 0, 0])
    socket_pose = np.concatenate([socket_position, socket_quat])

    return peg_pose, socket_pose


### helper functions

def compute_dict_mean(epoch_dicts):
    """Average a list of dicts key-wise; all dicts must share the keys of the first."""
    result = {k: None for k in epoch_dicts[0]}
    num_items = len(epoch_dicts)
    for k in result:
        value_sum = 0
        for epoch_dict in epoch_dicts:
            value_sum += epoch_dict[k]
        result[k] = value_sum / num_items
    return result


def detach_dict(d):
    """Return a new dict with every tensor value detached from the graph."""
    new_d = dict()
    for k, v in d.items():
        new_d[k] = v.detach()
    return new_d


def set_seed(seed):
    """Seed both torch and numpy RNGs for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)