import matplotlib.pyplot as plt

from generic_utils import generate_visualization


def do_tiba(transform, image, class_index=None):
    """Plot an image next to the visualization of its transformed version.

    Creates a one-row, two-column figure. The left panel shows ``image``
    as given; the right panel shows the result of calling
    ``generate_visualization`` on ``transform(image)``. Axis decorations
    are turned off on both panels.

    Args:
        transform: Callable applied to ``image``; its output is passed to
            ``generate_visualization``.
        image: Input rendered in the left panel via ``Axes.imshow``.
        class_index: Forwarded unchanged to ``generate_visualization`` as
            its ``class_index`` keyword argument (default ``None``).

    Returns:
        The newly created ``matplotlib.figure.Figure``. It is neither shown
        nor closed here; the caller owns it.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2)

    # Left panel: the untouched input.
    left_ax.imshow(image)
    left_ax.axis("off")

    # Right panel: visualization computed on the transformed input.
    # NOTE(review): presumably generate_visualization returns an
    # image-like array that imshow accepts — confirm in generic_utils.
    rendered = generate_visualization(transform(image), class_index=class_index)
    right_ax.imshow(rendered)
    right_ax.axis("off")

    return fig