Files
aloha/utils.py
2026-02-19 21:29:32 +08:00

451 lines
18 KiB
Python

import numpy as np
import torch
import os
import h5py
import re
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms.functional as TF
import IPython
e = IPython.embed
class EpisodicDataset(torch.utils.data.Dataset):
    """Dataset yielding one randomly-timed training sample per episode file.

    Item ``index`` is read from ``<dataset_dir>/episode_<episode_ids[index]>.hdf5``.
    A start timestep is drawn uniformly at random. The sample then contains:
    - the qpos and camera images at that timestep;
    - the action sequence from (about) that timestep onward, zero-padded to the
      longest episode, plus a padding mask;
    - text-instruction tensors (placeholders when text is disabled).
    """

    def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats,
                 use_text_instruction=False,
                 instruction_mode='timestep-level',
                 use_cached_text_features=True,
                 text_feature_dim=768,
                 text_tokenizer_name='distilbert-base-uncased',
                 text_max_length=32,
                 image_augment=False,
                 image_aug_cfg=None):
        """Create the dataset and eagerly load one sample.

        Args:
            episode_ids: ids of the episodes in this split; each maps to the file
                ``episode_<id>.hdf5``.
            dataset_dir: directory containing the episode files.
            camera_names: image datasets read from ``/observations/images/<name>``.
            norm_stats: dict with ``action_mean``, ``action_std``, ``qpos_mean``
                and ``qpos_std`` used to normalize actions and qpos.
            use_text_instruction: if True, read instruction text/features and
                build a DistilBERT tokenizer.
            instruction_mode: ``'timestep-level'`` prefers the per-timestep
                datasets ``/instruction_timestep`` and
                ``/instruction_features_timestep``. Any other value (and the
                fallback) uses ``/instruction`` and ``/instruction_features``.
            use_cached_text_features: if True, read precomputed instruction
                features from the file when present instead of tokenizing.
            text_feature_dim: length of the zero feature vector returned when no
                cached feature is used.
            text_tokenizer_name: name passed to
                ``DistilBertTokenizerFast.from_pretrained``.
            text_max_length: tokenizer padding/truncation length.
            image_augment: if True, apply ``_apply_image_augmentation``.
            image_aug_cfg: optional dict overriding the default augmentation config.

        Raises:
            ImportError: if ``use_text_instruction`` is set but ``transformers``
                is not installed.
            ValueError: if the episodes' ``/action`` metadata is invalid.
        """
        super().__init__()
        self.episode_ids = episode_ids
        self.dataset_dir = dataset_dir
        self.camera_names = camera_names
        self.norm_stats = norm_stats
        self.use_text_instruction = use_text_instruction
        self.instruction_mode = instruction_mode
        self.use_cached_text_features = use_cached_text_features
        self.text_feature_dim = text_feature_dim
        self.text_max_length = text_max_length
        self.image_augment = image_augment
        self.image_aug_cfg = {
            'p_color': 0.8,
            'p_highlight': 0.5,
            'p_noise': 0.5,
            'p_blur': 0.3,
            'brightness': 0.25,
            'contrast': 0.25,
            'saturation': 0.25,
            'hue': 0.08,
            'highlight_strength': (0.15, 0.5),
            'noise_std': (0.005, 0.03),
            'blur_sigma': (0.1, 1.5),
            'blur_kernel_choices': (3, 5),
        }
        if image_aug_cfg is not None:
            self.image_aug_cfg.update(image_aug_cfg)
        self.is_sim = None
        self.max_episode_len = None
        self.action_dim = None
        self.text_tokenizer = None
        if self.use_text_instruction:
            try:
                from transformers import DistilBertTokenizerFast
            except ImportError as exc:
                raise ImportError(
                    'transformers is required for text instruction loading. '
                    'Install it with: pip install transformers'
                ) from exc
            self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(text_tokenizer_name)
        self._init_episode_shapes()
        # Load one sample eagerly. This validates that the first episode is
        # readable and sets ``self.is_sim``, which callers read right after
        # construction.
        self.__getitem__(0)

    def _apply_image_augmentation(self, all_cam_images):
        """Augment all camera views of one sample with shared random parameters.

        Args:
            all_cam_images: uint8 array of shape [K, H, W, C] (K cameras).

        Returns:
            uint8 array of the same shape. Four augmentations are each applied
            with probability ``cfg['p_*']``: color jitter, a Gaussian "highlight"
            blob, additive Gaussian noise and Gaussian blur. When one is applied,
            every camera gets the same parameters.
        """
        imgs = torch.from_numpy(all_cam_images).float() / 255.0
        imgs = torch.einsum('k h w c -> k c h w', imgs)
        cfg = self.image_aug_cfg

        # Color jitter (parameters shared across cameras).
        if np.random.rand() < cfg['p_color']:
            b = 1.0 + np.random.uniform(-cfg['brightness'], cfg['brightness'])
            c = 1.0 + np.random.uniform(-cfg['contrast'], cfg['contrast'])
            s = 1.0 + np.random.uniform(-cfg['saturation'], cfg['saturation'])
            h = np.random.uniform(-cfg['hue'], cfg['hue'])
            for cam_idx in range(imgs.shape[0]):
                img = imgs[cam_idx]
                img = TF.adjust_brightness(img, b)
                img = TF.adjust_contrast(img, c)
                img = TF.adjust_saturation(img, s)
                img = TF.adjust_hue(img, h)
                imgs[cam_idx] = img

        # Synthetic highlight / glare: an additive 2-D Gaussian blob.
        if np.random.rand() < cfg['p_highlight']:
            _, h_img, w_img = imgs[0].shape
            cx = np.random.uniform(0.2 * w_img, 0.8 * w_img)
            cy = np.random.uniform(0.2 * h_img, 0.8 * h_img)
            sigma = np.random.uniform(0.08, 0.2) * min(h_img, w_img)
            strength = np.random.uniform(*cfg['highlight_strength'])
            yy, xx = torch.meshgrid(
                torch.arange(h_img, dtype=torch.float32),
                torch.arange(w_img, dtype=torch.float32),
                indexing='ij',
            )
            gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
            # Shape [1, H, W]; broadcasts over cameras and channels.
            gauss = (gauss * strength).unsqueeze(0)
            imgs = imgs + gauss

        # Additive Gaussian noise.
        if np.random.rand() < cfg['p_noise']:
            noise_std = np.random.uniform(*cfg['noise_std'])
            imgs = imgs + torch.randn_like(imgs) * noise_std

        # Gaussian blur.
        if np.random.rand() < cfg['p_blur']:
            kernel = int(np.random.choice(cfg['blur_kernel_choices']))
            sigma = float(np.random.uniform(*cfg['blur_sigma']))
            for cam_idx in range(imgs.shape[0]):
                imgs[cam_idx] = TF.gaussian_blur(
                    imgs[cam_idx],
                    kernel_size=[kernel, kernel],
                    sigma=[sigma, sigma],
                )

        imgs = imgs.clamp(0.0, 1.0)
        imgs = torch.einsum('k c h w -> k h w c', imgs)
        # Round rather than truncate. The float round trip v / 255 * 255 can land
        # just below an integer, and a bare ``.byte()`` would darken such pixels
        # by one level even when no augmentation fired.
        return torch.round(imgs * 255.0).to(torch.uint8).numpy()

    def _init_episode_shapes(self):
        """Scan every episode's ``/action`` dataset to set ``max_episode_len`` and ``action_dim``.

        Raises:
            ValueError: if an ``/action`` dataset is not 2-D, if the action dim
                differs between episodes, or if no valid metadata was found.
        """
        max_len = 0
        action_dim = None
        for episode_id in self.episode_ids:
            dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
            with h5py.File(dataset_path, 'r') as root:
                shape = root['/action'].shape
            if len(shape) != 2:
                raise ValueError(f'Expected /action to have shape [T, D], got {shape} in {dataset_path}')
            max_len = max(max_len, int(shape[0]))
            if action_dim is None:
                action_dim = int(shape[1])
            elif int(shape[1]) != action_dim:
                raise ValueError(
                    f'Inconsistent action dim in dataset. Expected {action_dim}, got {shape[1]} in {dataset_path}'
                )
        if max_len <= 0 or action_dim is None:
            raise ValueError(f'Invalid dataset metadata in {self.dataset_dir}')
        self.max_episode_len = max_len
        self.action_dim = action_dim

    @staticmethod
    def _decode_instruction(raw_value):
        """Convert an instruction value read from HDF5 into a Python ``str``.

        Handles each input type as follows:
        - ``None`` returns ``''``.
        - ``bytes`` (including ``np.bytes_``, which subclasses ``bytes``) are
          decoded as UTF-8.
        - A 0-d ndarray is unwrapped.
        - An empty ndarray returns ``''``.
        - Any other ndarray uses its first element.
        - Anything else goes through ``str()``.
        """
        if raw_value is None:
            return ''
        if isinstance(raw_value, bytes):
            return raw_value.decode('utf-8')
        if isinstance(raw_value, np.ndarray):
            if raw_value.shape == ():
                return EpisodicDataset._decode_instruction(raw_value.item())
            if raw_value.size == 0:
                return ''
            return EpisodicDataset._decode_instruction(raw_value.reshape(-1)[0])
        return str(raw_value)

    def __len__(self):
        return len(self.episode_ids)

    def _read_instruction(self, root, start_ts, episode_len):
        """Return the decoded instruction for ``start_ts`` from an open episode file ('' if absent)."""
        if self.instruction_mode == 'timestep-level' and '/instruction_timestep' in root:
            return self._decode_instruction(root['/instruction_timestep'][start_ts])
        if '/instruction' not in root:
            return ''
        node = root['/instruction']
        shape = getattr(node, 'shape', ())
        if shape == ():
            return self._decode_instruction(node[()])
        # A 1-D dataset with one entry per timestep is indexed by start_ts.
        if len(shape) == 1 and shape[0] == episode_len:
            return self._decode_instruction(node[start_ts])
        return self._decode_instruction(node[0])

    def _read_cached_text_feature(self, root, start_ts, episode_len):
        """Return the cached instruction feature for ``start_ts``, or None if the file has none."""
        if self.instruction_mode == 'timestep-level' and '/instruction_features_timestep' in root:
            return root['/instruction_features_timestep'][start_ts]
        if '/instruction_features' not in root:
            return None
        node = root['/instruction_features']
        shape = getattr(node, 'shape', ())
        if shape == ():
            return np.array(node[()])
        if len(shape) == 1:
            # A single episode-level feature vector.
            return node[()]
        if len(shape) == 2 and shape[0] == episode_len:
            # One feature row per timestep.
            return node[start_ts]
        return node[0]

    def _build_text_tensors(self, instruction, text_feature):
        """Return ``(input_ids, attention_mask, feature, feature_valid)`` for one sample.

        When text is enabled, the token placeholders always have length
        ``text_max_length``. This matches the tokenizer output, so samples with
        and without cached features can be collated into the same batch.
        """
        if not self.use_text_instruction:
            return (
                torch.zeros(1, dtype=torch.long),
                torch.zeros(1, dtype=torch.long),
                torch.zeros(self.text_feature_dim, dtype=torch.float32),
                torch.tensor(False, dtype=torch.bool),
            )
        if text_feature is not None:
            return (
                torch.zeros(self.text_max_length, dtype=torch.long),
                torch.zeros(self.text_max_length, dtype=torch.long),
                torch.from_numpy(np.array(text_feature)).float(),
                torch.tensor(True, dtype=torch.bool),
            )
        tokenized = self.text_tokenizer(
            instruction,
            padding='max_length',
            truncation=True,
            max_length=self.text_max_length,
            return_tensors='pt',
        )
        return (
            tokenized['input_ids'].squeeze(0).long(),
            tokenized['attention_mask'].squeeze(0).long(),
            torch.zeros(self.text_feature_dim, dtype=torch.float32),
            torch.tensor(False, dtype=torch.bool),
        )

    def __getitem__(self, index):
        """Load one sample from episode ``self.episode_ids[index]`` at a random timestep.

        Returns:
            Tuple of tensors:
            - ``image_data``: [K, C, H, W] float in [0, 1] (stored images are
              assumed to be [H, W, C]).
            - ``qpos_data``: normalized float32.
            - ``action_data``: [max_episode_len, action_dim], normalized float32.
            - ``is_pad``: [max_episode_len] bool, True for padded steps.
            - ``text_input_ids``, ``text_attention_mask``, ``text_feature_data``,
              ``text_feature_valid``.
        """
        sample_full_episode = False  # hardcode
        episode_id = self.episode_ids[index]
        dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
        with h5py.File(dataset_path, 'r') as root:
            is_sim = bool(root.attrs.get('sim', False))
            episode_len = root['/action'].shape[0]
            start_ts = 0 if sample_full_episode else np.random.choice(episode_len)
            # Observation at start_ts only.
            qpos = root['/observations/qpos'][start_ts]
            image_dict = {
                cam_name: root[f'/observations/images/{cam_name}'][start_ts]
                for cam_name in self.camera_names
            }
            instruction = ''
            text_feature = None
            if self.use_text_instruction:
                instruction = self._read_instruction(root, start_ts, episode_len)
                if self.use_cached_text_features:
                    text_feature = self._read_cached_text_feature(root, start_ts, episode_len)
            # All actions after and including start_ts. Real-robot data starts
            # one step earlier (hack, to make timesteps more aligned).
            action_start = start_ts if is_sim else max(0, start_ts - 1)
            action = root['/action'][action_start:]
            action_len = episode_len - action_start
        self.is_sim = is_sim

        padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
        padded_action[:action_len] = action
        is_pad = np.ones(self.max_episode_len)
        is_pad[:action_len] = 0

        # New leading axis for the different cameras.
        all_cam_images = np.stack([image_dict[cam_name] for cam_name in self.camera_names], axis=0)
        if self.image_augment:
            all_cam_images = self._apply_image_augmentation(all_cam_images)

        image_data = torch.from_numpy(all_cam_images)
        qpos_data = torch.from_numpy(qpos).float()
        action_data = torch.from_numpy(padded_action).float()
        is_pad = torch.from_numpy(is_pad).bool()

        # Channel last -> channel first, then scale to [0, 1].
        image_data = torch.einsum('k h w c -> k c h w', image_data)
        image_data = image_data / 255.0
        # Cast back to float32: subtracting float64 numpy stats would otherwise
        # silently promote the tensors to float64.
        action_data = ((action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]).float()
        qpos_data = ((qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]).float()

        text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = \
            self._build_text_tensors(instruction, text_feature)
        return image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid
def _discover_episode_ids(dataset_dir, num_episodes=None):
pattern = re.compile(r'^episode_(\d+)\.hdf5$')
episode_ids = []
for fname in os.listdir(dataset_dir):
m = pattern.match(fname)
if m:
episode_ids.append(int(m.group(1)))
episode_ids.sort()
if num_episodes is not None:
episode_ids = episode_ids[:num_episodes]
return episode_ids
def get_norm_stats(dataset_dir, episode_ids):
    """Compute per-dimension mean/std of qpos and action over the given episodes.

    Args:
        dataset_dir: directory holding ``episode_<id>.hdf5`` files.
        episode_ids: iterable of episode ids to include (list or numpy array).

    Returns:
        A dict containing:
        - ``action_mean``, ``action_std``, ``qpos_mean``, ``qpos_std``: squeezed
          numpy arrays. Each std is clipped to at least 1e-2 so that
          near-constant dimensions do not explode after normalization.
        - ``example_qpos``: the first qpos row of the first non-empty episode,
          or ``None``.

    Raises:
        ValueError: if ``episode_ids`` is empty. Previously this surfaced as an
            opaque ``torch.cat`` RuntimeError.
    """
    episode_ids = list(episode_ids)
    if not episode_ids:
        raise ValueError(f'Cannot compute normalization stats: no episodes given for {dataset_dir}')
    all_qpos_data = []
    all_action_data = []
    example_qpos = None
    for episode_idx in episode_ids:
        dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
        with h5py.File(dataset_path, 'r') as root:
            qpos = root['/observations/qpos'][()]
            action = root['/action'][()]
        all_qpos_data.append(torch.from_numpy(qpos))
        all_action_data.append(torch.from_numpy(action))
        if example_qpos is None and len(qpos) > 0:
            example_qpos = qpos[0]
    # Episodes may have different lengths; concatenate over the time axis.
    all_qpos_data = torch.cat(all_qpos_data, dim=0)
    all_action_data = torch.cat(all_action_data, dim=0)
    # Per-dimension action stats (std clipped from below).
    action_mean = all_action_data.mean(dim=0, keepdim=True)
    action_std = torch.clip(all_action_data.std(dim=0, keepdim=True), 1e-2, np.inf)
    # Per-dimension qpos stats (std clipped from below).
    qpos_mean = all_qpos_data.mean(dim=0, keepdim=True)
    qpos_std = torch.clip(all_qpos_data.std(dim=0, keepdim=True), 1e-2, np.inf)
    return {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
            "qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
            "example_qpos": example_qpos}
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val,
              use_text_instruction=False,
              instruction_mode='timestep-level',
              use_cached_text_features=True,
              text_feature_dim=768,
              text_tokenizer_name='distilbert-base-uncased',
              text_max_length=32,
              image_augment=False,
              image_aug_cfg=None):
    """Build train/val dataloaders for the episodes found in ``dataset_dir``.

    Episodes are split 80/20 using a random permutation. Normalization stats
    are computed over ALL discovered episodes (train and val). Image
    augmentation is only ever enabled for the training split.

    Returns:
        ``(train_dataloader, val_dataloader, norm_stats, is_sim)``, where
        ``is_sim`` is the training dataset's ``is_sim`` attribute.

    Raises:
        FileNotFoundError: if no ``episode_*.hdf5`` file is found.
        ValueError: if fewer than two episodes are available.
    """
    print(f'\nData from: {dataset_dir}\n')
    episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
    n_episodes = len(episode_ids)
    if n_episodes == 0:
        raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
    if n_episodes < 2:
        raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')

    # Random 80/20 train/val split.
    id_array = np.array(episode_ids)
    order = np.random.permutation(n_episodes)
    n_train = int(0.8 * n_episodes)
    train_episode_ids = id_array[order[:n_train]]
    val_episode_ids = id_array[order[n_train:]]

    norm_stats = get_norm_stats(dataset_dir, episode_ids)

    # Text-related options are identical for both splits.
    text_kwargs = dict(
        use_text_instruction=use_text_instruction,
        instruction_mode=instruction_mode,
        use_cached_text_features=use_cached_text_features,
        text_feature_dim=text_feature_dim,
        text_tokenizer_name=text_tokenizer_name,
        text_max_length=text_max_length,
    )
    train_dataset = EpisodicDataset(
        train_episode_ids, dataset_dir, camera_names, norm_stats,
        image_augment=image_augment, image_aug_cfg=image_aug_cfg, **text_kwargs,
    )
    val_dataset = EpisodicDataset(
        val_episode_ids, dataset_dir, camera_names, norm_stats,
        image_augment=False, **text_kwargs,
    )

    loader_kwargs = dict(shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, **loader_kwargs)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, **loader_kwargs)
    return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim
### env utils
def sample_box_pose():
    """Sample a random cube pose: a 3-D position followed by the quaternion [1, 0, 0, 0].

    x is uniform in [0.0, 0.2], y is uniform in [0.4, 0.6], and z is fixed at
    0.05. Returns a length-7 float array.
    """
    low = np.array([0.0, 0.4, 0.05])
    high = np.array([0.2, 0.6, 0.05])
    position = np.random.uniform(low, high)
    fixed_quat = np.array([1, 0, 0, 0])
    return np.concatenate([position, fixed_quat])
def sample_insertion_pose():
    """Sample random peg and socket poses for the insertion task.

    Returns:
        ``(peg_pose, socket_pose)``: two length-7 arrays, each a 3-D position
        followed by the quaternion [1, 0, 0, 0].
        - Peg x is uniform in [0.1, 0.2]; socket x is uniform in [-0.2, -0.1].
        - Both have y uniform in [0.4, 0.6] and z fixed at 0.05.
        - The peg is sampled before the socket.
    """
    fixed_quat = np.array([1, 0, 0, 0])
    poses = []
    # (x_low, x_high) for the peg, then the socket.
    for x_low, x_high in ((0.1, 0.2), (-0.2, -0.1)):
        position = np.random.uniform(np.array([x_low, 0.4, 0.05]), np.array([x_high, 0.6, 0.05]))
        poses.append(np.concatenate([position, fixed_quat]))
    peg_pose, socket_pose = poses
    return peg_pose, socket_pose
### helper functions
def compute_dict_mean(epoch_dicts):
    """Average each key of the first dict over all dicts in ``epoch_dicts``.

    For each key, values are summed starting from 0 and divided by
    ``len(epoch_dicts)``. Any value supporting ``0 + v`` and ``/ int`` works,
    e.g. numbers or tensors. Raises IndexError for an empty list and KeyError
    if a later dict lacks a key.
    """
    count = len(epoch_dicts)
    return {
        key: sum(entry[key] for entry in epoch_dicts) / count
        for key in epoch_dicts[0]
    }
def detach_dict(d):
    """Return a new dict with the same keys (same order), each value replaced by ``value.detach()``."""
    return {key: value.detach() for key, value in d.items()}
def set_seed(seed):
    """Seed the global torch RNG, then the global NumPy RNG, with ``seed``."""
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)