代码可以跑起来了
This commit is contained in:
@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
|
||||
train_backbone: bool,
|
||||
return_interm_layers: bool,
|
||||
dilation: bool):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
|
||||
backbone_builder = getattr(torchvision.models, name)
|
||||
weights = None
|
||||
if is_main_process():
|
||||
weight_enum_name_map = {
|
||||
'resnet18': 'ResNet18_Weights',
|
||||
'resnet34': 'ResNet34_Weights',
|
||||
'resnet50': 'ResNet50_Weights',
|
||||
'resnet101': 'ResNet101_Weights',
|
||||
}
|
||||
enum_name = weight_enum_name_map.get(name)
|
||||
if enum_name is not None and hasattr(torchvision.models, enum_name):
|
||||
weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')
|
||||
|
||||
try:
|
||||
backbone = backbone_builder(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
weights=weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
except TypeError:
|
||||
# Backward compatibility for older torchvision that still expects `pretrained`.
|
||||
backbone = backbone_builder(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=(weights is not None),
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user