From 81e1bf8838c785ff3a8db376a0697c79d489564c Mon Sep 17 00:00:00 2001 From: JC6123 Date: Fri, 20 Feb 2026 16:45:16 +0800 Subject: [PATCH] =?UTF-8?q?follow=E7=94=A8policy=5Flast?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- constants.py | 33 +++++++++++++++++++++++ detr/main.py | 2 +- imitate_episodes.py | 64 +++++++++++++++++++++++++++++++++++++++++---- utils.py | 60 +++++++++++++++++++++++++----------------- 4 files changed, 129 insertions(+), 30 deletions(-) diff --git a/constants.py b/constants.py index 463acad..764e61f 100644 --- a/constants.py +++ b/constants.py @@ -40,6 +40,7 @@ ENDOSCOPE_TASK_CONFIGS = { 'camera_names': ['top'], 'state_dim': 2, 'action_dim': 2, + 'real_action_t_minus_1': False, 'use_text_instruction': True, 'instruction_mode': 'timestep-level', 'use_cached_text_features': True, @@ -57,6 +58,7 @@ ENDOSCOPE_TASK_CONFIGS = { 'camera_names': ['top'], 'state_dim': 2, 'action_dim': 2, + 'real_action_t_minus_1': False, 'use_text_instruction': True, 'instruction_mode': 'timestep-level', 'use_cached_text_features': True, @@ -74,6 +76,37 @@ ENDOSCOPE_TASK_CONFIGS = { 'camera_names': ['top'], 'state_dim': 2, 'action_dim': 2, + 'real_action_t_minus_1': False, + 'use_text_instruction': False, + }, + 'endoscope_sanity_check': { + 'dataset_dir': DATA_DIR + '/sanity-check', + 'num_episodes': 3, + 'episode_len': 400, + 'camera_names': ['top'], + 'state_dim': 2, + 'action_dim': 2, + 'real_action_t_minus_1': False, + 'use_text_instruction': False, + }, + 'endoscope_cannulation_no_text': { + 'dataset_dir': DATA_DIR + '/cannulation-no-text', + 'num_episodes': 3, + 'episode_len': 400, + 'camera_names': ['top'], + 'state_dim': 2, + 'action_dim': 2, + 'real_action_t_minus_1': False, + 'use_text_instruction': False, + }, + 'endoscope_follow_no_text': { + 'dataset_dir': DATA_DIR + '/follow-no-text', + 'num_episodes': 3, + 'episode_len': 400, + 'camera_names': ['top'], + 'state_dim': 2, + 'action_dim': 2, + 'real_action_t_minus_1': False, 'use_text_instruction': False, }, } diff --git a/detr/main.py b/detr/main.py index 53ce92a..fe7b2f3 100644 --- a/detr/main.py +++ b/detr/main.py @@ -70,7 +70,7 @@ def get_args_parser(): parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True) parser.add_argument('--seed', action='store', type=int, help='seed', required=True) parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True) - parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False) + parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False) parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False) parser.add_argument('--temporal_agg', action='store_true') parser.add_argument('--image_aug', action='store_true') diff --git a/imitate_episodes.py b/imitate_episodes.py index eda6861..eee08d4 100644 --- a/imitate_episodes.py +++ b/imitate_episodes.py @@ -20,6 +20,15 @@ from visualize_episodes import save_videos import IPython e = IPython.embed + +def load_checkpoint_state_dict(ckpt_path): + """Load checkpoint state_dict safely across different torch versions.""" + try: + return torch.load(ckpt_path, map_location='cpu', weights_only=True) + except TypeError: + # For older PyTorch versions that do not support `weights_only`. + return torch.load(ckpt_path, map_location='cpu') + def main(args): set_seed(1) # command line parameters @@ -60,6 +69,7 @@ def main(args): text_max_length = task_config.get('text_max_length', 32) text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased') freeze_text_encoder = task_config.get('freeze_text_encoder', True) + real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True) if args.get('text_encoder_type') is not None: text_encoder_type = args['text_encoder_type'] @@ -67,6 +77,8 @@ def main(args): text_max_length = args['text_max_length'] if args.get('freeze_text_encoder', False): freeze_text_encoder = True + if args.get('disable_real_action_shift', False): + real_action_t_minus_1 = False # fixed parameters lr_backbone = 1e-5 @@ -110,6 +122,8 @@ def main(args): config = { 'num_epochs': num_epochs, + 'train_steps_per_epoch': args.get('train_steps_per_epoch', None), + 'resume_ckpt_path': args.get('resume_ckpt', None), 'ckpt_dir': ckpt_dir, 'episode_len': episode_len, 'state_dim': state_dim, @@ -131,6 +145,16 @@ def main(args): 'debug_input': args.get('debug_input', False), } + if config['resume_ckpt_path']: + resume_ckpt_path = config['resume_ckpt_path'] + if not os.path.isabs(resume_ckpt_path): + candidate_path = os.path.join(ckpt_dir, resume_ckpt_path) + if os.path.isfile(candidate_path): + resume_ckpt_path = candidate_path + if not os.path.isfile(resume_ckpt_path): + raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}') + config['resume_ckpt_path'] = resume_ckpt_path + if is_eval: ckpt_names = [f'policy_best.ckpt'] results = [] @@ -155,6 +179,7 @@ def main(args): text_feature_dim=text_feature_dim, text_tokenizer_name=text_tokenizer_name, text_max_length=text_max_length, + real_action_t_minus_1=real_action_t_minus_1, image_augment=args['image_aug'], ) @@ -223,7 +248,7 @@ def eval_bc(config, ckpt_name, save_episode=True): # load policy and stats ckpt_path = os.path.join(ckpt_dir, ckpt_name) policy = make_policy(policy_class, policy_config) - loading_status = policy.load_state_dict(torch.load(ckpt_path)) + loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path)) print(loading_status) policy.cuda() policy.eval() @@ -425,6 +450,8 @@ def forward_pass(data, policy, debug_input=False, debug_tag=''): def train_bc(train_dataloader, val_dataloader, config): num_epochs = config['num_epochs'] + train_steps_per_epoch = config.get('train_steps_per_epoch', None) + resume_ckpt_path = config.get('resume_ckpt_path', None) ckpt_dir = config['ckpt_dir'] seed = config['seed'] policy_class = config['policy_class'] @@ -435,6 +462,12 @@ def train_bc(train_dataloader, val_dataloader, config): policy = make_policy(policy_class, policy_config) policy.cuda() + + if resume_ckpt_path: + loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path)) + print(f'Loaded finetune init ckpt: {resume_ckpt_path}') + print(loading_status) + optimizer = make_optimizer(policy_class, policy) train_history = [] @@ -467,8 +500,22 @@ def train_bc(train_dataloader, val_dataloader, config): # training policy.train() optimizer.zero_grad() - for batch_idx, data in enumerate(train_dataloader): - should_debug = debug_input and epoch == 0 and batch_idx == 0 + epoch_train_dicts = [] + if train_steps_per_epoch is None or train_steps_per_epoch <= 0: + train_steps_this_epoch = len(train_dataloader) + train_iterator = iter(train_dataloader) + else: + train_steps_this_epoch = int(train_steps_per_epoch) + train_iterator = iter(train_dataloader) + + for step_idx in range(train_steps_this_epoch): + try: + data = next(train_iterator) + except StopIteration: + train_iterator = iter(train_dataloader) + data = next(train_iterator) + + should_debug = debug_input and epoch == 0 and step_idx == 0 forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0') # backward loss = forward_dict['loss'] @@ -476,7 +523,8 @@ def train_bc(train_dataloader, val_dataloader, config): optimizer.step() optimizer.zero_grad() train_history.append(detach_dict(forward_dict)) - epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)]) + epoch_train_dicts.append(detach_dict(forward_dict)) + epoch_summary = compute_dict_mean(epoch_train_dicts) epoch_train_loss = epoch_summary['loss'] print(f'Train loss: {epoch_train_loss:.5f}') summary_string = '' @@ -533,7 +581,7 @@ if __name__ == '__main__': parser.add_argument('--lr', action='store', type=float, help='lr', required=True) # for ACT - parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False) + parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False) parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False) parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False) parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False) @@ -543,6 +591,12 @@ if __name__ == '__main__': parser.add_argument('--text_max_length', action='store', type=int, required=False) parser.add_argument('--image_aug', action='store_true', help='Enable training-time image augmentation (color/highlight/noise/blur)') + parser.add_argument('--train_steps_per_epoch', action='store', type=int, required=False, + help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader') + parser.add_argument('--disable_real_action_shift', action='store_true', + help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])') + parser.add_argument('--resume_ckpt', action='store', type=str, required=False, + help='Optional checkpoint path to initialize model weights for fine-tuning') parser.add_argument('--debug_input', action='store_true', help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0') diff --git a/utils.py b/utils.py index 6d2054c..7539787 100644 --- a/utils.py +++ b/utils.py @@ -17,6 +17,7 @@ class EpisodicDataset(torch.utils.data.Dataset): text_feature_dim=768, text_tokenizer_name='distilbert-base-uncased', text_max_length=32, + real_action_t_minus_1=True, image_augment=False, image_aug_cfg=None): super(EpisodicDataset).__init__() @@ -29,20 +30,21 @@ class EpisodicDataset(torch.utils.data.Dataset): self.use_cached_text_features = use_cached_text_features self.text_feature_dim = text_feature_dim self.text_max_length = text_max_length + self.real_action_t_minus_1 = real_action_t_minus_1 self.image_augment = image_augment self.image_aug_cfg = { - 'p_color': 0.8, - 'p_highlight': 0.5, - 'p_noise': 0.5, - 'p_blur': 0.3, - 'brightness': 0.25, - 'contrast': 0.25, - 'saturation': 0.25, - 'hue': 0.08, - 'highlight_strength': (0.15, 0.5), - 'noise_std': (0.005, 0.03), - 'blur_sigma': (0.1, 1.5), - 'blur_kernel_choices': (3, 5), + 'p_color': 0.4, + 'p_highlight': 0.3, + 'p_noise': 0.35, + 'p_blur': 0.15, + 'brightness': 0.12, + 'contrast': 0.12, + 'saturation': 0.12, + 'hue': 0.03, + 'highlight_strength': (0.08, 0.25), + 'noise_std': (0.003, 0.015), + 'blur_sigma': (0.1, 0.8), + 'blur_kernel_choices': (3, ), } if image_aug_cfg is not None: self.image_aug_cfg.update(image_aug_cfg) @@ -221,8 +223,9 @@ class EpisodicDataset(torch.utils.data.Dataset): action = root['/action'][start_ts:] action_len = episode_len - start_ts else: - action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned - action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned + action_start = max(0, start_ts - 1) if self.real_action_t_minus_1 else start_ts + action = root['/action'][action_start:] + action_len = episode_len - action_start self.is_sim = is_sim padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32) @@ -334,23 +337,30 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s text_feature_dim=768, text_tokenizer_name='distilbert-base-uncased', text_max_length=32, + real_action_t_minus_1=True, image_augment=False, image_aug_cfg=None): print(f'\nData from: {dataset_dir}\n') episode_ids = _discover_episode_ids(dataset_dir, num_episodes) if len(episode_ids) == 0: raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}') - if len(episode_ids) < 2: - raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}') - # obtain train test split - train_ratio = 0.9 - shuffled_indices = np.random.permutation(len(episode_ids)) - train_count = int(train_ratio * len(episode_ids)) - train_indices = shuffled_indices[:train_count] - val_indices = shuffled_indices[train_count:] - train_episode_ids = np.array(episode_ids)[train_indices] - val_episode_ids = np.array(episode_ids)[val_indices] + # obtain train/val split + if len(episode_ids) == 1: + # sanity-check mode: reuse the same episode for both train and val + # so training/evaluation loops remain unchanged. + train_episode_ids = np.array(episode_ids) + val_episode_ids = np.array(episode_ids) + print('[load_data] Only 1 episode found. Reusing the same episode for both train and val (sanity-check mode).') + else: + train_ratio = 0.9 + shuffled_indices = np.random.permutation(len(episode_ids)) + train_count = int(train_ratio * len(episode_ids)) + train_count = max(1, min(len(episode_ids) - 1, train_count)) + train_indices = shuffled_indices[:train_count] + val_indices = shuffled_indices[train_count:] + train_episode_ids = np.array(episode_ids)[train_indices] + val_episode_ids = np.array(episode_ids)[val_indices] # obtain normalization stats for qpos and action norm_stats = get_norm_stats(dataset_dir, episode_ids) @@ -367,6 +377,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s text_feature_dim=text_feature_dim, text_tokenizer_name=text_tokenizer_name, text_max_length=text_max_length, + real_action_t_minus_1=real_action_t_minus_1, image_augment=image_augment, image_aug_cfg=image_aug_cfg, ) @@ -381,6 +392,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s text_feature_dim=text_feature_dim, text_tokenizer_name=text_tokenizer_name, text_max_length=text_max_length, + real_action_t_minus_1=real_action_t_minus_1, image_augment=False, ) train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)