import torch
|
|
import numpy as np
|
|
import os
|
|
import pickle
|
|
import argparse
|
|
import matplotlib.pyplot as plt
|
|
from copy import deepcopy
|
|
from tqdm import tqdm
|
|
from einops import rearrange
|
|
|
|
from constants import DT
|
|
from constants import PUPPET_GRIPPER_JOINT_OPEN
|
|
from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
|
|
from utils import load_data # data functions
|
|
from utils import sample_box_pose, sample_insertion_pose # robot functions
|
|
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
|
|
from policy import ACTPolicy, CNNMLPPolicy
|
|
from visualize_episodes import save_videos
|
|
|
|
import IPython
|
|
e = IPython.embed
|
|
|
|
|
|
def load_checkpoint_state_dict(ckpt_path):
    """Load a checkpoint's state_dict on CPU, tolerating older torch versions.

    Newer PyTorch accepts `weights_only=True` for safer deserialization;
    older releases raise TypeError on the unknown keyword, so retry without it.
    """
    load_kwargs = {'map_location': 'cpu'}
    try:
        return torch.load(ckpt_path, weights_only=True, **load_kwargs)
    except TypeError:
        # Older PyTorch: `weights_only` keyword is not supported.
        return torch.load(ckpt_path, **load_kwargs)
|
|
|
def main(args):
    """Train or evaluate an imitation-learning policy for one task.

    Merges command-line arguments with the task's config-table entry into a
    single `config` dict, then either runs evaluation rollouts (--eval) or
    trains and saves the best checkpoint.

    Args:
        args: dict of parsed command-line arguments (see the argparse setup
            at the bottom of this file).
    """
    set_seed(1)
    # command line parameters
    is_eval = args['eval']
    ckpt_dir = args['ckpt_dir']
    policy_class = args['policy_class']
    onscreen_render = args['onscreen_render']
    task_name = args['task_name']
    batch_size_train = args['batch_size']
    batch_size_val = args['batch_size']
    num_epochs = args['num_epochs']

    # get task parameters
    # Endoscope tasks are looked up first and treated as non-sim tasks.
    is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
    if is_endoscope:
        task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
        is_sim = False
    elif task_name in SIM_TASK_CONFIGS:
        task_config = SIM_TASK_CONFIGS[task_name]
        is_sim = True
    else:
        # Fall back to the ALOHA real-robot task table (requires the aloha package).
        from aloha_scripts.constants import TASK_CONFIGS
        task_config = TASK_CONFIGS[task_name]
        is_sim = False

    dataset_dir = task_config['dataset_dir']
    num_episodes = task_config['num_episodes']
    episode_len = task_config['episode_len']
    camera_names = task_config['camera_names']
    # State dim defaults to 14 (presumably bimanual joint space — TODO confirm);
    # action dim defaults to the state dim.
    state_dim = task_config.get('state_dim', 14)
    action_dim = task_config.get('action_dim', state_dim)
    # Optional text-instruction conditioning settings (all have defaults).
    use_text_instruction = task_config.get('use_text_instruction', False)
    instruction_mode = task_config.get('instruction_mode', 'timestep-level')
    use_cached_text_features = task_config.get('use_cached_text_features', True)
    text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
    text_feature_dim = task_config.get('text_feature_dim', 768)
    text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
    text_max_length = task_config.get('text_max_length', 32)
    text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
    freeze_text_encoder = task_config.get('freeze_text_encoder', True)
    real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)

    # Command-line overrides take precedence over the task config.
    if args.get('text_encoder_type') is not None:
        text_encoder_type = args['text_encoder_type']
    if args.get('text_max_length') is not None:
        text_max_length = args['text_max_length']
    if args.get('freeze_text_encoder', False):
        freeze_text_encoder = True
    if args.get('disable_real_action_shift', False):
        real_action_t_minus_1 = False

    # fixed parameters
    lr_backbone = 1e-5
    backbone = 'resnet18'
    if policy_class == 'ACT':
        enc_layers = 2
        dec_layers = 4
        nheads = 8
        policy_config = {'lr': args['lr'],
                         'num_queries': args['chunk_size'],
                         'kl_weight': args['kl_weight'],
                         'hidden_dim': args['hidden_dim'],
                         'dim_feedforward': args['dim_feedforward'],
                         'lr_backbone': lr_backbone,
                         'backbone': backbone,
                         'enc_layers': enc_layers,
                         'dec_layers': dec_layers,
                         'nheads': nheads,
                         'camera_names': camera_names,
                         'state_dim': state_dim,
                         'action_dim': action_dim,
                         'use_text': use_text_instruction,
                         'text_encoder_type': text_encoder_type,
                         'text_feature_dim': text_feature_dim,
                         'text_fusion_type': text_fusion_type,
                         'freeze_text_encoder': freeze_text_encoder,
                         'instruction_mode': instruction_mode,
                         'use_cached_text_features': use_cached_text_features,
                         'text_max_length': text_max_length,
                         'text_tokenizer_name': text_tokenizer_name,
                         }
    elif policy_class == 'CNNMLP':
        # CNNMLP predicts a single action per query (num_queries fixed to 1).
        policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
                         'camera_names': camera_names,
                         'state_dim': state_dim,
                         'action_dim': action_dim,
                         'use_text': use_text_instruction,
                         }
    else:
        raise NotImplementedError

    config = {
        'num_epochs': num_epochs,
        'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
        'resume_ckpt_path': args.get('resume_ckpt', None),
        'ckpt_dir': ckpt_dir,
        'episode_len': episode_len,
        'state_dim': state_dim,
        'action_dim': action_dim,
        'lr': args['lr'],
        'policy_class': policy_class,
        'onscreen_render': onscreen_render,
        'policy_config': policy_config,
        'task_name': task_name,
        'seed': args['seed'],
        'temporal_agg': args['temporal_agg'],
        'camera_names': camera_names,
        # Endoscope tasks are non-sim but do not use the ALOHA real-robot env either.
        'real_robot': (not is_sim) and (not is_endoscope),
        'use_text_instruction': use_text_instruction,
        'instruction_mode': instruction_mode,
        'use_cached_text_features': use_cached_text_features,
        'text_tokenizer_name': text_tokenizer_name,
        'text_max_length': text_max_length,
        'debug_input': args.get('debug_input', False),
    }

    if config['resume_ckpt_path']:
        # Resolve a relative --resume_ckpt against ckpt_dir before validating it.
        resume_ckpt_path = config['resume_ckpt_path']
        if not os.path.isabs(resume_ckpt_path):
            candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
            if os.path.isfile(candidate_path):
                resume_ckpt_path = candidate_path
        if not os.path.isfile(resume_ckpt_path):
            raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
        config['resume_ckpt_path'] = resume_ckpt_path

    if is_eval:
        # Evaluation only: load the best checkpoint, roll out, report, and exit.
        ckpt_names = [f'policy_best.ckpt']
        results = []
        for ckpt_name in ckpt_names:
            success_rate, avg_return = eval_bc(config, ckpt_name, save_episode=True)
            results.append([ckpt_name, success_rate, avg_return])

        for ckpt_name, success_rate, avg_return in results:
            print(f'{ckpt_name}: {success_rate=} {avg_return=}')
        print()
        exit()

    train_dataloader, val_dataloader, stats, _ = load_data(
        dataset_dir,
        num_episodes,
        camera_names,
        batch_size_train,
        batch_size_val,
        use_text_instruction=use_text_instruction,
        instruction_mode=instruction_mode,
        use_cached_text_features=use_cached_text_features,
        text_feature_dim=text_feature_dim,
        text_tokenizer_name=text_tokenizer_name,
        text_max_length=text_max_length,
        real_action_t_minus_1=real_action_t_minus_1,
        image_augment=args['image_aug'],
    )

    # save dataset stats
    # (needed at eval time by eval_bc for qpos/action normalization)
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)
    stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(stats, f)

    best_ckpt_info = train_bc(train_dataloader, val_dataloader, config)
    best_epoch, min_val_loss, best_state_dict = best_ckpt_info

    # save best checkpoint
    ckpt_path = os.path.join(ckpt_dir, f'policy_best.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Best ckpt, val loss {min_val_loss:.6f} @ epoch{best_epoch}')
|
|
|
|
|
def make_policy(policy_class, policy_config):
    """Instantiate the policy object named by `policy_class`.

    Raises NotImplementedError for any class other than 'ACT' or 'CNNMLP'.
    """
    if policy_class == 'ACT':
        return ACTPolicy(policy_config)
    if policy_class == 'CNNMLP':
        return CNNMLPPolicy(policy_config)
    raise NotImplementedError
|
|
|
|
|
def make_optimizer(policy_class, policy):
    """Build the optimizer via the policy's own `configure_optimizers` hook.

    Both supported policy classes ('ACT' and 'CNNMLP') expose the same hook,
    so the previously duplicated identical branches are collapsed into one
    membership test. Unknown classes still raise NotImplementedError.
    """
    if policy_class in ('ACT', 'CNNMLP'):
        return policy.configure_optimizers()
    raise NotImplementedError
|
|
|
|
|
def get_image(ts, camera_names):
    """Stack the requested camera frames into a (1, num_cams, C, H, W) cuda tensor.

    Each frame is transposed HWC -> CHW and divided by 255.0 (assumes 0-255
    pixel values — TODO confirm against the env's observation format).
    """
    frames = [
        rearrange(ts.observation['images'][cam], 'h w c -> c h w')
        for cam in camera_names
    ]
    stacked = np.stack(frames, axis=0)
    return torch.from_numpy(stacked / 255.0).float().cuda().unsqueeze(0)
|
|
|
|
|
def eval_bc(config, ckpt_name, save_episode=True):
    """Roll out a trained policy in sim (or on the real robot) and report success.

    Args:
        config: the full run config built by main().
        ckpt_name: checkpoint filename under config['ckpt_dir'] to load.
        save_episode: when True, write a video per rollout into ckpt_dir.

    Returns:
        (success_rate, avg_return) aggregated over the evaluation rollouts.
    """
    set_seed(1000)
    ckpt_dir = config['ckpt_dir']
    state_dim = config['state_dim']
    action_dim = config['action_dim']
    real_robot = config['real_robot']
    policy_class = config['policy_class']
    onscreen_render = config['onscreen_render']
    policy_config = config['policy_config']
    camera_names = config['camera_names']
    max_timesteps = config['episode_len']
    task_name = config['task_name']
    temporal_agg = config['temporal_agg']
    onscreen_cam = 'angle'
    # Lazily bound to the sim_env module's BOX_POSE list on the first sim rollout.
    BOX_POSE = None

    # load policy and stats
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    policy = make_policy(policy_class, policy_config)
    loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
    print(loading_status)
    policy.cuda()
    policy.eval()
    print(f'Loaded: {ckpt_path}')
    stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)

    # Normalize observed qpos into the training distribution; un-normalize
    # predicted actions back into robot units.
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']

    # load environment
    if real_robot:
        from aloha_scripts.robot_utils import move_grippers # requires aloha
        from aloha_scripts.real_env import make_real_env # requires aloha
        env = make_real_env(init_node=True)
        env_max_reward = 0
    else:
        from sim_env import make_sim_env
        env = make_sim_env(task_name)
        env_max_reward = env.task.max_reward

    query_frequency = policy_config['num_queries']
    if temporal_agg:
        # With temporal aggregation the policy is queried every step and the
        # overlapping action chunks are blended below.
        query_frequency = 1
        num_queries = policy_config['num_queries']

    max_timesteps = int(max_timesteps * 1) # may increase for real-world tasks

    num_rollouts = 50
    episode_returns = []
    highest_rewards = []
    for rollout_id in range(num_rollouts):
        rollout_id += 0
        ### set task
        # For sim tasks, write the sampled object pose into the sim_env
        # module-level BOX_POSE list so env.reset() picks it up.
        if 'sim_transfer_cube' in task_name:
            if BOX_POSE is None:
                from sim_env import BOX_POSE as _BOX_POSE
                BOX_POSE = _BOX_POSE
            BOX_POSE[0] = sample_box_pose() # used in sim reset
        elif 'sim_insertion' in task_name:
            if BOX_POSE is None:
                from sim_env import BOX_POSE as _BOX_POSE
                BOX_POSE = _BOX_POSE
            BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset

        ts = env.reset()

        ### onscreen render
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(env._physics.render(height=480, width=640, camera_id=onscreen_cam))
            plt.ion()

        ### evaluation loop
        if temporal_agg:
            # Row t stores the chunk predicted at step t, aligned to absolute time
            # (columns t..t+num_queries).
            all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()

        qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
        image_list = [] # for visualization
        qpos_list = []
        target_qpos_list = []
        rewards = []
        with torch.inference_mode():
            for t in range(max_timesteps):
                ### update onscreen render and wait for DT
                if onscreen_render:
                    image = env._physics.render(height=480, width=640, camera_id=onscreen_cam)
                    plt_img.set_data(image)
                    plt.pause(DT)

                ### process previous timestep to get qpos and image_list
                obs = ts.observation
                if 'images' in obs:
                    image_list.append(obs['images'])
                else:
                    image_list.append({'main': obs['image']})
                qpos_numpy = np.array(obs['qpos'])
                qpos = pre_process(qpos_numpy)
                qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
                qpos_history[:, t] = qpos
                curr_image = get_image(ts, camera_names)

                ### query policy
                if config['policy_class'] == "ACT":
                    if t % query_frequency == 0:
                        all_actions = policy(qpos, curr_image)
                    if temporal_agg:
                        all_time_actions[[t], t:t+num_queries] = all_actions
                        actions_for_curr_step = all_time_actions[:, t]
                        # Keep only rows whose chunk covers step t; all-zero rows
                        # are treated as unwritten.
                        actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
                        actions_for_curr_step = actions_for_curr_step[actions_populated]
                        # Exponentially weight the candidate actions (k controls decay).
                        k = 0.01
                        exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                        exp_weights = exp_weights / exp_weights.sum()
                        exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
                        raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
                    else:
                        # No aggregation: replay the chunk predicted at the last query step.
                        raw_action = all_actions[:, t % query_frequency]
                elif config['policy_class'] == "CNNMLP":
                    raw_action = policy(qpos, curr_image)
                else:
                    raise NotImplementedError

                ### post-process actions
                raw_action = raw_action.squeeze(0).cpu().numpy()
                action = post_process(raw_action)
                target_qpos = action

                ### step the environment
                ts = env.step(target_qpos)

                ### for visualization
                qpos_list.append(qpos_numpy)
                target_qpos_list.append(target_qpos)
                rewards.append(ts.reward)

        plt.close()
        if real_robot:
            move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) # open
            pass

        rewards = np.array(rewards)
        # Mask out None rewards before summing (some timesteps may carry no reward).
        episode_return = np.sum(rewards[rewards!=None])
        episode_returns.append(episode_return)
        episode_highest_reward = np.max(rewards)
        highest_rewards.append(episode_highest_reward)
        print(f'Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {env_max_reward=}, Success: {episode_highest_reward==env_max_reward}')

        if save_episode:
            save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4'))

    # Success = rollout reached the environment's maximum reward at least once.
    success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
    avg_return = np.mean(episode_returns)
    summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
    for r in range(env_max_reward+1):
        more_or_equal_r = (np.array(highest_rewards) >= r).sum()
        more_or_equal_r_rate = more_or_equal_r / num_rollouts
        summary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n'

    print(summary_str)

    # save success rate to txt
    result_file_name = 'result_' + ckpt_name.split('.')[0] + '.txt'
    with open(os.path.join(ckpt_dir, result_file_name), 'w') as f:
        f.write(summary_str)
        f.write(repr(episode_returns))
        f.write('\n\n')
        f.write(repr(highest_rewards))

    return success_rate, avg_return
|
|
|
|
|
def forward_pass(data, policy, debug_input=False, debug_tag=''):
    """Move one batch onto the GPU and run the policy's forward pass.

    `data` is the 8-tuple yielded by the dataloader; when `debug_input` is
    set, one-off sanity diagnostics (shapes, ranges, NaN checks) are printed.
    """
    (image_data, qpos_data, action_data, is_pad,
     text_input_ids, text_attention_mask,
     text_feature_data, text_feature_valid) = (tensor.cuda() for tensor in data)

    # Only pass cached text features through when at least one sample has a valid one.
    text_features = text_feature_data if torch.any(text_feature_valid) else None

    if debug_input:
        image_min, image_max = float(image_data.min().item()), float(image_data.max().item())
        qpos_mean, qpos_std = float(qpos_data.mean().item()), float(qpos_data.std().item())
        action_mean, action_std = float(action_data.mean().item()), float(action_data.std().item())
        pad_ratio = float(is_pad.float().mean().item())

        def _has_nan_or_inf(tensor):
            # True when any element is NaN or +/-inf.
            return bool(torch.logical_not(torch.isfinite(tensor)).any().item())

        print(f'[debug_input] {debug_tag} image shape={tuple(image_data.shape)} range=[{image_min:.4f}, {image_max:.4f}]')
        print(f'[debug_input] {debug_tag} qpos shape={tuple(qpos_data.shape)} mean/std=({qpos_mean:.4f}, {qpos_std:.4f})')
        print(f'[debug_input] {debug_tag} action shape={tuple(action_data.shape)} mean/std=({action_mean:.4f}, {action_std:.4f})')
        print(f'[debug_input] {debug_tag} is_pad shape={tuple(is_pad.shape)} pad_ratio={pad_ratio:.4f}')
        print(
            f'[debug_input] {debug_tag} has_nan_or_inf: '
            f'image={_has_nan_or_inf(image_data)}, '
            f'qpos={_has_nan_or_inf(qpos_data)}, '
            f'action={_has_nan_or_inf(action_data)}'
        )

    return policy(
        qpos_data,
        image_data,
        text_input_ids=text_input_ids,
        text_attention_mask=text_attention_mask,
        text_features=text_features,
        actions=action_data,
        is_pad=is_pad,
    )
|
|
|
|
|
def train_bc(train_dataloader, val_dataloader, config):
    """Train the policy with per-epoch validation and periodic checkpointing.

    Validates before each training epoch, tracks the best validation loss,
    saves `policy_epoch_*` checkpoints every 100 epochs plus `policy_last.ckpt`
    and the best-epoch checkpoint at the end.

    Returns:
        best_ckpt_info: (best_epoch, min_val_loss, best_state_dict).
    """
    num_epochs = config['num_epochs']
    train_steps_per_epoch = config.get('train_steps_per_epoch', None)
    resume_ckpt_path = config.get('resume_ckpt_path', None)
    ckpt_dir = config['ckpt_dir']
    seed = config['seed']
    policy_class = config['policy_class']
    policy_config = config['policy_config']
    debug_input = config.get('debug_input', False)

    set_seed(seed)

    policy = make_policy(policy_class, policy_config)
    policy.cuda()

    if resume_ckpt_path:
        # Fine-tuning: initialize model weights from the given checkpoint.
        loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
        print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
        print(loading_status)

    optimizer = make_optimizer(policy_class, policy)

    train_history = []
    validation_history = []
    min_val_loss = np.inf
    best_ckpt_info = None
    for epoch in tqdm(range(num_epochs)):
        print(f'\nEpoch {epoch}')
        # validation
        with torch.inference_mode():
            policy.eval()
            epoch_dicts = []
            for batch_idx, data in enumerate(val_dataloader):
                # Input diagnostics only for the very first validation batch.
                should_debug = debug_input and epoch == 0 and batch_idx == 0
                forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='val/epoch0/batch0')
                epoch_dicts.append(forward_dict)
            epoch_summary = compute_dict_mean(epoch_dicts)
            validation_history.append(epoch_summary)

            epoch_val_loss = epoch_summary['loss']
            if epoch_val_loss < min_val_loss:
                # New best model: keep a full copy of the weights on the side.
                min_val_loss = epoch_val_loss
                best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))
        print(f'Val loss: {epoch_val_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        # training
        policy.train()
        optimizer.zero_grad()
        epoch_train_dicts = []
        if train_steps_per_epoch is None or train_steps_per_epoch <= 0:
            # Default: one full pass over the training dataloader per epoch.
            train_steps_this_epoch = len(train_dataloader)
            train_iterator = iter(train_dataloader)
        else:
            # Fixed optimizer-step budget, cycling the dataloader as needed.
            train_steps_this_epoch = int(train_steps_per_epoch)
            train_iterator = iter(train_dataloader)

        for step_idx in range(train_steps_this_epoch):
            try:
                data = next(train_iterator)
            except StopIteration:
                # Dataloader exhausted mid-epoch: restart it and continue.
                train_iterator = iter(train_dataloader)
                data = next(train_iterator)

            should_debug = debug_input and epoch == 0 and step_idx == 0
            forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
            # backward
            loss = forward_dict['loss']
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_history.append(detach_dict(forward_dict))
            epoch_train_dicts.append(detach_dict(forward_dict))
        epoch_summary = compute_dict_mean(epoch_train_dicts)
        epoch_train_loss = epoch_summary['loss']
        print(f'Train loss: {epoch_train_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        if epoch % 100 == 0:
            # Periodic checkpoint + refreshed training curves.
            ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{epoch}_seed_{seed}.ckpt')
            torch.save(policy.state_dict(), ckpt_path)
            plot_history(train_history, validation_history, epoch, ckpt_dir, seed)

    ckpt_path = os.path.join(ckpt_dir, f'policy_last.ckpt')
    torch.save(policy.state_dict(), ckpt_path)

    best_epoch, min_val_loss, best_state_dict = best_ckpt_info
    ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{best_epoch}_seed_{seed}.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}')

    # save training curves
    plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed)

    return best_ckpt_info
|
|
|
|
|
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    """Save train/validation curves (one PNG per logged metric) under ckpt_dir.

    Fix: close each figure after saving it. This function is called repeatedly
    during training (every 100 epochs, plus once at the end); without
    `plt.close()` the figures accumulate across calls and leak memory.
    """
    # save training curves
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png')
        plt.figure()
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        # Histories are logged per-step (train) / per-epoch (val); spread both
        # over the same epoch axis so the curves line up.
        plt.plot(np.linspace(0, num_epochs-1, len(train_history)), train_values, label='train')
        plt.plot(np.linspace(0, num_epochs-1, len(validation_history)), val_values, label='validation')
        # plt.ylim([-0.1, 1])
        plt.tight_layout()
        plt.legend()
        plt.title(key)
        plt.savefig(plot_path)
        plt.close()  # release the figure to avoid unbounded accumulation
    print(f'Saved plots to {ckpt_dir}')
|
|
|
|
|
if __name__ == '__main__':
    # CLI definition; the parsed namespace is handed to main() as a plain dict.
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add('--eval', action='store_true')
    add('--onscreen_render', action='store_true')
    add('--ckpt_dir', type=str, help='ckpt_dir', required=True)
    add('--policy_class', type=str, help='policy_class, capitalize', required=True)
    add('--task_name', type=str, help='task_name', required=True)
    add('--batch_size', type=int, help='batch_size', required=True)
    add('--seed', type=int, help='seed', required=True)
    add('--num_epochs', type=int, help='num_epochs', required=True)
    add('--lr', type=float, help='lr', required=True)

    # for ACT
    add('--kl_weight', type=float, help='KL Weight')
    add('--chunk_size', type=int, help='chunk_size')
    add('--hidden_dim', type=int, help='hidden_dim')
    add('--dim_feedforward', type=int, help='dim_feedforward')
    add('--temporal_agg', action='store_true')
    add('--text_encoder_type', type=str)
    add('--freeze_text_encoder', action='store_true')
    add('--text_max_length', type=int)
    add('--image_aug', action='store_true',
        help='Enable training-time image augmentation (color/highlight/noise/blur)')
    add('--train_steps_per_epoch', type=int,
        help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
    add('--disable_real_action_shift', action='store_true',
        help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
    add('--resume_ckpt', type=str,
        help='Optional checkpoint path to initialize model weights for fine-tuning')
    add('--debug_input', action='store_true',
        help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')

    main(vars(parser.parse_args()))