import torch
import numpy


def predict(mask, label, threshold=0.5, score_type='combined'):
    """Turn network outputs into per-sample scores and binary predictions.

    Args:
        mask: Predicted mask tensor. It is averaged over dims 1-3, so it is
            assumed to be 4-D, e.g. (batch, channels, height, width) --
            TODO confirm against the model.
        label: Predicted per-sample label/score tensor.
        threshold: Scores strictly greater than this become 1.0, else 0.0.
        score_type: 'pixel' uses the mean of ``mask``; 'binary' uses
            ``label`` as-is; any other value averages the two ('combined').

    Returns:
        ``(preds, score)``: float32 0/1 predictions and the raw scores they
        were thresholded from. ``preds`` stays on the same device as
        ``score``.
    """
    with torch.no_grad():
        if score_type == 'pixel':
            score = torch.mean(mask, dim=(1, 2, 3))
        elif score_type == 'binary':
            score = label
        else:
            # NOTE(review): if ``label`` is (batch, 1) rather than (batch,),
            # this sum broadcasts to (batch, batch) -- confirm label shape.
            score = (torch.mean(mask, dim=(1, 2, 3)) + label) / 2
        # .float() keeps the tensor on its current device; the previous
        # .type(torch.FloatTensor) forced a copy to the CPU.
        preds = (score > threshold).float()
    return preds, score


def test_accuracy(model, test_dl, threshold=0.5, score_type='combined'):
    """Return the percentage of samples in ``test_dl`` classified correctly.

    Each batch from ``test_dl`` is unpacked as ``(img, mask, label)``. The
    model is called on ``img`` and must return ``(net_mask, net_label)``,
    which are turned into predictions with :func:`predict`.

    Args:
        model: Callable mapping an image batch to ``(net_mask, net_label)``.
        test_dl: Iterable of ``(img, mask, label)`` batches with a
            ``.dataset`` attribute supporting ``len`` (e.g. a DataLoader).
        threshold: Forwarded to :func:`predict`.
        score_type: Forwarded to :func:`predict`.

    Returns:
        Accuracy in percent (0-100) over all samples in ``test_dl.dataset``.

    Raises:
        ZeroDivisionError: if ``test_dl.dataset`` is empty.
    """
    correct = 0
    total = len(test_dl.dataset)
    # NOTE(review): the model's train/eval mode is left as the caller set it.
    with torch.no_grad():  # inference only: don't build an autograd graph
        for img, _mask, label in test_dl:
            net_mask, net_label = model(img)
            preds, _ = predict(net_mask, net_label, threshold, score_type)
            # Compare on the predictions' device so labels held on another
            # device don't raise a device-mismatch error.
            correct += (preds == label.to(preds.device)).sum().item()
    return (correct / total) * 100


# Evaluation helper, not a pytest test: opt out of collection in case this
# module is star-imported into (or itself named as) a pytest test module.
test_accuracy.__test__ = False


def test_loss(model, test_dl, loss_fn):
    """Return the mean of per-batch average losses over ``test_dl``.

    Args:
        model: Callable mapping an image batch to ``(net_mask, net_label)``.
        test_dl: Sized iterable of ``(img, mask, label)`` batches.
        loss_fn: Called as ``loss_fn(net_mask, net_label, mask, label)``;
            its result is reduced with ``torch.mean`` per batch.

    Returns:
        Sum of per-batch mean losses divided by the number of batches
        (``len(test_dl)``). A smaller final batch is therefore weighted the
        same as a full one.

    Raises:
        ZeroDivisionError: if ``test_dl`` yields no batches (``len == 0``).
    """
    running = 0.0
    num_batches = len(test_dl)
    with torch.no_grad():  # inference only: don't build an autograd graph
        for img, mask, label in test_dl:
            net_mask, net_label = model(img)
            losses = loss_fn(net_mask, net_label, mask, label)
            running += torch.mean(losses).item()
    return running / num_batches


# Evaluation helper, not a pytest test (see note on test_accuracy above).
test_loss.__test__ = False