File size: 2,894 Bytes
ad7bc89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch.nn as nn
import torch
from ...util import append_dims, instantiate_from_config


class Denoiser(nn.Module):
    """Wrap a network call with sigma-dependent input/output scaling.

    ``weighting_config`` and ``scaling_config`` are passed to
    ``instantiate_from_config``; the resulting objects are stored as
    ``self.weighting`` and ``self.scaling`` and are called with a sigma
    tensor. ``self.scaling`` must return four values
    ``(c_skip, c_out, c_in, c_noise)``.
    """

    def __init__(self, weighting_config, scaling_config):
        super().__init__()

        self.weighting = instantiate_from_config(weighting_config)
        self.scaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma):
        """Return ``sigma`` unchanged (hook for subclasses)."""
        return sigma

    def possibly_quantize_c_noise(self, c_noise):
        """Return ``c_noise`` unchanged (hook for subclasses)."""
        return c_noise

    def w(self, sigma):
        """Return the loss weighting for ``sigma``."""
        return self.weighting(sigma)

    # NOTE: this defines ``__call__`` directly instead of ``forward``, so it
    # replaces ``nn.Module.__call__`` for instances of this class.
    def __call__(self, network, input, sigma, cond, sigmas_ref=None, **kwargs):
        """Run ``network`` on the scaled input and rescale its prediction.

        Args:
            network: Callable invoked as
                ``network(input * c_in, c_noise, cond, **kwargs)``; it must
                return a 4-tuple whose first element is the prediction.
            input: Noisy input tensor.
            sigma: Noise level(s); expanded to ``input.ndim`` dimensions
                with ``append_dims`` before scaling.
            cond: Conditioning object forwarded to ``network`` untouched.
            sigmas_ref: Optional noise level(s) for a reference input. When
                given, it is forwarded to ``network`` as the ``sigmas_ref``
                keyword. If an ``input_ref`` keyword is also supplied (and
                is not ``None``), that reference is noised with
                ``sigmas_ref`` and multiplied by the matching ``c_in``
                before being forwarded.
            **kwargs: Extra keyword arguments forwarded to ``network``.

        Returns:
            ``(denoised, fg_mask_list, alphas_list, rgb_list)`` where
            ``denoised = predict * c_out + input * c_skip`` and the other
            three items are passed through from ``network``.
        """
        sigma = self.possibly_quantize_sigma(sigma)
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)

        if sigmas_ref is not None:
            kwargs["sigmas_ref"] = sigmas_ref
            # ``input_ref`` is optional: use .get() so that passing only
            # ``sigmas_ref`` does not raise KeyError.
            input_ref = kwargs.get("input_ref")
            if input_ref is not None:
                sigmas_ref_expanded = append_dims(sigmas_ref, input_ref.ndim)
                noise = torch.randn_like(input_ref)
                input_ref = input_ref + noise * sigmas_ref_expanded
                _, _, c_in_ref, _ = self.scaling(sigmas_ref_expanded)
                kwargs["input_ref"] = input_ref * c_in_ref
                # NOTE(review): the value forwarded to the network is derived
                # from the raw ``sigmas_ref``, not from the c_noise returned
                # by ``self.scaling`` above — confirm this is intended.
                kwargs["sigmas_ref"] = self.possibly_quantize_c_noise(sigmas_ref)

        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        # Bring c_noise back to the shape sigma had before append_dims.
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        predict, fg_mask_list, alphas_list, rgb_list = network(
            input * c_in, c_noise, cond, **kwargs
        )
        denoised = predict * c_out + input * c_skip
        return denoised, fg_mask_list, alphas_list, rgb_list


class DiscreteDenoiser(Denoiser):
    """Denoiser that snaps noise levels to a fixed table of sigmas.

    The table is produced by calling the object built from
    ``discretization_config`` with
    ``(num_idx, do_append_zero=do_append_zero, flip=flip)`` and is stored
    as the ``sigmas`` buffer.
    """

    def __init__(
        self,
        weighting_config,
        scaling_config,
        num_idx,
        discretization_config,
        do_append_zero=False,
        quantize_c_noise=True,
        flip=True,
    ):
        super().__init__(weighting_config, scaling_config)
        sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip
        )
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise

    def sigma_to_idx(self, sigma):
        """Return, for each entry of ``sigma``, the index of the nearest table sigma.

        ``sigma`` may have any shape; the result has the same shape.
        Assumes ``self.sigmas`` is one-dimensional.
        """
        # Flatten the queries so the distance table is always
        # (num_table_sigmas, num_queries), whatever the shape of ``sigma``.
        dists = sigma.reshape(1, -1) - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx):
        """Look up table sigmas by index."""
        return self.sigmas[idx]

    def possibly_quantize_sigma(self, sigma):
        """Replace each sigma with its nearest entry in the table."""
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

    def possibly_quantize_c_noise(self, c_noise):
        """Map ``c_noise`` to nearest-table indices when quantization is enabled."""
        if self.quantize_c_noise:
            return self.sigma_to_idx(c_noise)
        return c_noise