# https://raw.githubusercontent.com/CompVis/latent-diffusion/e66308c7f2e64cb581c6d27ab6fbeb846828253b/ldm/modules/distributions/distributions.py
import torch
import numpy as np
from pdb import set_trace as st


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


@torch.jit.script
def soft_clamp20(x: torch.Tensor):
    # 20. * torch.tanh(x / 20.) <--> soft, differentiable clamp to [-20, 20]
    return x.div(20.).tanh().mul(20.)


# @torch.jit.script
# def soft_clamp(x: torch.Tensor, a: torch.Tensor):
#     return x.div(a).tanh_().mul(a)
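

# Illustrative sketch (not from the original source): compares the soft clamp above with a
# hard torch.clamp. soft_clamp20 saturates smoothly at +/-20 and keeps a non-zero gradient
# everywhere, whereas torch.clamp has zero gradient outside the clamp range. Values below
# are arbitrary.
if __name__ == "__main__":
    _x = torch.tensor([-100., -10., 0., 10., 100.], requires_grad=True)
    _soft = soft_clamp20(_x)
    _hard = torch.clamp(_x, -20., 20.)
    _soft.sum().backward()
    print(_soft.detach(), _hard, _x.grad)  # soft clamp keeps gradients alive at the tails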


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False, soft_clamp=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        if soft_clamp:
            # as in LSGM, softly bound logvar to [-20, 20]; needs re-training?
            self.logvar = soft_clamp20(self.logvar)
        else:
            self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(
            self.mean.shape).to(device=self.parameters.device)
        return x
    # https://github.dev/NVlabs/LSGM/util/distributions.py
    def log_p(self, samples):
        # element-wise Gaussian log-density of `samples` under this distribution;
        # used for the negative encoder entropy term
        normalized_samples = (samples - self.mean) / self.std
        log_p = -0.5 * normalized_samples * normalized_samples - 0.5 * np.log(
            2 * np.pi) - 0.5 * self.logvar
        return log_p
    def normal_entropy(self):
        # Differential entropy of the diagonal Gaussian, per element:
        # 0.5 * logvar + 0.5 * (log(2 * pi) + 1). Motivation: supervise logvar directly.
        # entropy = torch.sum(self.logvar + 0.5 * (np.log(2 * np.pi) + 1),
        #                     dim=[1, 2, 3]).mean(0)
        entropy = 0.5 * self.logvar + 0.5 * (np.log(2 * np.pi) + 1)  # follow eps-loss tradition here; reduce over all dims downstream.
        return entropy
    def kl(self, other=None, pt_ft_separate=False, ft_separate=False):

        def kl_fn(mean, var, logvar):
            # KL(q || N(0, I)) reduced over all non-batch dims; supports B-L-C style VAE latents
            return 0.5 * torch.sum(
                torch.pow(mean, 2) + var - 1.0 - logvar,
                dim=list(range(1, mean.ndim)))

        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                if pt_ft_separate:  # as in LION: separate KL for point (first 3 channels) and feature latents
                    pt_kl = kl_fn(self.mean[:, :3], self.var[:, :3], self.logvar[:, :3])  # (B C L) input
                    ft_kl = kl_fn(self.mean[:, 3:], self.var[:, 3:], self.logvar[:, 3:])  # (B C L) input
                    return pt_kl, ft_kl
                elif ft_separate:
                    ft_kl = kl_fn(self.mean[:, :], self.var[:, :], self.logvar[:, :])  # (B C L) input
                    return ft_kl
                else:
                    return 0.5 * torch.sum(
                        torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                        dim=list(range(1, self.mean.ndim)))  # support B-L-C style VAE latents
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var +
                    self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])
    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar +
                               torch.pow(sample - self.mean, 2) / self.var,
                               dim=dims)

    def mode(self):
        return self.mean
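

# Minimal usage sketch (not from the original source): assumes an encoder that outputs
# concatenated means and log-variances along dim=1, e.g. a (B, 2*C, H, W) tensor as in
# latent-diffusion's AutoencoderKL. Shapes and names below are illustrative only.
if __name__ == "__main__":
    _B, _C, _H, _W = 2, 4, 8, 8
    _moments = torch.randn(_B, 2 * _C, _H, _W)  # stand-in for an encoder output
    _posterior = DiagonalGaussianDistribution(_moments)
    _z = _posterior.sample()  # reparameterised sample, shape (B, C, H, W)
    _kl = _posterior.kl()     # KL(q(z|x) || N(0, I)) per batch element, shape (B,)
    print(_z.shape, _kl.shape)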


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
                  ((mean1 - mean2)**2) * torch.exp(-logvar2))
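

# Sanity-check sketch (not from the original source): against a standard normal
# (mean2 = 0, logvar2 = 0), the broadcasted normal_kl summed over non-batch dims
# should match DiagonalGaussianDistribution.kl(). Purely illustrative.
if __name__ == "__main__":
    _mean = torch.randn(2, 4, 8, 8)
    _logvar = torch.randn(2, 4, 8, 8).clamp(-3., 3.)
    _dist = DiagonalGaussianDistribution(torch.cat([_mean, _logvar], dim=1))
    _kl_a = _dist.kl()
    _kl_b = normal_kl(_mean, _dist.logvar, 0.0, 0.0).sum(dim=[1, 2, 3])
    print(torch.allclose(_kl_a, _kl_b, atol=1e-5))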