4 from os.path
import join, exists, isfile, realpath, dirname
7 import torch.nn.functional
as F
8 import torch.optim
as optim
9 from torch.autograd
import Variable
10 from torch.utils.data
import DataLoader, SubsetRandomSampler
11 from torch.utils.data.dataset
import Subset
12 import torchvision.transforms
as transforms
14 from datetime
import datetime
15 import torchvision.datasets
as datasets
16 import torchvision.models
as models
21 from sklearn.neighbors
import NearestNeighbors
22 from cslam.vpr.cosplace_utils.network
import GeoLocalizationNet
23 from ament_index_python.packages
import get_package_share_directory
25 IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
26 IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
37 params (dict): parameters
42 self.
enableenable = self.
paramsparams[
'frontend.nn_checkpoint'].lower(
45 pkg_folder = get_package_share_directory(
"cslam")
46 self.
paramsparams[
'frontend.nn_checkpoint'] = join(
47 pkg_folder, self.
paramsparams[
'frontend.nn_checkpoint'])
49 if torch.cuda.is_available():
50 self.
devicedevice = torch.device(
"cuda")
52 self.
devicedevice = torch.device(
"cpu")
55 'frontend.cosplace.descriptor_dim']
56 self.
modelmodel = GeoLocalizationNet(
60 resume_ckpt = self.
paramsparams[
'frontend.nn_checkpoint']
61 if isfile(resume_ckpt):
62 self.
nodenode.get_logger().info(
"loading checkpoint '{}'".format(resume_ckpt))
63 checkpoint = torch.load(
64 resume_ckpt, map_location=
lambda storage, loc: storage)
66 self.
modelmodel.load_state_dict(checkpoint)
69 self.
nodenode.get_logger().error(
"Error: Checkpoint path is incorrect {}".format(resume_ckpt))
72 self.
modelmodel.eval()
74 transforms.CenterCrop(self.
paramsparams[
"frontend.image_crop_size"]),
75 transforms.Resize(224, interpolation=3),
76 transforms.ToTensor(),
77 transforms.Normalize(IMAGENET_DEFAULT_MEAN,
78 IMAGENET_DEFAULT_STD),
82 """Load image to device and extract the global image descriptor
85 keyframe (image): image to match
88 np.array: global image descriptor
92 image = Image.fromarray(keyframe)
94 input = torch.unsqueeze(input, 0)
95 input = input.to(self.
devicedevice)
97 image_encoding = self.
modelmodel.forward(input)
99 output = image_encoding[0].detach().cpu().numpy()
100 del input, image_encoding, image
def __init__(self, params, node)
def compute_embedding(self, keyframe)