File size: 5,182 Bytes
11e6f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# https://raw.githubusercontent.com/CompVis/latent-diffusion/e66308c7f2e64cb581c6d27ab6fbeb846828253b/ldm/modules/distributions/distributions.py

import torch
import numpy as np
from pdb import set_trace as st


class AbstractDistribution:
    """Interface for distributions that can be sampled and have a mode.

    Both methods raise ``NotImplementedError`` here; subclasses are expected
    to override them.
    """

    def sample(self):
        """Draw a sample from the distribution (not implemented here)."""
        raise NotImplementedError

    def mode(self):
        """Return the mode of the distribution (not implemented here)."""
        raise NotImplementedError


class DiracDistribution(AbstractDistribution):
    """Point-mass distribution: all probability sits on a single value."""

    def __init__(self, value):
        # The stored value is returned as-is by both sample() and mode().
        self.value = value

    def sample(self):
        """Return the stored value (sampling a point mass is deterministic)."""
        point = self.value
        return point

    def mode(self):
        """Return the stored value, which is also the mode."""
        point = self.value
        return point


@torch.jit.script
def soft_clamp20(x: torch.Tensor):
    """Soft, differentiable clamp of ``x`` to the open interval (-20, 20).

    Computes ``20 * tanh(x / 20)``: approximately the identity near zero and
    smoothly saturating towards +/-20 for large magnitudes. The input tensor
    is not modified in place.
    """
    bound = 20.
    squashed = torch.tanh(x / bound)
    return squashed * bound


# @torch.jit.script
# def soft_clamp(x: torch.Tensor, a: torch.Tensor):
#     return x.div(a).tanh_().mul(a)


class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian built from a concatenated (mean, logvar) tensor.

    ``parameters`` is split into two equal chunks along ``dim=1``: the first
    chunk is the mean, the second the log-variance.

    Args:
        parameters: tensor holding mean and log-variance stacked on dim 1.
        deterministic: if True, ``std`` and ``var`` are set to zero, so
            ``sample()`` returns the mean and ``kl()`` / ``nll()`` return a
            single-element zero tensor.
        soft_clamp: if True, the log-variance is bounded with
            ``soft_clamp20`` (tanh-based, range (-20, 20), as in LSGM);
            otherwise it is hard-clamped to [-30, 20].
    """

    def __init__(self, parameters, deterministic=False, soft_clamp=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)

        if soft_clamp:
            # As in LSGM: smooth bound on the log-variance, range (-20, 20).
            self.logvar = soft_clamp20(self.logvar)
        else:
            self.logvar = torch.clamp(self.logvar, -30.0, 20.0)

        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device/dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def _zero(self):
        """Return a one-element zero tensor on the same device as the mean."""
        return torch.zeros(1, dtype=self.mean.dtype, device=self.mean.device)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps`` standard normal noise."""
        # Noise is drawn with the default (CPU) generator and then moved, so
        # seeded runs keep producing the same values as before.
        noise = torch.randn(
            self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    # https://github.dev/NVlabs/LSGM/util/distributions.py
    def log_p(self, samples):
        """Element-wise Gaussian log-density of ``samples``.

        Used for the negative encoder entropy term. Computes
        ``-0.5 * ((x - mean) / std)**2 - 0.5 * log(2*pi) - 0.5 * logvar``,
        i.e. the residual is normalised by the standard deviation and the
        normaliser is ``log(std) = 0.5 * logvar``. Summing the negated result
        over dims (1, 2, 3) equals ``nll(samples)``.

        In deterministic mode ``std`` is zero, so the result is not finite.
        """
        normalized_samples = (samples - self.mean) / self.std
        log_p = (-0.5 * normalized_samples * normalized_samples -
                 0.5 * np.log(2 * np.pi) - 0.5 * self.logvar)
        return log_p

    def normal_entropy(self):
        """Element-wise differential entropy of the Gaussian.

        Returns ``0.5 * logvar + 0.5 * (log(2*pi) + 1)`` without any
        reduction; callers are expected to sum or average as needed.
        """
        entropy = 0.5 * self.logvar + 0.5 * (np.log(2 * np.pi) + 1)
        return entropy

    def kl(self, other=None):
        """KL divergence from this distribution, summed over dims (1, 2, 3).

        Args:
            other: another object exposing ``mean``, ``var`` and ``logvar``;
                when None the reference is the standard normal.

        Returns:
            Tensor reduced over dims 1, 2 and 3, or a one-element zero tensor
            in deterministic mode.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var +
            self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``.

        Args:
            sample: tensor broadcastable against the mean.
            dims: dimensions to reduce over (default: 1, 2 and 3).

        Returns:
            Reduced tensor, or a one-element zero tensor in deterministic
            mode.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(self.logvar + logtwopi +
                               torch.pow(sample - self.mean, 2) / self.var,
                               dim=dims)

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Element-wise KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).
    Inputs broadcast against each other, so tensors may be mixed with
    Python scalars; at least one of the four arguments must be a Tensor.
    """
    # The first Tensor argument supplies dtype/device for scalar promotion.
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2)
         if isinstance(arg, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # torch.exp() needs Tensor inputs, so scalar log-variances are promoted
    # explicitly; broadcasting alone would not do it.
    if not isinstance(logvar1, torch.Tensor):
        logvar1 = torch.tensor(logvar1).to(reference)
    if not isinstance(logvar2, torch.Tensor):
        logvar2 = torch.tensor(logvar2).to(reference)

    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = ((mean1 - mean2)**2) * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff)