File size: 1,832 Bytes
06a7cdc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from . import config
from .transforms import test_transforms
def generate_confidences(
    model,
    input_img,
    num_top_preds,
):
    """Run ``model`` on one image and return its most confident class predictions.

    Args:
        model: The classifier (a ``torch.nn.Module``; ``eval``/``train`` are
            called on it). Its first output row is passed through
            ``torch.exp``, so it is presumably expected to emit
            log-probabilities (e.g. a ``log_softmax`` head) — TODO confirm
            against the model definition.
        input_img: Raw image handed to ``test_transforms(image=...)``; the
            transform's ``"image"`` entry must be a tensor (``unsqueeze`` is
            called on it). Presumably an HxWxC array — confirm with caller.
        num_top_preds: Number of highest-confidence classes to keep. Python
            slice semantics apply (``None`` keeps every class).

    Returns:
        ``(input_tensor, confidences)``. ``input_tensor`` is the transformed
        image with a leading batch dimension of size 1, suitable for feeding
        to the model again (e.g. for Grad-CAM). ``confidences`` maps names
        from ``config.CLASSES`` to probabilities, ordered from most to least
        confident and truncated to ``num_top_preds`` entries.
    """
    input_tensor = test_transforms(image=input_img)["image"].unsqueeze(0)

    # Restore the caller's mode afterwards instead of forcing train(). Leaving
    # an eval-mode model in train mode would make later forward passes (such
    # as Grad-CAM's) use batch statistics and update BatchNorm running stats.
    was_training = model.training
    model.eval()
    try:
        # Plain inference needs no autograd graph.
        with torch.no_grad():
            log_probs = model(input_tensor)[0]
    finally:
        model.train(was_training)

    probs = torch.exp(log_probs)
    confidences = {
        class_name: float(probs[idx]) for idx, class_name in enumerate(config.CLASSES)
    }
    # Keep only the ``num_top_preds`` most confident classes, highest first.
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return input_tensor, dict(ranked[:num_top_preds])
def generate_gradcam(
    model,
    org_img,
    input_img,
    show_gradcam,
    gradcam_layer,
    gradcam_opacity,
):
    """Optionally build a Grad-CAM heat-map overlay for a single image.

    Args:
        model: Network to explain. For the chosen layer it must expose an
            ``l3`` (for ``-1``) or ``l2`` (for ``-2``) attribute supporting
            ``[-1]`` indexing; that last element is used as the CAM target.
        org_img: Image the heat-map is blended onto. It is divided by 255
            before blending, so it is presumably a 0-255 RGB array — confirm
            with caller.
        input_img: Preprocessed model input passed straight to ``GradCAM``
            (presumably the batched tensor from ``generate_confidences``).
        show_gradcam: If falsy, no work is done and ``None`` is returned.
        gradcam_layer: ``-1`` selects ``model.l3[-1]``, ``-2`` selects
            ``model.l2[-1]``.
        gradcam_opacity: Heat-map weight; the original image is weighted by
            ``1 - gradcam_opacity``.

    Returns:
        The blended image produced by ``show_cam_on_image``, or ``None`` when
        ``show_gradcam`` is falsy.

    Raises:
        ValueError: If ``gradcam_layer`` is not one of the supported values.
    """
    if not show_gradcam:
        return None

    # Supported layer selectors -> name of the model attribute whose last
    # element is hooked by Grad-CAM.
    layer_attrs = {-1: "l3", -2: "l2"}
    try:
        attr_name = layer_attrs[gradcam_layer]
    except (KeyError, TypeError):
        raise ValueError(
            f"Unsupported gradcam_layer {gradcam_layer!r}; "
            f"expected one of {sorted(layer_attrs)}"
        ) from None
    target_layers = [getattr(model, attr_name)[-1]]

    cam = GradCAM(
        model=model,
        target_layers=target_layers,
    )
    # targets=None leaves the choice of target class to GradCAM itself.
    grayscale_cam = cam(input_tensor=input_img, targets=None)
    # Take the map for the first image in the batch (presumably the only one).
    grayscale_cam = grayscale_cam[0, :]
    return show_cam_on_image(
        org_img / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=(1 - gradcam_opacity),
    )
def generate_missclassified_imgs(
    model,
    show_misclassified,
    num_misclassified,
):
    """Return the model's misclassified-samples plot, or ``None`` when disabled.

    Args:
        model: Object exposing a ``plot_incorrect_predictions_helper`` method.
        show_misclassified: When falsy, nothing is plotted.
        num_misclassified: Forwarded unchanged to
            ``model.plot_incorrect_predictions_helper``.

    Returns:
        Whatever ``plot_incorrect_predictions_helper`` returns, or ``None``.
    """
    if not show_misclassified:
        return None
    return model.plot_incorrect_predictions_helper(num_misclassified)
|