follow用policy_last
This commit is contained in:
33
constants.py
33
constants.py
@@ -40,6 +40,7 @@ ENDOSCOPE_TASK_CONFIGS = {
|
||||
'camera_names': ['top'],
|
||||
'state_dim': 2,
|
||||
'action_dim': 2,
|
||||
'real_action_t_minus_1': False,
|
||||
'use_text_instruction': True,
|
||||
'instruction_mode': 'timestep-level',
|
||||
'use_cached_text_features': True,
|
||||
@@ -57,6 +58,7 @@ ENDOSCOPE_TASK_CONFIGS = {
|
||||
'camera_names': ['top'],
|
||||
'state_dim': 2,
|
||||
'action_dim': 2,
|
||||
'real_action_t_minus_1': False,
|
||||
'use_text_instruction': True,
|
||||
'instruction_mode': 'timestep-level',
|
||||
'use_cached_text_features': True,
|
||||
@@ -74,6 +76,37 @@ ENDOSCOPE_TASK_CONFIGS = {
|
||||
'camera_names': ['top'],
|
||||
'state_dim': 2,
|
||||
'action_dim': 2,
|
||||
'real_action_t_minus_1': False,
|
||||
'use_text_instruction': False,
|
||||
},
|
||||
'endoscope_sanity_check': {
|
||||
'dataset_dir': DATA_DIR + '/sanity-check',
|
||||
'num_episodes': 3,
|
||||
'episode_len': 400,
|
||||
'camera_names': ['top'],
|
||||
'state_dim': 2,
|
||||
'action_dim': 2,
|
||||
'real_action_t_minus_1': False,
|
||||
'use_text_instruction': False,
|
||||
},
|
||||
'endoscope_cannulation_no_text': {
|
||||
'dataset_dir': DATA_DIR + '/cannulation-no-text',
|
||||
'num_episodes': 3,
|
||||
'episode_len': 400,
|
||||
'camera_names': ['top'],
|
||||
'state_dim': 2,
|
||||
'action_dim': 2,
|
||||
'real_action_t_minus_1': False,
|
||||
'use_text_instruction': False,
|
||||
},
|
||||
'endoscope_follow_no_text': {
|
||||
'dataset_dir': DATA_DIR + '/follow-no-text',
|
||||
'num_episodes': 3,
|
||||
'episode_len': 400,
|
||||
'camera_names': ['top'],
|
||||
'state_dim': 2,
|
||||
'action_dim': 2,
|
||||
'real_action_t_minus_1': False,
|
||||
'use_text_instruction': False,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ def get_args_parser():
|
||||
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
|
||||
parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
|
||||
parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
|
||||
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
|
||||
parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
|
||||
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
|
||||
parser.add_argument('--temporal_agg', action='store_true')
|
||||
parser.add_argument('--image_aug', action='store_true')
|
||||
|
||||
@@ -20,6 +20,15 @@ from visualize_episodes import save_videos
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def load_checkpoint_state_dict(ckpt_path):
|
||||
"""Load checkpoint state_dict safely across different torch versions."""
|
||||
try:
|
||||
return torch.load(ckpt_path, map_location='cpu', weights_only=True)
|
||||
except TypeError:
|
||||
# For older PyTorch versions that do not support `weights_only`.
|
||||
return torch.load(ckpt_path, map_location='cpu')
|
||||
|
||||
def main(args):
|
||||
set_seed(1)
|
||||
# command line parameters
|
||||
@@ -60,6 +69,7 @@ def main(args):
|
||||
text_max_length = task_config.get('text_max_length', 32)
|
||||
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
|
||||
freeze_text_encoder = task_config.get('freeze_text_encoder', True)
|
||||
real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
|
||||
|
||||
if args.get('text_encoder_type') is not None:
|
||||
text_encoder_type = args['text_encoder_type']
|
||||
@@ -67,6 +77,8 @@ def main(args):
|
||||
text_max_length = args['text_max_length']
|
||||
if args.get('freeze_text_encoder', False):
|
||||
freeze_text_encoder = True
|
||||
if args.get('disable_real_action_shift', False):
|
||||
real_action_t_minus_1 = False
|
||||
|
||||
# fixed parameters
|
||||
lr_backbone = 1e-5
|
||||
@@ -110,6 +122,8 @@ def main(args):
|
||||
|
||||
config = {
|
||||
'num_epochs': num_epochs,
|
||||
'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
|
||||
'resume_ckpt_path': args.get('resume_ckpt', None),
|
||||
'ckpt_dir': ckpt_dir,
|
||||
'episode_len': episode_len,
|
||||
'state_dim': state_dim,
|
||||
@@ -131,6 +145,16 @@ def main(args):
|
||||
'debug_input': args.get('debug_input', False),
|
||||
}
|
||||
|
||||
if config['resume_ckpt_path']:
|
||||
resume_ckpt_path = config['resume_ckpt_path']
|
||||
if not os.path.isabs(resume_ckpt_path):
|
||||
candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
|
||||
if os.path.isfile(candidate_path):
|
||||
resume_ckpt_path = candidate_path
|
||||
if not os.path.isfile(resume_ckpt_path):
|
||||
raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
|
||||
config['resume_ckpt_path'] = resume_ckpt_path
|
||||
|
||||
if is_eval:
|
||||
ckpt_names = [f'policy_best.ckpt']
|
||||
results = []
|
||||
@@ -155,6 +179,7 @@ def main(args):
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
real_action_t_minus_1=real_action_t_minus_1,
|
||||
image_augment=args['image_aug'],
|
||||
)
|
||||
|
||||
@@ -223,7 +248,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
# load policy and stats
|
||||
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
|
||||
policy = make_policy(policy_class, policy_config)
|
||||
loading_status = policy.load_state_dict(torch.load(ckpt_path))
|
||||
loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
|
||||
print(loading_status)
|
||||
policy.cuda()
|
||||
policy.eval()
|
||||
@@ -425,6 +450,8 @@ def forward_pass(data, policy, debug_input=False, debug_tag=''):
|
||||
|
||||
def train_bc(train_dataloader, val_dataloader, config):
|
||||
num_epochs = config['num_epochs']
|
||||
train_steps_per_epoch = config.get('train_steps_per_epoch', None)
|
||||
resume_ckpt_path = config.get('resume_ckpt_path', None)
|
||||
ckpt_dir = config['ckpt_dir']
|
||||
seed = config['seed']
|
||||
policy_class = config['policy_class']
|
||||
@@ -435,6 +462,12 @@ def train_bc(train_dataloader, val_dataloader, config):
|
||||
|
||||
policy = make_policy(policy_class, policy_config)
|
||||
policy.cuda()
|
||||
|
||||
if resume_ckpt_path:
|
||||
loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
|
||||
print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
|
||||
print(loading_status)
|
||||
|
||||
optimizer = make_optimizer(policy_class, policy)
|
||||
|
||||
train_history = []
|
||||
@@ -467,8 +500,22 @@ def train_bc(train_dataloader, val_dataloader, config):
|
||||
# training
|
||||
policy.train()
|
||||
optimizer.zero_grad()
|
||||
for batch_idx, data in enumerate(train_dataloader):
|
||||
should_debug = debug_input and epoch == 0 and batch_idx == 0
|
||||
epoch_train_dicts = []
|
||||
if train_steps_per_epoch is None or train_steps_per_epoch <= 0:
|
||||
train_steps_this_epoch = len(train_dataloader)
|
||||
train_iterator = iter(train_dataloader)
|
||||
else:
|
||||
train_steps_this_epoch = int(train_steps_per_epoch)
|
||||
train_iterator = iter(train_dataloader)
|
||||
|
||||
for step_idx in range(train_steps_this_epoch):
|
||||
try:
|
||||
data = next(train_iterator)
|
||||
except StopIteration:
|
||||
train_iterator = iter(train_dataloader)
|
||||
data = next(train_iterator)
|
||||
|
||||
should_debug = debug_input and epoch == 0 and step_idx == 0
|
||||
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
|
||||
# backward
|
||||
loss = forward_dict['loss']
|
||||
@@ -476,7 +523,8 @@ def train_bc(train_dataloader, val_dataloader, config):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
train_history.append(detach_dict(forward_dict))
|
||||
epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
|
||||
epoch_train_dicts.append(detach_dict(forward_dict))
|
||||
epoch_summary = compute_dict_mean(epoch_train_dicts)
|
||||
epoch_train_loss = epoch_summary['loss']
|
||||
print(f'Train loss: {epoch_train_loss:.5f}')
|
||||
summary_string = ''
|
||||
@@ -533,7 +581,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--lr', action='store', type=float, help='lr', required=True)
|
||||
|
||||
# for ACT
|
||||
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
|
||||
parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
|
||||
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
|
||||
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
|
||||
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
|
||||
@@ -543,6 +591,12 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--text_max_length', action='store', type=int, required=False)
|
||||
parser.add_argument('--image_aug', action='store_true',
|
||||
help='Enable training-time image augmentation (color/highlight/noise/blur)')
|
||||
parser.add_argument('--train_steps_per_epoch', action='store', type=int, required=False,
|
||||
help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
|
||||
parser.add_argument('--disable_real_action_shift', action='store_true',
|
||||
help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
|
||||
parser.add_argument('--resume_ckpt', action='store', type=str, required=False,
|
||||
help='Optional checkpoint path to initialize model weights for fine-tuning')
|
||||
parser.add_argument('--debug_input', action='store_true',
|
||||
help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')
|
||||
|
||||
|
||||
46
utils.py
46
utils.py
@@ -17,6 +17,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
text_feature_dim=768,
|
||||
text_tokenizer_name='distilbert-base-uncased',
|
||||
text_max_length=32,
|
||||
real_action_t_minus_1=True,
|
||||
image_augment=False,
|
||||
image_aug_cfg=None):
|
||||
super(EpisodicDataset).__init__()
|
||||
@@ -29,20 +30,21 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
self.use_cached_text_features = use_cached_text_features
|
||||
self.text_feature_dim = text_feature_dim
|
||||
self.text_max_length = text_max_length
|
||||
self.real_action_t_minus_1 = real_action_t_minus_1
|
||||
self.image_augment = image_augment
|
||||
self.image_aug_cfg = {
|
||||
'p_color': 0.8,
|
||||
'p_highlight': 0.5,
|
||||
'p_noise': 0.5,
|
||||
'p_blur': 0.3,
|
||||
'brightness': 0.25,
|
||||
'contrast': 0.25,
|
||||
'saturation': 0.25,
|
||||
'hue': 0.08,
|
||||
'highlight_strength': (0.15, 0.5),
|
||||
'noise_std': (0.005, 0.03),
|
||||
'blur_sigma': (0.1, 1.5),
|
||||
'blur_kernel_choices': (3, 5),
|
||||
'p_color': 0.4,
|
||||
'p_highlight': 0.3,
|
||||
'p_noise': 0.35,
|
||||
'p_blur': 0.15,
|
||||
'brightness': 0.12,
|
||||
'contrast': 0.12,
|
||||
'saturation': 0.12,
|
||||
'hue': 0.03,
|
||||
'highlight_strength': (0.08, 0.25),
|
||||
'noise_std': (0.003, 0.015),
|
||||
'blur_sigma': (0.1, 0.8),
|
||||
'blur_kernel_choices': (3, ),
|
||||
}
|
||||
if image_aug_cfg is not None:
|
||||
self.image_aug_cfg.update(image_aug_cfg)
|
||||
@@ -221,8 +223,9 @@ class EpisodicDataset(torch.utils.data.Dataset):
|
||||
action = root['/action'][start_ts:]
|
||||
action_len = episode_len - start_ts
|
||||
else:
|
||||
action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned
|
||||
action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
|
||||
action_start = max(0, start_ts - 1) if self.real_action_t_minus_1 else start_ts
|
||||
action = root['/action'][action_start:]
|
||||
action_len = episode_len - action_start
|
||||
|
||||
self.is_sim = is_sim
|
||||
padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
|
||||
@@ -334,19 +337,26 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
||||
text_feature_dim=768,
|
||||
text_tokenizer_name='distilbert-base-uncased',
|
||||
text_max_length=32,
|
||||
real_action_t_minus_1=True,
|
||||
image_augment=False,
|
||||
image_aug_cfg=None):
|
||||
print(f'\nData from: {dataset_dir}\n')
|
||||
episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
|
||||
if len(episode_ids) == 0:
|
||||
raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
|
||||
if len(episode_ids) < 2:
|
||||
raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
|
||||
|
||||
# obtain train test split
|
||||
# obtain train/val split
|
||||
if len(episode_ids) == 1:
|
||||
# sanity-check mode: reuse the same episode for both train and val
|
||||
# so training/evaluation loops remain unchanged.
|
||||
train_episode_ids = np.array(episode_ids)
|
||||
val_episode_ids = np.array(episode_ids)
|
||||
print('[load_data] Only 1 episode found. Reusing the same episode for both train and val (sanity-check mode).')
|
||||
else:
|
||||
train_ratio = 0.9
|
||||
shuffled_indices = np.random.permutation(len(episode_ids))
|
||||
train_count = int(train_ratio * len(episode_ids))
|
||||
train_count = max(1, min(len(episode_ids) - 1, train_count))
|
||||
train_indices = shuffled_indices[:train_count]
|
||||
val_indices = shuffled_indices[train_count:]
|
||||
train_episode_ids = np.array(episode_ids)[train_indices]
|
||||
@@ -367,6 +377,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
real_action_t_minus_1=real_action_t_minus_1,
|
||||
image_augment=image_augment,
|
||||
image_aug_cfg=image_aug_cfg,
|
||||
)
|
||||
@@ -381,6 +392,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
real_action_t_minus_1=real_action_t_minus_1,
|
||||
image_augment=False,
|
||||
)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||
|
||||
Reference in New Issue
Block a user