|
import torch |
|
import torch.nn as nn |
|
from torch.utils import data |
|
import torchvision.transforms as transform |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
import numpy as np |
|
from collections import defaultdict, deque |
|
import torch.distributed as dist |
|
|
|
def colorize_mask(mask):
    """Render a label mask as a palette-mode ('P') PIL image.

    Args:
        mask: numpy array of class indices. Values are cast to ``uint8``
            before conversion (presumably a 2-D map of train IDs; confirm
            against callers).

    Returns:
        A ``PIL.Image`` in mode ``'P'`` carrying the colour palette below.
        Palette slots past the listed colours are padded with black.
    """
    # 19 RGB triplets. These appear to be the Cityscapes train-ID colours.
    palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156,
               190, 153, 153, 153, 153, 153, 250, 170, 30, 220, 220, 0,
               107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,
               255, 0, 0, 0, 0, 142, 0, 0, 70, 0, 60, 100, 0, 80, 100,
               0, 0, 230, 119, 11, 32]
    # A full palette holds 256 RGB entries; fill the unused ones with black.
    palette += [0] * (256 * 3 - len(palette))

    new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
    new_mask.putpalette(palette)
    return new_mask
|
|
|
|
|
def build_img(args, size=(256, 512)):
    """Load the image at ``args.input_path`` as a normalised single-image batch.

    Args:
        args: object exposing an ``input_path`` attribute (presumably an
            argparse namespace; confirm against callers).
        size: ``(height, width)`` the image is resized to. Defaults to the
            previously hard-coded ``(256, 512)``.

    Returns:
        A float tensor of shape ``(1, 3, size[0], size[1])``. It is normalised
        with the mean/std below, which match the commonly used ImageNet
        statistics.
    """
    input_transform = transform.Compose([
        # Resize the PIL image before converting it. Some torchvision releases
        # accept only PIL input in Resize, and resampling raw pixels avoids
        # interpolating values that have already been normalised.
        transform.Resize(size),
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    # The context manager releases the file handle straight away.
    # convert('RGB') guarantees 3 channels: greyscale, RGBA or palette images
    # would otherwise not match the 3-element mean/std passed to Normalize.
    with Image.open(args.input_path) as img:
        rgb_img = img.convert('RGB')
    return input_transform(rgb_img).unsqueeze(0)
|
|
|
class ConfusionMatrix(object):
    """Accumulate an ``num_classes x num_classes`` integer confusion matrix.

    ``mat[i, j]`` counts positions where the first argument to ``update`` is
    ``i`` and the second is ``j``. Presumably the first argument is the
    ground truth and the second the prediction; confirm against callers.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Allocated lazily in update() so it lives on the inputs' device.
        self.mat = None

    def update(self, a, b):
        """Add every ``(a, b)`` pair to the matrix.

        Positions where ``a`` lies outside ``[0, num_classes)`` (e.g. an ignore
        label) are skipped. ``b`` is not range-checked, so its values must
        already lie in ``[0, num_classes)``. ``a`` and ``b`` must have the same
        shape; boolean indexing flattens them.
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            k = (a >= 0) & (a < n)
            # Encode each pair as the flat index a * n + b so a single
            # bincount fills the whole matrix.
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        """Zero the accumulated counts; a no-op before the first update()."""
        # mat is only allocated by update(), so guard against None here.
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        """Return ``(acc_global, acc, iu)`` as float tensors.

        ``acc_global`` is the diagonal sum over the total count. ``acc`` is the
        diagonal over the row sums. ``iu`` is the per-class intersection over
        union. Classes with no counts give NaN (0 / 0). Requires at least one
        prior update().
        """
        h = self.mat.float()
        diag = torch.diag(h)
        acc_global = diag.sum() / h.sum()
        acc = diag / h.sum(1)
        iu = diag / (h.sum(1) + h.sum(0) - diag)
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        """All-reduce ``mat`` across processes (sum, all_reduce's default op).

        This is a no-op when torch.distributed is unavailable or not
        initialised.
        """
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.mat)

    def __str__(self):
        _, _, iu = self.compute()
        # Classes that never occur give NaN IoU. Exclude them from the mean so
        # a single unseen class does not turn the summary into "nan".
        seen_iu = iu[~torch.isnan(iu)]
        return (
            'per-class IoU(%): \n {}\n'
            'mean IoU(%): {:.1f}').format(
                ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
                seen_iu.mean().item() * 100)