follow用policy_last
This commit is contained in:
@@ -20,6 +20,15 @@ from visualize_episodes import save_videos
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def load_checkpoint_state_dict(ckpt_path):
|
||||
"""Load checkpoint state_dict safely across different torch versions."""
|
||||
try:
|
||||
return torch.load(ckpt_path, map_location='cpu', weights_only=True)
|
||||
except TypeError:
|
||||
# For older PyTorch versions that do not support `weights_only`.
|
||||
return torch.load(ckpt_path, map_location='cpu')
|
||||
|
||||
def main(args):
|
||||
set_seed(1)
|
||||
# command line parameters
|
||||
@@ -60,6 +69,7 @@ def main(args):
|
||||
text_max_length = task_config.get('text_max_length', 32)
|
||||
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
|
||||
freeze_text_encoder = task_config.get('freeze_text_encoder', True)
|
||||
real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
|
||||
|
||||
if args.get('text_encoder_type') is not None:
|
||||
text_encoder_type = args['text_encoder_type']
|
||||
@@ -67,6 +77,8 @@ def main(args):
|
||||
text_max_length = args['text_max_length']
|
||||
if args.get('freeze_text_encoder', False):
|
||||
freeze_text_encoder = True
|
||||
if args.get('disable_real_action_shift', False):
|
||||
real_action_t_minus_1 = False
|
||||
|
||||
# fixed parameters
|
||||
lr_backbone = 1e-5
|
||||
@@ -110,6 +122,8 @@ def main(args):
|
||||
|
||||
config = {
|
||||
'num_epochs': num_epochs,
|
||||
'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
|
||||
'resume_ckpt_path': args.get('resume_ckpt', None),
|
||||
'ckpt_dir': ckpt_dir,
|
||||
'episode_len': episode_len,
|
||||
'state_dim': state_dim,
|
||||
@@ -131,6 +145,16 @@ def main(args):
|
||||
'debug_input': args.get('debug_input', False),
|
||||
}
|
||||
|
||||
if config['resume_ckpt_path']:
|
||||
resume_ckpt_path = config['resume_ckpt_path']
|
||||
if not os.path.isabs(resume_ckpt_path):
|
||||
candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
|
||||
if os.path.isfile(candidate_path):
|
||||
resume_ckpt_path = candidate_path
|
||||
if not os.path.isfile(resume_ckpt_path):
|
||||
raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
|
||||
config['resume_ckpt_path'] = resume_ckpt_path
|
||||
|
||||
if is_eval:
|
||||
ckpt_names = [f'policy_best.ckpt']
|
||||
results = []
|
||||
@@ -155,6 +179,7 @@ def main(args):
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
real_action_t_minus_1=real_action_t_minus_1,
|
||||
image_augment=args['image_aug'],
|
||||
)
|
||||
|
||||
@@ -223,7 +248,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
# load policy and stats
|
||||
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
|
||||
policy = make_policy(policy_class, policy_config)
|
||||
loading_status = policy.load_state_dict(torch.load(ckpt_path))
|
||||
loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
|
||||
print(loading_status)
|
||||
policy.cuda()
|
||||
policy.eval()
|
||||
@@ -425,6 +450,8 @@ def forward_pass(data, policy, debug_input=False, debug_tag=''):
|
||||
|
||||
def train_bc(train_dataloader, val_dataloader, config):
|
||||
num_epochs = config['num_epochs']
|
||||
train_steps_per_epoch = config.get('train_steps_per_epoch', None)
|
||||
resume_ckpt_path = config.get('resume_ckpt_path', None)
|
||||
ckpt_dir = config['ckpt_dir']
|
||||
seed = config['seed']
|
||||
policy_class = config['policy_class']
|
||||
@@ -435,6 +462,12 @@ def train_bc(train_dataloader, val_dataloader, config):
|
||||
|
||||
policy = make_policy(policy_class, policy_config)
|
||||
policy.cuda()
|
||||
|
||||
if resume_ckpt_path:
|
||||
loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
|
||||
print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
|
||||
print(loading_status)
|
||||
|
||||
optimizer = make_optimizer(policy_class, policy)
|
||||
|
||||
train_history = []
|
||||
@@ -467,8 +500,22 @@ def train_bc(train_dataloader, val_dataloader, config):
|
||||
# training
|
||||
policy.train()
|
||||
optimizer.zero_grad()
|
||||
for batch_idx, data in enumerate(train_dataloader):
|
||||
should_debug = debug_input and epoch == 0 and batch_idx == 0
|
||||
epoch_train_dicts = []
|
||||
if train_steps_per_epoch is None or train_steps_per_epoch <= 0:
|
||||
train_steps_this_epoch = len(train_dataloader)
|
||||
train_iterator = iter(train_dataloader)
|
||||
else:
|
||||
train_steps_this_epoch = int(train_steps_per_epoch)
|
||||
train_iterator = iter(train_dataloader)
|
||||
|
||||
for step_idx in range(train_steps_this_epoch):
|
||||
try:
|
||||
data = next(train_iterator)
|
||||
except StopIteration:
|
||||
train_iterator = iter(train_dataloader)
|
||||
data = next(train_iterator)
|
||||
|
||||
should_debug = debug_input and epoch == 0 and step_idx == 0
|
||||
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
|
||||
# backward
|
||||
loss = forward_dict['loss']
|
||||
@@ -476,7 +523,8 @@ def train_bc(train_dataloader, val_dataloader, config):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
train_history.append(detach_dict(forward_dict))
|
||||
epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
|
||||
epoch_train_dicts.append(detach_dict(forward_dict))
|
||||
epoch_summary = compute_dict_mean(epoch_train_dicts)
|
||||
epoch_train_loss = epoch_summary['loss']
|
||||
print(f'Train loss: {epoch_train_loss:.5f}')
|
||||
summary_string = ''
|
||||
@@ -533,7 +581,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--lr', action='store', type=float, help='lr', required=True)
|
||||
|
||||
# for ACT
|
||||
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
|
||||
parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
|
||||
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
|
||||
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
|
||||
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
|
||||
@@ -543,6 +591,12 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--text_max_length', action='store', type=int, required=False)
|
||||
parser.add_argument('--image_aug', action='store_true',
|
||||
help='Enable training-time image augmentation (color/highlight/noise/blur)')
|
||||
parser.add_argument('--train_steps_per_epoch', action='store', type=int, required=False,
|
||||
help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
|
||||
parser.add_argument('--disable_real_action_shift', action='store_true',
|
||||
help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
|
||||
parser.add_argument('--resume_ckpt', action='store', type=str, required=False,
|
||||
help='Optional checkpoint path to initialize model weights for fine-tuning')
|
||||
parser.add_argument('--debug_input', action='store_true',
|
||||
help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user