代码可以跑起来了
This commit is contained in:
18
detr/main.py
18
detr/main.py
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.optim.adamw import AdamW
|
||||
from .models import build_ACT_model, build_CNNMLP_model
|
||||
|
||||
import IPython
|
||||
@@ -30,6 +31,15 @@ def get_args_parser():
|
||||
help="Type of positional embedding to use on top of the image features")
|
||||
parser.add_argument('--camera_names', default=[], type=list, # will be overridden
|
||||
help="A list of camera names")
|
||||
parser.add_argument('--state_dim', default=14, type=int)
|
||||
parser.add_argument('--action_dim', default=14, type=int)
|
||||
parser.add_argument('--use_text', action='store_true')
|
||||
parser.add_argument('--text_encoder_type', default='distilbert', type=str)
|
||||
parser.add_argument('--text_feature_dim', default=768, type=int)
|
||||
parser.add_argument('--text_fusion_type', default='concat_transformer_input', type=str)
|
||||
parser.add_argument('--freeze_text_encoder', action='store_true')
|
||||
parser.add_argument('--text_max_length', default=32, type=int)
|
||||
parser.add_argument('--text_tokenizer_name', default='distilbert-base-uncased', type=str)
|
||||
|
||||
# * Transformer
|
||||
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
|
||||
@@ -84,8 +94,8 @@ def build_ACT_model_and_optimizer(args_override):
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
optimizer = AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
@@ -107,8 +117,8 @@ def build_CNNMLP_model_and_optimizer(args_override):
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
optimizer = AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user