代码可以跑起来了

This commit is contained in:
2026-02-19 15:32:28 +08:00
parent b701d939c2
commit 88d14221ae
11 changed files with 503 additions and 89 deletions

View File

@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
backbone_builder = getattr(torchvision.models, name)
weights = None
if is_main_process():
weight_enum_name_map = {
'resnet18': 'ResNet18_Weights',
'resnet34': 'ResNet34_Weights',
'resnet50': 'ResNet50_Weights',
'resnet101': 'ResNet101_Weights',
}
enum_name = weight_enum_name_map.get(name)
if enum_name is not None and hasattr(torchvision.models, enum_name):
weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')
try:
backbone = backbone_builder(
replace_stride_with_dilation=[False, False, dilation],
weights=weights,
norm_layer=FrozenBatchNorm2d,
)
except TypeError:
# Backward compatibility for older torchvision that still expects `pretrained`.
backbone = backbone_builder(
replace_stride_with_dilation=[False, False, dilation],
pretrained=(weights is not None),
norm_layer=FrozenBatchNorm2d,
)
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)