Spaces:
Runtime error
Runtime error
File size: 2,160 Bytes
a9289c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import matplotlib.pyplot as plt
from dpt.vit import get_mean_attention_map
def visualize_attention(input, model, prediction, model_type):
    """Plot the input, the prediction and eight attention maps on a 3x4 grid.

    The first grid row shows the rescaled input image and the prediction.
    The second and third rows each show one map per stored attention entry
    (``attn_1`` .. ``attn_4``), computed by ``get_mean_attention_map`` with
    a second argument of ``1`` (row labelled "Upper left corner") and ``-1``
    (row labelled "Lower right corner") respectively.

    Args:
        input: Image data; rescaled with ``(input + 1.0) / 2.0`` and shown
            via ``squeeze().permute(1, 2, 0)``. Presumably a single
            channel-first tensor normalised to [-1, 1] -- TODO confirm
            against the caller.
        model: Object exposing ``model.pretrained.attention``, indexed with
            the keys ``"attn_1"`` through ``"attn_4"``.
        prediction: Image-like object passed straight to ``plt.imshow``.
        model_type: ``"dpt_hybrid"`` titles the maps with layers 3/6/9/12;
            any other value titles them with layers 6/12/18/24.

    Returns:
        None. The figure is laid out with ``plt.tight_layout()`` and
        displayed with ``plt.show()``.
    """

    def _blank_ticks():
        # Hide tick marks and tick labels but keep the axes visible, so the
        # y-label set on the first column of each row is still drawn.
        current = plt.gca()
        x_axis, y_axis = current.axes.xaxis, current.axes.yaxis
        x_axis.set_ticklabels([])
        y_axis.set_ticklabels([])
        x_axis.set_ticks([])
        y_axis.set_ticks([])

    input = (input + 1.0) / 2.0

    stored = model.pretrained.attention
    attn_maps = [stored[key] for key in ("attn_1", "attn_2", "attn_3", "attn_4")]

    # Layer numbers used only for the column titles.
    if model_type == "dpt_hybrid":
        layers = [3, 6, 9, 12]
    else:
        layers = [6, 12, 18, 24]

    # Row 1: input image and prediction.
    plt.subplot(3, 4, 1)
    plt.imshow(input.squeeze().permute(1, 2, 0))
    plt.title("Input", fontsize=8)
    plt.axis("off")

    plt.subplot(3, 4, 2)
    plt.imshow(prediction)
    # NOTE(review): per the matplotlib docs set_cmap also changes the default
    # colormap, so the attention maps drawn below presumably pick it up too.
    plt.set_cmap("inferno")
    plt.title("Prediction", fontsize=8)
    plt.axis("off")

    # Rows 2 and 3: (second argument for get_mean_attention_map, row label,
    # whether the row carries "Layer N" titles).
    row_specs = (
        (1, "Upper left corner", True),
        (-1, "Lower right corner", False),
    )
    for row, (token, row_label, with_titles) in enumerate(row_specs, start=1):
        for col, (attn, layer) in enumerate(zip(attn_maps, layers)):
            first_column = col == 0

            # Grid cells 5-8 form row 2, cells 9-12 form row 3.
            plt.subplot(3, 4, 4 * row + col + 1)
            plt.imshow(get_mean_attention_map(attn, token, input.shape))

            if first_column:
                plt.ylabel(row_label, fontsize=8)
            if with_titles:
                plt.title(f"Layer {layer}", fontsize=8)

            if first_column:
                _blank_ticks()
            else:
                plt.axis("off")

    plt.tight_layout()
    plt.show()
|