follow uses policy_last

This commit is contained in:
2026-02-20 16:45:16 +08:00
parent 88d0cc5ca2
commit 81e1bf8838
4 changed files with 129 additions and 30 deletions

View File

@@ -20,6 +20,15 @@ from visualize_episodes import save_videos
import IPython
e = IPython.embed
def load_checkpoint_state_dict(ckpt_path):
    """Load checkpoint state_dict safely across different torch versions."""
    # Prefer the safer weights-only deserialization path; always map to CPU
    # so loading works on machines without the checkpoint's original device.
    load_kwargs = {'map_location': 'cpu'}
    try:
        return torch.load(ckpt_path, weights_only=True, **load_kwargs)
    except TypeError:
        # Older PyTorch releases reject the `weights_only` keyword entirely;
        # retry with the legacy full-pickle loader.
        return torch.load(ckpt_path, **load_kwargs)
def main(args):
set_seed(1)
# command line parameters
@@ -60,6 +69,7 @@ def main(args):
text_max_length = task_config.get('text_max_length', 32)
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
freeze_text_encoder = task_config.get('freeze_text_encoder', True)
real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
if args.get('text_encoder_type') is not None:
text_encoder_type = args['text_encoder_type']
@@ -67,6 +77,8 @@ def main(args):
text_max_length = args['text_max_length']
if args.get('freeze_text_encoder', False):
freeze_text_encoder = True
if args.get('disable_real_action_shift', False):
real_action_t_minus_1 = False
# fixed parameters
lr_backbone = 1e-5
@@ -110,6 +122,8 @@ def main(args):
config = {
'num_epochs': num_epochs,
'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
'resume_ckpt_path': args.get('resume_ckpt', None),
'ckpt_dir': ckpt_dir,
'episode_len': episode_len,
'state_dim': state_dim,
@@ -131,6 +145,16 @@ def main(args):
'debug_input': args.get('debug_input', False),
}
if config['resume_ckpt_path']:
resume_ckpt_path = config['resume_ckpt_path']
if not os.path.isabs(resume_ckpt_path):
candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
if os.path.isfile(candidate_path):
resume_ckpt_path = candidate_path
if not os.path.isfile(resume_ckpt_path):
raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
config['resume_ckpt_path'] = resume_ckpt_path
if is_eval:
ckpt_names = [f'policy_best.ckpt']
results = []
@@ -155,6 +179,7 @@ def main(args):
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
real_action_t_minus_1=real_action_t_minus_1,
image_augment=args['image_aug'],
)
@@ -223,7 +248,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
# load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
policy = make_policy(policy_class, policy_config)
loading_status = policy.load_state_dict(torch.load(ckpt_path))
loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
print(loading_status)
policy.cuda()
policy.eval()
@@ -425,6 +450,8 @@ def forward_pass(data, policy, debug_input=False, debug_tag=''):
def train_bc(train_dataloader, val_dataloader, config):
num_epochs = config['num_epochs']
train_steps_per_epoch = config.get('train_steps_per_epoch', None)
resume_ckpt_path = config.get('resume_ckpt_path', None)
ckpt_dir = config['ckpt_dir']
seed = config['seed']
policy_class = config['policy_class']
@@ -435,6 +462,12 @@ def train_bc(train_dataloader, val_dataloader, config):
policy = make_policy(policy_class, policy_config)
policy.cuda()
if resume_ckpt_path:
loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
print(loading_status)
optimizer = make_optimizer(policy_class, policy)
train_history = []
@@ -467,8 +500,22 @@ def train_bc(train_dataloader, val_dataloader, config):
# training
policy.train()
optimizer.zero_grad()
for batch_idx, data in enumerate(train_dataloader):
should_debug = debug_input and epoch == 0 and batch_idx == 0
epoch_train_dicts = []
if train_steps_per_epoch is None or train_steps_per_epoch <= 0:
train_steps_this_epoch = len(train_dataloader)
train_iterator = iter(train_dataloader)
else:
train_steps_this_epoch = int(train_steps_per_epoch)
train_iterator = iter(train_dataloader)
for step_idx in range(train_steps_this_epoch):
try:
data = next(train_iterator)
except StopIteration:
train_iterator = iter(train_dataloader)
data = next(train_iterator)
should_debug = debug_input and epoch == 0 and step_idx == 0
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
# backward
loss = forward_dict['loss']
@@ -476,7 +523,8 @@ def train_bc(train_dataloader, val_dataloader, config):
optimizer.step()
optimizer.zero_grad()
train_history.append(detach_dict(forward_dict))
epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
epoch_train_dicts.append(detach_dict(forward_dict))
epoch_summary = compute_dict_mean(epoch_train_dicts)
epoch_train_loss = epoch_summary['loss']
print(f'Train loss: {epoch_train_loss:.5f}')
summary_string = ''
@@ -533,7 +581,7 @@ if __name__ == '__main__':
parser.add_argument('--lr', action='store', type=float, help='lr', required=True)
# for ACT
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
@@ -543,6 +591,12 @@ if __name__ == '__main__':
parser.add_argument('--text_max_length', action='store', type=int, required=False)
parser.add_argument('--image_aug', action='store_true',
help='Enable training-time image augmentation (color/highlight/noise/blur)')
parser.add_argument('--train_steps_per_epoch', action='store', type=int, required=False,
help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
parser.add_argument('--disable_real_action_shift', action='store_true',
help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
parser.add_argument('--resume_ckpt', action='store', type=str, required=False,
help='Optional checkpoint path to initialize model weights for fine-tuning')
parser.add_argument('--debug_input', action='store_true',
help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')