加了--debug_input参数

This commit is contained in:
2026-02-20 14:59:44 +08:00
parent d85cce8a52
commit 88d0cc5ca2
2 changed files with 31 additions and 5 deletions

View File

@@ -128,6 +128,7 @@ def main(args):
'use_cached_text_features': use_cached_text_features,
'text_tokenizer_name': text_tokenizer_name,
'text_max_length': text_max_length,
'debug_input': args.get('debug_input', False),
}
if is_eval:
@@ -376,7 +377,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
return success_rate, avg_return
def forward_pass(data, policy):
def forward_pass(data, policy, debug_input=False, debug_tag=''):
image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
image_data = image_data.cuda()
qpos_data = qpos_data.cuda()
@@ -391,6 +392,26 @@ def forward_pass(data, policy):
if torch.any(text_feature_valid):
text_features = text_feature_data
if debug_input:
image_min = float(image_data.min().item())
image_max = float(image_data.max().item())
qpos_mean = float(qpos_data.mean().item())
qpos_std = float(qpos_data.std().item())
action_mean = float(action_data.mean().item())
action_std = float(action_data.std().item())
pad_ratio = float(is_pad.float().mean().item())
print(f'[debug_input] {debug_tag} image shape={tuple(image_data.shape)} range=[{image_min:.4f}, {image_max:.4f}]')
print(f'[debug_input] {debug_tag} qpos shape={tuple(qpos_data.shape)} mean/std=({qpos_mean:.4f}, {qpos_std:.4f})')
print(f'[debug_input] {debug_tag} action shape={tuple(action_data.shape)} mean/std=({action_mean:.4f}, {action_std:.4f})')
print(f'[debug_input] {debug_tag} is_pad shape={tuple(is_pad.shape)} pad_ratio={pad_ratio:.4f}')
print(
f'[debug_input] {debug_tag} has_nan_or_inf: '
f'image={bool(torch.logical_not(torch.isfinite(image_data)).any().item())}, '
f'qpos={bool(torch.logical_not(torch.isfinite(qpos_data)).any().item())}, '
f'action={bool(torch.logical_not(torch.isfinite(action_data)).any().item())}'
)
return policy(
qpos_data,
image_data,
@@ -408,6 +429,7 @@ def train_bc(train_dataloader, val_dataloader, config):
seed = config['seed']
policy_class = config['policy_class']
policy_config = config['policy_config']
debug_input = config.get('debug_input', False)
set_seed(seed)
@@ -426,7 +448,8 @@ def train_bc(train_dataloader, val_dataloader, config):
policy.eval()
epoch_dicts = []
for batch_idx, data in enumerate(val_dataloader):
forward_dict = forward_pass(data, policy)
should_debug = debug_input and epoch == 0 and batch_idx == 0
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='val/epoch0/batch0')
epoch_dicts.append(forward_dict)
epoch_summary = compute_dict_mean(epoch_dicts)
validation_history.append(epoch_summary)
@@ -445,7 +468,8 @@ def train_bc(train_dataloader, val_dataloader, config):
policy.train()
optimizer.zero_grad()
for batch_idx, data in enumerate(train_dataloader):
forward_dict = forward_pass(data, policy)
should_debug = debug_input and epoch == 0 and batch_idx == 0
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
# backward
loss = forward_dict['loss']
loss.backward()
@@ -519,5 +543,7 @@ if __name__ == '__main__':
parser.add_argument('--text_max_length', action='store', type=int, required=False)
parser.add_argument('--image_aug', action='store_true',
help='Enable training-time image augmentation (color/highlight/noise/blur)')
parser.add_argument('--debug_input', action='store_true',
help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')
main(vars(parser.parse_args()))