|
import numpy as np |
|
import torch |
|
import torchvision.transforms as transforms |
|
from torch.utils.data import DataLoader |
|
from numpy import * |
|
import argparse |
|
from PIL import Image |
|
import imageio |
|
import os |
|
from tqdm import tqdm |
|
from SegmentationTest.utils.metrices import * |
|
|
|
from SegmentationTest.utils import render |
|
from SegmentationTest.utils.saver import Saver |
|
from SegmentationTest.utils.iou import IoU |
|
|
|
from SegmentationTest.data.Imagenet import Imagenet_Segmentation |
|
|
|
|
|
|
|
|
|
from ViT.ViT import vit_base_patch16_224 as vit |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ViT.explainer import generate_relevance, get_image_with_relevance |
|
|
|
from sklearn.metrics import precision_recall_curve |
|
import matplotlib.pyplot as plt |
|
|
|
import torch.nn.functional as F |
|
|
|
import warnings |
|
# Silence every warning category for the whole run so the tqdm progress bar
# stays readable.
warnings.filterwarnings(action="ignore")

# Select the non-interactive Agg backend: figures are only written to disk
# (see the plt.savefig call at the end of the script), never shown.
plt.switch_backend('agg')
|
|
|
|
|
# Data-loading constants.
# NOTE(review): `num_workers` is not referenced by the DataLoader call visible
# in this file (which passes a literal instead) -- confirm which is intended.
batch_size = 1
num_workers = 0

# Twenty object-category names. Not referenced elsewhere in the visible part
# of this script. The spelling 'motobike' is kept as-is because the string is
# runtime data.
cls = [
    'airplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'dining table', 'dog', 'horse', 'motobike', 'person',
    'potted plant', 'sheep', 'sofa', 'train', 'tv',
]
|
|
|
|
|
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, and every non-empty string (including ``"False"``) is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') used to return.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line interface. argparse maps dashes in option names to underscores,
# so e.g. ``--save-img`` is read later as ``args.save_img``.
parser = argparse.ArgumentParser(description='Training multi-class classifier')
parser.add_argument('--arc', type=str, default='vgg', metavar='N',
                    help='Model architecture')
parser.add_argument('--train_dataset', type=str, default='imagenet', metavar='N',
                    help='Testing Dataset')
parser.add_argument('--method', type=str,
                    default='grad_rollout',
                    # 'grad_rollout' is listed so that the default value can
                    # also be passed explicitly; argparse does not validate
                    # defaults against `choices`, so it was previously
                    # reachable only by omitting the flag.
                    choices=['grad_rollout', 'rollout', 'lrp', 'transformer_attribution',
                             'full_lrp', 'lrp_last_layer',
                             'attn_last_layer', 'attn_gradcam'],
                    help='')
parser.add_argument('--thr', type=float, default=0.,
                    help='threshold')
parser.add_argument('--K', type=int, default=1,
                    help='new - top K results')
parser.add_argument('--save-img', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-ia', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-fgx', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-m', action='store_true',
                    default=False,
                    help='')
parser.add_argument('--no-reg', action='store_true',
                    default=False,
                    help='')
# Parsed with _str2bool so that "--is-ablation False" really yields False.
parser.add_argument('--is-ablation', type=_str2bool,
                    default=False,
                    help='')
parser.add_argument('--imagenet-seg-path', type=str, required=True)
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
args = parser.parse_args()

# Experiment name handed to Saver: "<method>_<arc>".
args.checkname = args.method + '_' + args.arc
|
|
|
# NOTE(review): `alpha` is not referenced in the visible part of this script.
alpha = 2

# Prefer the GPU when one is available.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Saver is a project helper; this script only reads its `experiment_dir`
# attribute and attaches a `results_dir` attribute to it.
saver = Saver(args)
saver.results_dir = os.path.join(saver.experiment_dir, 'results')

# Destinations for the rendered explanations and raw arrays.
args.exp_img_path = os.path.join(saver.results_dir, 'explain/img')
args.exp_np_path = os.path.join(saver.results_dir, 'explain/np')

# Create the whole output tree. `exist_ok=True` replaces the previous
# "check os.path.exists, then makedirs" pairs, which could raise
# FileExistsError if another process created the directory between the check
# and the call.
for _out_dir in (
        saver.results_dir,
        os.path.join(saver.results_dir, 'input'),
        os.path.join(saver.results_dir, 'explain'),
        args.exp_img_path,
        args.exp_np_path,
):
    os.makedirs(_out_dir, exist_ok=True)
|
|
|
|
|
# Input normalisation: (x - 0.5) / 0.5 on each of the three channels.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Image pipeline: resize to 224x224, convert to a tensor, normalise.
test_img_trans = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)

# Label pipeline: resize only, with nearest-neighbour interpolation.
test_lbl_trans = transforms.Compose(
    [transforms.Resize(size=(224, 224), interpolation=Image.NEAREST)]
)

# Imagenet_Segmentation is a project dataset class; the main loop below
# unpacks each item it yields as an (image, labels) pair.
ds = Imagenet_Segmentation(
    args.imagenet_seg_path,
    transform=test_img_trans,
    target_transform=test_lbl_trans,
)

# NOTE(review): the worker count is the literal 1 here, not the module-level
# `num_workers` constant -- confirm which one is intended.
dl = DataLoader(
    dataset=ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)
|
|
|
|
|
# Build the ViT and place it on the device selected at start-up (`device`)
# instead of calling .cuda() unconditionally; on a CUDA host the two are
# equivalent.
if args.checkpoint:
    print(f"loading model from checkpoint {args.checkpoint}")
    model = vit().to(device)
    # map_location remaps the stored tensors onto `device`; without it the
    # load fails when the checkpoint was saved on a device that does not
    # exist on this machine (e.g. another GPU index).
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
else:
    model = vit(pretrained=True).to(device)

# NOTE(review): `metric` is not used in the visible part of this script.
metric = IoU(2, ignore_index=-1)

# Progress bar over the data loader; its description is refreshed with the
# running metrics inside the main loop.
iterator = tqdm(dl)

# Inference mode for the network.
model.eval()
|
|
|
|
|
def compute_pred(output, num_classes=1000):
    """Return a one-hot float encoding of the arg-max class of each sample.

    Compared with the previous implementation this version
      * works for any batch size (the old squeeze/expand_dims sequence only
        produced a valid one-hot matrix for a batch of one),
      * no longer crashes on hosts without CUDA, and
      * exposes the former hard-coded class count as ``num_classes``.

    Args:
        output: score tensor; assumed to be ``(batch, classes)`` -- the
            maximum is taken along dimension 1.
        num_classes: width of the one-hot encoding (default 1000, the value
            that used to be hard-coded).

    Returns:
        A ``torch.float32`` tensor of shape ``(batch, num_classes)`` holding
        a single 1.0 per row. It is moved to the GPU when CUDA is available
        (as before) and left on the CPU otherwise.
    """
    # Index of the highest score for every sample, as a flat CPU vector.
    top1 = output.data.max(1)[1].reshape(-1).cpu()

    # Compare each index against 0..num_classes-1 to build the one-hot rows.
    class_ids = torch.arange(num_classes)
    one_hot = (top1.unsqueeze(1) == class_ids).type(torch.FloatTensor)

    return one_hot.cuda() if torch.cuda.is_available() else one_hot
|
|
|
|
|
def eval_batch(image, labels, evaluator, index):
    """Score the relevance map of one batch against its segmentation mask.

    The relevance map returned by ``generate_relevance`` is min-max
    normalised, split into foreground/background at its mean value, and
    compared with ``labels`` through the project metric helpers.

    Changes relative to the previous version:
      * ``generate_relevance`` receives ``evaluator`` (the argument) instead
        of silently using the global ``model``;
      * ``requires_grad`` is set once instead of twice;
      * the in-place NaN clean-up of ``Res`` is explicit (it used to happen
        through an alias), and the no-op clean-up of the 0/1 masks is gone;
      * a constant relevance map no longer yields NaN/inf ``pred`` values,
        which would make ``precision_recall_curve`` fail at the very end of
        the run;
      * dead accumulators (``batch_* = 0`` followed by one ``+=``) and the
        unused ``batch_f1`` were removed.

    Args:
        image: input batch; ``image[0]`` is treated as a channel-first image.
            NOTE(review): assumed to be ``(1, 3, H, W)`` -- confirm.
        labels: ground-truth mask batch matching ``image``.
        evaluator: the network to explain.
        index: running batch index, used only to name the saved files.

    Returns:
        ``(correct, labeled, inter, union, ap, pred, target)`` where the first
        four come from ``batch_pix_accuracy`` / ``batch_intersection_union``,
        ``ap`` from ``get_ap_scores`` (NaN replaced by 0), and ``pred`` /
        ``target`` are the flattened relevance scores and labels as NumPy
        arrays.
    """
    evaluator.zero_grad()

    if args.save_img:
        # Rescale the normalised input to 0..255 and dump it with its mask.
        img = image[0].permute(1, 2, 0).data.cpu().numpy()
        img = 255 * (img - img.min()) / (img.max() - img.min())
        img = img.astype('uint8')
        Image.fromarray(img, 'RGB').save(
            os.path.join(saver.results_dir, 'input/{}_input.png'.format(index)))
        mask_rgb = (labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255).astype('uint8')
        Image.fromarray(mask_rgb, 'RGB').save(
            os.path.join(saver.results_dir, 'input/{}_mask.png'.format(index)))

    image = image.requires_grad_()

    # NOTE(review): the result of this forward pass is unused and looks
    # redundant with generate_relevance; the call is kept in case the forward
    # has side effects the explainer relies on -- confirm before removing.
    evaluator(image)
    Res = generate_relevance(evaluator, image.to(device))

    # Min-max normalise the relevance map; a constant map gives 0/0 = NaN.
    Res = (Res - Res.min()) / (Res.max() - Res.min())

    # Foreground / background split at the mean relevance. Comparisons with
    # NaN are False, so both masks are all-zero for a degenerate map.
    ret = Res.mean()
    Res_1 = Res.gt(ret).type(Res.type())
    Res_0 = Res.le(ret).type(Res.type())

    # Soft scores for the AP metric. The complement is taken BEFORE the NaN
    # clean-up so that NaN positions end up as 0 in both channels, exactly as
    # before.
    Res_0_AP = 1 - Res
    Res[Res != Res] = 0
    Res_0_AP[Res_0_AP != Res_0_AP] = 0
    Res_1_AP = Res

    # Scores for the precision/recall curve. After the clean-up the maximum
    # is 0 only for a degenerate map; dividing by it would give NaN/inf, so
    # fall back to all-zero scores in that case.
    peak = Res.max()
    if peak > 0:
        pred = Res.clamp(min=args.thr) / peak
    else:
        pred = torch.zeros_like(Res)
    pred = pred.view(-1).data.cpu().numpy()
    target = labels.view(-1).data.cpu().numpy()

    # Channel 0 = background, channel 1 = foreground.
    output = torch.cat((Res_0, Res_1), 1)
    output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)

    if args.save_img:
        # Binary foreground mask, resampled to 64x64.
        mask = F.interpolate(Res_1, [64, 64], mode='bilinear')
        mask = mask[0].squeeze().data.cpu().numpy()
        mask = 255 * mask
        mask = mask.astype('uint8')
        imageio.imsave(os.path.join(args.exp_img_path, 'mask_' + str(index) + '.jpg'), mask)

        # Relevance heat-map, resampled to 64x64 and colour-mapped.
        relevance = F.interpolate(Res, [64, 64], mode='bilinear')
        relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()
        hm = np.sum(relevance, axis=-1)
        maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)
        imageio.imsave(os.path.join(args.exp_img_path, 'heatmap_' + str(index) + '.jpg'), maps)

    # Segmentation metrics on the hard 0/1 prediction (two classes).
    hard_pred = output[0].data.cpu()
    correct, labeled = batch_pix_accuracy(hard_pred, labels[0])
    inter, union = batch_intersection_union(hard_pred, labels[0], 2)

    # Average precision on the soft scores; NaN is replaced by 0.
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))

    return correct, labeled, inter, union, ap, pred, target
|
|
|
|
|
# Running totals over the whole dataset.
total_inter, total_union, total_correct, total_label = np.int64(0), np.int64(0), np.int64(0), np.int64(0)
# NOTE(review): `total_f1` is never filled in the visible part of this script.
total_ap, total_f1 = [], []

# Per-batch flattened scores / labels, concatenated after the loop for the
# precision-recall curve.
predictions, targets = [], []
for batch_idx, (image, labels) in enumerate(iterator):
    # Move the batch to the device selected at start-up (`device`) rather
    # than calling .cuda() unconditionally; identical on a CUDA host.
    # NOTE(review): "blur" is not among the --method choices, so the first
    # branch looks unreachable from the command line -- kept as-is.
    if args.method == "blur":
        images = (image[0].to(device), image[1].to(device))
    else:
        images = image.to(device)
    labels = labels.to(device)

    correct, labeled, inter, union, ap, pred, target = eval_batch(images, labels, model, batch_idx)

    predictions.append(pred)
    targets.append(target)

    total_correct += correct.astype('int64')
    total_label += labeled.astype('int64')
    total_inter += inter.astype('int64')
    total_union += union.astype('int64')
    total_ap.append(ap)

    # Running metrics; np.spacing(1) keeps the denominators non-zero.
    pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
    # Named `class_iou` (not `IoU`) so the imported IoU class is not shadowed.
    class_iou = np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
    mIoU = class_iou.mean()
    mAp = np.mean(total_ap)

    iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f' % (pixAcc, mIoU, mAp))
|
|
|
# Precision/recall curve over every pixel of every evaluated image.
predictions = np.concatenate(predictions)
targets = np.concatenate(targets)
pr, rc, thr = precision_recall_curve(targets, predictions)
np.save(os.path.join(saver.experiment_dir, 'precision.npy'), pr)
np.save(os.path.join(saver.experiment_dir, 'recall.npy'), rc)

plt.figure()
plt.plot(rc, pr)
plt.savefig(os.path.join(saver.experiment_dir, 'PR_curve_{}.png'.format(args.method)))

# Final report: the same three lines go to stdout and to a text file whose
# name embeds the mean IoU.
txtfile = os.path.join(saver.experiment_dir, 'result_mIoU_%.4f.txt' % mIoU)
report_lines = [
    "Mean IoU over %d classes: %.4f\n" % (2, mIoU),
    "Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100),
    "Mean AP over %d classes: %.4f\n" % (2, mAp),
]

# The context manager guarantees the file is closed even if a write fails
# (the previous open()/close() pair leaked the handle on error).
with open(txtfile, 'w') as fh:
    for report_line in report_lines:
        print(report_line)
    for report_line in report_lines:
        fh.write(report_line)