File size: 29,163 Bytes
7f51798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Collection of Losses.
"""

import torch
import torch.nn.functional as F
from torch import nn
from torchtyping import TensorType
from torch.autograd import Variable
import numpy as np
from math import exp

# from nerfstudio.cameras.rays import RaySamples
# from nerfstudio.field_components.field_heads import FieldHeadNames

# Aliases for the stock PyTorch criterion classes.
L1Loss = nn.L1Loss
MSELoss = nn.MSELoss

# Lookup table from a short loss name to its criterion class (a class, not an instance).
LOSSES = {"L1": L1Loss, "MSE": MSELoss}

# Small constant added to denominators (see lossfun_outer) to avoid division by zero.
EPS = 1.0e-7


def outer(
    t0_starts: TensorType[..., "num_samples_0"],
    t0_ends: TensorType[..., "num_samples_0"],
    t1_starts: TensorType[..., "num_samples_1"],
    t1_ends: TensorType[..., "num_samples_1"],
    y1: TensorType[..., "num_samples_1"],
) -> TensorType[..., "num_samples_0"]:
    """Accumulate the ``y1`` mass of every ``t1`` bin touched by each ``t0`` interval.

    Faster version of

    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L117
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L64

    Args:
        t0_starts: left edges of the query intervals
        t0_ends: right edges of the query intervals
        t1_starts: left edges of the bins that carry ``y1``
        t1_ends: right edges of the bins that carry ``y1``
        y1: weight of each ``t1`` bin

    Returns:
        One value per query interval: the cumulative ``y1`` weight from the
        first to the last ``t1`` bin selected for that interval.
    """
    last_bin = y1.shape[-1] - 1

    # Prefix sums with a leading zero, so prefix[k] is the total weight of bins 0..k-1.
    leading_zero = torch.zeros_like(y1[..., :1])
    prefix = torch.cat([leading_zero, torch.cumsum(y1, dim=-1)], dim=-1)

    # Index of the last bin starting at or before each query start (clamped into range).
    first_hit = torch.searchsorted(t1_starts.contiguous(), t0_starts.contiguous(), side="right") - 1
    first_hit = first_hit.clamp(min=0, max=last_bin)

    # Index of the first bin ending strictly after each query end (clamped into range).
    last_hit = torch.searchsorted(t1_ends.contiguous(), t0_ends.contiguous(), side="right")
    last_hit = last_hit.clamp(min=0, max=last_bin)

    mass_before = torch.take_along_dim(prefix[..., :-1], first_hit, dim=-1)
    mass_through = torch.take_along_dim(prefix[..., 1:], last_hit, dim=-1)
    return mass_through - mass_before


def lossfun_outer(
    t: TensorType[..., "num_samples+1"],
    w: TensorType[..., "num_samples"],
    t_env: TensorType[..., "num_samples+1"],
    w_env: TensorType[..., "num_samples"],
):
    """Penalize histogram weights that exceed the bound given by an enveloping histogram.

    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L136
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L80

    Args:
        t: interval edges
        w: weights
        t_env: interval edges of the upper bound enveloping histogram
        w_env: weights that should upper bound the inner (t,w) histogram

    Returns:
        Per-bin loss with the same shape as ``w``: the squared positive excess of
        ``w`` over the bound, divided by ``w + EPS``.
    """
    inner_starts, inner_ends = t[..., :-1], t[..., 1:]
    env_starts, env_ends = t_env[..., :-1], t_env[..., 1:]

    # Upper bound on each inner bin implied by the envelope histogram.
    w_bound = outer(inner_starts, inner_ends, env_starts, env_ends, w_env)

    # Only the part of w above the bound is penalized.
    excess = torch.clip(w - w_bound, min=0)
    return excess**2 / (w + EPS)


def ray_samples_to_sdist(ray_samples):
    """Convert ray samples to s space.

    Args:
        ray_samples: object exposing ``spacing_starts`` and ``spacing_ends``
            tensors whose last axis has a channel at index 0.

    Returns:
        All bin starts followed by the end of the last bin, concatenated along
        the last axis, i.e. (num_rays, num_samples + 1).
    """
    bin_starts = ray_samples.spacing_starts[..., 0]
    # Only the final bin contributes its end; every other end is the next bin's start.
    final_end = ray_samples.spacing_ends[..., -1:, 0]
    return torch.cat([bin_starts, final_end], dim=-1)


def interlevel_loss(weights_list, ray_samples_list):
    """Calculates the proposal loss in the MipNeRF-360 paper.

    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/model.py#L515
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/train_utils.py#L133

    Args:
        weights_list: per-level weights; the last entry is the target level.
        ray_samples_list: per-level ray samples, aligned with ``weights_list``.

    Returns:
        Sum over all non-final levels of the mean ``lossfun_outer`` value
        (the float ``0.0`` when there are no such levels).
    """
    # The final level acts as a fixed target, so no gradient flows into it.
    target_edges = ray_samples_to_sdist(ray_samples_list[-1]).detach()
    target_weights = weights_list[-1][..., 0].detach()

    total = 0.0
    for proposal_samples, proposal_weights in zip(ray_samples_list[:-1], weights_list[:-1]):
        proposal_edges = ray_samples_to_sdist(proposal_samples)  # (num_rays, num_samples + 1)
        per_bin = lossfun_outer(
            target_edges,
            target_weights,
            proposal_edges,
            proposal_weights[..., 0],  # (num_rays, num_samples)
        )
        total += torch.mean(per_bin)
    return total


## zip-NeRF losses
def blur_stepfun(x, y, r):
    """Blur a batched step function by a box of half-width ``r`` (Zip-NeRF style).

    Args:
        x: 2-D tensor of step edges; each row has one more entry than ``y``.
        y: 2-D tensor of step heights, one per interval of ``x``.
        r: blur radius; every edge is duplicated at ``x - r`` and ``x + r``.

    Returns:
        ``(x_r, y_r)``: the sorted shifted edges and the value of the resulting
        piecewise-linear function at each of them (the first value is zero).
    """
    pad = torch.zeros_like(y[:, :1])

    # Every original edge becomes two knots, one on each side.
    shifted_edges = torch.cat([x - r, x + r], dim=-1)
    x_r, order = torch.sort(shifted_edges, dim=-1)

    # Height jump at each edge, scaled to the slope of a ramp of width 2r.
    jumps = (torch.cat([y, pad], dim=-1) - torch.cat([pad, y], dim=-1)) / (2 * r)
    # A ramp starts at x - r (+jump) and ends at x + r (-jump).
    signed_jumps = torch.cat([jumps, -jumps], dim=-1)

    # Slope change at each sorted knot; the last knot is not needed.
    order = order[:, :-1]
    slope_changes = torch.gather(signed_jumps, -1, order)

    # Integrate slope changes to slopes, then slopes over knot spacing to values.
    knot_spacing = x_r[:, 1:] - x_r[:, :-1]
    y_r = torch.cumsum(knot_spacing * torch.cumsum(slope_changes, dim=-1), dim=-1)
    y_r = torch.cat([pad, y_r], dim=-1)
    return x_r, y_r


def interlevel_loss_zip(weights_list, ray_samples_list):
    """Calculates the proposal loss in the Zip-NeRF paper.

    Args:
        weights_list: per-level weights; the last entry is the target level.
        ray_samples_list: per-level ray samples, aligned with ``weights_list``.

    Returns:
        Sum of the per-level losses. Only as many proposal levels as there are
        blur radii (two) are visited, because the loop zips against that list.
    """
    # The final level is the fixed target: no gradient flows into it.
    fine_edges = ray_samples_to_sdist(ray_samples_list[-1]).detach()
    fine_weights = weights_list[-1][..., 0].detach()

    # 1. normalize: turn per-bin weights into densities by dividing by bin width
    fine_density = fine_weights / (fine_edges[:, 1:] - fine_edges[:, :-1])

    blur_radii = [0.03, 0.003]
    total = 0.0
    for proposal_samples, proposal_weights, radius in zip(ray_samples_list[:-1], weights_list[:-1], blur_radii):
        # 2. step blur with a level-specific radius
        knots, heights = blur_stepfun(fine_edges, fine_density, radius)
        heights = torch.clip(heights, min=0)
        # After the clip this can only fail when NaNs are present.
        assert (heights >= 0.0).all()

        # 3. accumulate: trapezoidal integral of the blurred curve, starting at zero
        segment_area = (heights[:, 1:] + heights[:, :-1]) * 0.5 * (knots[:, 1:] - knots[:, :-1])
        cumulative = torch.cumsum(segment_area, dim=-1)
        cumulative = torch.cat([torch.zeros_like(cumulative[:, :1]), cumulative], dim=-1)

        # 4. loss
        cp = ray_samples_to_sdist(proposal_samples)  # (num_rays, num_samples + 1)
        wp = proposal_weights[..., 0]  # (num_rays, num_samples)

        # resample: linearly interpolate the cumulative curve at the proposal edges
        last_knot = knots.shape[-1] - 1
        inds = torch.searchsorted(knots, cp, side="right")
        below = torch.clamp(inds - 1, 0, last_knot)
        above = torch.clamp(inds, 0, last_knot)
        knot_lo = torch.gather(knots, -1, below)
        cum_lo = torch.gather(cumulative, -1, below)
        knot_hi = torch.gather(knots, -1, above)
        cum_hi = torch.gather(cumulative, -1, above)

        # nan_to_num covers the 0/0 case when both neighbours are the same knot.
        frac = torch.clip(torch.nan_to_num((cp - knot_lo) / (knot_hi - knot_lo), 0), 0, 1)
        cum_at_cp = cum_lo + frac * (cum_hi - cum_lo)

        # Blurred target mass falling inside each proposal bin.
        w_gt = cum_at_cp[:, 1:] - cum_at_cp[:, :-1]

        # TODO here might be unstable when wp is very small
        total += torch.mean(torch.clip(w_gt - wp, min=0) ** 2 / (wp + 1e-5))

    return total


# Verified
def lossfun_distortion(t, w):
    """Distortion loss of a weighted step function along the last axis.

    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L142
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L266

    Args:
        t: interval edges (one more entry than ``w`` on the last axis).
        w: weight of each interval.

    Returns:
        Inter-interval term plus intra-interval term, reduced over the last axis.
    """
    midpoints = (t[..., 1:] + t[..., :-1]) / 2
    widths = t[..., 1:] - t[..., :-1]

    # Inter-interval term: weights times pairwise distance between interval midpoints.
    pairwise_gap = torch.abs(midpoints[..., :, None] - midpoints[..., None, :])
    weighted_gap = torch.sum(w[..., None, :] * pairwise_gap, dim=-1)
    loss_inter = torch.sum(w * weighted_gap, dim=-1)

    # Intra-interval term: each interval's own spread.
    loss_intra = torch.sum(w**2 * widths, dim=-1) / 3

    return loss_inter + loss_intra


def distortion_loss(weights_list, ray_samples_list):
    """From mipnerf360: mean distortion loss of the last (finest) sampling level.

    Args:
        weights_list: per-level weights; only the last entry is used.
        ray_samples_list: per-level ray samples; only the last entry is used.

    Returns:
        Mean of ``lossfun_distortion`` over that level.
    """
    finest_edges = ray_samples_to_sdist(ray_samples_list[-1])
    finest_weights = weights_list[-1][..., 0]
    return torch.mean(lossfun_distortion(finest_edges, finest_weights))


# def nerfstudio_distortion_loss(
#     ray_samples: RaySamples,
#     densities: TensorType["bs":..., "num_samples", 1] = None,
#     weights: TensorType["bs":..., "num_samples", 1] = None,
# ) -> TensorType["bs":..., 1]:
#     """Ray based distortion loss proposed in MipNeRF-360. Returns distortion Loss.

#     .. math::

#         \\mathcal{L}(\\mathbf{s}, \\mathbf{w}) =\\iint\\limits_{-\\infty}^{\\,\\,\\,\\infty}
#         \\mathbf{w}_\\mathbf{s}(u)\\mathbf{w}_\\mathbf{s}(v)|u - v|\\,d_{u}\\,d_{v}

#     where :math:`\\mathbf{w}_\\mathbf{s}(u)=\\sum_i w_i \\mathbb{1}_{[\\mathbf{s}_i, \\mathbf{s}_{i+1})}(u)`
#     is the weight at location :math:`u` between bin locations :math:`s_i` and :math:`s_{i+1}`.

#     Args:
#         ray_samples: Ray samples to compute loss over
#         densities: Predicted sample densities
#         weights: Predicted weights from densities and sample locations
#     """
#     if torch.is_tensor(densities):
#         assert not torch.is_tensor(weights), "Cannot use both densities and weights"
#         # Compute the weight at each sample location
#         weights = ray_samples.get_weights(densities)
#     if torch.is_tensor(weights):
#         assert not torch.is_tensor(densities), "Cannot use both densities and weights"

#     starts = ray_samples.spacing_starts
#     ends = ray_samples.spacing_ends

#     assert starts is not None and ends is not None, "Ray samples must have spacing starts and ends"
#     midpoints = (starts + ends) / 2.0  # (..., num_samples, 1)

#     loss = (
#         weights * weights[..., None, :, 0] * torch.abs(midpoints - midpoints[..., None, :, 0])
#     )  # (..., num_samples, num_samples)
#     loss = torch.sum(loss, dim=(-1, -2))[..., None]  # (..., num_samples)
#     loss = loss + 1 / 3.0 * torch.sum(weights**2 * (ends - starts), dim=-2)

#     return loss


def orientation_loss(
    weights: TensorType["bs":..., "num_samples", 1],
    normals: TensorType["bs":..., "num_samples", 3],
    viewdirs: TensorType["bs":..., 3],
):
    """Orientation loss proposed in Ref-NeRF.
    Loss that encourages that all visible normals are facing towards the camera.

    Args:
        weights: per-sample weights.
        normals: per-sample normals.
        viewdirs: one direction per ray, broadcast over that ray's samples.

    Returns:
        Per-ray sum of ``weight * min(0, n . v) ** 2``; only samples whose dot
        product with ``viewdirs`` is negative contribute.
    """
    cos_to_view = torch.sum(normals * viewdirs[..., None, :], dim=-1)
    # Keep only the negative part of the dot product; fmin also ignores NaNs.
    negative_part = torch.fmin(torch.zeros_like(cos_to_view), cos_to_view)
    return torch.sum(weights[..., 0] * negative_part**2, dim=-1)


def pred_normal_loss(
    weights: TensorType["bs":..., "num_samples", 1],
    normals: TensorType["bs":..., "num_samples", 3],
    pred_normals: TensorType["bs":..., "num_samples", 3],
):
    """Loss between normals calculated from density and normals from prediction network.

    Returns:
        Per-ray weighted sum of ``1 - <normals, pred_normals>`` over the samples.
    """
    alignment = torch.sum(normals * pred_normals, dim=-1)
    per_sample = weights[..., 0] * (1.0 - alignment)
    return per_sample.sum(dim=-1)


def monosdf_normal_loss(normal_pred: torch.Tensor, normal_gt: torch.Tensor):
    """Normal consistency loss as in MonoSDF: L1 term plus cosine term.

    Both inputs are L2-normalized along the last axis before comparison.

    Args:
        normal_pred (torch.Tensor): volume rendered normal
        normal_gt (torch.Tensor): monocular normal

    Returns:
        Scalar tensor: mean L1 distance (summed over the last axis) plus mean
        of ``1 - cosine similarity``.
    """
    unit_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1)
    unit_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1)

    l1_term = (unit_pred - unit_gt).abs().sum(dim=-1).mean()
    cosine_term = (1.0 - (unit_pred * unit_gt).sum(dim=-1)).mean()
    return l1_term + cosine_term


# copy from MiDaS
def compute_scale_and_shift(prediction, target, mask):
    """Per-image least-squares scale and shift aligning ``prediction`` to ``target``.

    Solves, for every image in the batch, the 2x2 normal equations of
    ``min sum(mask * (scale * prediction + shift - target) ** 2)``.

    Args:
        prediction: batch of predictions, reduced over axes 1 and 2.
        target: batch of targets with the same layout.
        mask: per-pixel weights; masked-out pixels do not enter the sums.

    Returns:
        ``(scale, shift)``, one value per image. Images whose system is
        singular (zero determinant) get ``0`` for both.
    """
    masked_pred = mask * prediction

    # system matrix: A = [[a_00, a_01], [a_10, a_11]] with a_10 == a_01
    a_00 = torch.sum(masked_pred * prediction, (1, 2))
    a_01 = torch.sum(masked_pred, (1, 2))
    a_11 = torch.sum(mask, (1, 2))

    # right hand side: b = [b_0, b_1]
    b_0 = torch.sum(masked_pred * target, (1, 2))
    b_1 = torch.sum(mask * target, (1, 2))

    det = a_00 * a_11 - a_01 * a_01
    solvable = det != 0

    # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / det . b
    scale = torch.zeros_like(b_0)
    shift = torch.zeros_like(b_1)
    scale[solvable] = (a_11[solvable] * b_0[solvable] - a_01[solvable] * b_1[solvable]) / det[solvable]
    shift[solvable] = (-a_01[solvable] * b_0[solvable] + a_00[solvable] * b_1[solvable]) / det[solvable]

    return scale, shift


def reduction_batch_based(image_loss, M):
    """Average the loss over all valid pixels of the whole batch.

    Args:
        image_loss: per-image summed loss.
        M: per-image normalizer (e.g. number of valid pixels).

    Returns:
        Scalar tensor ``sum(image_loss) / sum(M)``. When ``sum(M)`` is zero the
        result is a zero scalar tensor with the dtype and device of
        ``image_loss`` (not a Python int), so callers always get a tensor back.
    """
    divisor = torch.sum(M)

    # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
    if divisor == 0:
        return image_loss.new_zeros(())

    return torch.sum(image_loss) / divisor


def reduction_image_based(image_loss, M):
    """Mean over images of each image's loss averaged over its valid pixels.

    Args:
        image_loss: per-image summed loss. It is NOT modified.
        M: per-image normalizer (e.g. number of valid pixels).

    Returns:
        Scalar tensor: mean over the batch of ``image_loss / M``; entries with
        ``M == 0`` keep their ``image_loss`` value undivided.
    """
    has_valid = M != 0

    # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0); the divisor is
    # replaced by 1 where M == 0 so the branch that is not selected stays
    # finite and cannot leak NaNs into the gradient.
    safe_divisor = torch.where(has_valid, M, torch.ones_like(M))
    per_image = torch.where(has_valid, image_loss / safe_divisor, image_loss)

    return torch.mean(per_image)


def mse_loss(prediction, target, mask, reduction=reduction_batch_based):
    """Masked squared-error loss, normalized by twice the mask sum.

    Args:
        prediction: batch of predictions, reduced over axes 1 and 2.
        target: batch of targets with the same layout.
        mask: per-pixel weights applied to the squared residual.
        reduction: callable ``(image_loss, M)`` that reduces the per-image sums.

    Returns:
        Whatever ``reduction`` returns for the per-image sums and ``2 * M``.
    """
    valid_per_image = torch.sum(mask, (1, 2))
    residual = prediction - target
    squared_error = torch.sum(mask * residual * residual, (1, 2))

    # The normalizer is doubled, i.e. the loss is half the masked mean squared error.
    return reduction(squared_error, 2 * valid_per_image)


def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
    """Masked L1 loss on the spatial gradients of ``prediction - target``.

    Args:
        prediction: batch of predictions, differenced along axes 1 and 2.
        target: batch of targets with the same layout.
        mask: per-pixel weights; a finite difference only counts when both of
            its pixels are unmasked.
        reduction: callable ``(image_loss, M)`` that reduces the per-image sums.

    Returns:
        Whatever ``reduction`` returns for the per-image gradient sums and the
        per-image mask sums.
    """
    valid_per_image = torch.sum(mask, (1, 2))
    masked_diff = mask * (prediction - target)

    # Finite differences along the last axis, kept where both neighbours are valid.
    pair_mask_x = mask[:, :, 1:] * mask[:, :, :-1]
    grad_x = pair_mask_x * torch.abs(masked_diff[:, :, 1:] - masked_diff[:, :, :-1])

    # Same along axis 1.
    pair_mask_y = mask[:, 1:, :] * mask[:, :-1, :]
    grad_y = pair_mask_y * torch.abs(masked_diff[:, 1:, :] - masked_diff[:, :-1, :])

    image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
    return reduction(image_loss, valid_per_image)


class MiDaSMSELoss(nn.Module):
    """Module wrapper around ``mse_loss`` with a selectable reduction."""

    def __init__(self, reduction="batch-based"):
        """
        Args:
            reduction: ``"batch-based"`` selects ``reduction_batch_based``; any
                other value selects ``reduction_image_based``.
        """
        super().__init__()
        self.__reduction = reduction_batch_based if reduction == "batch-based" else reduction_image_based

    def forward(self, prediction, target, mask):
        """Return ``mse_loss`` of the inputs under the configured reduction."""
        return mse_loss(prediction, target, mask, reduction=self.__reduction)


class GradientLoss(nn.Module):
    """Multi-scale gradient matching loss built on ``gradient_loss``."""

    def __init__(self, scales=4, reduction="batch-based"):
        """
        Args:
            scales: number of resolutions; level ``k`` subsamples with stride ``2 ** k``.
            reduction: ``"batch-based"`` selects ``reduction_batch_based``; any
                other value selects ``reduction_image_based``.
        """
        super().__init__()
        self.__reduction = reduction_batch_based if reduction == "batch-based" else reduction_image_based
        self.__scales = scales

    def forward(self, prediction, target, mask):
        """Sum ``gradient_loss`` over all scales (int ``0`` when ``scales`` is 0)."""
        total = 0

        for level in range(self.__scales):
            stride = 1 << level  # 1, 2, 4, ...
            total += gradient_loss(
                prediction[:, ::stride, ::stride],
                target[:, ::stride, ::stride],
                mask[:, ::stride, ::stride],
                reduction=self.__reduction,
            )

        return total


class ScaleAndShiftInvariantLoss(nn.Module):
    """Scale-and-shift invariant loss: aligned MSE plus weighted gradient term."""

    def __init__(self, alpha=0.5, scales=4, reduction="batch-based"):
        """
        Args:
            alpha: weight of the gradient regularizer; it is skipped when ``alpha <= 0``.
            scales: number of scales passed to ``GradientLoss``.
            reduction: reduction name passed to both sub-losses.
        """
        super().__init__()

        self.__data_loss = MiDaSMSELoss(reduction=reduction)
        self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
        self.__alpha = alpha

        # Most recent aligned prediction, exposed through ``prediction_ssi``.
        self.__prediction_ssi = None

    def forward(self, prediction, target, mask):
        """Align ``prediction`` to ``target`` per image, then compute the loss."""
        scale, shift = compute_scale_and_shift(prediction, target, mask)
        aligned = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
        self.__prediction_ssi = aligned

        total = self.__data_loss(aligned, target, mask)
        if self.__alpha > 0:
            total += self.__alpha * self.__regularization_loss(aligned, target, mask)

        return total

    def __get_prediction_ssi(self):
        """Return the aligned prediction of the last ``forward`` call (or None)."""
        return self.__prediction_ssi

    prediction_ssi = property(__get_prediction_ssi)


# end copy


# copy from https://github.com/svip-lab/Indoor-SfMLearner/blob/0d682b7ce292484e5e3e2161fc9fc07e2f5ca8d1/layers.py#L218
class SSIM(nn.Module):
    """Layer to compute the SSIM loss between a pair of images"""

    def __init__(self, patch_size):
        """
        Args:
            patch_size: side of the averaging window; inputs are reflection-padded
                by ``patch_size // 2`` before pooling with stride 1.
        """
        super().__init__()

        def box_filter():
            return nn.AvgPool2d(patch_size, 1)

        # Separate (parameter-free) pooling layers for the means and second moments.
        self.mu_x_pool = box_filter()
        self.mu_y_pool = box_filter()
        self.sig_x_pool = box_filter()
        self.sig_y_pool = box_filter()
        self.sig_xy_pool = box_filter()

        self.refl = nn.ReflectionPad2d(patch_size // 2)

        # Stabilizing constants of the SSIM formula.
        self.C1 = 0.01**2
        self.C2 = 0.03**2

    def forward(self, x, y):
        """Return the per-pixel dissimilarity ``(1 - SSIM) / 2`` clamped to [0, 1]."""
        x, y = self.refl(x), self.refl(y)

        mean_x = self.mu_x_pool(x)
        mean_y = self.mu_y_pool(y)

        # Local (co)variances via E[ab] - E[a]E[b].
        var_x = self.sig_x_pool(x**2) - mean_x**2
        var_y = self.sig_y_pool(y**2) - mean_y**2
        cov_xy = self.sig_xy_pool(x * y) - mean_x * mean_y

        numerator = (2 * mean_x * mean_y + self.C1) * (2 * cov_xy + self.C2)
        denominator = (mean_x**2 + mean_y**2 + self.C1) * (var_x + var_y + self.C2)

        dissimilarity = (1 - numerator / denominator) / 2
        return torch.clamp(dissimilarity, 0, 1)


# TODO test different losses
class NCC(nn.Module):
    """Layer to compute the normalized cross correlation (NCC) of patches"""

    def __init__(self, patch_size: int = 11, min_patch_variance: float = 0.01):
        """
        Args:
            patch_size: stored on the module; ``forward`` does not read it.
            min_patch_variance: patches whose summed squared deviation is below
                this threshold are treated as perfectly matching.
        """
        super().__init__()
        self.patch_size = patch_size
        self.min_patch_variance = min_patch_variance

    def forward(self, x, y):
        """Return ``1 - NCC`` per patch, shaped ``(N, 1, 1, 1)`` (0 best, 2 worst)."""
        # TODO if we use gray image we should do it right after loading the image to save computations
        # to gray image: average over the channel axis
        gray_x = x.mean(dim=1)
        gray_y = y.mean(dim=1)

        # Remove each patch's mean intensity.
        dev_x = gray_x - gray_x.mean(dim=(1, 2), keepdim=True)
        dev_y = gray_y - gray_y.mean(dim=(1, 2), keepdim=True)

        energy_x = torch.square(dev_x).sum(dim=(1, 2))
        energy_y = torch.square(dev_y).sum(dim=(1, 2))

        cross = torch.sum(dev_x * dev_y, dim=(1, 2))
        denom = torch.sqrt(energy_x * energy_y + 1e-6)
        ncc = cross / (denom + 1e-6)

        # ignore patches with low variances
        low_texture = (energy_x < self.min_patch_variance) | (energy_y < self.min_patch_variance)
        ncc[low_texture] = 1.0

        score = 1 - ncc.clip(-1.0, 1.0)  # 0->2: smaller, better
        return score[:, None, None, None]


class MultiViewLoss(nn.Module):
    """Multi-view consistency loss between a reference view and source views.

    The first image along axis 0 of ``patches`` is the reference; every other
    image is a source. Each source patch is scored against the matching
    reference patch, and for every ray only the ``topk`` lowest scores are kept.
    """

    def __init__(self, patch_size: int = 11, topk: int = 4, min_patch_variance: float = 0.01):
        """
        Args:
            patch_size: side length of the square patches.
            topk: number of best-matching source views kept per ray. It is
                clamped to the number of source views actually provided.
            min_patch_variance: forwarded to the NCC patch metric.
        """
        super().__init__()
        self.patch_size = patch_size
        self.topk = topk
        self.min_patch_variance = min_patch_variance
        # TODO make metric configurable
        # NOTE: despite the attribute name, the metric in use is NCC, not SSIM.
        self.ssim = NCC(patch_size=patch_size, min_patch_variance=min_patch_variance)

        # Not read by forward(); kept so existing attribute access keeps working.
        self.iter = 0

    def forward(self, patches: torch.Tensor, valid: torch.Tensor):
        """Average the top-k lowest patch errors over the valid selections.

        Args:
            patches (torch.Tensor): unpacked as
                ``(num_imgs, num_rays, patch_size * patch_size, num_channels)``;
                index 0 along the first axis is the reference view.
            valid (torch.Tensor): validity flags reshaped like ``patches`` with
                a single channel; a source patch counts as valid only if all
                of its pixels are valid.

        Returns:
            torch.Tensor: scalar loss. Zero when there are no rays or no
            source views.
        """
        num_imgs, num_rays, _, num_channels = patches.shape
        num_src = num_imgs - 1
        patch = self.patch_size

        if num_rays <= 0 or num_src <= 0:
            return torch.tensor(0.0, device=patches.device)

        ref_patches = (
            patches[:1, ...]
            .reshape(1, num_rays, patch, patch, num_channels)
            .expand(num_src, num_rays, patch, patch, num_channels)
            .reshape(-1, patch, patch, num_channels)
            .permute(0, 3, 1, 2)
        )  # [N_src*N_rays, C, patch_size, patch_size]
        src_patches = (
            patches[1:, ...]
            .reshape(num_src, num_rays, patch, patch, num_channels)
            .reshape(-1, patch, patch, num_channels)
            .permute(0, 3, 1, 2)
        )  # [N_src*N_rays, C, patch_size, patch_size]

        # apply same reshape to the valid mask
        src_patches_valid = (
            valid[1:, ...]
            .reshape(num_src, num_rays, patch, patch, 1)
            .reshape(-1, patch, patch, 1)
            .permute(0, 3, 1, 2)
        )  # [N_src*N_rays, 1, patch_size, patch_size]

        # The reference is detached: gradients only flow through the source patches.
        ssim = self.ssim(ref_patches.detach(), src_patches)
        ssim = torch.mean(ssim, dim=(1, 2, 3))
        ssim = ssim.reshape(num_src, num_rays)

        ssim_valid = src_patches_valid.reshape(-1, patch * patch).all(dim=-1).reshape(num_src, num_rays)

        # we should mask the error after we select the topk value, otherwise we might
        # select far away patches that happen to be inside the image
        k = min(self.topk, num_src)  # topk raises if k exceeds the number of source views
        min_ssim, idx = torch.topk(ssim, k=k, largest=False, dim=0, sorted=True)

        ray_idx = torch.arange(num_rays, device=idx.device)[None].expand_as(idx)
        min_ssim_valid = ssim_valid[idx, ray_idx]
        # TODO how to set this value for better visualization
        min_ssim[torch.logical_not(min_ssim_valid)] = 0.0  # invalid picks contribute nothing

        return torch.sum(min_ssim) / (min_ssim_valid.float().sum() + 1e-6)


# sensor depth loss, adapted from https://github.com/dazinovic/neural-rgbd-surface-reconstruction/blob/main/losses.py
# class SensorDepthLoss(nn.Module):
#     """Sensor Depth loss"""

#     def __init__(self, truncation: float):
#         super(SensorDepthLoss, self).__init__()
#         self.truncation = truncation  #  0.05 * 0.3 5cm scaled

#     def forward(self, batch, outputs):
#         """take the mim

#         Args:
#             batch (Dict): inputs
#             outputs (Dict): outputs data from surface model

#         Returns:
#             l1_loss: l1 loss
#             freespace_loss: free space loss
#             sdf_loss: sdf loss
#         """
#         depth_pred = outputs["depth"]
#         depth_gt = batch["sensor_depth"].to(depth_pred.device)[..., None]
#         valid_gt_mask = depth_gt > 0.0

#         l1_loss = torch.sum(valid_gt_mask * torch.abs(depth_gt - depth_pred)) / (valid_gt_mask.sum() + 1e-6)

#         # free space loss and sdf loss
#         ray_samples = outputs["ray_samples"]
#         filed_outputs = outputs["field_outputs"]
#         pred_sdf = filed_outputs[FieldHeadNames.SDF][..., 0]
#         directions_norm = outputs["directions_norm"]

#         z_vals = ray_samples.frustums.starts[..., 0] / directions_norm

#         truncation = self.truncation
#         front_mask = valid_gt_mask & (z_vals < (depth_gt - truncation))
#         back_mask = valid_gt_mask & (z_vals > (depth_gt + truncation))
#         sdf_mask = valid_gt_mask & (~front_mask) & (~back_mask)

#         num_fs_samples = front_mask.sum()
#         num_sdf_samples = sdf_mask.sum()
#         num_samples = num_fs_samples + num_sdf_samples + 1e-6
#         fs_weight = 1.0 - num_fs_samples / num_samples
#         sdf_weight = 1.0 - num_sdf_samples / num_samples

#         free_space_loss = torch.mean((F.relu(truncation - pred_sdf) * front_mask) ** 2) * fs_weight

#         sdf_loss = torch.mean(((z_vals + pred_sdf) - depth_gt) ** 2 * sdf_mask) * sdf_weight

#         return l1_loss, free_space_loss, sdf_loss

r"""Implements Stochastic Structural SIMilarity(S3IM) algorithm.
It is proposed in the ICCV2023 paper  
`S3IM: Stochastic Structural SIMilarity and Its Unreasonable Effectiveness for Neural Fields`.

Arguments:
    s3im_kernel_size (int): kernel size in ssim's convolution(default: 4)
    s3im_stride (int): stride in ssim's convolution(default: 4)
    s3im_repeat_time (int): repeat time in re-shuffle virtual patch(default: 10)
    s3im_patch_height (int): height of virtual patch (default: 64)
"""

class S3IM(torch.nn.Module):
    """Stochastic Structural SIMilarity (S3IM) loss.

    Rays are stacked (once in their given order, then randomly shuffled) into
    a virtual image patch, and ``1 - SSIM`` is computed between the predicted
    and target patches.

    Args:
        s3im_kernel_size (int): kernel size in SSIM's convolution (default: 4)
        s3im_stride (int): stride in SSIM's convolution (default: 4)
        s3im_repeat_time (int): number of stacked copies of the rays; the first
            keeps the given order, the others are shuffled (default: 10)
        s3im_patch_height (int): height of the virtual patch (default: 64)
        size_average (bool): average the SSIM map over everything when True,
            otherwise per batch element (default: True)
    """

    def __init__(self, s3im_kernel_size=4, s3im_stride=4, s3im_repeat_time=10, s3im_patch_height=64, size_average=True):
        super().__init__()
        self.s3im_kernel_size = s3im_kernel_size
        self.s3im_stride = s3im_stride
        self.s3im_repeat_time = s3im_repeat_time
        self.s3im_patch_height = s3im_patch_height
        self.size_average = size_average
        # Cached kernel and the channel count it was built for; ssim_loss rebuilds
        # both when the input's channels, dtype or device change.
        self.channel = 1
        self.s3im_kernel = self.create_kernel(s3im_kernel_size, self.channel)

    def gaussian(self, s3im_kernel_size, sigma):
        """Return a normalized 1-D Gaussian window of length ``s3im_kernel_size``."""
        center = s3im_kernel_size // 2
        gauss = torch.Tensor([exp(-((x - center) ** 2) / float(2 * sigma**2)) for x in range(s3im_kernel_size)])
        return gauss / gauss.sum()

    def create_kernel(self, s3im_kernel_size, channel):
        """Return a ``(channel, 1, k, k)`` Gaussian kernel (sigma 1.5) for grouped conv."""
        _1D_window = self.gaussian(s3im_kernel_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        return _2D_window.expand(channel, 1, s3im_kernel_size, s3im_kernel_size).contiguous()

    def _ssim(self, img1, img2, s3im_kernel, s3im_kernel_size, channel, size_average=True, s3im_stride=None):
        """Compute SSIM between two ``(b, c, h, w)`` images with the given kernel.

        ``s3im_stride=None`` is treated as a stride of 1.
        """
        padding = (s3im_kernel_size - 1) // 2
        stride = 1 if s3im_stride is None else s3im_stride

        def blur(img):
            # Per-channel (grouped) Gaussian filtering.
            return F.conv2d(img, s3im_kernel, padding=padding, groups=channel, stride=stride)

        mu1 = blur(img1)
        mu2 = blur(img2)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local (co)variances via E[ab] - E[a]E[b].
        sigma1_sq = blur(img1 * img1) - mu1_sq
        sigma2_sq = blur(img2 * img2) - mu2_sq
        sigma12 = blur(img1 * img2) - mu1_mu2

        C1 = 0.01**2
        C2 = 0.03**2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()
        return ssim_map.mean(1).mean(1).mean(1)

    def ssim_loss(self, img1, img2):
        """Return the SSIM of two images.

        img1, img2: torch.Tensor([b,c,h,w])
        """
        (_, channel, _, _) = img1.size()

        s3im_kernel = self.s3im_kernel
        cache_ok = (
            channel == self.channel and s3im_kernel.dtype == img1.dtype and s3im_kernel.device == img1.device
        )
        if not cache_ok:
            # Rebuild for this channel count and move to the input's device/dtype
            # (comparing devices also catches a change of GPU index).
            s3im_kernel = self.create_kernel(self.s3im_kernel_size, channel).to(device=img1.device, dtype=img1.dtype)
            self.s3im_kernel = s3im_kernel
            self.channel = channel

        return self._ssim(
            img1, img2, s3im_kernel, self.s3im_kernel_size, channel, self.size_average, s3im_stride=self.s3im_stride
        )

    def forward(self, src_vec, tar_vec):
        """Return ``1 - SSIM`` between virtual patches built from the rays.

        Args:
            src_vec: predicted values, one row per ray, ``(num_rays, channels)``.
            tar_vec: target values with the same layout.

        ``num_rays * s3im_repeat_time`` must be divisible by
        ``s3im_patch_height``, otherwise the reshape below fails.
        """
        num_rays = len(tar_vec)
        channels = tar_vec.shape[-1]  # no longer hard-coded to 3

        # First copy keeps the given ray order; the remaining copies are shuffled.
        # Indices are generated on the CPU so the default RNG stream is used.
        index_list = [
            torch.arange(num_rays) if i == 0 else torch.randperm(num_rays) for i in range(self.s3im_repeat_time)
        ]
        res_index = torch.cat(index_list)

        tar_all = tar_vec[res_index]
        src_all = src_vec[res_index]
        tar_patch = tar_all.permute(1, 0).reshape(1, channels, self.s3im_patch_height, -1)
        src_patch = src_all.permute(1, 0).reshape(1, channels, self.s3im_patch_height, -1)
        return 1 - self.ssim_loss(src_patch, tar_patch)