Swarm-SLAM  1.0.0
C-SLAM Framework
network.py
Go to the documentation of this file.
1 
2 import torch
3 import logging
4 import torchvision
5 from torch import nn
6 
7 from cslam.vpr.cosplace_utils.layers import Flatten, L2Norm, GeM
8 
9 
# Channel count produced by the last convolutional layer kept for each
# supported backbone name. get_backbone() returns this value as the
# feature dimension fed to the aggregation head's linear layer.
CHANNELS_NUM_IN_LAST_CONV = dict(
    resnet18=512,
    resnet50=2048,
    resnet101=2048,
    resnet152=2048,
    vgg16=512,
)
17 
18 
class GeoLocalizationNet(nn.Module):
    """Global-descriptor network: a CNN backbone followed by a pooling head.

    The backbone is built by ``get_backbone``. The aggregation head applies,
    in order: ``L2Norm``, ``GeM`` pooling, ``Flatten``, a linear projection
    to ``fc_output_dim`` features, and a final ``L2Norm``.
    """

    def __init__(self, backbone, fc_output_dim, node):
        """Build the backbone and the aggregation head.

        Args:
            backbone: Backbone name forwarded to ``get_backbone``
                (e.g. ``"resnet18"``).
            fc_output_dim: Output size of the head's linear layer, i.e. the
                length of the produced descriptor.
            node: Stored as ``self.node``; not used by this class itself.
                NOTE(review): presumably the owning ROS node kept for use by
                callers -- confirm.
        """
        super().__init__()
        self.backbone, features_dim = get_backbone(backbone)
        # Submodule attribute names become ``state_dict`` key prefixes
        # (``backbone.*``, ``aggregation.*``), so they must match the names
        # used by the checkpoints this model loads.
        self.aggregation = nn.Sequential(
            L2Norm(),
            GeM(),
            Flatten(),
            nn.Linear(features_dim, fc_output_dim),
            L2Norm(),
        )
        self.node = node

    def forward(self, x):
        """Run ``x`` through the backbone and then the aggregation head.

        NOTE(review): ``x`` is presumably a ``(B, 3, H, W)`` image batch --
        confirm against callers.
        """
        x = self.backbone(x)
        x = self.aggregation(x)
        return x
36 
37 
def get_backbone(backbone_name):
    """Build a truncated, partially frozen, ImageNet-pretrained CNN backbone.

    Args:
        backbone_name: One of the keys of ``CHANNELS_NUM_IN_LAST_CONV``
            ("resnet18", "resnet50", "resnet101", "resnet152", "vgg16").

    Returns:
        A ``(backbone, features_dim)`` tuple. ``backbone`` is a
        ``torch.nn.Sequential`` of the layers kept from the torchvision model.
        ``features_dim`` is its output channel count, read from
        ``CHANNELS_NUM_IN_LAST_CONV``.

    Raises:
        ValueError: If ``backbone_name`` is not a supported backbone.
    """
    # Validate first: an unsupported name (e.g. "resnet34" or "vgg19") used to
    # fall through the branches below and crash with UnboundLocalError.
    # Adding a key to CHANNELS_NUM_IN_LAST_CONV also requires a branch here.
    if backbone_name not in CHANNELS_NUM_IN_LAST_CONV:
        raise ValueError(
            f"Unsupported backbone '{backbone_name}'; expected one of "
            f"{sorted(CHANNELS_NUM_IN_LAST_CONV)}"
        )

    # NOTE: ``pretrained=True`` is deprecated in newer torchvision releases in
    # favour of ``weights=...``; kept for compatibility with older versions.
    if backbone_name.startswith("resnet"):
        # Every supported ResNet key is also the torchvision constructor name.
        backbone = getattr(torchvision.models, backbone_name)(pretrained=True)

        # Freeze every top-level child that precedes layer3.
        for name, child in backbone.named_children():
            if name == "layer3":
                break
            for params in child.parameters():
                params.requires_grad = False
        logging.debug(f"Train only layer3 and layer4 of the {backbone_name}, freeze the previous ones")
        layers = list(backbone.children())[:-2]  # Remove avg pooling and FC layer

    elif backbone_name == "vgg16":
        backbone = torchvision.models.vgg16(pretrained=True)
        # ``features`` is only the convolutional stack (avgpool/classifier are
        # separate attributes); [:-2] drops its last two modules, which in
        # torchvision's VGG16 are the final ReLU and MaxPool2d.
        layers = list(backbone.features.children())[:-2]
        # Freeze everything except the last 5 kept modules.
        for layer in layers[:-5]:
            for p in layer.parameters():
                p.requires_grad = False
        logging.debug("Train last layers of the VGG-16, freeze the previous ones")

    backbone = torch.nn.Sequential(*layers)

    features_dim = CHANNELS_NUM_IN_LAST_CONV[backbone_name]

    return backbone, features_dim
69 
def forward(self, x)
Definition: network.py:32
def __init__(self, backbone, fc_output_dim, node)
Definition: network.py:20
def get_backbone(backbone_name)
Definition: network.py:38