# pyre-strict

import typing as tp

import torch


# Based on code from https://github.com/pix2pixzero/pix2pix-zero
@torch.enable_grad()
def noise_regularization(
    e_t: torch.Tensor,
    noise_pred_optimal: torch.Tensor,
    lambda_kl: float,
    lambda_ac: float,
    num_reg_steps: int,
    num_ac_rolls: int,
    generator: tp.Optional[torch._C.Generator] = None,
) -> torch.Tensor:
    """Regularize a noise prediction with auto-correlation and patch-wise KL terms."""
    # The gradient steps below are computed on CPU when the input lives on MPS;
    # the result is moved back to MPS before returning.
    should_move_back_to_cpu = e_t.device.type == "mps"
    if should_move_back_to_cpu:
        e_t = e_t.to("cpu")
        noise_pred_optimal = noise_pred_optimal.to("cpu")
    for _outer in range(num_reg_steps):
        if lambda_kl > 0:
            # One gradient step on the patch-wise KL divergence to the reference noise.
            _var = e_t.detach().clone().requires_grad_(True)
            l_kld = patchify_latents_kl_divergence(_var, noise_pred_optimal)
            l_kld.backward()
            _grad = _var.grad.detach()  # pyre-ignore
            _grad = torch.clip(_grad, -100, 100)
            e_t = e_t - lambda_kl * _grad
        if lambda_ac > 0:
            # num_ac_rolls gradient steps (each scaled by 1 / num_ac_rolls) on the
            # auto-correlation penalty.
            for _inner in range(num_ac_rolls):
                _var = e_t.detach().clone().requires_grad_(True)
                l_ac = auto_corr_loss(_var.unsqueeze(1), generator=generator)
                l_ac.backward()  # pyre-ignore
                _grad = _var.grad.detach() / num_ac_rolls
                e_t = e_t - lambda_ac * _grad
        e_t = e_t.detach()

    return e_t if not should_move_back_to_cpu else e_t.to("mps")


# Based on code from https://github.com/pix2pixzero/pix2pix-zero
def auto_corr_loss(
    x: torch.Tensor,
    random_shift: bool = True,
    generator: tp.Optional[torch._C.Generator] = None,
) -> tp.Union[float, torch.Tensor]:
    """Penalize spatial auto-correlation of ``x`` across multiple scales.

    ``x`` is expected to have shape [1, C, H, W]; for each channel the loss sums,
    over successively average-pooled resolutions, the squared mean correlation
    between the noise and a spatially rolled copy of itself.
    """
    B, C, H, W = x.shape
    assert B == 1
    x = x.squeeze(0)
    # x must be shape [C,H,W] now
    reg_loss = 0.0
    for ch_idx in range(x.shape[0]):
        noise = x[ch_idx][None, None, :, :]
        # Accumulate correlation terms at progressively downsampled scales.
        while True:
            if random_shift:
                roll_amount = torch.randint(
                    0, noise.shape[2] // 2, (1,), generator=generator
                ).item()
            else:
                roll_amount = 1
            reg_loss += torch.pow(
                (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean(), 2  # pyre-ignore
            )
            reg_loss += torch.pow(
                (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean(), 2  # pyre-ignore
            )
            if noise.shape[2] <= 8:
                break
            noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2)
    return reg_loss


def patchify_latents_kl_divergence(
    x0: torch.Tensor, x1: torch.Tensor, patch_size: int = 4, num_channels: int = 4
) -> torch.Tensor:
    """Split both latents into non-overlapping patches and sum their per-patch KL terms."""

    def patchify_tensor(input_tensor: torch.Tensor) -> torch.Tensor:
        patches = (
            input_tensor.unfold(1, patch_size, patch_size)
            .unfold(2, patch_size, patch_size)
            .unfold(3, patch_size, patch_size)
        )
        patches = patches.contiguous().view(-1, num_channels, patch_size, patch_size)
        return patches

    x0 = patchify_tensor(x0)
    x1 = patchify_tensor(x1)

    kl = latents_kl_divergence(x0, x1).sum()
    return kl


def latents_kl_divergence(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
    """Closed-form Gaussian KL-divergence term between the per-channel statistics of x0 and x1."""
    EPSILON = 1e-6
    x0 = x0.view(x0.shape[0], x0.shape[1], -1)
    x1 = x1.view(x1.shape[0], x1.shape[1], -1)
    mu0 = x0.mean(dim=-1)
    mu1 = x1.mean(dim=-1)
    var0 = x0.var(dim=-1)
    var1 = x1.var(dim=-1)
    kl = (
        torch.log((var1 + EPSILON) / (var0 + EPSILON))
        + (var0 + torch.pow((mu0 - mu1), 2)) / (var1 + EPSILON)
        - 1
    )
    kl = torch.abs(kl).sum(dim=-1)
    return kl
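

# Minimal smoke-test sketch (illustrative only, not part of the original module).
# The 1x4x64x64 tensors below are an assumed latent layout; real callers pass a
# model's noise prediction and a reference noise sample.
if __name__ == "__main__":
    latent_noise = torch.randn(1, 4, 64, 64)
    reference_noise = torch.randn(1, 4, 64, 64)

    # Patch-wise KL between the two noise tensors.
    kl = patchify_latents_kl_divergence(latent_noise, reference_noise)
    print(f"patch-wise KL divergence: {kl.item():.4f}")

    # Multi-scale auto-correlation penalty for a [1, C, H, W] tensor.
    ac = auto_corr_loss(latent_noise)
    print(f"auto-correlation loss: {ac.item():.4f}")  # pyre-ignore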