import torch
from math import sqrt
from abc import ABC, abstractmethod
from itertools import combinations

from src.gan.gankits import nz


class ExclusionReg(ABC):
    """Base class for exclusion regularisers.

    NOTE: the returned value is to be maximised.
    """

    def __init__(self, lbd):
        self.lbd = lbd  # regularisation weight

    @abstractmethod
    def forward(self, muss, stdss, betas):
        """muss, stdss: [b, m, d] per-component parameters; betas: [b, m] weights."""


class WassersteinExclusion(ExclusionReg):
    """Pairwise 2-Wasserstein-style exclusion: component pairs are rewarded,
    weighted by their betas, for being far apart."""

    def forward(self, muss, stdss, betas):
        b, m, d = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            x = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)  # squared mean gap
            y = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            z = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            w = (x + y - 2 * z).sqrt()  # pairwise distance between components i and j
            rho += betas[:, i] * betas[:, j] * w
        return self.lbd * rho
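
# Reference formula (added note, hedged): for diagonal Gaussians
# N(mu_i, diag(v_i)) and N(mu_j, diag(v_j)) the squared 2-Wasserstein distance
# has the closed form
#     W2^2 = ||mu_i - mu_j||^2 + sum_k (v_ik + v_jk) - 2 * sum_k sqrt(v_ik * v_jk).
# Reading stdss as the per-dimension variances v, x and y above are the first
# two terms, while z takes the square root outside the sum over k. Since the
# l1 norm dominates the l2 norm, z never exceeds the exact cross term, so w
# upper-bounds the exact distance.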


class LogWassersteinExclusion(ExclusionReg):
    """Same pairwise distance as WassersteinExclusion, damped by log(1 + w)."""

    def forward(self, muss, stdss, betas):
        b, m, d = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            x = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            y = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            z = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            w = (x + y - 2 * z).sqrt()
            rho += betas[:, i] * betas[:, j] * torch.log(w + 1)
        return self.lbd * rho


class ClipExclusion(ExclusionReg):
    """Pairwise distance clipped at wbar: pairs already further apart than wbar
    earn no extra reward. The default threshold 0.6 * sqrt(nz) scales with the
    dimension nz."""

    def __init__(self, lbd, wbar=0.6 * sqrt(nz)):
        super(ClipExclusion, self).__init__(lbd)
        self.wbar = wbar  # clipping threshold for the pairwise distance

    def forward(self, muss, stdss, betas):
        b, m, d = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            x = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            y = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            z = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            w = (x + y - 2 * z).sqrt()
            rho += betas[:, i] * betas[:, j] * torch.clip(w, max=self.wbar)
        return self.lbd * rho


class LogClipExclusion(ExclusionReg):
    """Clipped pairwise distance combined with the log(1 + w) damping."""

    def __init__(self, lbd, wbar=0.6 * sqrt(nz)):
        super(LogClipExclusion, self).__init__(lbd)
        self.wbar = wbar  # clipping threshold for the pairwise distance

    def forward(self, muss, stdss, betas):
        b, m, d = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            x = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            y = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            z = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            w = (x + y - 2 * z).sqrt()
            rho += betas[:, i] * betas[:, j] * torch.log(torch.clip(w, max=self.wbar) + 1)
        return self.lbd * rho
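

# A minimal vectorised sketch (added; not part of the source module): the
# O(m^2) Python loops above can be batched by gathering all i < j index pairs
# with torch.triu_indices. The helper name pairwise_w2 is a hypothetical
# choice; it reproduces the same per-pair w as the loops above.
def pairwise_w2(muss, stdss):
    b, m, d = muss.shape
    idx_i, idx_j = torch.triu_indices(m, m, offset=1)  # all pairs with i < j
    mu_i, mu_j = muss[:, idx_i, :], muss[:, idx_j, :]  # [b, p, d], p = m*(m-1)/2
    sd_i, sd_j = stdss[:, idx_i, :], stdss[:, idx_j, :]
    x = (mu_i - mu_j).square().sum(dim=-1)
    y = (sd_i + sd_j).sum(dim=-1)
    z = (sd_i * sd_j).sum(dim=-1).sqrt()
    return (x + y - 2 * z).sqrt(), idx_i, idx_j  # w has shape [b, p]
# With this helper, WassersteinExclusion.forward would reduce to
#     w, idx_i, idx_j = pairwise_w2(muss, stdss)
#     rho = (betas[:, idx_i] * betas[:, idx_j] * w).sum(dim=-1)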


# class SurrogateDistReg:
#     def __init__(self, lbd, clip=30.):
#         self.lbd = lbd
#         self.clip = clip
#
#     def forward(self, muss, stdss, betas):


if __name__ == '__main__':
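    # Added usage sketch (hedged): a quick smoke test on random inputs. Shapes
    # follow the forward() contract above; the concrete sizes b=4, m=3 and the
    # weight lbd=0.1 are illustrative, not values taken from the source.
    b, m = 4, 3
    muss = torch.randn(b, m, nz)
    stdss = torch.rand(b, m, nz) + 0.1  # positive spreads keep the sqrt real
    betas = torch.softmax(torch.randn(b, m), dim=-1)  # weights sum to 1 per row
    for reg in (WassersteinExclusion(0.1), LogWassersteinExclusion(0.1),
                ClipExclusion(0.1), LogClipExclusion(0.1)):
        print(type(reg).__name__, reg.forward(muss, stdss, betas))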