comparative-explainability / generic_utils.py
sayakpaul's picture
sayakpaul HF staff
remove cuda
1548a67
raw history blame
No virus
3.35 kB
import sys
import cv2
import numpy as np
import torch
from imagenet_class_indices import CLS2IDX
sys.path.append("Transformer-Explainability")
from baselines.ViT.ViT_explanation_generator import LRP, Baselines
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit
# Overlay a colour-mapped saliency mask on an image.
def show_cam_on_image(img, mask):
    """Blend a JET colour map of ``mask`` with ``img`` and rescale to peak 1.

    Args:
        img: Image array added element-wise to the 3-channel colour map, so it
            must broadcast against (H, W, 3); ``generate_visualization`` passes
            an (H, W, C) float image already rescaled to [0, 1].
        mask: Saliency map with the same spatial size as ``img``. It is
            multiplied by 255 and cast to uint8, so values are expected in
            [0, 1].

    Returns:
        float32 array: colour map + image, divided by its global maximum.
    """
    # applyColorMap works on 8-bit input, hence the 0..255 quantisation.
    quantised_mask = np.uint8(255 * mask)
    # NOTE(review): OpenCV colour maps follow OpenCV's BGR channel order;
    # confirm this matches the channel order of ``img`` at every call site.
    colour_map = np.float32(cv2.applyColorMap(quantised_mask, cv2.COLORMAP_JET)) / 255
    blended = colour_map + np.float32(img)
    peak = np.max(blended)
    return blended / peak
# Build, at import time, the two explainers that generate_visualization uses.
# NOTE(review): pretrained=True presumably loads downloaded weights (needs
# network access or a warm cache) - confirm in the Transformer-Explainability
# model factories.
# ViT-Base patch16/224 from ViT_LRP, wrapped by the LRP explainer; serves the
# LRP=True path via attribution_generator.generate_LRP.
model = vit_LRP(pretrained=True)
# Put the model in eval (inference) mode before generating attributions.
model.eval()
attribution_generator = LRP(model)
# A separate ViT-Base patch16/224 from ViT_new for the baseline explainers,
# used on the LRP=False path (generate_cam_attn / generate_rollout).
model_baseline = vit(pretrained=True)
model_baseline.eval()
baselines_generator = Baselines(model_baseline)
def _min_max_normalize(values):
    """Linearly rescale a NumPy array so its minimum is 0 and its maximum is 1.

    A constant array (max == min) is returned as zeros of the same shape and
    dtype, instead of the all-NaN result of a bare ``(x - min) / (max - min)``
    that would otherwise flow into the colour map and uint8 casts.
    """
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return np.zeros_like(values)
    return (values - lowest) / span


def generate_visualization(
    original_image, class_index=None, method="transformer_attribution", LRP=True
):
    """Render an attribution heatmap for ``original_image`` blended onto it.

    Args:
        original_image: 3-D image tensor in (C, H, W) layout (it is permuted to
            (H, W, C) below). The 224x224 reshape requires the attribution to
            cover 224x224 pixels; presumably a 3x224x224 ViT input - confirm.
        class_index: Target class forwarded as ``index`` to the LRP and
            attention Grad-CAM explainers (``None`` is passed through as-is).
            Not used by attention rollout.
        method: With ``LRP=True`` it is forwarded to ``generate_LRP``. With
            ``LRP=False``, ``"gradcam"`` selects attention Grad-CAM and any
            other value selects attention rollout. ``"full"`` means the
            attribution is reshaped directly to 224x224 instead of being
            upsampled from a 14x14 patch grid.
        LRP: True for the LRP explainer, False for the baseline explainers.
            (Inside this function the name shadows the imported ``LRP`` class.)

    Returns:
        np.ndarray (uint8, 224 x 224 x 3): the heatmap/image overlay after an
        RGB->BGR channel swap.
    """
    batch = original_image.unsqueeze(0)  # explainers take a batch of one
    if LRP:
        attribution = attribution_generator.generate_LRP(
            batch, method=method, index=class_index
        ).detach()
    elif method == "gradcam":
        attribution = baselines_generator.generate_cam_attn(
            batch, index=class_index
        ).detach()
    else:
        attribution = baselines_generator.generate_rollout(batch).detach()

    if method != "full":
        # ViT-B/16 at 224x224 gives one score per 16x16 patch (a 14x14 grid);
        # bilinear upsampling by the patch size restores 224x224 pixels.
        attribution = attribution.reshape(1, 1, 14, 14)
        attribution = torch.nn.functional.interpolate(
            attribution, scale_factor=16, mode="bilinear", align_corners=False
        )
    else:
        attribution = attribution.reshape(1, 1, 224, 224)
    # Guarded normalisation: a constant map (e.g. all-zero) becomes zeros, not NaN.
    heatmap = _min_max_normalize(attribution.reshape(224, 224).cpu().numpy())

    # Rescale the input image to [0, 1] for blending (it may arrive
    # mean/std-normalised - assumption based on the min-max step).
    image = _min_max_normalize(
        original_image.permute(1, 2, 0).detach().cpu().numpy()
    )

    vis = np.uint8(255 * show_cam_on_image(image, heatmap))
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
def print_top_classes(predictions, **kwargs):
    """Print the five highest-scoring classes for the first sample of a batch.

    One aligned line is printed per class with its index, its ``CLS2IDX``
    label, the raw score and the softmax probability as a percentage.

    Args:
        predictions: 2-D score tensor indexed ``[batch, class]``; softmax is
            taken over dim 1 and only row 0 is reported. Presumably model
            logits - confirm against callers.
        **kwargs: Accepted but unused.
    """
    probabilities = torch.softmax(predictions, dim=1)
    _, top_indices = predictions.data.topk(5, dim=1)
    top_indices = top_indices[0].tolist()

    labels = [CLS2IDX[idx] for idx in top_indices]
    # Pad every label to the longest one so the score columns line up.
    width = max(map(len, labels), default=0)
    row = "\t{idx} : {label}{pad}\t\tvalue = {value:.3f}\t prob = {prob:.1f}%"

    print("Top 5 classes:")
    for idx, label in zip(top_indices, labels):
        print(
            row.format(
                idx=idx,
                label=label,
                pad=" " * (width - len(label)),
                value=predictions[0, idx],
                prob=100 * probabilities[0, idx],
            )
        )