baiyanlali-zhao's picture
init
eaf2e33
raw
history blame
3.04 kB
import torch
from math import sqrt
from abc import abstractmethod
from itertools import combinations
from src.gan.gankits import nz
class ExclusionReg:
    """Base class for exclusion regularisers.

    NOTE: To be maximised.

    Subclasses override ``forward(muss, stdss, betas)``. The class does not
    use ``ABCMeta``, so ``@abstractmethod`` alone does not block
    instantiation; the base ``forward`` therefore raises explicitly instead
    of silently returning ``None``.
    """

    def __init__(self, lbd):
        """Store ``lbd`` on the instance for use by subclasses."""
        self.lbd = lbd

    @abstractmethod
    def forward(self, muss, stdss, betas):
        """Compute the regularisation term; must be overridden.

        Raises:
            NotImplementedError: always, when called on a class that did not
                override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )
class WassersteinExclusion(ExclusionReg):
    """Exclusion term summing weighted pairwise distances between components.

    NOTE: To be maximised.
    """

    def forward(self, muss, stdss, betas):
        """Return ``lbd * sum_{i<j} betas[:, i] * betas[:, j] * w_ij``.

        ``w_ij`` is ``sqrt(x + y - 2 * z)`` where ``x`` is the squared
        difference of ``muss`` rows ``i`` and ``j`` summed over the last
        axis, ``y`` is the sum of the two ``stdss`` rows over the last axis,
        and ``z`` is the square root of their summed element-wise product.

        NOTE(review): if ``stdss`` holds per-dimension variances, the
        closed-form 2-Wasserstein cross term would be
        ``sqrt(s_i * s_j).sum()`` rather than ``sqrt((s_i * s_j).sum())``;
        the original formula is kept unchanged -- confirm intent.

        Args:
            muss: tensor whose shape unpacks as ``(b, m, d)``.
            stdss: tensor indexed as ``[:, i, :]``; presumably the same
                shape as ``muss`` -- TODO confirm.
            betas: tensor indexed as ``[:, i]``; presumably ``(b, m)``
                component weights -- TODO confirm.

        Returns:
            The accumulated term scaled by ``self.lbd``; all zeros of shape
            ``(b,)`` when ``m < 2``.
        """
        b, m, _ = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            mean_gap = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            spread = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            cross = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            # Rounding can leave a near-zero radicand slightly negative;
            # clamp so the square root cannot produce NaN.
            w = torch.clamp(mean_gap + spread - 2 * cross, min=0).sqrt()
            # Out-of-place add: an in-place `+=` would silently downcast
            # float64 terms into the default-dtype accumulator.
            rho = rho + betas[:, i] * betas[:, j] * w
        return self.lbd * rho
class LogWassersteinExclusion(ExclusionReg):
    """Exclusion term summing weighted ``log(1 + w_ij)`` over component pairs.

    NOTE: To be maximised.
    """

    def forward(self, muss, stdss, betas):
        """Return ``lbd * sum_{i<j} betas[:, i] * betas[:, j] * log(1 + w_ij)``.

        ``w_ij`` is ``sqrt(x + y - 2 * z)`` where ``x`` is the squared
        difference of ``muss`` rows ``i`` and ``j`` summed over the last
        axis, ``y`` is the sum of the two ``stdss`` rows over the last axis,
        and ``z`` is the square root of their summed element-wise product.

        NOTE(review): if ``stdss`` holds per-dimension variances, the
        closed-form 2-Wasserstein cross term would be
        ``sqrt(s_i * s_j).sum()`` rather than ``sqrt((s_i * s_j).sum())``;
        the original formula is kept unchanged -- confirm intent.

        Args:
            muss: tensor whose shape unpacks as ``(b, m, d)``.
            stdss: tensor indexed as ``[:, i, :]``; presumably the same
                shape as ``muss`` -- TODO confirm.
            betas: tensor indexed as ``[:, i]``; presumably ``(b, m)``
                component weights -- TODO confirm.

        Returns:
            The accumulated term scaled by ``self.lbd``; all zeros of shape
            ``(b,)`` when ``m < 2``.
        """
        b, m, _ = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            mean_gap = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            spread = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            cross = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            # Rounding can leave a near-zero radicand slightly negative;
            # clamp so the square root cannot produce NaN.
            w = torch.clamp(mean_gap + spread - 2 * cross, min=0).sqrt()
            # log1p(w) equals log(w + 1) but keeps precision for small w.
            # Out-of-place add: an in-place `+=` would silently downcast
            # float64 terms into the default-dtype accumulator.
            rho = rho + betas[:, i] * betas[:, j] * torch.log1p(w)
        return self.lbd * rho
class ClipExclusion(ExclusionReg):
    """Exclusion term summing weighted pairwise distances capped at ``wbar``.

    NOTE: To be maximised.
    """

    def __init__(self, lbd, wbar=0.6 * sqrt(nz)):
        """Store ``lbd`` and the upper bound ``wbar`` applied to each distance.

        Args:
            lbd: forwarded to ``ExclusionReg.__init__``.
            wbar: maximum value each pairwise distance is clipped to;
                defaults to ``0.6 * sqrt(nz)``.
        """
        super().__init__(lbd)
        self.wbar = wbar

    def forward(self, muss, stdss, betas):
        """Return ``lbd * sum_{i<j} betas[:, i] * betas[:, j] * min(w_ij, wbar)``.

        ``w_ij`` is ``sqrt(x + y - 2 * z)`` where ``x`` is the squared
        difference of ``muss`` rows ``i`` and ``j`` summed over the last
        axis, ``y`` is the sum of the two ``stdss`` rows over the last axis,
        and ``z`` is the square root of their summed element-wise product.

        NOTE(review): if ``stdss`` holds per-dimension variances, the
        closed-form 2-Wasserstein cross term would be
        ``sqrt(s_i * s_j).sum()`` rather than ``sqrt((s_i * s_j).sum())``;
        the original formula is kept unchanged -- confirm intent.

        Args:
            muss: tensor whose shape unpacks as ``(b, m, d)``.
            stdss: tensor indexed as ``[:, i, :]``; presumably the same
                shape as ``muss`` -- TODO confirm.
            betas: tensor indexed as ``[:, i]``; presumably ``(b, m)``
                component weights -- TODO confirm.

        Returns:
            The accumulated term scaled by ``self.lbd``; all zeros of shape
            ``(b,)`` when ``m < 2``.
        """
        b, m, _ = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            mean_gap = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            spread = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            cross = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            # Rounding can leave a near-zero radicand slightly negative;
            # clamp so the square root cannot produce NaN.
            w = torch.clamp(mean_gap + spread - 2 * cross, min=0).sqrt()
            # Out-of-place add: an in-place `+=` would silently downcast
            # float64 terms into the default-dtype accumulator.
            rho = rho + betas[:, i] * betas[:, j] * torch.clip(w, max=self.wbar)
        return self.lbd * rho
class LogClipExclusion(ExclusionReg):
    """Exclusion term summing weighted ``log(1 + min(w_ij, wbar))`` over pairs.

    NOTE: To be maximised.
    """

    def __init__(self, lbd, wbar=0.6 * sqrt(nz)):
        """Store ``lbd`` and the upper bound ``wbar`` applied to each distance.

        Args:
            lbd: forwarded to ``ExclusionReg.__init__``.
            wbar: maximum value each pairwise distance is clipped to before
                the logarithm; defaults to ``0.6 * sqrt(nz)``.
        """
        super().__init__(lbd)
        self.wbar = wbar

    def forward(self, muss, stdss, betas):
        """Return ``lbd * sum_{i<j} betas[:, i] * betas[:, j] * log(1 + min(w_ij, wbar))``.

        ``w_ij`` is ``sqrt(x + y - 2 * z)`` where ``x`` is the squared
        difference of ``muss`` rows ``i`` and ``j`` summed over the last
        axis, ``y`` is the sum of the two ``stdss`` rows over the last axis,
        and ``z`` is the square root of their summed element-wise product.

        NOTE(review): if ``stdss`` holds per-dimension variances, the
        closed-form 2-Wasserstein cross term would be
        ``sqrt(s_i * s_j).sum()`` rather than ``sqrt((s_i * s_j).sum())``;
        the original formula is kept unchanged -- confirm intent.

        Args:
            muss: tensor whose shape unpacks as ``(b, m, d)``.
            stdss: tensor indexed as ``[:, i, :]``; presumably the same
                shape as ``muss`` -- TODO confirm.
            betas: tensor indexed as ``[:, i]``; presumably ``(b, m)``
                component weights -- TODO confirm.

        Returns:
            The accumulated term scaled by ``self.lbd``; all zeros of shape
            ``(b,)`` when ``m < 2``.
        """
        b, m, _ = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            mean_gap = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            spread = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            cross = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            # Rounding can leave a near-zero radicand slightly negative;
            # clamp so the square root cannot produce NaN.
            w = torch.clamp(mean_gap + spread - 2 * cross, min=0).sqrt()
            capped = torch.clip(w, max=self.wbar)
            # log1p(v) equals log(v + 1) but keeps precision for small v.
            # Out-of-place add: an in-place `+=` would silently downcast
            # float64 terms into the default-dtype accumulator.
            rho = rho + betas[:, i] * betas[:, j] * torch.log1p(capped)
        return self.lbd * rho
# class SurrogateDistReg:
# def __init__(self, lbd, clip=30.):
# self.lbd = lbd
# self.clip = clip
#
# def forward(self, muss, stdss, betas):
# Script entry guard: running this module directly intentionally does nothing.
if __name__ == '__main__':
    pass