Swarm-SLAM  1.0.0
C-SLAM Framework
cosplace.py
Go to the documentation of this file.
1 import numpy as np
2 
3 import os
4 from os.path import join, exists, isfile, realpath, dirname
5 import torch
6 import torch.nn as nn
7 import torch.nn.functional as F
8 import torch.optim as optim
9 from torch.autograd import Variable
10 from torch.utils.data import DataLoader, SubsetRandomSampler
11 from torch.utils.data.dataset import Subset
12 import torchvision.transforms as transforms
13 from PIL import Image
14 from datetime import datetime
15 import torchvision.datasets as datasets
16 import torchvision.models as models
17 import numpy as np
18 import sys
19 import pickle
20 import sklearn
21 from sklearn.neighbors import NearestNeighbors
22 from cslam.vpr.cosplace_utils.network import GeoLocalizationNet
23 from ament_index_python.packages import get_package_share_directory
24 
# Per-channel mean and standard deviation handed to transforms.Normalize when
# preprocessing keyframes. These are the standard ImageNet statistics
# (presumably matching how the backbone was pretrained — confirm against the
# checkpoint's training setup).
IMAGENET_DEFAULT_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DEFAULT_STD = (
    0.229,
    0.224,
    0.225,
)
27 
28 
class CosPlace(object):
    """CosPlace global-descriptor extractor for place recognition.

    When ``params['frontend.nn_checkpoint']`` equals ``'disable'``
    (case-insensitive), no network is loaded and :meth:`compute_embedding`
    returns random descriptors. That mode is meant for testing only.
    """

    def __init__(self, params, node):
        """Initialization

        Args:
            params (dict): parameters. Always reads
                ``'frontend.nn_checkpoint'`` (a path joined onto the ``cslam``
                package share directory, or ``'disable'``) and
                ``'frontend.cosplace.descriptor_dim'``. When enabled, also
                reads ``'frontend.cosplace.backbone'`` and
                ``'frontend.image_crop_size'``. When enabled, the
                ``'frontend.nn_checkpoint'`` entry is overwritten in place
                with the joined path.
            node: object exposing ``get_logger()`` (presumably the owning
                ROS 2 node — confirm). It is also forwarded to
                ``GeoLocalizationNet``.

        Raises:
            SystemExit: (status 1) if the network is enabled but the
                checkpoint file does not exist.
        """
        self.params = params
        self.node = node

        self.enable = self.params['frontend.nn_checkpoint'].lower(
        ) != 'disable'

        # Needed in both modes: the disabled path of compute_embedding draws
        # a random vector of this length, so it must not be set only when
        # the network is enabled.
        self.descriptor_dim = self.params['frontend.cosplace.descriptor_dim']

        if not self.enable:
            return

        pkg_folder = get_package_share_directory("cslam")
        self.params['frontend.nn_checkpoint'] = join(
            pkg_folder, self.params['frontend.nn_checkpoint'])

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.model = GeoLocalizationNet(
            self.params['frontend.cosplace.backbone'], self.descriptor_dim,
            node)

        resume_ckpt = self.params['frontend.nn_checkpoint']
        if not isfile(resume_ckpt):
            self.node.get_logger().error(
                "Error: Checkpoint path is incorrect {}".format(resume_ckpt))
            # sys.exit rather than the site-injected exit() builtin, which is
            # not guaranteed to exist. Use a non-zero status because this is
            # a fatal error.
            sys.exit(1)

        self.node.get_logger().info(
            "loading checkpoint '{}'".format(resume_ckpt))
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints. The map_location lambda keeps tensors on CPU before
        # the model is moved to self.device below.
        checkpoint = torch.load(
            resume_ckpt, map_location=lambda storage, loc: storage)
        self.model.load_state_dict(checkpoint)
        self.model = self.model.to(self.device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.CenterCrop(self.params["frontend.image_crop_size"]),
            # BICUBIC is the enum equivalent of the former magic int 3
            # (PIL.Image.BICUBIC).
            transforms.Resize(
                224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN,
                                 IMAGENET_DEFAULT_STD),
        ])

    def compute_embedding(self, keyframe):
        """Load image to device and extract the global image descriptor

        Args:
            keyframe (image): array accepted by ``PIL.Image.fromarray``
                (presumably an HxWx3 uint8 image — confirm with the caller).
                Ignored when the network is disabled.

        Returns:
            np.array: global image descriptor. This is the first (only) row
            of the model output when enabled. When disabled, it is a random
            vector of length ``descriptor_dim`` drawn uniformly from [0, 1).
        """
        if not self.enable:
            # Random descriptor if disabled
            # Use this option only for testing
            return np.random.rand(self.descriptor_dim)

        with torch.no_grad():
            image = Image.fromarray(keyframe)
            # Preprocess, add a batch dimension of size 1, move to device.
            batch = self.transform(image).unsqueeze(0).to(self.device)
            # Call the module rather than .forward() so hooks still run.
            image_encoding = self.model(batch)
            return image_encoding[0].detach().cpu().numpy()
def __init__(self, params, node)
Definition: cosplace.py:33
def compute_embedding(self, keyframe)
Definition: cosplace.py:81