diff --git a/detr/main.py b/detr/main.py index 044b2a3..53ce92a 100644 --- a/detr/main.py +++ b/detr/main.py @@ -80,7 +80,7 @@ def get_args_parser(): def build_ACT_model_and_optimizer(args_override): parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) - args = parser.parse_args() + args, _ = parser.parse_known_args() for k, v in args_override.items(): setattr(args, k, v) @@ -103,7 +103,7 @@ def build_ACT_model_and_optimizer(args_override): def build_CNNMLP_model_and_optimizer(args_override): parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) - args = parser.parse_args() + args, _ = parser.parse_known_args() for k, v in args_override.items(): setattr(args, k, v) diff --git a/imitate_episodes.py b/imitate_episodes.py index 0b3dd88..eda6861 100644 --- a/imitate_episodes.py +++ b/imitate_episodes.py @@ -128,6 +128,7 @@ def main(args): 'use_cached_text_features': use_cached_text_features, 'text_tokenizer_name': text_tokenizer_name, 'text_max_length': text_max_length, + 'debug_input': args.get('debug_input', False), } if is_eval: @@ -376,7 +377,7 @@ def eval_bc(config, ckpt_name, save_episode=True): return success_rate, avg_return -def forward_pass(data, policy): +def forward_pass(data, policy, debug_input=False, debug_tag=''): image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data image_data = image_data.cuda() qpos_data = qpos_data.cuda() @@ -391,6 +392,26 @@ def forward_pass(data, policy): if torch.any(text_feature_valid): text_features = text_feature_data + if debug_input: + image_min = float(image_data.min().item()) + image_max = float(image_data.max().item()) + qpos_mean = float(qpos_data.mean().item()) + qpos_std = float(qpos_data.std().item()) + action_mean = float(action_data.mean().item()) + action_std = float(action_data.std().item()) + pad_ratio = float(is_pad.float().mean().item()) + + print(f'[debug_input] {debug_tag} image shape={tuple(image_data.shape)} range=[{image_min:.4f}, {image_max:.4f}]') + print(f'[debug_input] {debug_tag} qpos shape={tuple(qpos_data.shape)} mean/std=({qpos_mean:.4f}, {qpos_std:.4f})') + print(f'[debug_input] {debug_tag} action shape={tuple(action_data.shape)} mean/std=({action_mean:.4f}, {action_std:.4f})') + print(f'[debug_input] {debug_tag} is_pad shape={tuple(is_pad.shape)} pad_ratio={pad_ratio:.4f}') + print( + f'[debug_input] {debug_tag} has_nan_or_inf: ' + f'image={bool(torch.logical_not(torch.isfinite(image_data)).any().item())}, ' + f'qpos={bool(torch.logical_not(torch.isfinite(qpos_data)).any().item())}, ' + f'action={bool(torch.logical_not(torch.isfinite(action_data)).any().item())}' + ) + return policy( qpos_data, image_data, @@ -408,6 +429,7 @@ def train_bc(train_dataloader, val_dataloader, config): seed = config['seed'] policy_class = config['policy_class'] policy_config = config['policy_config'] + debug_input = config.get('debug_input', False) set_seed(seed) @@ -426,7 +448,8 @@ def train_bc(train_dataloader, val_dataloader, config): policy.eval() epoch_dicts = [] for batch_idx, data in enumerate(val_dataloader): - forward_dict = forward_pass(data, policy) + should_debug = debug_input and epoch == 0 and batch_idx == 0 + forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='val/epoch0/batch0') epoch_dicts.append(forward_dict) epoch_summary = compute_dict_mean(epoch_dicts) validation_history.append(epoch_summary) @@ -445,7 +468,8 @@ def train_bc(train_dataloader, val_dataloader, config): policy.train() optimizer.zero_grad() for batch_idx, data in enumerate(train_dataloader): - forward_dict = forward_pass(data, policy) + should_debug = debug_input and epoch == 0 and batch_idx == 0 + forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0') # backward loss = forward_dict['loss'] loss.backward() @@ -519,5 +543,7 @@ if __name__ == '__main__': parser.add_argument('--text_max_length', action='store', type=int, required=False) parser.add_argument('--image_aug', action='store_true', help='Enable training-time image augmentation (color/highlight/noise/blur)') + parser.add_argument('--debug_input', action='store_true', + help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0') main(vars(parser.parse_args()))