import sys

import cv2
import numpy as np
import torch

from imagenet_class_indices import CLS2IDX

sys.path.append("Transformer-Explainability")

from baselines.ViT.ViT_explanation_generator import LRP, Baselines
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit


def show_cam_on_image(img, mask):
    """Overlay a [0, 1] relevance mask on an image as a JET heatmap."""
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)  # rescale so the overlay stays in [0, 1]
    return cam


# Initialize the pretrained ViT-B/16 models: one instrumented for LRP-based
# attribution, one vanilla model for the baseline methods.
model = vit_LRP(pretrained=True)
model.eval()
attribution_generator = LRP(model)

model_baseline = vit(pretrained=True)
model_baseline.eval()
baselines_generator = Baselines(model_baseline)


def generate_visualization(
    original_image, class_index=None, method="transformer_attribution", use_lrp=True
):
    """Return a heatmap overlay of the relevance map for `class_index`
    (the predicted class when None) on `original_image` of shape (C, H, W)."""
    if use_lrp:
        transformer_attribution = attribution_generator.generate_LRP(
            original_image.unsqueeze(0), method=method, index=class_index
        ).detach()
    elif method == "gradcam":
        transformer_attribution = baselines_generator.generate_cam_attn(
            original_image.unsqueeze(0), index=class_index
        ).detach()
    else:
        transformer_attribution = baselines_generator.generate_rollout(
            original_image.unsqueeze(0)
        ).detach()

    if method != "full":
        # Per-token relevance: ViT-B/16 on a 224x224 input has a 14x14 patch
        # grid, so reshape and bilinearly upsample by 16x to pixel resolution.
        transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
        transformer_attribution = torch.nn.functional.interpolate(
            transformer_attribution, scale_factor=16, mode="bilinear"
        )
    else:
        # The "full" method already yields per-pixel relevance.
        transformer_attribution = transformer_attribution.reshape(1, 1, 224, 224)

    # Min-max normalize the relevance map to [0, 1].
    transformer_attribution = (
        transformer_attribution.reshape(224, 224).data.cpu().numpy()
    )
    transformer_attribution = (
        transformer_attribution - transformer_attribution.min()
    ) / (transformer_attribution.max() - transformer_attribution.min())

    # Min-max scale the (normalized) input image back to [0, 1] for display.
    image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
    image_transformer_attribution = (
        image_transformer_attribution - image_transformer_attribution.min()
    ) / (image_transformer_attribution.max() - image_transformer_attribution.min())

    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis


def print_top_classes(predictions):
    """Print the top-5 predicted classes with their logits and softmax probabilities."""
    prob = torch.softmax(predictions, dim=1)
    class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
    max_str_len = max(len(CLS2IDX[cls_idx]) for cls_idx in class_indices)

    print("Top 5 classes:")
    for cls_idx in class_indices:
        output_string = "\t{} : {}".format(cls_idx, CLS2IDX[cls_idx])
        output_string += " " * (max_str_len - len(CLS2IDX[cls_idx])) + "\t\t"
        output_string += "value = {:.3f}\t prob = {:.1f}%".format(
            predictions[0, cls_idx], 100 * prob[0, cls_idx]
        )
        print(output_string)
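

# ----------------------------------------------------------------------
# Usage sketch, not part of the original script: the sample image path
# "samples/catdog.png" is hypothetical, and the 0.5/0.5 normalization is
# an assumption matching the usual ViT-B/16 preprocessing; adjust both
# for your own inputs.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # load and preprocess any RGB image (path is a placeholder)
    image = Image.open("samples/catdog.png").convert("RGB")
    image_tensor = transform(image)

    # report the top-5 predictions for the image
    output = model(image_tensor.unsqueeze(0))
    print_top_classes(output)

    # relevance heatmap for the predicted class via transformer attribution
    vis = generate_visualization(image_tensor)
    cv2.imwrite("attribution.png", vis)  # vis is already BGR for OpenCV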