"""Side-by-side plots of an image and its relevance visualization."""

import matplotlib.pyplot as plt

from generic_utils import generate_visualization


def _plot_with_relevance(transform, image, class_index, method):
    """Plot *image* next to its ``generate_visualization`` output.

    Shared implementation for :func:`do_lrp` and :func:`do_partial_lrp`.
    Those two differ only in the ``method`` string forwarded to
    ``generate_visualization``.

    Args:
        transform: Callable applied to ``image``; its result is what gets
            passed to ``generate_visualization``. Presumably this is the
            model's preprocessing (e.g. resize/normalize to a tensor); confirm
            against callers.
        image: Raw image shown unchanged in the left panel. It must be
            something ``Axes.imshow`` accepts.
        class_index: Forwarded verbatim to ``generate_visualization``.
            ``None`` is passed through as-is (presumably "use the predicted
            class"; confirm in ``generic_utils``).
        method: Method name forwarded to ``generate_visualization``.

    Returns:
        A matplotlib Figure with two side-by-side axes. The left axis shows
        ``image`` and the right axis shows the visualization; both have
        their axis decorations turned off.

    Raises:
        Whatever ``transform`` or ``generate_visualization`` raises. The
        partially built figure is closed first so it does not linger in
        pyplot's figure registry.
    """
    fig, axs = plt.subplots(1, 2)
    try:
        axs[0].imshow(image)
        axs[0].axis("off")

        # Keep the original call order (figure created before the
        # visualization is computed): generate_visualization is opaque here
        # and may depend on pyplot's current figure/axes state.
        transformed_image = transform(image)
        viz = generate_visualization(
            transformed_image, class_index=class_index, method=method
        )

        axs[1].imshow(viz)
        axs[1].axis("off")
    except BaseException:
        # pyplot keeps every figure from plt.subplots alive until it is
        # explicitly closed; don't leak an empty figure when we fail midway.
        plt.close(fig)
        raise
    return fig


def do_lrp(transform, image, class_index=None):
    """Return a figure showing *image* beside its "full" relevance map.

    Calls ``generate_visualization`` with ``method="full"``. See
    :func:`_plot_with_relevance` for the meaning of the arguments and the
    layout of the returned figure.
    """
    return _plot_with_relevance(transform, image, class_index, "full")


def do_partial_lrp(transform, image, class_index=None):
    """Return a figure showing *image* beside its "last_layer" relevance map.

    Calls ``generate_visualization`` with ``method="last_layer"``. See
    :func:`_plot_with_relevance` for the meaning of the arguments and the
    layout of the returned figure.
    """
    return _plot_with_relevance(transform, image, class_index, "last_layer")