import matplotlib.pyplot as plt | |
from generic_utils import generate_visualization | |
def do_rollout(transform, image, class_index=None):
    """Plot an image side by side with its "rollout" visualization.

    Args:
        transform: Callable applied to ``image``; its result is what gets
            passed to ``generate_visualization`` (presumably model
            preprocessing such as resize/normalize to a tensor — TODO confirm).
            Assumed not to mutate ``image`` in place, since ``image`` is
            displayed after the transform runs.
        image: The input image. It is shown unmodified in the left panel via
            ``imshow`` and fed to ``transform``.
        class_index: Forwarded unchanged to ``generate_visualization``.
            ``None`` is passed through as-is (presumably "use the predicted
            class" — verify against generic_utils).

    Returns:
        A matplotlib Figure with two panels, both with axes hidden: the
        original image on the left and the visualization on the right. The
        figure is registered with pyplot, so the caller is responsible for
        closing it (e.g. ``plt.close(fig)``) when done.
    """
    # Do the fallible work before creating the figure. plt.subplots registers
    # the new figure with pyplot's global state, so an exception raised after
    # it would leave an orphaned open figure behind.
    transformed_image = transform(image)
    viz = generate_visualization(
        transformed_image, class_index=class_index, method="rollout", LRP=False
    )

    fig, axs = plt.subplots(1, 2)
    # Left: original image; right: visualization (assumed to be something
    # imshow can render, e.g. an HxW or HxWxC array — TODO confirm).
    for ax, panel in zip(axs, (image, viz)):
        ax.imshow(panel)
        ax.axis("off")
    return fig