import logging
import torch
import torchvision
from torch import nn

from cslam.vpr.cosplace_utils.layers import Flatten, L2Norm, GeM

# Number of channels produced by the last convolutional layer of each supported backbone.
CHANNELS_NUM_IN_LAST_CONV = {
    "resnet18": 512,
    "resnet50": 2048,
    "resnet101": 2048,
    "resnet152": 2048,
    "vgg16": 512,
}


# Image descriptor network adapted from CosPlace (class name and aggregation head
# assumed from the upstream CosPlace implementation).
class GeoLocalizationNet(nn.Module):
    def __init__(self, backbone, fc_output_dim, node):
        super().__init__()
        self.backbone, features_dim = get_backbone(backbone)
        # GeM pooling followed by a linear projection to the descriptor dimension.
        self.aggregation = nn.Sequential(
            L2Norm(),
            GeM(),
            Flatten(),
            nn.Linear(features_dim, fc_output_dim),
            L2Norm(),
        )


def get_backbone(backbone_name):
    if backbone_name.startswith("resnet"):
        if backbone_name == "resnet18":
            backbone = torchvision.models.resnet18(pretrained=True)
        elif backbone_name == "resnet50":
            backbone = torchvision.models.resnet50(pretrained=True)
        elif backbone_name == "resnet101":
            backbone = torchvision.models.resnet101(pretrained=True)
        elif backbone_name == "resnet152":
            backbone = torchvision.models.resnet152(pretrained=True)

        # Freeze every layer before layer3; only layer3 and layer4 remain trainable.
        for name, child in backbone.named_children():
            if name == "layer3":
                break
            for params in child.parameters():
                params.requires_grad = False
        logging.debug(f"Train only layer3 and layer4 of the {backbone_name}, freeze the previous ones")
        layers = list(backbone.children())[:-2]  # Drop average pooling and the FC layer

    elif backbone_name == "vgg16":
        backbone = torchvision.models.vgg16(pretrained=True)
        layers = list(backbone.features.children())[:-2]  # Keep convolutional features, drop the final two layers
        # Freeze everything up to conv5_1; only the last convolutional block stays trainable.
        for l in layers[:-5]:
            for p in l.parameters():
                p.requires_grad = False
        logging.debug("Train last layers of the VGG-16, freeze the previous ones")

    backbone = torch.nn.Sequential(*layers)

    features_dim = CHANNELS_NUM_IN_LAST_CONV[backbone_name]

    return backbone, features_dim
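

# Usage sketch (illustrative, not part of the original module): builds the truncated
# backbone returned by get_backbone and checks that the feature-map depth matches
# CHANNELS_NUM_IN_LAST_CONV. The 224x224 input size is an assumption for the example.
if __name__ == "__main__":
    backbone, features_dim = get_backbone("resnet18")
    with torch.no_grad():
        feature_map = backbone(torch.randn(1, 3, 224, 224))
    # For resnet18 the truncated backbone outputs a 512-channel map, e.g. [1, 512, 7, 7].
    assert feature_map.shape[1] == features_dim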