# Files
# aloha/imitate_episodes.py
# 2026-02-20 16:45:16 +08:00
#
# 604 lines
# 25 KiB
# Python

import torch
import numpy as np
import os
import pickle
import argparse
import matplotlib.pyplot as plt
from copy import deepcopy
from tqdm import tqdm
from einops import rearrange
from constants import DT
from constants import PUPPET_GRIPPER_JOINT_OPEN
from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
from utils import load_data # data functions
from utils import sample_box_pose, sample_insertion_pose # robot functions
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
from policy import ACTPolicy, CNNMLPPolicy
from visualize_episodes import save_videos
import IPython
e = IPython.embed  # short alias for IPython.embed (calling e() opens an interactive shell); not called anywhere in the visible code
def load_checkpoint_state_dict(ckpt_path):
    """Load a saved state_dict onto the CPU, preferring the safe loader.

    First tries ``torch.load(..., weights_only=True)``, which refuses to
    unpickle arbitrary objects. PyTorch releases that predate the
    ``weights_only`` keyword raise ``TypeError`` for it; in that case the
    file is loaded again without the keyword.

    Args:
        ckpt_path: path to a file written by ``torch.save``.

    Returns:
        The deserialized object (a state_dict for the checkpoints in this file).
    """
    load_kwargs = {'map_location': 'cpu'}
    try:
        state_dict = torch.load(ckpt_path, weights_only=True, **load_kwargs)
    except TypeError:
        # Older torch: `weights_only` is not a recognised argument.
        state_dict = torch.load(ckpt_path, **load_kwargs)
    return state_dict
def main(args):
    """Train a policy, or evaluate the trained one, as selected on the CLI.

    ``args`` is the dict produced by ``vars(parser.parse_args())``. Task
    parameters are taken from ``ENDOSCOPE_TASK_CONFIGS``, then
    ``SIM_TASK_CONFIGS``, and otherwise from the real-robot
    ``aloha_scripts.constants.TASK_CONFIGS``. The CLI can override some of
    the text-instruction settings.

    With ``args['eval']`` the checkpoint ``policy_best.ckpt`` in ``ckpt_dir``
    is rolled out via ``eval_bc`` and the function returns.

    Otherwise it:
      * loads the data;
      * pickles the normalization stats to ``ckpt_dir/dataset_stats.pkl``;
      * trains with ``train_bc``;
      * writes the lowest-validation-loss weights to ``ckpt_dir/policy_best.ckpt``.

    Raises:
        NotImplementedError: if ``policy_class`` is neither 'ACT' nor 'CNNMLP'.
        FileNotFoundError: if ``--resume_ckpt`` does not resolve to a file.
        KeyError: if ``task_name`` is missing from every task table. If the
            optional ``aloha_scripts`` package is not installed, the real-robot
            lookup fails with ImportError instead.
    """
    set_seed(1)
    # command line parameters
    is_eval = args['eval']
    ckpt_dir = args['ckpt_dir']
    policy_class = args['policy_class']
    onscreen_render = args['onscreen_render']
    task_name = args['task_name']
    batch_size_train = args['batch_size']
    batch_size_val = args['batch_size']
    num_epochs = args['num_epochs']

    # get task parameters
    is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
    if is_endoscope:
        task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
        is_sim = False
    elif task_name in SIM_TASK_CONFIGS:
        task_config = SIM_TASK_CONFIGS[task_name]
        is_sim = True
    else:
        # Real-robot task tables live in the optional aloha_scripts package.
        from aloha_scripts.constants import TASK_CONFIGS
        task_config = TASK_CONFIGS[task_name]
        is_sim = False
    dataset_dir = task_config['dataset_dir']
    num_episodes = task_config['num_episodes']
    episode_len = task_config['episode_len']
    camera_names = task_config['camera_names']
    state_dim = task_config.get('state_dim', 14)
    action_dim = task_config.get('action_dim', state_dim)

    # text-instruction settings: task-config values with defaults
    use_text_instruction = task_config.get('use_text_instruction', False)
    instruction_mode = task_config.get('instruction_mode', 'timestep-level')
    use_cached_text_features = task_config.get('use_cached_text_features', True)
    text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
    text_feature_dim = task_config.get('text_feature_dim', 768)
    text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
    text_max_length = task_config.get('text_max_length', 32)
    text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
    freeze_text_encoder = task_config.get('freeze_text_encoder', True)
    real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)

    # CLI overrides. The two boolean flags can only push their setting one way
    # (freeze -> True, action shift -> False).
    if args.get('text_encoder_type') is not None:
        text_encoder_type = args['text_encoder_type']
    if args.get('text_max_length') is not None:
        text_max_length = args['text_max_length']
    if args.get('freeze_text_encoder', False):
        freeze_text_encoder = True
    if args.get('disable_real_action_shift', False):
        real_action_t_minus_1 = False

    # fixed parameters
    lr_backbone = 1e-5
    backbone = 'resnet18'
    if policy_class == 'ACT':
        enc_layers = 2
        dec_layers = 4
        nheads = 8
        policy_config = {'lr': args['lr'],
                         'num_queries': args['chunk_size'],
                         'kl_weight': args['kl_weight'],
                         'hidden_dim': args['hidden_dim'],
                         'dim_feedforward': args['dim_feedforward'],
                         'lr_backbone': lr_backbone,
                         'backbone': backbone,
                         'enc_layers': enc_layers,
                         'dec_layers': dec_layers,
                         'nheads': nheads,
                         'camera_names': camera_names,
                         'state_dim': state_dim,
                         'action_dim': action_dim,
                         'use_text': use_text_instruction,
                         'text_encoder_type': text_encoder_type,
                         'text_feature_dim': text_feature_dim,
                         'text_fusion_type': text_fusion_type,
                         'freeze_text_encoder': freeze_text_encoder,
                         'instruction_mode': instruction_mode,
                         'use_cached_text_features': use_cached_text_features,
                         'text_max_length': text_max_length,
                         'text_tokenizer_name': text_tokenizer_name,
                         }
    elif policy_class == 'CNNMLP':
        policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone': backbone, 'num_queries': 1,
                         'camera_names': camera_names,
                         'state_dim': state_dim,
                         'action_dim': action_dim,
                         'use_text': use_text_instruction,
                         }
    else:
        raise NotImplementedError(f'Unsupported policy_class: {policy_class!r} (expected ACT or CNNMLP)')

    config = {
        'num_epochs': num_epochs,
        'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
        'resume_ckpt_path': args.get('resume_ckpt', None),
        'ckpt_dir': ckpt_dir,
        'episode_len': episode_len,
        'state_dim': state_dim,
        'action_dim': action_dim,
        'lr': args['lr'],
        'policy_class': policy_class,
        'onscreen_render': onscreen_render,
        'policy_config': policy_config,
        'task_name': task_name,
        'seed': args['seed'],
        'temporal_agg': args['temporal_agg'],
        'camera_names': camera_names,
        'real_robot': (not is_sim) and (not is_endoscope),
        'use_text_instruction': use_text_instruction,
        'instruction_mode': instruction_mode,
        'use_cached_text_features': use_cached_text_features,
        'text_tokenizer_name': text_tokenizer_name,
        'text_max_length': text_max_length,
        'debug_input': args.get('debug_input', False),
    }

    if config['resume_ckpt_path']:
        # A relative path is tried against ckpt_dir first, then as given (relative to the CWD).
        resume_ckpt_path = config['resume_ckpt_path']
        if not os.path.isabs(resume_ckpt_path):
            candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
            if os.path.isfile(candidate_path):
                resume_ckpt_path = candidate_path
        if not os.path.isfile(resume_ckpt_path):
            raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
        config['resume_ckpt_path'] = resume_ckpt_path

    if is_eval:
        ckpt_names = ['policy_best.ckpt']
        results = []
        for ckpt_name in ckpt_names:
            success_rate, avg_return = eval_bc(config, ckpt_name, save_episode=True)
            results.append([ckpt_name, success_rate, avg_return])

        for ckpt_name, success_rate, avg_return in results:
            print(f'{ckpt_name}: {success_rate=} {avg_return=}')
        print()
        # Return rather than calling the site-module `exit()` helper: it is not
        # guaranteed to exist (e.g. `python -S`), and it kills programmatic callers.
        return

    train_dataloader, val_dataloader, stats, _ = load_data(
        dataset_dir,
        num_episodes,
        camera_names,
        batch_size_train,
        batch_size_val,
        use_text_instruction=use_text_instruction,
        instruction_mode=instruction_mode,
        use_cached_text_features=use_cached_text_features,
        text_feature_dim=text_feature_dim,
        text_tokenizer_name=text_tokenizer_name,
        text_max_length=text_max_length,
        real_action_t_minus_1=real_action_t_minus_1,
        image_augment=args['image_aug'],
    )

    # save dataset stats (eval_bc reloads them for (de)normalization)
    os.makedirs(ckpt_dir, exist_ok=True)
    stats_path = os.path.join(ckpt_dir, 'dataset_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(stats, f)

    best_ckpt_info = train_bc(train_dataloader, val_dataloader, config)
    best_epoch, min_val_loss, best_state_dict = best_ckpt_info

    # save best checkpoint
    ckpt_path = os.path.join(ckpt_dir, 'policy_best.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Best ckpt, val loss {min_val_loss:.6f} @ epoch {best_epoch}')
def make_policy(policy_class, policy_config):
    """Instantiate the policy selected by ``policy_class``.

    Args:
        policy_class: 'ACT' or 'CNNMLP'.
        policy_config: dict passed unchanged to the policy constructor.

    Returns:
        An ``ACTPolicy`` or ``CNNMLPPolicy`` instance.

    Raises:
        NotImplementedError: for any other ``policy_class``. The message names
            the rejected value, so CLI typos are easy to spot.
    """
    if policy_class == 'ACT':
        return ACTPolicy(policy_config)
    if policy_class == 'CNNMLP':
        return CNNMLPPolicy(policy_config)
    raise NotImplementedError(f'Unsupported policy_class: {policy_class!r} (expected ACT or CNNMLP)')
def make_optimizer(policy_class, policy):
    """Return the optimizer built by ``policy.configure_optimizers()``.

    Both supported policy classes construct their own optimizer. The class
    name therefore only decides which policies are accepted; the two
    previously duplicated branches are merged into one.

    Args:
        policy_class: 'ACT' or 'CNNMLP'.
        policy: policy object exposing ``configure_optimizers()``.

    Returns:
        Whatever ``policy.configure_optimizers()`` returns.

    Raises:
        NotImplementedError: for any other ``policy_class``; the message names it.
    """
    if policy_class not in ('ACT', 'CNNMLP'):
        raise NotImplementedError(f'Unsupported policy_class: {policy_class!r} (expected ACT or CNNMLP)')
    return policy.configure_optimizers()
def get_image(ts, camera_names):
    """Return the current camera frames of ``ts`` as one normalized CUDA batch.

    For each name in ``camera_names`` (in order), the frame
    ``ts.observation['images'][name]`` is rearranged from ``h w c`` to
    ``c h w``. The frames are stacked along a new leading camera axis,
    divided by 255.0 and converted to a float32 CUDA tensor. A leading
    batch axis of size 1 is added, giving shape ``(1, num_cams, c, h, w)``.
    """
    frames_chw = [
        rearrange(ts.observation['images'][name], 'h w c -> c h w')
        for name in camera_names
    ]
    scaled = np.stack(frames_chw, axis=0) / 255.0
    return torch.from_numpy(scaled).float().cuda().unsqueeze(0)
def eval_bc(config, ckpt_name, save_episode=True):
    """Roll out a trained policy for 50 episodes and report reward statistics.

    Steps:
      * Loads the weights ``ckpt_dir/ckpt_name`` and the normalization stats
        ``ckpt_dir/dataset_stats.pkl``.
      * Builds the environment: the real ALOHA one when
        ``config['real_robot']`` is true, otherwise ``sim_env.make_sim_env(task_name)``.
      * Runs ``num_rollouts`` episodes of ``config['episode_len']`` steps each,
        with the policy on the GPU.

    Query schedule:
      * ACT is queried every ``num_queries`` steps and the predicted chunk
        is played back open-loop.
      * With ``config['temporal_agg']``, ACT is queried every step instead,
        and all chunks that predicted the current step are blended with
        exponential weights.
      * CNNMLP is queried every step.

    Args:
        config: run-configuration dict; the keys used are read at the top of
            the function.
        ckpt_name: checkpoint file name inside ``config['ckpt_dir']``.
        save_episode: if True, save ``video{rollout_id}.mp4`` for every rollout.

    Returns:
        ``(success_rate, avg_return)``. A rollout counts as a success when its
        highest reward equals ``env_max_reward``.

    Side effects:
        Writes ``result_<ckpt_name up to the first '.'>.txt`` to ``ckpt_dir``.
        The file holds the summary, the per-episode returns and the highest
        rewards.
    """
    set_seed(1000)
    ckpt_dir = config['ckpt_dir']
    state_dim = config['state_dim']
    action_dim = config['action_dim']
    real_robot = config['real_robot']
    policy_class = config['policy_class']
    onscreen_render = config['onscreen_render']
    policy_config = config['policy_config']
    camera_names = config['camera_names']
    max_timesteps = config['episode_len']
    task_name = config['task_name']
    temporal_agg = config['temporal_agg']
    onscreen_cam = 'angle'
    # Bound lazily to sim_env.BOX_POSE for the two sim tasks below; index 0 is
    # overwritten before each reset (see the "used in sim reset" comments).
    BOX_POSE = None

    # load policy and stats
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    policy = make_policy(policy_class, policy_config)
    loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
    print(loading_status)
    policy.cuda()
    policy.eval()
    print(f'Loaded: {ckpt_path}')
    stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)

    # Normalize qpos with the dataset stats; map policy outputs back to action units.
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']

    # load environment
    if real_robot:
        from aloha_scripts.robot_utils import move_grippers # requires aloha
        from aloha_scripts.real_env import make_real_env # requires aloha
        env = make_real_env(init_node=True)
        # NOTE(review): with 0 here, "success" below only means the highest
        # reward was 0, so success metrics are not meaningful on the real robot.
        env_max_reward = 0
    else:
        from sim_env import make_sim_env
        env = make_sim_env(task_name)
        env_max_reward = env.task.max_reward

    # Without temporal aggregation a new chunk is requested every num_queries steps.
    query_frequency = policy_config['num_queries']
    if temporal_agg:
        query_frequency = 1
        num_queries = policy_config['num_queries']

    max_timesteps = int(max_timesteps * 1) # may increase for real-world tasks

    num_rollouts = 50
    episode_returns = []
    highest_rewards = []
    for rollout_id in range(num_rollouts):
        rollout_id += 0  # no-op (kept as an offset hook)
        ### set task
        if 'sim_transfer_cube' in task_name:
            if BOX_POSE is None:
                from sim_env import BOX_POSE as _BOX_POSE
                BOX_POSE = _BOX_POSE
            BOX_POSE[0] = sample_box_pose() # used in sim reset
        elif 'sim_insertion' in task_name:
            if BOX_POSE is None:
                from sim_env import BOX_POSE as _BOX_POSE
                BOX_POSE = _BOX_POSE
            BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset

        ts = env.reset()

        ### onscreen render
        # NOTE(review): uses env._physics, which presumably exists only on the sim env.
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(env._physics.render(height=480, width=640, camera_id=onscreen_cam))
            plt.ion()

        ### evaluation loop
        if temporal_agg:
            # Row i holds the chunk predicted at step i, written into columns i..i+num_queries.
            all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()

        # NOTE(review): qpos_history is written each step but never read in this function.
        qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
        image_list = [] # for visualization
        qpos_list = []
        target_qpos_list = []
        rewards = []
        with torch.inference_mode():
            for t in range(max_timesteps):
                ### update onscreen render and wait for DT
                if onscreen_render:
                    image = env._physics.render(height=480, width=640, camera_id=onscreen_cam)
                    plt_img.set_data(image)
                    plt.pause(DT)

                ### process previous timestep to get qpos and image_list
                obs = ts.observation
                if 'images' in obs:
                    image_list.append(obs['images'])
                else:
                    image_list.append({'main': obs['image']})
                qpos_numpy = np.array(obs['qpos'])
                qpos = pre_process(qpos_numpy)
                qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
                qpos_history[:, t] = qpos
                curr_image = get_image(ts, camera_names)

                ### query policy
                # NOTE(review): only qpos and images are passed; no text inputs reach
                # the policy here even when policy_config['use_text'] is set. Confirm.
                if config['policy_class'] == "ACT":
                    if t % query_frequency == 0:
                        all_actions = policy(qpos, curr_image)
                    if temporal_agg:
                        all_time_actions[[t], t:t+num_queries] = all_actions
                        actions_for_curr_step = all_time_actions[:, t]
                        # Rows still all-zero were never filled. NOTE(review): a row with
                        # any component exactly 0.0 is also dropped by this test.
                        actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
                        actions_for_curr_step = actions_for_curr_step[actions_populated]
                        # Weights exp(-k*i), i=0 for the oldest prediction, so older
                        # predictions of this step get the larger weight.
                        k = 0.01
                        exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                        exp_weights = exp_weights / exp_weights.sum()
                        exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
                        raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
                    else:
                        # Play back the current chunk open-loop.
                        raw_action = all_actions[:, t % query_frequency]
                elif config['policy_class'] == "CNNMLP":
                    raw_action = policy(qpos, curr_image)
                else:
                    raise NotImplementedError

                ### post-process actions
                raw_action = raw_action.squeeze(0).cpu().numpy()
                action = post_process(raw_action)
                target_qpos = action

                ### step the environment
                ts = env.step(target_qpos)

                ### for visualization
                qpos_list.append(qpos_numpy)
                target_qpos_list.append(target_qpos)
                rewards.append(ts.reward)

            plt.close()
        if real_robot:
            # move_grippers is only bound in the real_robot import branch above.
            move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) # open
            pass

        rewards = np.array(rewards)
        # None rewards are excluded from the return sum.
        episode_return = np.sum(rewards[rewards!=None])
        episode_returns.append(episode_return)
        episode_highest_reward = np.max(rewards)
        highest_rewards.append(episode_highest_reward)
        print(f'Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {env_max_reward=}, Success: {episode_highest_reward==env_max_reward}')

        if save_episode:
            save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4'))

    success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
    avg_return = np.mean(episode_returns)
    summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
    # Cumulative histogram: fraction of rollouts reaching at least reward r.
    for r in range(env_max_reward+1):
        more_or_equal_r = (np.array(highest_rewards) >= r).sum()
        more_or_equal_r_rate = more_or_equal_r / num_rollouts
        summary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n'

    print(summary_str)

    # save success rate to txt
    result_file_name = 'result_' + ckpt_name.split('.')[0] + '.txt'
    with open(os.path.join(ckpt_dir, result_file_name), 'w') as f:
        f.write(summary_str)
        f.write(repr(episode_returns))
        f.write('\n\n')
        f.write(repr(highest_rewards))

    return success_rate, avg_return
def forward_pass(data, policy, debug_input=False, debug_tag=''):
    """Move one batch to the GPU and run the policy on it.

    ``data`` must unpack into exactly eight items, in this order:
      1. images
      2. qpos
      3. actions
      4. padding mask
      5. text token ids
      6. text attention mask
      7. cached text features
      8. text-feature validity flags

    Every item is moved with ``.cuda()``. The cached text features are handed
    to the policy only when at least one validity flag is true; otherwise
    ``text_features=None`` is passed.

    If ``debug_input`` is set, the following are printed, each line prefixed
    with ``debug_tag``:
      * shapes of the image, qpos, action and padding tensors;
      * the image value range and qpos/action mean and std;
      * the padding ratio;
      * whether image/qpos/action contain NaN or Inf.

    Returns:
        The policy's output, unchanged.
    """
    (image_data, qpos_data, action_data, is_pad,
     text_input_ids, text_attention_mask,
     text_feature_data, text_feature_valid) = (item.cuda() for item in data)

    # Only forward cached text features when at least one sample actually has them.
    text_features = text_feature_data if torch.any(text_feature_valid) else None

    if debug_input:
        def _has_non_finite(tensor):
            return bool(torch.logical_not(torch.isfinite(tensor)).any().item())

        image_lo, image_hi = float(image_data.min().item()), float(image_data.max().item())
        qpos_mu, qpos_sigma = float(qpos_data.mean().item()), float(qpos_data.std().item())
        action_mu, action_sigma = float(action_data.mean().item()), float(action_data.std().item())
        padded_fraction = float(is_pad.float().mean().item())
        print(f'[debug_input] {debug_tag} image shape={tuple(image_data.shape)} range=[{image_lo:.4f}, {image_hi:.4f}]')
        print(f'[debug_input] {debug_tag} qpos shape={tuple(qpos_data.shape)} mean/std=({qpos_mu:.4f}, {qpos_sigma:.4f})')
        print(f'[debug_input] {debug_tag} action shape={tuple(action_data.shape)} mean/std=({action_mu:.4f}, {action_sigma:.4f})')
        print(f'[debug_input] {debug_tag} is_pad shape={tuple(is_pad.shape)} pad_ratio={padded_fraction:.4f}')
        print(
            f'[debug_input] {debug_tag} has_nan_or_inf: '
            f'image={_has_non_finite(image_data)}, '
            f'qpos={_has_non_finite(qpos_data)}, '
            f'action={_has_non_finite(action_data)}'
        )

    policy_kwargs = dict(
        text_input_ids=text_input_ids,
        text_attention_mask=text_attention_mask,
        text_features=text_features,
        actions=action_data,
        is_pad=is_pad,
    )
    return policy(qpos_data, image_data, **policy_kwargs)
def train_bc(train_dataloader, val_dataloader, config):
    """Behaviour-cloning training loop that tracks the best validation loss.

    Each epoch:
      1. Evaluates the current weights on ``val_dataloader``
         (``policy.eval()`` under ``torch.inference_mode``).
      2. Whenever the mean validation ``'loss'`` improves, keeps a deep copy
         of the state_dict.
      3. Trains. One epoch is normally one pass over ``train_dataloader``.
         If ``config['train_steps_per_epoch']`` is a positive int, exactly
         that many optimizer steps are taken instead, restarting the
         dataloader whenever it runs out.

    Files written to ``config['ckpt_dir']``:
      * every 100 epochs: ``policy_epoch_{epoch}_seed_{seed}.ckpt`` (weights
        after that epoch's training) plus loss plots;
      * ``policy_last.ckpt`` after the final epoch;
      * ``policy_epoch_{best_epoch}_seed_{seed}.ckpt`` with the best weights;
      * final ``train_val_*`` plots.

    Args:
        train_dataloader: iterable of batches accepted by ``forward_pass``.
        val_dataloader: iterable of validation batches.
        config: run configuration. Keys read: ``num_epochs``, ``ckpt_dir``,
            ``seed``, ``policy_class``, ``policy_config``, and optionally
            ``train_steps_per_epoch``, ``resume_ckpt_path``, ``debug_input``.

    Returns:
        ``(best_epoch, min_val_loss, best_state_dict)``.

    Raises:
        ValueError: if ``num_epochs < 1``. No epoch would run, so there would
            be no best checkpoint. Previously this failed at the end with an
            opaque unpacking TypeError.
        RuntimeError: if the validation loss never drops below infinity (for
            example, it is always NaN). ``policy_last.ckpt`` is still saved first.
        RuntimeError: if fixed-step mode is requested but ``train_dataloader``
            yields no batches.
    """
    num_epochs = config['num_epochs']
    if num_epochs < 1:
        raise ValueError(f'num_epochs must be >= 1 to produce a best checkpoint, got {num_epochs}')
    train_steps_per_epoch = config.get('train_steps_per_epoch', None)
    resume_ckpt_path = config.get('resume_ckpt_path', None)
    ckpt_dir = config['ckpt_dir']
    seed = config['seed']
    policy_class = config['policy_class']
    policy_config = config['policy_config']
    debug_input = config.get('debug_input', False)
    # A positive step count switches from "one pass per epoch" to "fixed steps per epoch".
    use_fixed_steps = train_steps_per_epoch is not None and train_steps_per_epoch > 0

    set_seed(seed)

    policy = make_policy(policy_class, policy_config)
    policy.cuda()
    if resume_ckpt_path:
        loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
        print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
        print(loading_status)
    optimizer = make_optimizer(policy_class, policy)

    train_history = []
    validation_history = []
    min_val_loss = np.inf
    best_ckpt_info = None
    for epoch in tqdm(range(num_epochs)):
        print(f'\nEpoch {epoch}')
        # validation
        with torch.inference_mode():
            policy.eval()
            epoch_dicts = []
            for batch_idx, data in enumerate(val_dataloader):
                should_debug = debug_input and epoch == 0 and batch_idx == 0
                forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='val/epoch0/batch0')
                epoch_dicts.append(forward_dict)
            epoch_summary = compute_dict_mean(epoch_dicts)
            validation_history.append(epoch_summary)

            epoch_val_loss = epoch_summary['loss']
            if epoch_val_loss < min_val_loss:
                min_val_loss = epoch_val_loss
                # These are the weights *before* this epoch's training step.
                best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))
        print(f'Val loss: {epoch_val_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        # training
        policy.train()
        optimizer.zero_grad()
        epoch_train_dicts = []
        if use_fixed_steps:
            train_steps_this_epoch = int(train_steps_per_epoch)
        else:
            train_steps_this_epoch = len(train_dataloader)
        train_iterator = iter(train_dataloader)
        for step_idx in range(train_steps_this_epoch):
            try:
                data = next(train_iterator)
            except StopIteration:
                # Normally only reached in fixed-step mode: begin another pass.
                train_iterator = iter(train_dataloader)
                try:
                    data = next(train_iterator)
                except StopIteration:
                    raise RuntimeError(
                        'train_dataloader yielded no batches; cannot run the requested training steps'
                    ) from None
            should_debug = debug_input and epoch == 0 and step_idx == 0
            forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
            # backward
            loss = forward_dict['loss']
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_history.append(detach_dict(forward_dict))
            epoch_train_dicts.append(detach_dict(forward_dict))
        epoch_summary = compute_dict_mean(epoch_train_dicts)
        epoch_train_loss = epoch_summary['loss']
        print(f'Train loss: {epoch_train_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        if epoch % 100 == 0:
            ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{epoch}_seed_{seed}.ckpt')
            torch.save(policy.state_dict(), ckpt_path)
            plot_history(train_history, validation_history, epoch, ckpt_dir, seed)

    ckpt_path = os.path.join(ckpt_dir, 'policy_last.ckpt')
    torch.save(policy.state_dict(), ckpt_path)

    if best_ckpt_info is None:
        raise RuntimeError('validation loss never improved on inf (NaN loss?); no best checkpoint to save')
    best_epoch, min_val_loss, best_state_dict = best_ckpt_info
    # NOTE(review): if best_epoch is a multiple of 100, this overwrites the periodic
    # checkpoint of the same name. That file held the post-training weights of
    # that epoch; the best weights are the pre-training ones.
    ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{best_epoch}_seed_{seed}.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}')

    # save training curves
    plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed)

    return best_ckpt_info
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    """Save one train/validation curve PNG per logged metric.

    For every key of ``train_history[0]``, writes
    ``ckpt_dir/train_val_{key}_seed_{seed}.png``. The training records and
    the validation records are each spread evenly over
    ``[0, num_epochs - 1]`` on the x axis, so both curves share an epoch
    scale even when their lengths differ.

    Each figure is closed after saving. Previously the figures were never
    closed, so every periodic call leaked one figure per metric.

    Args:
        train_history: list of metric dicts; values must support ``.item()``.
        validation_history: list of metric dicts containing every key of
            ``train_history[0]``.
        num_epochs: right end of the x axis plus one.
        ckpt_dir: output directory.
        seed: included in the file names.

    An empty ``train_history`` is skipped instead of raising ``IndexError``.
    """
    if not train_history:
        print('No training history to plot; skipping')
        return
    # save training curves
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png')
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        fig = plt.figure()
        try:
            plt.plot(np.linspace(0, num_epochs-1, len(train_history)), train_values, label='train')
            plt.plot(np.linspace(0, num_epochs-1, len(validation_history)), val_values, label='validation')
            plt.tight_layout()
            plt.legend()
            plt.title(key)
            plt.savefig(plot_path)
        finally:
            # Release the figure even if saving fails, so figures do not accumulate.
            plt.close(fig)
    print(f'Saved plots to {ckpt_dir}')
if __name__ == '__main__':
    # Command-line entry point. The parsed namespace is handed to main() as a plain dict.
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument  # every option uses argparse's default 'store' action unless noted

    # mode switches
    add_arg('--eval', action='store_true')
    add_arg('--onscreen_render', action='store_true')

    # core run settings (all mandatory)
    add_arg('--ckpt_dir', type=str, help='ckpt_dir', required=True)
    add_arg('--policy_class', type=str, help='policy_class, capitalize', required=True)
    add_arg('--task_name', type=str, help='task_name', required=True)
    add_arg('--batch_size', type=int, help='batch_size', required=True)
    add_arg('--seed', type=int, help='seed', required=True)
    add_arg('--num_epochs', type=int, help='num_epochs', required=True)
    add_arg('--lr', type=float, help='lr', required=True)

    # for ACT (optional; they default to None when omitted)
    add_arg('--kl_weight', type=float, help='KL Weight')
    add_arg('--chunk_size', type=int, help='chunk_size')
    add_arg('--hidden_dim', type=int, help='hidden_dim')
    add_arg('--dim_feedforward', type=int, help='dim_feedforward')
    add_arg('--temporal_agg', action='store_true')

    # text-instruction overrides
    add_arg('--text_encoder_type', type=str)
    add_arg('--freeze_text_encoder', action='store_true')
    add_arg('--text_max_length', type=int)

    # data / training options
    add_arg('--image_aug', action='store_true',
            help='Enable training-time image augmentation (color/highlight/noise/blur)')
    add_arg('--train_steps_per_epoch', type=int,
            help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
    add_arg('--disable_real_action_shift', action='store_true',
            help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
    add_arg('--resume_ckpt', type=str,
            help='Optional checkpoint path to initialize model weights for fine-tuning')
    add_arg('--debug_input', action='store_true',
            help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')

    main(vars(parser.parse_args()))