File size: 2,728 Bytes
169e11c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
import torchvision.transforms.functional as F
from torchvision import transforms
from typing import DefaultDict
import matplotlib.pyplot as plt
import matplotlib
import torch
import logging
from torchvision.utils import draw_bounding_boxes
matplotlib.style.use('ggplot')
logging.getLogger('matplotlib').setLevel(logging.CRITICAL)
logging.getLogger('PIL').setLevel(logging.CRITICAL)


def save_plot(train_loss_list, label, output_dir):
    """
    Save a single loss curve to disk as ``<output_dir>/<label>.png``.

    :param train_loss_list: Sequence of per-epoch loss values to plot.
    :param label: Legend label for the curve; also used as the file name.
    :param output_dir: Directory the PNG is written into (must exist).
    """
    fig = plt.figure(figsize=(10, 7))
    plt.plot(
        train_loss_list, linestyle='-',
        label=label
    )
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(f"{output_dir}/{label}.png")
    # Close the figure explicitly: this is called once per tracked loss key,
    # and un-closed figures accumulate in matplotlib's global state (memory
    # leak + "too many open figures" warning).
    plt.close(fig)


def save_train_loss_plot(train_loss_dict: DefaultDict, output_dir):
    """
    Write one loss plot per tracked metric in ``train_loss_dict``.

    :param train_loss_dict: Mapping of metric name -> list of loss values.
    :param output_dir: Directory the PNG files are written into.
    """
    for metric_name, loss_values in train_loss_dict.items():
        save_plot(loss_values, metric_name, output_dir)


def show(imgs):
    """
    Display one or more image tensors in a single-column matplotlib grid.

    :param imgs: An image tensor, or a list of image tensors.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    fig, axes = plt.subplots(nrows=len(images), ncols=1,
                             figsize=(45, 21), squeeze=False)
    for row, image in enumerate(images):
        # Detach from the autograd graph, then convert tensor -> PIL -> ndarray
        # so imshow receives a plain HWC array.
        pixels = np.asarray(F.to_pil_image(image.detach()))
        axis = axes[row, 0]
        axis.imshow(pixels)
        # Hide all tick marks and labels — only the image matters here.
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()


def plot_img_tensor(img_tensor):
    """Open an image tensor in the system's default image viewer."""
    pil_image = transforms.ToPILImage()(img_tensor)
    pil_image.show()


def show_img(data_loader, model, device, th=0.7):
    """
    Visualize model predictions vs. ground truth for one sample.

    Takes the first image of the first batch from ``data_loader``, runs the
    model on it, and opens two images: one with thresholded predicted boxes
    and one with the ground-truth boxes.

    :param data_loader: Iterable yielding (images, targets) batches.
    :param model: Detection model; called with a list of image tensors.
    :param device: Device the image is moved to before inference.
    :param th: Score threshold applied to predicted boxes.
    """
    for images, targets in data_loader:
        sample = images[0]
        # Inference only — no gradients needed.
        with torch.no_grad():
            prediction = model([sample.to(device)])[0]
        # Predicted boxes, filtered by score threshold.
        plot_img_tensor(add_bbox(sample, prediction, th))
        # Ground-truth boxes, drawn without any threshold.
        plot_img_tensor(add_bbox(sample, targets[0]['boxes']))
        break  # only the first sample is visualized


def add_bbox(img, output, th=None):
    """
    Draw bounding boxes onto a uint8 copy of ``img`` and return it.

    :param img: Float image tensor with values in [0, 1] (C, H, W).
    :param output: When ``th`` is None, a plain boxes tensor (ground truth).
        Otherwise a detection dict with 'boxes', 'scores' and 'labels' keys.
    :param th: Optional score threshold. When given, only detections with
        label == 1 ("person") scoring above it are drawn, each annotated
        as "person:<score>".
    :return: uint8 image tensor with the boxes drawn.
    """
    # Scale the [0, 1] float image to the uint8 range draw_bounding_boxes
    # expects. (The original cloned `img` first and immediately overwrote
    # the clone — torch.clip already returns a new tensor, so no copy is
    # needed.)
    img_canvas = torch.clip(img * 255, 0, 255).type(torch.uint8)

    if th is None:
        # Ground-truth path: `output` is the boxes tensor itself.
        return draw_bounding_boxes(img_canvas, boxes=output, width=4)

    # Keep only confident "person" (label == 1) detections. Because the
    # mask already restricts labels to 1, every kept box is a person.
    mask = (output["scores"] > th) & (output["labels"] == 1)
    labels = [
        f"person:{score:.3f}" for score in output["scores"][mask].tolist()
    ]
    return draw_bounding_boxes(
        img_canvas, boxes=output["boxes"][mask], labels=labels, width=4)