File size: 3,843 Bytes
be6ec35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import lightning.pytorch as pl
from . import config
from .utils import (check_class_accuracy,get_evaluation_bboxes,mean_average_precision,plot_couple_examples)
from lightning.pytorch.callbacks import Callback


class plot_examples_callback(Callback):
    """Plot model predictions on training data every ``epoch_interval`` epochs.

    The plotting itself is delegated to ``utils.plot_couple_examples``. It is
    fed the module's training dataloader and its ``scaled_anchors`` attribute.

    Args:
        epoch_interval: Plot after every ``epoch_interval``-th finished epoch.
            Must be a positive integer.
        thresh: Threshold forwarded as ``thresh`` to ``plot_couple_examples``.
            This is presumably a confidence cut-off; confirm in ``utils``.
        iou_thresh: Threshold forwarded as ``iou_thresh`` to
            ``plot_couple_examples``. This is presumably an NMS IoU cut-off;
            confirm in ``utils``.
    """

    def __init__(self, epoch_interval: int = 5, thresh: float = 0.6, iou_thresh: float = 0.5) -> None:
        super().__init__()
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError at the end of the first training epoch.
        if epoch_interval < 1:
            raise ValueError(f"epoch_interval must be a positive integer, got {epoch_interval}")
        self.epoch_interval = epoch_interval
        self.thresh = thresh
        self.iou_thresh = iou_thresh

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Plot examples when the number of finished epochs is a multiple of the interval."""
        # current_epoch is 0-based, so +1 gives the count of finished epochs.
        if (trainer.current_epoch + 1) % self.epoch_interval != 0:
            return
        plot_couple_examples(
            model=pl_module,
            loader=pl_module.train_dataloader(),
            thresh=self.thresh,
            iou_thresh=self.iou_thresh,
            anchors=pl_module.scaled_anchors,
        )


class class_accuracy_callback(pl.Callback):
    """Report class, no-object and object accuracy on the train and test sets.

    At the end of a training epoch, ``utils.check_class_accuracy`` runs on:

    - the module's train dataloader, every ``train_epoch_interval`` epochs;
    - its test dataloader, every ``test_epoch_interval`` epochs.

    Each result is rounded to two decimals, logged via ``pl_module.log_dict``
    under ``train_*`` / ``test_*`` keys, and printed.

    Args:
        train_epoch_interval: Evaluate the train set after every N-th finished
            epoch. Must be a positive integer.
        test_epoch_interval: Evaluate the test set after every N-th finished
            epoch. Must be a positive integer.
    """

    def __init__(self, train_epoch_interval: int = 1, test_epoch_interval: int = 10) -> None:
        super().__init__()
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError at the end of the first training epoch.
        for name, value in (
            ("train_epoch_interval", train_epoch_interval),
            ("test_epoch_interval", test_epoch_interval),
        ):
            if value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value}")
        self.train_epoch_interval = train_epoch_interval
        self.test_epoch_interval = test_epoch_interval

    @staticmethod
    def _evaluate_and_log(pl_module: pl.LightningModule, loader, prefix: str):
        """Compute accuracies on ``loader``, log them under ``prefix``, and return them.

        Returns:
            ``(class_acc, no_obj_acc, obj_acc)`` as Python floats rounded to two decimals.
        """
        accs = check_class_accuracy(model=pl_module, loader=loader, threshold=config.CONF_THRESHOLD)
        # The three results expose .item() (presumably scalar tensors); convert
        # them to plain floats rounded to two decimals for logging and printing.
        class_acc, no_obj_acc, obj_acc = (round(acc.item(), 2) for acc in accs)
        pl_module.log_dict(
            {
                f"{prefix}_class_acc": class_acc,
                f"{prefix}_no_obj_acc": no_obj_acc,
                f"{prefix}_obj_acc": obj_acc,
            },
            logger=True,
        )
        return class_acc, no_obj_acc, obj_acc

    @staticmethod
    def _print_accuracies(class_acc: float, no_obj_acc: float, obj_acc: float) -> None:
        """Print the three accuracies as percentages with two decimal places."""
        # ".2f" means two decimals; ":2f" would mean "min width 2, six decimals".
        print(f"Class Accuracy: {class_acc:.2f}%")
        print(f"No Object Accuracy: {no_obj_acc:.2f}%")
        print(f"Object Accuracy: {obj_acc:.2f}%")

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Evaluate and report train/test accuracies on their respective schedules."""
        # current_epoch is 0-based, so +1 gives the count of finished epochs.
        epoch = trainer.current_epoch + 1

        if epoch % self.train_epoch_interval == 0:
            accs = self._evaluate_and_log(pl_module, pl_module.train_dataloader(), "train")
            print(f"Epoch: {epoch}")
            print("Train Metrics")
            # "train_loss_epoch" is presumably produced by the module logging
            # "train_loss" with on_epoch=True. Use .get so a missing key prints
            # None instead of aborting training from a print statement.
            print(f"Loss: {trainer.callback_metrics.get('train_loss_epoch')}")
            self._print_accuracies(*accs)

        if epoch % self.test_epoch_interval == 0:
            accs = self._evaluate_and_log(pl_module, pl_module.test_dataloader(), "test")
            print("Test Metrics")
            self._print_accuracies(*accs)

class map_callback(pl.Callback):
    """Compute, print and log mean average precision (mAP) on the test set.

    Every ``epoch_interval`` finished epochs, this callback:

    - collects predicted and ground-truth boxes from the module's test
      dataloader with ``utils.get_evaluation_bboxes``;
    - scores them with ``utils.mean_average_precision``;
    - prints the result and logs it under the key ``"MAP"``.

    Args:
        epoch_interval: Evaluate after every N-th finished epoch. Must be a
            positive integer.
    """

    def __init__(self, epoch_interval: int = 10) -> None:
        super().__init__()
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError at the end of the first training epoch.
        if epoch_interval < 1:
            raise ValueError(f"epoch_interval must be a positive integer, got {epoch_interval}")
        self.epoch_interval = epoch_interval

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Evaluate mAP when the number of finished epochs is a multiple of the interval."""
        # current_epoch is 0-based, so +1 gives the count of finished epochs.
        if (trainer.current_epoch + 1) % self.epoch_interval != 0:
            return
        try:
            pred_boxes, true_boxes = get_evaluation_bboxes(
                loader=pl_module.test_dataloader(),
                model=pl_module,
                iou_threshold=config.NMS_IOU_THRESH,
                # NOTE(review): this passes the unscaled config.ANCHORS, while
                # plot_couple_examples receives pl_module.scaled_anchors. It is
                # presumably get_evaluation_bboxes that scales them internally;
                # confirm in utils.
                anchors=config.ANCHORS,
                threshold=config.CONF_THRESHOLD,
                device=config.DEVICE,
            )
            map_val = mean_average_precision(
                pred_boxes=pred_boxes,
                true_boxes=true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=config.NUM_CLASSES,
            ).item()
        finally:
            # Evaluation may leave the module in eval mode. Always hand it back
            # to the training loop in train mode, even if evaluation raised.
            pl_module.train()
        print("MAP: ", map_val)
        pl_module.log("MAP", map_val, logger=True)