Matteo Sirri
feat: initial commit
169e11c
import numpy as np
import torchvision.transforms.functional as F
from torchvision import transforms
from typing import DefaultDict
import matplotlib.pyplot as plt
import matplotlib
import torch
import logging
from torchvision.utils import draw_bounding_boxes
# Every figure produced by this module uses the ggplot look.
matplotlib.style.use('ggplot')
# Raise the threshold of these third-party loggers so that only CRITICAL
# records from them get through.
for _noisy_logger in ('matplotlib', 'PIL'):
    logging.getLogger(_noisy_logger).setLevel(logging.CRITICAL)
del _noisy_logger
def save_plot(train_loss_list, label, output_dir):
    """
    Save a single loss curve to ``{output_dir}/{label}.png``.

    Args:
        train_loss_list: sequence of loss values, plotted in order against
            an x axis labelled "Epochs".
        label: curve label; used both as the legend entry and as the file
            name (without extension).
        output_dir: directory the PNG file is written into.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        ax.plot(train_loss_list, linestyle='-', label=label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.savefig(f"{output_dir}/{label}.png")
    finally:
        # Close the figure even if saving fails. Otherwise every call leaves
        # one pyplot figure open, and repeated calls (one per loss key)
        # accumulate figures and memory.
        plt.close(fig)
def save_train_loss_plot(train_loss_dict: DefaultDict, output_dir):
    """
    Write one loss-curve PNG per entry of ``train_loss_dict``.

    Each key is used as the plot label and file name; its value is the
    sequence of loss values passed on to ``save_plot``.
    """
    for loss_name, loss_values in train_loss_dict.items():
        save_plot(loss_values, loss_name, output_dir)
def show(imgs):
    """
    Display one or more image tensors stacked vertically in a single figure.

    Args:
        imgs: a single image tensor or a list of them. Each one is detached,
            converted with ``F.to_pil_image`` and drawn in its own row with
            all ticks removed.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, grid = plt.subplots(nrows=len(images), ncols=1,
                           figsize=(45, 21), squeeze=False)
    # squeeze=False keeps ``grid`` two-dimensional, so column 0 holds one
    # Axes per image even when there is only a single image.
    for row_ax, tensor in zip(grid[:, 0], images):
        pixels = np.asarray(F.to_pil_image(tensor.detach()))
        row_ax.imshow(pixels)
        row_ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()
def plot_img_tensor(img_tensor):
    """Convert ``img_tensor`` to a PIL image and call its ``show()`` method."""
    to_pil = transforms.ToPILImage()
    pil_image = to_pil(img_tensor)
    pil_image.show()
def show_img(data_loader, model, device, th=0.7):
    """
    Display the model's prediction and the ground-truth boxes for one image.

    Only the first image of the first batch yielded by ``data_loader`` is
    used. If the loader yields nothing, the function returns without doing
    anything.

    Args:
        data_loader: iterable of ``(imgs, target)`` batches; ``target[0]``
            must contain a ``'boxes'`` entry.
        model: called as ``model([image])`` and expected to return a list
            whose first item is the prediction dict consumed by ``add_bbox``.
        device: device the image is moved to before inference.
        th: score threshold forwarded to ``add_bbox`` for the prediction.
    """
    batches = iter(data_loader)
    try:
        imgs, target = next(batches)
    except StopIteration:
        return
    first_image = imgs[0]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = model([first_image.to(device)])[0]
    plot_img_tensor(add_bbox(first_image, prediction, th))
    plot_img_tensor(add_bbox(first_image, target[0]['boxes']))
def add_bbox(img, output, th=None):
    """
    Return a uint8 copy of ``img`` with bounding boxes drawn on it.

    Args:
        img: image tensor to draw on; it is not modified. A uint8 image is
            used as-is. Any other dtype is multiplied by 255, clipped to
            [0, 255] and cast to uint8, so it is presumably expected to be
            in [0, 1].
        output: when ``th`` is None, a tensor of boxes that are all drawn
            unlabelled. Otherwise a prediction dict with "boxes", "scores"
            and "labels" entries.
        th: score threshold. When given, only detections with
            ``score > th`` and ``label == 1`` are drawn, each annotated as
            ``"person:<score>"`` with three decimals.

    Returns:
        The uint8 image tensor produced by ``draw_bounding_boxes``.
    """
    if img.dtype == torch.uint8:
        # Already in [0, 255]. Multiplying by 255 in uint8 arithmetic would
        # wrap around instead of saturating.
        img_canvas = img.clone()
    else:
        img_canvas = torch.clip(img * 255, 0, 255).type(torch.uint8)

    if th is None:
        return draw_bounding_boxes(img_canvas, boxes=output, width=4)

    # Keep only confident detections of class 1, which is drawn as "person".
    mask = (output["scores"] > th) & (output["labels"] == 1)
    kept_scores = output["scores"][mask].tolist()
    # The mask already guarantees label == 1, so every kept detection gets
    # the "person" caption (this also works if labels are stored as floats).
    labels = [f"person:{score:.3f}" for score in kept_scores]
    return draw_bounding_boxes(
        img_canvas, boxes=output["boxes"][mask], labels=labels, width=4)