import logging
from typing import DefaultDict

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as F
from torchvision import transforms
from torchvision.utils import draw_bounding_boxes

matplotlib.style.use('ggplot')

# Suppress all but CRITICAL log records from matplotlib and PIL.
logging.getLogger('matplotlib').setLevel(logging.CRITICAL)
logging.getLogger('PIL').setLevel(logging.CRITICAL)


def save_plot(train_loss_list, label, output_dir):
    """
    Save a line plot of a loss history to ``{output_dir}/{label}.png``.

    Args:
        train_loss_list: Sequence of loss values, plotted against their index.
            The x-axis is labelled "Epochs", so presumably one value per epoch.
        label: Legend label; also used as the output file name stem.
        output_dir: Directory the PNG file is written to.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        ax.plot(train_loss_list, linestyle='-', label=label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.savefig(f"{output_dir}/{label}.png")
    finally:
        # pyplot keeps every figure alive until it is explicitly closed, so
        # without this each call would leak one figure.
        plt.close(fig)


def save_train_loss_plot(train_loss_dict: DefaultDict, output_dir):
    """
    Save one loss plot per entry of ``train_loss_dict``.

    Each key is used as both the legend label and the file name stem, so the
    values under key ``k`` are plotted to ``{output_dir}/{k}.png``.
    """
    for name, losses in train_loss_dict.items():
        save_plot(losses, name, output_dir)


def show(imgs):
    """
    Display one or more image tensors stacked vertically in a pyplot window.

    Args:
        imgs: A single image tensor, or a list of them, in any form accepted
            by ``torchvision.transforms.functional.to_pil_image``.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    # squeeze=False keeps ``axs`` 2-D so column 0 can be indexed even when
    # there is only one image.
    fig, axs = plt.subplots(nrows=len(imgs), ncols=1, figsize=(45, 21), squeeze=False)
    for ax, img in zip(axs[:, 0], imgs):
        ax.imshow(np.asarray(F.to_pil_image(img.detach())))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()


def plot_img_tensor(img_tensor):
    """Open ``img_tensor`` in PIL's default image viewer."""
    transforms.ToPILImage()(img_tensor).show()


def show_img(data_loader, model, device, th=0.7):
    """
    Show predictions and ground-truth boxes for the first image of the first
    batch yielded by ``data_loader``.

    Two images are opened in PIL's viewer: first the model's detections
    filtered by ``th`` (see ``add_bbox``), then the target boxes. Does nothing
    if ``data_loader`` yields no batches.

    NOTE(review): assumes ``model`` is already in eval mode, and that each
    batch is ``(images, targets)`` with ``targets[i]['boxes']``. Confirm this
    against the callers.
    """
    batch = next(iter(data_loader), None)
    if batch is None:
        return
    imgs, target = batch
    img = imgs[0]
    with torch.no_grad():
        prediction = model([img.to(device)])[0]
    plot_img_tensor(add_bbox(img, prediction, th))
    plot_img_tensor(add_bbox(img, target[0]['boxes']))


def add_bbox(img, output, th=None, *, label_id=1, label_name="person", width=4):
    """
    Return a uint8 copy of ``img`` with bounding boxes drawn on it.

    Args:
        img: Image tensor. It is multiplied by 255, clipped to [0, 255] and
            cast to uint8 before drawing, so presumably a float image in
            [0, 1]. The input tensor itself is not modified.
        output: If ``th`` is None, a boxes tensor passed straight to
            ``draw_bounding_boxes``. Otherwise, a detection dict with
            ``"boxes"``, ``"scores"`` and ``"labels"`` tensors.
        th: Optional score threshold. When given, only detections with
            ``score > th`` and ``label == label_id`` are drawn, each captioned
            ``"{label_name}:{score:.3f}"``.
        label_id: Class id kept when filtering detections (default 1).
        label_name: Caption prefix for kept detections (default "person").
        width: Line width passed to ``draw_bounding_boxes`` (default 4).

    Returns:
        The annotated uint8 image tensor produced by ``draw_bounding_boxes``.
    """
    canvas = torch.clip(img * 255, 0, 255).type(torch.uint8)

    if th is None:
        return draw_bounding_boxes(canvas, boxes=output, width=width)

    mask = (output["scores"] > th) & (output["labels"] == label_id)
    # Every detection surviving the mask has label ``label_id``, so each one
    # gets the same name; only the score differs.
    captions = [f"{label_name}:{score:.3f}" for score in output["scores"][mask].tolist()]
    return draw_bounding_boxes(
        canvas, boxes=output["boxes"][mask], labels=captions, width=width)