|
import matplotlib.pyplot as plt |
|
|
|
from generic_utils import generate_visualization |
|
|
|
|
|
def _plot_with_visualization(transform, image, class_index, method):
    """Draw ``image`` beside the output of ``generate_visualization``.

    Shared implementation for :func:`do_lrp` and :func:`do_partial_lrp`,
    which previously duplicated this body and differ only in ``method``.

    Args:
        transform: Callable applied to ``image``; its result is what gets
            passed to ``generate_visualization`` (presumably a model-ready
            tensor -- confirm against ``generic_utils``).
        image: The original image, drawn unchanged in the left panel. Must be
            something ``Axes.imshow`` accepts.
        class_index: Forwarded unchanged to ``generate_visualization``.
        method: Method name forwarded to ``generate_visualization``
            (``"full"`` or ``"last_layer"`` in this module).

    Returns:
        matplotlib.figure.Figure: A 1x2 figure with the original image on the
        left and the visualization on the right, both with axes hidden.

    NOTE: the figure is created through ``pyplot`` and is not closed here;
    callers producing many figures may want ``plt.close(fig)`` afterwards.
    """
    fig, axs = plt.subplots(1, 2)

    # Left panel: the untransformed input image.
    axs[0].imshow(image)
    axs[0].axis("off")

    # Right panel: the relevance visualization computed on the transformed input.
    transformed_image = transform(image)
    viz = generate_visualization(
        transformed_image, class_index=class_index, method=method
    )

    axs[1].imshow(viz)
    axs[1].axis("off")

    return fig


def do_lrp(transform, image, class_index=None):
    """Plot ``image`` next to its ``method="full"`` visualization.

    Args:
        transform: Callable that prepares ``image`` for
            ``generate_visualization``.
        image: Image shown as-is in the left panel.
        class_index: Forwarded to ``generate_visualization``; presumably
            ``None`` means "use the predicted class" -- confirm in
            ``generic_utils``.

    Returns:
        matplotlib.figure.Figure: 1x2 figure (original | visualization).
    """
    return _plot_with_visualization(transform, image, class_index, method="full")


def do_partial_lrp(transform, image, class_index=None):
    """Plot ``image`` next to its ``method="last_layer"`` visualization.

    Args:
        transform: Callable that prepares ``image`` for
            ``generate_visualization``.
        image: Image shown as-is in the left panel.
        class_index: Forwarded to ``generate_visualization``; presumably
            ``None`` means "use the predicted class" -- confirm in
            ``generic_utils``.

    Returns:
        matplotlib.figure.Figure: 1x2 figure (original | visualization).
    """
    return _plot_with_visualization(
        transform, image, class_index, method="last_layer"
    )
|
|