# GradCAMViT / plot.py
# Author: raedinkhaled
# Commit a14b289: "Create plot.py"
import matplotlib.pyplot as plt
from app import inference, examples
from PIL import Image
# Script: run `inference` on every example image and save two comparison
# strips, CAM.png and ROLLOUT.png, each with one panel per example.
#
# NOTE(review): the indexing below assumes `inference(image)` returns at
# least four items laid out as (caption, image, caption, image). The first
# pair presumably feeds CAM.png and the second ROLLOUT.png. Confirm this
# against app.inference.

# Every figure created below shares this size: one wide, short row.
plt.rcParams["figure.figsize"] = (11, 2)

# Output file stems. Entry k names the figure built from the k-th
# (caption, image) pair of each inference result.
title = ["CAM", "ROLLOUT"]

# One (figure, axes) row per output file, one column per example.
# squeeze=False keeps `axes` a 2-D array even when there is only one
# example, so axes[0, col] works for any number of examples.
plots = [plt.subplots(1, len(examples), squeeze=False) for _ in title]

for col, image_path in enumerate(examples):
    # The context manager closes the file handle opened by Image.open.
    # Drawing stays inside it in case any returned item still reads from
    # the opened file.
    with Image.open(image_path) as image:
        result = inference(image)
        for row, (_fig, axes) in enumerate(plots):
            ax = axes[0, col]
            ax.imshow(result[2 * row + 1])  # odd index: the picture
            ax.set_title(result[2 * row])  # even index: its caption
            ax.set_axis_off()

for (fig, _axes), stem in zip(plots, title):
    # fig.suptitle(stem)  # optional figure-wide heading, disabled as before
    fig.savefig(f"{stem}.png")
    # Release the figure once it is written; nothing below uses it.
    plt.close(fig)