File size: 4,914 Bytes
482ab8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from typing import Dict, List
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from skimage import segmentation
def get_multi_view_consistency_loss(opt):
    """Build a ``MultiViewConsistencyLoss`` from an options object.

    Args:
        opt: Object exposing the attributes ``mvc_soft``, ``mvc_zeros_on_au``,
            ``mvc_single_weight``, ``modality``, ``mvc_spixel`` and
            ``mvc_num_spixel``; each one is forwarded to the constructor
            argument of the corresponding name.

    Returns:
        The configured ``MultiViewConsistencyLoss`` module.
    """
    return MultiViewConsistencyLoss(
        soft=opt.mvc_soft,
        zeros_on_au=opt.mvc_zeros_on_au,
        single_weight=opt.mvc_single_weight,
        modality=opt.modality,
        spixel=opt.mvc_spixel,
        num_spixel=opt.mvc_num_spixel,
    )
class MultiViewConsistencyLoss(nn.Module):
    """Pull every modality's output map toward one shared target map.

    The target is the weighted sum of the per-modality ``out_map`` tensors,
    built without gradient tracking. Depending on configuration it is then
    averaged inside superpixels, binarised at 0.5, and/or zeroed for samples
    whose label equals 0. Each modality is finally penalised with a
    mean-squared error against that target.

    Args:
        soft: When true the target stays a soft (real-valued) map. When false
            it is binarised at 0.5 and samples whose label equals 0 are zeroed.
        zeros_on_au: When true, zero the target of every sample whose label
            equals 0 (presumably "authentic" samples -- confirm with caller).
        single_weight: Mapping from lower-cased modality name to the weight
            of that modality in the target's weighted sum.
        modality: Modality names; used verbatim as keys into ``output`` and
            lower-cased as keys into ``single_weight``.
        spixel: When true, the target is averaged inside each superpixel.
        num_spixel: Stored on the instance; not read by the code in this class.
        eps: Stored on the instance; not read by the code in this class.
    """

    def __init__(
        self,
        soft: bool,
        zeros_on_au: bool,
        single_weight: Dict,
        modality: List,
        spixel: bool = False,
        num_spixel: int = 100,
        eps: float = 1e-4,
    ):
        super().__init__()
        self.soft = soft
        self.zeros_on_au = zeros_on_au
        self.single_weight = single_weight
        self.modality = modality
        self.spixel = spixel
        self.num_spixel = num_spixel
        self.eps = eps
        self.mse_loss = nn.MSELoss(reduction="mean")

    def forward(self, output: Dict, label, spixel=None, image=None, mask=None):
        """Compute the per-modality consistency losses and their sum.

        Args:
            output: Mapping ``modality -> {"out_map": tensor}``. All maps are
                summed together, so they must be broadcast-compatible.
            label: Per-sample label tensor indexed along dim 0. Values equal
                to 1.0 / 0.0 select the special handling described below.
            spixel: Superpixel index map; required when the loss was built
                with ``spixel=True``.
            image: Accepted for interface compatibility; unused here.
            mask: Accepted for interface compatibility; unused here.

        Returns:
            A dict holding one ``multi_view_consistency_loss_<modality>``
            entry per modality, the target under ``"tgt_map"`` and the summed
            loss under ``"total_loss"``.

        Raises:
            ValueError: If superpixel averaging is enabled but ``spixel`` is
                ``None``.
        """
        if self.spixel and spixel is None:
            # Fail early with a clear message instead of an opaque indexing
            # error from inside the superpixel helper.
            raise ValueError(
                "spixel must be provided when the loss is built with spixel=True"
            )

        tgt_map = torch.zeros_like(
            output[self.modality[0]]["out_map"], requires_grad=False
        )
        # The target is a fixed regression goal: no gradient may flow into it.
        with torch.no_grad():
            for modality in self.modality:
                weight = self.single_weight[modality.lower()]
                tgt_map = tgt_map + weight * output[modality]["out_map"]

            if self.spixel:
                tgt_map = get_spixel_tgt_map(tgt_map, spixel)

            if not self.soft:
                for b in range(tgt_map.shape[0]):
                    sample = tgt_map[b, ...]  # view: writes reach tgt_map
                    peak = sample.max()
                    if peak <= 0.5 and label[b] == 1.0:
                        # A label-1 sample must keep at least one positive
                        # pixel after thresholding, so promote its maxima.
                        sample[sample == peak] = 1.0
                tgt_map[tgt_map > 0.5] = 1
                tgt_map[tgt_map <= 0.5] = 0

            # Hard targets always blank label-0 samples; soft targets do so
            # only when zeros_on_au is set.
            if not self.soft or self.zeros_on_au:
                tgt_map[torch.where(label == 0.0)[0], ...] = 0.0

        total_loss = 0.0
        loss_dict = {}
        for modality in self.modality:
            loss = self.mse_loss(output[modality]["out_map"], tgt_map)
            loss_dict[f"multi_view_consistency_loss_{modality}"] = loss
            total_loss = total_loss + loss

        return {**loss_dict, "tgt_map": tgt_map, "total_loss": total_loss}

    def _save(
        self,
        spixel: torch.Tensor,
        image: torch.Tensor,
        mask: torch.Tensor,
        tgt_map: torch.Tensor,
        raw_tgt_map: torch.Tensor,
        out_path: str = "tmp/spixel_tgt_map.png",
    ):
        """Write a debug figure with one row of five panels per sample.

        Every argument is a 4-D tensor whose dim 1 is moved to the last axis
        before plotting (presumably channels-first input -- confirm). ``mask``
        and both target maps are scaled by 255. The figure is saved to
        ``out_path``; the parent directory is created when missing.

        Not called by any code visible in this file.
        """
        import os

        def to_numpy(tensor):
            # Move dim 1 last so matplotlib receives (H, W, C)-style arrays.
            return tensor.permute(0, 2, 3, 1)

        spixel = to_numpy(spixel).detach().cpu().numpy()
        image = to_numpy(image).detach().cpu().numpy()
        mask = to_numpy(mask).detach().cpu().numpy() * 255.0
        tgt_map = to_numpy(tgt_map).squeeze(3).detach().cpu().numpy() * 255.0
        raw_tgt_map = (
            to_numpy(raw_tgt_map).squeeze(3).detach().cpu().numpy() * 255.0
        )

        panels = (
            ("image", image),
            ("mask", mask),
            ("superpixel", spixel),
            ("raw target map", raw_tgt_map),
            ("target map", tgt_map),
        )
        bn = spixel.shape[0]
        for b in range(bn):
            for col, (title, data) in enumerate(panels):
                # Subplot indices are 1-based and run row by row.
                plt.subplot(bn, len(panels), b * len(panels) + col + 1)
                plt.imshow(data[b])
                plt.axis("off")
                plt.title(title)
        plt.tight_layout()

        # savefig does not create missing directories on its own.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(out_path, dpi=300)
        plt.close()
def get_spixel_tgt_map(weighted_sum, spixel):
    """Replace every value by the mean of ``weighted_sum`` over its superpixel.

    Both tensors are processed sample by sample along dim 0. For each
    distinct value in a sample's ``spixel`` map, the positions holding that
    value receive the average of ``weighted_sum`` over those positions. The
    result is still a soft map; any thresholding is left to the caller.

    Args:
        weighted_sum: Tensor of values to average; dim 0 indexes samples.
        spixel: Superpixel index map; dim 0 indexes samples. The positions it
            selects are used to index the matching ``weighted_sum`` sample, so
            the per-sample shapes are presumably identical -- confirm with
            the caller.

    Returns:
        A new tensor shaped like ``weighted_sum``; the inputs are not modified.
    """
    spixel_tgt_map = torch.zeros_like(weighted_sum, requires_grad=False)
    for bidx in range(weighted_sum.shape[0]):
        sample_spixel = spixel[bidx, ...]
        sample_sum = weighted_sum[bidx, ...]
        for spixel_idx in sample_spixel.unique().tolist():
            # Build the membership mask once and reuse it for the area, the
            # gather and the scatter.
            member = sample_spixel == spixel_idx
            region = torch.where(member)
            avg_area = sample_sum[region].sum() / member.sum()
            spixel_tgt_map[bidx][region] = avg_area
    return spixel_tgt_map
if __name__ == "__main__":
    # Smoke test: only checks that the loss module can be constructed.
    # The constructor requires both a weight mapping and the modality list;
    # the modality names below are placeholders, and the weights are keyed by
    # the lower-cased modality name as forward() looks them up.
    mvc_loss = MultiViewConsistencyLoss(
        True,
        True,
        {"m0": 1, "m1": 1, "m2": 2},
        ["m0", "m1", "m2"],
    )
    print("a")
|