import pytorch_lightning as pl
import torch
import torch.nn as nn

from models.utils import calculate_metrics


class TrainingEnvironment(pl.LightningModule):
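    """Generic Lightning training wrapper around a model/criterion pair.

    Logs train/val/test metrics computed by ``calculate_metrics``, optionally
    writes input spectrograms to TensorBoard, and trains with Adam plus a
    ReduceLROnPlateau scheduler that monitors ``val/loss``.
    """
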
    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        config: dict,
        learning_rate=1e-4,
        log_spectrograms=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.log_spectrograms = log_spectrograms
        self.config = config
        # CrossEntropyLoss implies single-label targets; any other criterion is
        # treated as producing multi-label predictions.
        self.has_multi_label_predictions = (
            type(criterion).__name__ != "CrossEntropyLoss"
        )
        self.save_hyperparameters(
            {
                "model": type(model).__name__,
                "loss": type(criterion).__name__,
                "config": config,
                **kwargs,
            }
        )

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> torch.Tensor:
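        # Forward pass, compute the loss, and log per-batch training metrics.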
        features, labels = batch
        outputs = self.model(features)
        loss = self.criterion(outputs, labels)
        metrics = calculate_metrics(
            outputs,
            labels,
            prefix="train/",
            multi_label=self.has_multi_label_predictions,
        )
        self.log_dict(metrics, prog_bar=True)
        # Periodically log a randomly chosen input spectrogram to TensorBoard
        # (assumes the attached logger is a TensorBoardLogger).
        if self.log_spectrograms and batch_index % 100 == 0:
            tensorboard = self.logger.experiment
            img_index = torch.randint(0, len(features), (1,)).item()
            img = features[img_index][0]
            # Min-max normalize to [0, 1] for display.
            img = (img - img.min()) / (img.max() - img.min())
            tensorboard.add_image(
                f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW"
            )
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ):
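        # Compute and log validation metrics plus val/loss, which the
        # ReduceLROnPlateau scheduler monitors.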
        x, y = batch
        preds = self.model(x)
        metrics = calculate_metrics(
            preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
        )
        metrics["val/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
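        # Compute and log test-set metrics.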
        x, y = batch
        preds = self.model(x)
        self.log_dict(
            calculate_metrics(
                preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
            ),
            prog_bar=True,
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # ReduceLROnPlateau needs a metric to track, so the monitor key is
        # supplied inside the lr_scheduler config expected by Lightning.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val/loss"},
        }
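

# Illustrative usage sketch (not part of the module): how this wrapper is
# typically driven by a pl.Trainer. MyModel, train_loader, val_loader,
# test_loader, and the config dict are placeholders for project-specific
# objects.
#
#     model = MyModel()
#     criterion = nn.CrossEntropyLoss()  # single-label classification
#     env = TrainingEnvironment(model, criterion, config={"lr": 1e-4})
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(env, train_dataloaders=train_loader, val_dataloaders=val_loader)
#     trainer.test(env, dataloaders=test_loader)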