MMD_MP_Text_Dection / utils_MMD.py
alwayse's picture
Upload 9 files
d0e1f8b
raw
history blame contribute delete
No virus
1.71 kB
import torch
def Pdist2(x, y):
"""compute the paired distance between x and y."""
x_norm = (x ** 2).sum(1).view(-1, 1)
if y is not None:
y_norm = (y ** 2).sum(1).view(1, -1)
else:
y = x
y_norm = x_norm.view(1, -1)
Pdist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
Pdist[Pdist<0]=0
return Pdist
def MMD_batch2(Fea, len_s, Fea_org, sigma, sigma0=0.1, epsilon = 10**(-10), is_smooth=True, is_var_computed=True, use_1sample_U=True, coeff_xy=2):
X = Fea[0:len_s, :]
Y = Fea[len_s:, :]
if is_smooth:
X_org = Fea_org[0:len_s, :]
Y_org = Fea_org[len_s:, :]
L = 1 # generalized Gaussian (if L>1)
nx = X.shape[0]
ny = Y.shape[0]
Dxx = Pdist2(X, X)
Dyy = torch.zeros(Fea.shape[0] - len_s, 1).to(Dxx.device)
# Dyy = Pdist2(Y, Y)
Dxy = Pdist2(X, Y).transpose(0,1)
if is_smooth:
Dxx_org = Pdist2(X_org, X_org)
Dyy_org = torch.zeros(Fea.shape[0] - len_s, 1).to(Dxx.device)
# Dyy_org = Pdist2(Y_org, Y_org) # 1,1 0
Dxy_org = Pdist2(X_org, Y_org).transpose(0,1)
if is_smooth:
Kx = (1-epsilon) * torch.exp(-(Dxx / sigma0)**L -Dxx_org / sigma) + epsilon * torch.exp(-Dxx_org / sigma)
Ky = (1-epsilon) * torch.exp(-(Dyy / sigma0)**L -Dyy_org / sigma) + epsilon * torch.exp(-Dyy_org / sigma)
Kxy = (1-epsilon) * torch.exp(-(Dxy / sigma0)**L -Dxy_org / sigma) + epsilon * torch.exp(-Dxy_org / sigma)
else:
Kx = torch.exp(-Dxx / sigma0)
Ky = torch.exp(-Dyy / sigma0)
Kxy = torch.exp(-Dxy / sigma0)
nx = Kx.shape[0]
is_unbiased = False
if 1:
xx = torch.div((torch.sum(Kx)), (nx * nx))
yy = Ky.reshape(-1)
# one-sample U-statistic.
xy = torch.div(torch.sum(Kxy, dim = 1), (nx ))
mmd2 = xx - 2 * xy + yy
return mmd2