4 import torch.nn.functional
as F
5 from torch.nn.parameter
import Parameter
8 def gem(x, p=3, eps=1e-6):
9 return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
15 self.
pp = Parameter(torch.ones(1)*p)
18 return gem(x, p=self.
pp, eps=self.
epseps)
20 return self.__class__.__name__ +
'(' +
'p=' +
'{:.4f}'.format(self.
pp.data.tolist()[0]) +
', ' +
'eps=' + str(self.
epseps) +
')'
27 assert x.shape[2] == x.shape[3] == 1, f
"{x.shape[2]} != {x.shape[3]} != 1"
36 return F.normalize(x, p=2, dim=self.
dimdim)
def __init__(self, p=3, eps=1e-6)
def __init__(self, dim=1)
def gem(x, p=3, eps=1e-6)