加了--debug_input参数
This commit is contained in:
@@ -128,6 +128,7 @@ def main(args):
|
||||
'use_cached_text_features': use_cached_text_features,
|
||||
'text_tokenizer_name': text_tokenizer_name,
|
||||
'text_max_length': text_max_length,
|
||||
'debug_input': args.get('debug_input', False),
|
||||
}
|
||||
|
||||
if is_eval:
|
||||
@@ -376,7 +377,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
return success_rate, avg_return
|
||||
|
||||
|
||||
def forward_pass(data, policy):
|
||||
def forward_pass(data, policy, debug_input=False, debug_tag=''):
|
||||
image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
|
||||
image_data = image_data.cuda()
|
||||
qpos_data = qpos_data.cuda()
|
||||
@@ -391,6 +392,26 @@ def forward_pass(data, policy):
|
||||
if torch.any(text_feature_valid):
|
||||
text_features = text_feature_data
|
||||
|
||||
if debug_input:
|
||||
image_min = float(image_data.min().item())
|
||||
image_max = float(image_data.max().item())
|
||||
qpos_mean = float(qpos_data.mean().item())
|
||||
qpos_std = float(qpos_data.std().item())
|
||||
action_mean = float(action_data.mean().item())
|
||||
action_std = float(action_data.std().item())
|
||||
pad_ratio = float(is_pad.float().mean().item())
|
||||
|
||||
print(f'[debug_input] {debug_tag} image shape={tuple(image_data.shape)} range=[{image_min:.4f}, {image_max:.4f}]')
|
||||
print(f'[debug_input] {debug_tag} qpos shape={tuple(qpos_data.shape)} mean/std=({qpos_mean:.4f}, {qpos_std:.4f})')
|
||||
print(f'[debug_input] {debug_tag} action shape={tuple(action_data.shape)} mean/std=({action_mean:.4f}, {action_std:.4f})')
|
||||
print(f'[debug_input] {debug_tag} is_pad shape={tuple(is_pad.shape)} pad_ratio={pad_ratio:.4f}')
|
||||
print(
|
||||
f'[debug_input] {debug_tag} has_nan_or_inf: '
|
||||
f'image={bool(torch.logical_not(torch.isfinite(image_data)).any().item())}, '
|
||||
f'qpos={bool(torch.logical_not(torch.isfinite(qpos_data)).any().item())}, '
|
||||
f'action={bool(torch.logical_not(torch.isfinite(action_data)).any().item())}'
|
||||
)
|
||||
|
||||
return policy(
|
||||
qpos_data,
|
||||
image_data,
|
||||
@@ -408,6 +429,7 @@ def train_bc(train_dataloader, val_dataloader, config):
|
||||
seed = config['seed']
|
||||
policy_class = config['policy_class']
|
||||
policy_config = config['policy_config']
|
||||
debug_input = config.get('debug_input', False)
|
||||
|
||||
set_seed(seed)
|
||||
|
||||
@@ -426,7 +448,8 @@ def train_bc(train_dataloader, val_dataloader, config):
|
||||
policy.eval()
|
||||
epoch_dicts = []
|
||||
for batch_idx, data in enumerate(val_dataloader):
|
||||
forward_dict = forward_pass(data, policy)
|
||||
should_debug = debug_input and epoch == 0 and batch_idx == 0
|
||||
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='val/epoch0/batch0')
|
||||
epoch_dicts.append(forward_dict)
|
||||
epoch_summary = compute_dict_mean(epoch_dicts)
|
||||
validation_history.append(epoch_summary)
|
||||
@@ -445,7 +468,8 @@ def train_bc(train_dataloader, val_dataloader, config):
|
||||
policy.train()
|
||||
optimizer.zero_grad()
|
||||
for batch_idx, data in enumerate(train_dataloader):
|
||||
forward_dict = forward_pass(data, policy)
|
||||
should_debug = debug_input and epoch == 0 and batch_idx == 0
|
||||
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
|
||||
# backward
|
||||
loss = forward_dict['loss']
|
||||
loss.backward()
|
||||
@@ -519,5 +543,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--text_max_length', action='store', type=int, required=False)
|
||||
parser.add_argument('--image_aug', action='store_true',
|
||||
help='Enable training-time image augmentation (color/highlight/noise/blur)')
|
||||
parser.add_argument('--debug_input', action='store_true',
|
||||
help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')
|
||||
|
||||
main(vars(parser.parse_args()))
|
||||
|
||||
Reference in New Issue
Block a user