# Spaces:
# Runtime error
# Runtime error
from itertools import cycle

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt

from constants import COLORS
from utils import fig2img
def visualize_prediction(
    pil_img, output_dict, threshold=0.7, id2label=None
):
    """Draw the detections whose score exceeds ``threshold`` on ``pil_img``.

    Args:
        pil_img: Image to draw on; passed directly to ``Axes.imshow``.
        output_dict: Mapping with ``"scores"``, ``"boxes"`` and ``"labels"``
            tensors indexed in parallel (one entry per detection). Each box is
            unpacked as ``(xmin, ymin, xmax, ymax)``; presumably in the same
            pixel space as ``pil_img`` -- confirm against the post-processor.
        threshold: Strict lower bound on a detection's score for it to be
            drawn.
        id2label: Optional mapping from label id to display name. When given,
            every kept label id must be a key (``KeyError`` otherwise).

    Returns:
        The result of ``fig2img`` applied to the rendered figure.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        ax.imshow(pil_img)
        # cycle() never runs out, so every kept detection gets a color; the
        # previous ``COLORS * 100`` made zip() silently drop detections once
        # there were more than 100 * len(COLORS) of them.
        for score, (xmin, ymin, xmax, ymax), label, color in zip(
            scores, boxes, labels, cycle(COLORS)
        ):
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=2,
                )
            )
            ax.text(
                xmin,
                ymin,
                f"{label}: {score:0.2f}",
                fontsize=10,
                bbox=dict(facecolor="yellow", alpha=0.5),
            )
        ax.axis("off")
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app leaks one figure per call, and the open figure
        # stays "current" for later pyplot-state drawing calls.
        plt.close(fig)
def visualize_attention_map(pil_img, attention_map):
    """Overlay the head-averaged attention of the last layer on ``pil_img``.

    Args:
        pil_img: PIL image; its ``size`` (``(width, height)``) sets the
            resolution the attention map is resized to.
        attention_map: Indexable sequence of attention tensors; only the last
            element is used. Assumed to be shaped ``(1, heads, h, w)`` --
            TODO confirm: dim 1 is averaged away and the result must squeeze
            down to 2-D for the bicubic resize below.

    Returns:
        The result of ``fig2img`` applied to the rendered overlay figure.
    """
    # float32 cast: CPU bicubic interpolation does not support every reduced
    # precision dtype, and ``.numpy()`` rejects bfloat16 outright.
    last_layer = attention_map[-1].detach().cpu().float()
    avg_attention_weight = torch.mean(last_layer, dim=1).squeeze()
    # interpolate() needs a 4-D (N, C, H, W) input; PIL's size is (w, h)
    # while interpolate expects (h, w), hence the reversal.
    avg_attention_weight_resized = F.interpolate(
        avg_attention_weight.unsqueeze(0).unsqueeze(0),
        size=pil_img.size[::-1],
        mode="bicubic",
    )[0, 0].numpy()  # index rather than squeeze so 1-px dims survive

    # Draw on a dedicated figure instead of pyplot's "current" one, which may
    # still hold an earlier, unrelated plot.
    fig, ax = plt.subplots()
    try:
        ax.imshow(pil_img)
        ax.imshow(avg_attention_weight_resized, alpha=0.7, cmap="viridis")
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)