follow用policy_last

This commit is contained in:
2026-02-20 16:45:16 +08:00
parent 88d0cc5ca2
commit 81e1bf8838
4 changed files with 129 additions and 30 deletions

View File

@@ -40,6 +40,7 @@ ENDOSCOPE_TASK_CONFIGS = {
'camera_names': ['top'], 'camera_names': ['top'],
'state_dim': 2, 'state_dim': 2,
'action_dim': 2, 'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': True, 'use_text_instruction': True,
'instruction_mode': 'timestep-level', 'instruction_mode': 'timestep-level',
'use_cached_text_features': True, 'use_cached_text_features': True,
@@ -57,6 +58,7 @@ ENDOSCOPE_TASK_CONFIGS = {
'camera_names': ['top'], 'camera_names': ['top'],
'state_dim': 2, 'state_dim': 2,
'action_dim': 2, 'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': True, 'use_text_instruction': True,
'instruction_mode': 'timestep-level', 'instruction_mode': 'timestep-level',
'use_cached_text_features': True, 'use_cached_text_features': True,
@@ -74,6 +76,37 @@ ENDOSCOPE_TASK_CONFIGS = {
'camera_names': ['top'], 'camera_names': ['top'],
'state_dim': 2, 'state_dim': 2,
'action_dim': 2, 'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': False,
},
'endoscope_sanity_check': {
'dataset_dir': DATA_DIR + '/sanity-check',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': False,
},
'endoscope_cannulation_no_text': {
'dataset_dir': DATA_DIR + '/cannulation-no-text',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': False,
},
'endoscope_follow_no_text': {
'dataset_dir': DATA_DIR + '/follow-no-text',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': False, 'use_text_instruction': False,
}, },
} }

View File

@@ -70,7 +70,7 @@ def get_args_parser():
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True) parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
parser.add_argument('--seed', action='store', type=int, help='seed', required=True) parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True) parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False) parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False) parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
parser.add_argument('--temporal_agg', action='store_true') parser.add_argument('--temporal_agg', action='store_true')
parser.add_argument('--image_aug', action='store_true') parser.add_argument('--image_aug', action='store_true')

View File

@@ -20,6 +20,15 @@ from visualize_episodes import save_videos
import IPython import IPython
e = IPython.embed e = IPython.embed
def load_checkpoint_state_dict(ckpt_path):
"""Load checkpoint state_dict safely across different torch versions."""
try:
return torch.load(ckpt_path, map_location='cpu', weights_only=True)
except TypeError:
# For older PyTorch versions that do not support `weights_only`.
return torch.load(ckpt_path, map_location='cpu')
def main(args): def main(args):
set_seed(1) set_seed(1)
# command line parameters # command line parameters
@@ -60,6 +69,7 @@ def main(args):
text_max_length = task_config.get('text_max_length', 32) text_max_length = task_config.get('text_max_length', 32)
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased') text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
freeze_text_encoder = task_config.get('freeze_text_encoder', True) freeze_text_encoder = task_config.get('freeze_text_encoder', True)
real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
if args.get('text_encoder_type') is not None: if args.get('text_encoder_type') is not None:
text_encoder_type = args['text_encoder_type'] text_encoder_type = args['text_encoder_type']
@@ -67,6 +77,8 @@ def main(args):
text_max_length = args['text_max_length'] text_max_length = args['text_max_length']
if args.get('freeze_text_encoder', False): if args.get('freeze_text_encoder', False):
freeze_text_encoder = True freeze_text_encoder = True
if args.get('disable_real_action_shift', False):
real_action_t_minus_1 = False
# fixed parameters # fixed parameters
lr_backbone = 1e-5 lr_backbone = 1e-5
@@ -110,6 +122,8 @@ def main(args):
config = { config = {
'num_epochs': num_epochs, 'num_epochs': num_epochs,
'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
'resume_ckpt_path': args.get('resume_ckpt', None),
'ckpt_dir': ckpt_dir, 'ckpt_dir': ckpt_dir,
'episode_len': episode_len, 'episode_len': episode_len,
'state_dim': state_dim, 'state_dim': state_dim,
@@ -131,6 +145,16 @@ def main(args):
'debug_input': args.get('debug_input', False), 'debug_input': args.get('debug_input', False),
} }
if config['resume_ckpt_path']:
resume_ckpt_path = config['resume_ckpt_path']
if not os.path.isabs(resume_ckpt_path):
candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
if os.path.isfile(candidate_path):
resume_ckpt_path = candidate_path
if not os.path.isfile(resume_ckpt_path):
raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
config['resume_ckpt_path'] = resume_ckpt_path
if is_eval: if is_eval:
ckpt_names = [f'policy_best.ckpt'] ckpt_names = [f'policy_best.ckpt']
results = [] results = []
@@ -155,6 +179,7 @@ def main(args):
text_feature_dim=text_feature_dim, text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name, text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length, text_max_length=text_max_length,
real_action_t_minus_1=real_action_t_minus_1,
image_augment=args['image_aug'], image_augment=args['image_aug'],
) )
@@ -223,7 +248,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
# load policy and stats # load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name) ckpt_path = os.path.join(ckpt_dir, ckpt_name)
policy = make_policy(policy_class, policy_config) policy = make_policy(policy_class, policy_config)
loading_status = policy.load_state_dict(torch.load(ckpt_path)) loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
print(loading_status) print(loading_status)
policy.cuda() policy.cuda()
policy.eval() policy.eval()
@@ -425,6 +450,8 @@ def forward_pass(data, policy, debug_input=False, debug_tag=''):
def train_bc(train_dataloader, val_dataloader, config): def train_bc(train_dataloader, val_dataloader, config):
num_epochs = config['num_epochs'] num_epochs = config['num_epochs']
train_steps_per_epoch = config.get('train_steps_per_epoch', None)
resume_ckpt_path = config.get('resume_ckpt_path', None)
ckpt_dir = config['ckpt_dir'] ckpt_dir = config['ckpt_dir']
seed = config['seed'] seed = config['seed']
policy_class = config['policy_class'] policy_class = config['policy_class']
@@ -435,6 +462,12 @@ def train_bc(train_dataloader, val_dataloader, config):
policy = make_policy(policy_class, policy_config) policy = make_policy(policy_class, policy_config)
policy.cuda() policy.cuda()
if resume_ckpt_path:
loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
print(loading_status)
optimizer = make_optimizer(policy_class, policy) optimizer = make_optimizer(policy_class, policy)
train_history = [] train_history = []
@@ -467,8 +500,22 @@ def train_bc(train_dataloader, val_dataloader, config):
# training # training
policy.train() policy.train()
optimizer.zero_grad() optimizer.zero_grad()
for batch_idx, data in enumerate(train_dataloader): epoch_train_dicts = []
should_debug = debug_input and epoch == 0 and batch_idx == 0 if train_steps_per_epoch is None or train_steps_per_epoch <= 0:
train_steps_this_epoch = len(train_dataloader)
train_iterator = iter(train_dataloader)
else:
train_steps_this_epoch = int(train_steps_per_epoch)
train_iterator = iter(train_dataloader)
for step_idx in range(train_steps_this_epoch):
try:
data = next(train_iterator)
except StopIteration:
train_iterator = iter(train_dataloader)
data = next(train_iterator)
should_debug = debug_input and epoch == 0 and step_idx == 0
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0') forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
# backward # backward
loss = forward_dict['loss'] loss = forward_dict['loss']
@@ -476,7 +523,8 @@ def train_bc(train_dataloader, val_dataloader, config):
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
train_history.append(detach_dict(forward_dict)) train_history.append(detach_dict(forward_dict))
epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)]) epoch_train_dicts.append(detach_dict(forward_dict))
epoch_summary = compute_dict_mean(epoch_train_dicts)
epoch_train_loss = epoch_summary['loss'] epoch_train_loss = epoch_summary['loss']
print(f'Train loss: {epoch_train_loss:.5f}') print(f'Train loss: {epoch_train_loss:.5f}')
summary_string = '' summary_string = ''
@@ -533,7 +581,7 @@ if __name__ == '__main__':
parser.add_argument('--lr', action='store', type=float, help='lr', required=True) parser.add_argument('--lr', action='store', type=float, help='lr', required=True)
# for ACT # for ACT
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False) parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False) parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False) parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False) parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
@@ -543,6 +591,12 @@ if __name__ == '__main__':
parser.add_argument('--text_max_length', action='store', type=int, required=False) parser.add_argument('--text_max_length', action='store', type=int, required=False)
parser.add_argument('--image_aug', action='store_true', parser.add_argument('--image_aug', action='store_true',
help='Enable training-time image augmentation (color/highlight/noise/blur)') help='Enable training-time image augmentation (color/highlight/noise/blur)')
parser.add_argument('--train_steps_per_epoch', action='store', type=int, required=False,
help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
parser.add_argument('--disable_real_action_shift', action='store_true',
help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
parser.add_argument('--resume_ckpt', action='store', type=str, required=False,
help='Optional checkpoint path to initialize model weights for fine-tuning')
parser.add_argument('--debug_input', action='store_true', parser.add_argument('--debug_input', action='store_true',
help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0') help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')

View File

@@ -17,6 +17,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
text_feature_dim=768, text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased', text_tokenizer_name='distilbert-base-uncased',
text_max_length=32, text_max_length=32,
real_action_t_minus_1=True,
image_augment=False, image_augment=False,
image_aug_cfg=None): image_aug_cfg=None):
super(EpisodicDataset).__init__() super(EpisodicDataset).__init__()
@@ -29,20 +30,21 @@ class EpisodicDataset(torch.utils.data.Dataset):
self.use_cached_text_features = use_cached_text_features self.use_cached_text_features = use_cached_text_features
self.text_feature_dim = text_feature_dim self.text_feature_dim = text_feature_dim
self.text_max_length = text_max_length self.text_max_length = text_max_length
self.real_action_t_minus_1 = real_action_t_minus_1
self.image_augment = image_augment self.image_augment = image_augment
self.image_aug_cfg = { self.image_aug_cfg = {
'p_color': 0.8, 'p_color': 0.4,
'p_highlight': 0.5, 'p_highlight': 0.3,
'p_noise': 0.5, 'p_noise': 0.35,
'p_blur': 0.3, 'p_blur': 0.15,
'brightness': 0.25, 'brightness': 0.12,
'contrast': 0.25, 'contrast': 0.12,
'saturation': 0.25, 'saturation': 0.12,
'hue': 0.08, 'hue': 0.03,
'highlight_strength': (0.15, 0.5), 'highlight_strength': (0.08, 0.25),
'noise_std': (0.005, 0.03), 'noise_std': (0.003, 0.015),
'blur_sigma': (0.1, 1.5), 'blur_sigma': (0.1, 0.8),
'blur_kernel_choices': (3, 5), 'blur_kernel_choices': (3, ),
} }
if image_aug_cfg is not None: if image_aug_cfg is not None:
self.image_aug_cfg.update(image_aug_cfg) self.image_aug_cfg.update(image_aug_cfg)
@@ -221,8 +223,9 @@ class EpisodicDataset(torch.utils.data.Dataset):
action = root['/action'][start_ts:] action = root['/action'][start_ts:]
action_len = episode_len - start_ts action_len = episode_len - start_ts
else: else:
action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned action_start = max(0, start_ts - 1) if self.real_action_t_minus_1 else start_ts
action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned action = root['/action'][action_start:]
action_len = episode_len - action_start
self.is_sim = is_sim self.is_sim = is_sim
padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32) padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
@@ -334,19 +337,26 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
text_feature_dim=768, text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased', text_tokenizer_name='distilbert-base-uncased',
text_max_length=32, text_max_length=32,
real_action_t_minus_1=True,
image_augment=False, image_augment=False,
image_aug_cfg=None): image_aug_cfg=None):
print(f'\nData from: {dataset_dir}\n') print(f'\nData from: {dataset_dir}\n')
episode_ids = _discover_episode_ids(dataset_dir, num_episodes) episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
if len(episode_ids) == 0: if len(episode_ids) == 0:
raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}') raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
if len(episode_ids) < 2:
raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
# obtain train test split # obtain train/val split
if len(episode_ids) == 1:
# sanity-check mode: reuse the same episode for both train and val
# so training/evaluation loops remain unchanged.
train_episode_ids = np.array(episode_ids)
val_episode_ids = np.array(episode_ids)
print('[load_data] Only 1 episode found. Reusing the same episode for both train and val (sanity-check mode).')
else:
train_ratio = 0.9 train_ratio = 0.9
shuffled_indices = np.random.permutation(len(episode_ids)) shuffled_indices = np.random.permutation(len(episode_ids))
train_count = int(train_ratio * len(episode_ids)) train_count = int(train_ratio * len(episode_ids))
train_count = max(1, min(len(episode_ids) - 1, train_count))
train_indices = shuffled_indices[:train_count] train_indices = shuffled_indices[:train_count]
val_indices = shuffled_indices[train_count:] val_indices = shuffled_indices[train_count:]
train_episode_ids = np.array(episode_ids)[train_indices] train_episode_ids = np.array(episode_ids)[train_indices]
@@ -367,6 +377,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
text_feature_dim=text_feature_dim, text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name, text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length, text_max_length=text_max_length,
real_action_t_minus_1=real_action_t_minus_1,
image_augment=image_augment, image_augment=image_augment,
image_aug_cfg=image_aug_cfg, image_aug_cfg=image_aug_cfg,
) )
@@ -381,6 +392,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
text_feature_dim=text_feature_dim, text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name, text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length, text_max_length=text_max_length,
real_action_t_minus_1=real_action_t_minus_1,
image_augment=False, image_augment=False,
) )
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)