加了--debug_input参数
This commit is contained in:
@@ -80,7 +80,7 @@ def get_args_parser():
|
|||||||
|
|
||||||
def build_ACT_model_and_optimizer(args_override):
|
def build_ACT_model_and_optimizer(args_override):
|
||||||
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
||||||
args = parser.parse_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
for k, v in args_override.items():
|
for k, v in args_override.items():
|
||||||
setattr(args, k, v)
|
setattr(args, k, v)
|
||||||
@@ -103,7 +103,7 @@ def build_ACT_model_and_optimizer(args_override):
|
|||||||
|
|
||||||
def build_CNNMLP_model_and_optimizer(args_override):
|
def build_CNNMLP_model_and_optimizer(args_override):
|
||||||
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
||||||
args = parser.parse_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
for k, v in args_override.items():
|
for k, v in args_override.items():
|
||||||
setattr(args, k, v)
|
setattr(args, k, v)
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ def main(args):
|
|||||||
'use_cached_text_features': use_cached_text_features,
|
'use_cached_text_features': use_cached_text_features,
|
||||||
'text_tokenizer_name': text_tokenizer_name,
|
'text_tokenizer_name': text_tokenizer_name,
|
||||||
'text_max_length': text_max_length,
|
'text_max_length': text_max_length,
|
||||||
|
'debug_input': args.get('debug_input', False),
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_eval:
|
if is_eval:
|
||||||
@@ -376,7 +377,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
|||||||
return success_rate, avg_return
|
return success_rate, avg_return
|
||||||
|
|
||||||
|
|
||||||
def forward_pass(data, policy):
|
def forward_pass(data, policy, debug_input=False, debug_tag=''):
|
||||||
image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
|
image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
|
||||||
image_data = image_data.cuda()
|
image_data = image_data.cuda()
|
||||||
qpos_data = qpos_data.cuda()
|
qpos_data = qpos_data.cuda()
|
||||||
@@ -391,6 +392,26 @@ def forward_pass(data, policy):
|
|||||||
if torch.any(text_feature_valid):
|
if torch.any(text_feature_valid):
|
||||||
text_features = text_feature_data
|
text_features = text_feature_data
|
||||||
|
|
||||||
|
if debug_input:
|
||||||
|
image_min = float(image_data.min().item())
|
||||||
|
image_max = float(image_data.max().item())
|
||||||
|
qpos_mean = float(qpos_data.mean().item())
|
||||||
|
qpos_std = float(qpos_data.std().item())
|
||||||
|
action_mean = float(action_data.mean().item())
|
||||||
|
action_std = float(action_data.std().item())
|
||||||
|
pad_ratio = float(is_pad.float().mean().item())
|
||||||
|
|
||||||
|
print(f'[debug_input] {debug_tag} image shape={tuple(image_data.shape)} range=[{image_min:.4f}, {image_max:.4f}]')
|
||||||
|
print(f'[debug_input] {debug_tag} qpos shape={tuple(qpos_data.shape)} mean/std=({qpos_mean:.4f}, {qpos_std:.4f})')
|
||||||
|
print(f'[debug_input] {debug_tag} action shape={tuple(action_data.shape)} mean/std=({action_mean:.4f}, {action_std:.4f})')
|
||||||
|
print(f'[debug_input] {debug_tag} is_pad shape={tuple(is_pad.shape)} pad_ratio={pad_ratio:.4f}')
|
||||||
|
print(
|
||||||
|
f'[debug_input] {debug_tag} has_nan_or_inf: '
|
||||||
|
f'image={bool(torch.logical_not(torch.isfinite(image_data)).any().item())}, '
|
||||||
|
f'qpos={bool(torch.logical_not(torch.isfinite(qpos_data)).any().item())}, '
|
||||||
|
f'action={bool(torch.logical_not(torch.isfinite(action_data)).any().item())}'
|
||||||
|
)
|
||||||
|
|
||||||
return policy(
|
return policy(
|
||||||
qpos_data,
|
qpos_data,
|
||||||
image_data,
|
image_data,
|
||||||
@@ -408,6 +429,7 @@ def train_bc(train_dataloader, val_dataloader, config):
|
|||||||
seed = config['seed']
|
seed = config['seed']
|
||||||
policy_class = config['policy_class']
|
policy_class = config['policy_class']
|
||||||
policy_config = config['policy_config']
|
policy_config = config['policy_config']
|
||||||
|
debug_input = config.get('debug_input', False)
|
||||||
|
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
|
|
||||||
@@ -426,7 +448,8 @@ def train_bc(train_dataloader, val_dataloader, config):
|
|||||||
policy.eval()
|
policy.eval()
|
||||||
epoch_dicts = []
|
epoch_dicts = []
|
||||||
for batch_idx, data in enumerate(val_dataloader):
|
for batch_idx, data in enumerate(val_dataloader):
|
||||||
forward_dict = forward_pass(data, policy)
|
should_debug = debug_input and epoch == 0 and batch_idx == 0
|
||||||
|
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='val/epoch0/batch0')
|
||||||
epoch_dicts.append(forward_dict)
|
epoch_dicts.append(forward_dict)
|
||||||
epoch_summary = compute_dict_mean(epoch_dicts)
|
epoch_summary = compute_dict_mean(epoch_dicts)
|
||||||
validation_history.append(epoch_summary)
|
validation_history.append(epoch_summary)
|
||||||
@@ -445,7 +468,8 @@ def train_bc(train_dataloader, val_dataloader, config):
|
|||||||
policy.train()
|
policy.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
for batch_idx, data in enumerate(train_dataloader):
|
for batch_idx, data in enumerate(train_dataloader):
|
||||||
forward_dict = forward_pass(data, policy)
|
should_debug = debug_input and epoch == 0 and batch_idx == 0
|
||||||
|
forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
|
||||||
# backward
|
# backward
|
||||||
loss = forward_dict['loss']
|
loss = forward_dict['loss']
|
||||||
loss.backward()
|
loss.backward()
|
||||||
@@ -519,5 +543,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--text_max_length', action='store', type=int, required=False)
|
parser.add_argument('--text_max_length', action='store', type=int, required=False)
|
||||||
parser.add_argument('--image_aug', action='store_true',
|
parser.add_argument('--image_aug', action='store_true',
|
||||||
help='Enable training-time image augmentation (color/highlight/noise/blur)')
|
help='Enable training-time image augmentation (color/highlight/noise/blur)')
|
||||||
|
parser.add_argument('--debug_input', action='store_true',
|
||||||
|
help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')
|
||||||
|
|
||||||
main(vars(parser.parse_args()))
|
main(vars(parser.parse_args()))
|
||||||
|
|||||||
Reference in New Issue
Block a user