Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,182 Bytes
11e6f7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
# https://raw.githubusercontent.com/CompVis/latent-diffusion/e66308c7f2e64cb581c6d27ab6fbeb846828253b/ldm/modules/distributions/distributions.py
import torch
import numpy as np
from pdb import set_trace as st
class AbstractDistribution:
    """Base interface for distribution objects.

    Concrete subclasses provide ``sample`` (draw a value) and ``mode``
    (the most likely value). Calling either on this base class raises
    ``NotImplementedError``.
    """

    def sample(self):
        """Draw a value from the distribution. Subclasses must override."""
        raise NotImplementedError()

    def mode(self):
        """Return the distribution's mode. Subclasses must override."""
        raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
    """Point-mass distribution concentrated on a single stored value.

    Both ``sample`` and ``mode`` hand back the exact object passed to the
    constructor (no copy is made).

    Args:
        value: the value every sample and the mode resolve to.
    """

    def __init__(self, value):
        self.value = value

    def sample(self):
        """Return the stored value (sampling a point mass is deterministic)."""
        return self.value

    def mode(self):
        """Return the stored value, which is trivially the mode."""
        return self.value
@torch.jit.script
def soft_clamp20(x: torch.Tensor) -> torch.Tensor:
    """Softly clamp ``x`` into the open interval (-20, 20).

    Computes ``20 * tanh(x / 20)``. This is close to the identity for small
    ``|x|``, saturates smoothly towards +/-20, and, unlike ``torch.clamp``,
    has a non-zero gradient everywhere. In this file it is used to bound
    log-variances (as in LSGM).
    """
    return x.div(20.).tanh().mul(20.)
# @torch.jit.script
# def soft_clamp(x: torch.Tensor, a: torch.Tensor):
# return x.div(a).tanh_().mul(a)
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by one tensor holding mean and log-variance.

    ``parameters`` is split in two along dim 1: the first half is the mean,
    the second half is the log-variance. ``kl`` and ``nll`` reduce over dims
    ``(1, 2, 3)`` by default, so they expect at least a 4-D input
    (presumably ``(B, 2C, H, W)`` -- TODO confirm against callers).

    Args:
        parameters: tensor whose dim 1 holds ``[mean, logvar]`` concatenated.
        deterministic: if True, ``std`` and ``var`` are zero, so ``sample``
            returns the mean, and ``kl``/``nll`` return a one-element zero.
        soft_clamp: if True, bound logvar smoothly to (-20, 20) with
            ``soft_clamp20`` (as in LSGM); otherwise hard-clamp it to
            [-30, 20].
    """

    def __init__(self, parameters, deterministic=False, soft_clamp=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        if soft_clamp:
            # Smooth, differentiable bound to (-20, 20), as in LSGM.
            self.logvar = soft_clamp20(self.logvar)
        else:
            # Hard bound keeps exp(logvar) finite.
            self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Zero spread: sample() collapses to the mean. std and var alias
            # the same tensor. zeros_like already matches mean's device.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw a reparameterised sample ``mean + std * eps``, ``eps ~ N(0, I)``."""
        # Noise is created with torch's default device and then moved, which
        # keeps the RNG stream identical to the upstream implementation.
        x = self.mean + self.std * torch.randn(
            self.mean.shape).to(device=self.parameters.device)
        return x

    # https://github.dev/NVlabs/LSGM/util/distributions.py
    def log_p(self, samples):
        """Element-wise Gaussian log-density of ``samples`` (no reduction).

        Used for the negative encoder-entropy term.
        """
        # Standardise by the standard deviation (not the variance) and
        # subtract log(sigma) = 0.5 * logvar, so this is a true log-density.
        normalized_samples = (samples - self.mean) / self.std
        log_p = (-0.5 * normalized_samples * normalized_samples
                 - 0.5 * np.log(2 * np.pi) - 0.5 * self.logvar)
        return log_p

    def normal_entropy(self):
        """Element-wise differential entropy ``0.5 * log(2*pi*e*var)`` (no reduction).

        Motivation: supervise logvar directly. Any averaging or summing over
        dims is left to the caller.
        """
        # H = 0.5 * logvar + 0.5 * (log(2*pi) + 1)
        entropy = 0.5 * self.logvar + 0.5 * (np.log(2 * np.pi) + 1)
        return entropy

    def kl(self, other=None, dims=(1, 2, 3)):
        """KL divergence ``KL(self || other)`` summed over ``dims``.

        Args:
            other: another ``DiagonalGaussianDistribution``; if None, the
                standard normal N(0, I) is used.
            dims: dimensions to sum over (default ``(1, 2, 3)``).

        Returns:
            Per-item KL tensor, or a one-element zero tensor on the
            parameters' device when ``deterministic`` is set.
        """
        if self.deterministic:
            # Allocate on the same device so it combines with GPU losses.
            return torch.zeros(1, device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=dims)
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var +
            self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=dims)

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative Gaussian log-likelihood of ``sample`` summed over ``dims``.

        Returns a one-element zero tensor on the parameters' device when
        ``deterministic`` is set.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar +
                               torch.pow(sample - self.mean, 2) / self.var,
                               dim=dims)

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Element-wise KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).

    Ported from guided-diffusion:
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12

    Arguments may mix tensors and Python scalars; they broadcast, so a batch
    can be compared against a scalar. At least one argument must be a tensor.
    Scalar log-variances are turned into tensors (matching that tensor's
    dtype/device) because ``torch.exp`` needs tensor input.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2)
         if isinstance(arg, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    if not isinstance(logvar1, torch.Tensor):
        logvar1 = torch.tensor(logvar1).to(reference)
    if not isinstance(logvar2, torch.Tensor):
        logvar2 = torch.tensor(logvar2).to(reference)

    variance_ratio = torch.exp(logvar1 - logvar2)
    mean_term = ((mean1 - mean2)**2) * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + mean_term)
|