File size: 1,360 Bytes
482ab8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
import torch.nn as nn
from einops import rearrange
def get_volume_mask_loss(opt):
    """Build and return a fresh ``VolumeMaskLoss`` module.

    ``opt`` is accepted so the factory fits a common ``get_*_loss(opt)``
    call shape, but it is not read here: the loss has no configurable
    options.
    """
    loss_module = VolumeMaskLoss()
    return loss_module
class VolumeMaskLoss(nn.Module):
    """BCE loss between a predicted pairwise volume and a mask-derived target.

    The target is built from a segmentation-style mask: for every pair of
    spatial positions it is 0 when both positions fall on the same side of
    the 0.5 threshold and 1 when they fall on different sides.
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCELoss(reduction="mean")

    def _get_volume_mask(self, mask):
        """Return the pairwise inconsistency volume for ``mask``.

        Args:
            mask: tensor of shape ``(b, c, h, w)``. Values above 0.5 are
                treated as 1, everything else as 0. The input tensor is
                NOT modified.
                NOTE(review): the 0/1 target range only holds for ``c == 1``;
                with more channels the per-channel agreements are summed by
                the matrix product -- confirm callers pass single-channel
                masks.

        Returns:
            Tensor of shape ``(b, h, w, h, w)`` with no gradient history,
            where 0 indicates consistency and 1 indicates inconsistency
            between the two indexed positions.

        Raises:
            ValueError: if ``mask`` is not 4-dimensional.
        """
        if mask.dim() != 4:
            raise ValueError(
                f"mask must have shape (b, c, h, w), got {tuple(mask.shape)}"
            )
        with torch.no_grad():
            batch, _, h, w = mask.shape
            # Binarize out of place: the previous in-place indexed assignment
            # overwrote the caller's mask whenever no interpolation had
            # produced a fresh tensor first.
            binary = (mask > 0.5).to(mask.dtype)
            # b c h w -> b c (h w)
            binary = binary.flatten(start_dim=2)
            # Use orthogonal vectors [1, 0] and [0, 1] per position so the
            # dot product of two positions is 1 iff they share a label.
            one_hot = torch.cat([binary, 1 - binary], dim=1)
            consistency = torch.bmm(one_hot.transpose(-1, -2), one_hot)
            # b (h1 w1) (h2 w2) -> b h1 w1 h2 w2
            consistency = consistency.reshape(batch, h, w, h, w)
            # 0 indicates consistency, and 1 indicates inconsistency
            return 1 - consistency

    def forward(self, out_volume, mask):
        """Compute the mean BCE between ``out_volume`` and the mask target.

        Args:
            out_volume: predicted volume; it must match the target shape
                ``(b, h, w, h, w)`` and, as required by ``nn.BCELoss``,
                hold probabilities in ``[0, 1]``.
            mask: tensor of shape ``(b, c, H, W)``. If its spatial size
                differs from the last two dims of ``out_volume`` it is
                bilinearly resized to match before thresholding.

        Returns:
            ``{"loss": scalar_tensor}``.
        """
        volume_size = out_volume.shape[-2:]
        if volume_size != mask.shape[-2:]:
            mask = nn.functional.interpolate(
                mask, size=volume_size, mode="bilinear", align_corners=False
            )
        volume_mask = self._get_volume_mask(mask)
        loss = self.bce_loss(out_volume, volume_mask)
        return {"loss": loss}
|