# python3.8
"""Contains implementation of Discriminator described in StyleNeRF."""

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.utils.ops import upsample
from models.utils.ops import downsample
from models.utils.camera import camera_9d_to_16d

from models.utils.official_stylegan2_model_helper import EqualConv2d
from models.utils.official_stylegan2_model_helper import MappingNetwork
from models.utils.official_stylegan2_model_helper import DiscriminatorBlock
from models.utils.official_stylegan2_model_helper import DiscriminatorEpilogue


class Discriminator(nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 1,        # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        lowres_head         = None,     # Resolution of an additional low-resolution discriminator head, None = disable.
        dual_discriminator  = False,    # Also discriminate the low-resolution (NeRF) image with a second branch.
        dual_input_ratio    = None,     # Optional extra low-res input, upsampled and concatenated to the main input.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
        upsample_type       = 'default',

        progressive         = False,
        resize_real_early   = False,    # Perform resizing before the training loop
        enable_ema          = False,    # Additionally save an EMA checkpoint

        predict_camera      = False,    # Learn a camera predictor, InfoGAN-style
        predict_9d_camera   = False,    # Use the 9D camera rotation parameterization
        predict_3d_camera   = False,    # Use a 3D camera (u, v, r), assuming the camera lies on the unit sphere
        no_camera_condition = False,    # Disable camera conditioning in the discriminator
        saperate_camera     = False,    # Use a separate camera-predictor branch (runs at the lowest resolutions only)
        **unused
    ):
        super().__init__()
        # setup parameters
        self.img_resolution      = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels        = img_channels
        self.block_resolutions   = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        self.architecture        = architecture
        self.lowres_head         = lowres_head
        self.dual_input_ratio    = dual_input_ratio
        self.dual_discriminator  = dual_discriminator
        self.upsample_type       = upsample_type
        self.progressive         = progressive
        self.resize_real_early   = resize_real_early
        self.enable_ema          = enable_ema
        self.predict_camera      = predict_camera
        self.predict_9d_camera   = predict_9d_camera
        self.predict_3d_camera   = predict_3d_camera
        self.no_camera_condition = no_camera_condition
        self.separate_camera     = saperate_camera
        if self.progressive:
            assert self.architecture == 'skip', "progressive growing currently supports only the 'skip' architecture."
        if self.dual_input_ratio is not None:  # similar to EG3D: concatenate low-/high-res images
            self.img_channels    = self.img_channels * 2
        if self.predict_camera:
            assert not (self.predict_9d_camera and self.predict_3d_camera), "9D and 3D camera predictions are mutually exclusive."
        channel_base = int(channel_base * 32768)
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        # camera prediction module
        self.c_dim = c_dim
        if predict_camera:
            # Output dimensionality of the camera-predictor head.
            if self.predict_3d_camera:
                out_dim = 3                  # (u, v, r) on the unit sphere
            elif self.predict_9d_camera:
                out_dim = 9                  # 9D rotation, expanded to 16D by camera_9d_to_16d
            else:
                out_dim = 16                 # flattened 4x4 extrinsic matrix
            if not self.no_camera_condition:  # condition the discriminator on the camera
                self.c_dim = 3 if self.predict_3d_camera else 16
            # A 4x4 convolution over the final 4x4 feature map acts as a global linear head.
            self.projector = EqualConv2d(channels_dict[4], out_dim, 4, padding=0, bias=False)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if self.c_dim == 0:
            cmap_dim = 0
        if self.c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=self.c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)

        # main discriminator blocks
        common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers

        # dual discriminator or separate camera predictor
        if self.separate_camera or self.dual_discriminator:
            cur_layer_idx = 0
            for res in [r for r in self.block_resolutions if r <= self.lowres_head]:
                in_channels = channels_dict[res] if res < img_resolution else 0
                tmp_channels = channels_dict[res]
                out_channels = channels_dict[res // 2]
                block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                    first_layer_idx=cur_layer_idx, use_fp16=False, **block_kwargs, **common_kwargs)
                setattr(self, f'c{res}', block)
                cur_layer_idx += block.num_layers

        # final output module
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer("alpha", torch.scalar_tensor(-1))

    def set_alpha(self, alpha):
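        # Progressive-growing blend factor in [0, 1]; updating in place keeps
        # `alpha` a registered buffer on its current device and dtype.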
        if alpha is not None:
            self.alpha = self.alpha * 0 + alpha

    def set_resolution(self, res):
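        # Despite the name, `res` is the generator's progressive-growing status
        # tuple (n_levels, _, before_res, target_res); it is unpacked in
        # get_block_resolutions().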
        self.curr_status = res

    def get_estimated_camera(self, img, **block_kwargs):
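        # Run the low-resolution branch (`c{res}` blocks) on a (possibly
        # downsampled) copy of the image and, when `separate_camera` is set,
        # project its 4x4 features to camera parameters.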
        if isinstance(img, dict):
            img = img['img']
        img4cam = img.clone()
        if self.progressive and (img.size(-1) != self.lowres_head):
            img4cam = downsample(img, self.lowres_head)

        c, xc = None, None
        for res in [r for r in self.block_resolutions if r <= self.lowres_head or (not self.progressive)]:
            xc, img4cam = getattr(self, f'c{res}')(xc, img4cam, **block_kwargs)

        if self.separate_camera:
            c = self.projector(xc)[:,:,0,0]
            if self.predict_9d_camera:
                c = camera_9d_to_16d(c)
        return c, xc, img4cam

    def get_camera_loss(self, RT=None, UV=None, c=None):
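        # Supervise the predicted camera `c`: MSE against (u, v, r) labels (UV)
        # when available, otherwise smooth-L1 (scaled by 10) against the
        # flattened 4x4 extrinsics (RT).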
        if UV is not None:  # UV labels take priority over RT.
            return F.mse_loss(UV, c)
            # Alternative wrap-around loss for the periodic u coordinate (kept for reference):
            # lu = torch.stack([(UV[:,0] - c[:, 0]) ** 2, (UV[:,0] - c[:, 0] + 1) ** 2, (UV[:,0] - c[:, 0] - 1) ** 2], 0).min(0).values
            # return torch.mean(sum(lu + (UV[:,1] - c[:, 1]) ** 2 + (UV[:,2] - c[:, 2]) ** 2))
        elif RT is not None:
            return F.smooth_l1_loss(RT.reshape(RT.size(0), -1), c) * 10
        return None

    def get_block_resolutions(self, input_img):
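        # Determine which discriminator resolutions are active under
        # progressive growing, the per-level blend factor `alpha`, and the
        # resolution at which the old and new paths are blended.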
        block_resolutions = self.block_resolutions
        lowres_head = self.lowres_head
        alpha = self.alpha
        img_res = input_img.size(-1)
        if self.progressive and (self.lowres_head is not None) and (self.alpha > -1):
            if (self.alpha < 1) and (self.alpha > 0):
                try:
                    n_levels, _, before_res, target_res = self.curr_status
                    alpha, index = math.modf(self.alpha * n_levels)
                    index = int(index)
                except Exception:  # TODO: this is a hack; better to save the status as buffers.
                    before_res = target_res = img_res
                if before_res == target_res:  # no upsampling was used in generator, do not increase the discriminator
                    alpha = 0
                block_resolutions = [res for res in self.block_resolutions if res <= target_res]
                lowres_head = before_res
            elif self.alpha == 0:
                block_resolutions = [res for res in self.block_resolutions if res <= lowres_head]
        return block_resolutions, alpha, lowres_head

    def forward(self, inputs, c=None, aug_pipe=None, return_camera=False, **block_kwargs):
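        # `inputs` is either an image tensor or a dict with key 'img' and,
        # optionally, 'img_nerf' and 'camera_matrices'. Returns a dict with
        # 'logits' plus, when enabled, 'camera_loss' and 'camera'.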
        if not isinstance(inputs, dict):
            inputs = {'img': inputs}
        img = inputs['img']
        if c is None:  # allow label-free calls; downstream checks rely on c.size(-1)
            c = torch.zeros(img.size(0), 0, device=img.device)
        block_resolutions, alpha, lowres_head = self.get_block_resolutions(img)
        if img.size(-1) > block_resolutions[0]:
            img = downsample(img, block_resolutions[0])

        # Derive the NeRF-resolution image from real images when it is not provided.
        if (self.dual_discriminator or (self.dual_input_ratio is not None)) and ('img_nerf' not in inputs):
            inputs['img_nerf'] = img
            if self.dual_discriminator and (inputs['img_nerf'].size(-1) > self.lowres_head):  # cap at the low-res head input size
                inputs['img_nerf'] = downsample(inputs['img_nerf'], self.lowres_head)
            elif self.dual_input_ratio is not None:   # similar to EG3D
                if inputs['img_nerf'].size(-1) > (img.size(-1) // self.dual_input_ratio):
                    inputs['img_nerf'] = downsample(inputs['img_nerf'], img.size(-1) // self.dual_input_ratio)
                img = torch.cat([img, upsample(inputs['img_nerf'], img.size(-1))], 1)

        camera_loss = None
        RT = inputs['camera_matrices'][1].detach() if 'camera_matrices' in inputs else None
        UV = inputs['camera_matrices'][2].detach() if 'camera_matrices' in inputs else None

        # run the separate camera predictor and/or the dual (NeRF) discriminator branch
        if self.dual_discriminator or self.separate_camera:
            temp_img = img if not self.dual_discriminator else inputs['img_nerf']
            c_nerf, x_nerf, img_nerf = self.get_estimated_camera(temp_img, **block_kwargs)
            if c.size(-1) == 0 and self.separate_camera:
                c = c_nerf
                if self.predict_3d_camera:
                    camera_loss = self.get_camera_loss(RT, UV, c)

        # optionally apply data augmentation (e.g., ADA) to the discriminator input
        if aug_pipe is not None:
            assert self.separate_camera or (not self.predict_camera), "ada may break the camera predictor."
            img = aug_pipe(img)

        # obtain the downsampled image for progressive growing
        if self.progressive and (self.lowres_head is not None) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0):
            img0 = downsample(img, img.size(-1) // 2)

        x = None if (not self.progressive) or (block_resolutions[0] == self.img_resolution) \
            else getattr(self, f'b{block_resolutions[0]}').fromrgb(img)
        for res in block_resolutions:
            block = getattr(self, f'b{res}')
            if (lowres_head == res) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0):
                if self.architecture == 'skip':
                    img = img * alpha + img0 * (1 - alpha)
                if self.progressive:
                    x = x * alpha + block.fromrgb(img0) * (1 - alpha)
            x, img = block(x, img, **block_kwargs)

        # predict camera based on discriminator features
        if (c.size(-1) == 0) and self.predict_camera and (not self.separate_camera):
            c = self.projector(x)[:,:,0,0]
            if self.predict_9d_camera:
                c = camera_9d_to_16d(c)
            if self.predict_3d_camera:
                camera_loss = self.get_camera_loss(RT, UV, c)

        # camera-conditional discriminator
        cmap = None
        if self.c_dim > 0:
            cc = c.clone().detach()  # stop gradients so conditioning does not backprop into the camera predictor
            cmap = self.mapping(None, cc)
        logits  = self.b4(x, img, cmap)
        if self.dual_discriminator:
            logits = torch.cat([logits, self.b4(x_nerf, img_nerf, cmap)], 0)

        outputs = {'logits': logits}
        if self.predict_camera and (camera_loss is not None):
            outputs['camera_loss'] = camera_loss
        if return_camera:
            outputs['camera'] = c
        return outputs
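

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative, not part of the original file).
    # It assumes the `models.utils` helpers imported above are available and
    # behave like the official StyleGAN2 implementations; the batch size and
    # the 256x256 resolution below are arbitrary choices for the example.
    disc = Discriminator(c_dim=0, img_resolution=256, img_channels=3)
    fake = torch.randn(2, 3, 256, 256)  # a [N, C, H, W] batch of images
    out = disc(fake)
    print(out['logits'].shape)  # expected: torch.Size([2, 1])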