Spaces:
Runtime error
Runtime error
File size: 2,932 Bytes
557fb53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from models.utils import calculate_metrics
import pytorch_lightning as pl
import torch
import torch.nn as nn
class TrainingEnvironment(pl.LightningModule):
    """Lightning module that trains and evaluates a wrapped ``nn.Module``.

    The wrapped model is called on the first element of each batch and its
    outputs are scored against the second element with ``criterion``.
    Metrics come from ``calculate_metrics`` and are logged under the
    ``train/``, ``val/`` and ``test/`` prefixes.
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        config: dict,
        learning_rate=1e-4,
        log_spectrograms=False,
        *args,
        **kwargs,
    ):
        """Store the model, loss and settings and record hyperparameters.

        Args:
            model: Network invoked as ``model(features)``.
            criterion: Loss invoked as ``criterion(outputs, labels)``.
            config: Stored on the instance and saved as a hyperparameter;
                not otherwise read by this class.
            learning_rate: Learning rate handed to the Adam optimizer.
            log_spectrograms: When true, an input image is sent to the
                logger on every 100th training batch.
            *args: Forwarded to ``pl.LightningModule.__init__``.
            **kwargs: Forwarded to ``pl.LightningModule.__init__`` and also
                saved alongside the other hyperparameters.
        """
        super().__init__(*args, **kwargs)
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.log_spectrograms = log_spectrograms
        self.config = config
        # Any criterion whose class is not named exactly "CrossEntropyLoss"
        # is treated as producing multi-label predictions.
        self.has_multi_label_predictions = (
            type(criterion).__name__ != "CrossEntropyLoss"
        )
        # Only class names are saved for the model and the loss, not the
        # module objects themselves.
        self.save_hyperparameters(
            {
                "model": type(model).__name__,
                "loss": type(criterion).__name__,
                "config": config,
                **kwargs,
            }
        )

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> torch.Tensor:
        """Run one optimisation step and return the loss.

        Args:
            batch: ``(features, labels)`` pair.
            batch_index: Index of the batch within the current epoch.

        Returns:
            The value of ``criterion(model(features), labels)``.
        """
        features, labels = batch
        outputs = self.model(features)
        loss = self.criterion(outputs, labels)
        metrics = calculate_metrics(
            outputs,
            labels,
            prefix="train/",
            multi_label=self.has_multi_label_predictions,
        )
        self.log_dict(metrics, prog_bar=True)
        # Log spectrograms
        if self.log_spectrograms and batch_index % 100 == 0:
            self._log_random_spectrogram(features, batch_index)
        return loss

    def _log_random_spectrogram(
        self, features: torch.Tensor, batch_index: int
    ) -> None:
        """Send one randomly chosen, min-max normalised input to the logger.

        Nothing is logged when no logger is attached or when the logger's
        experiment object has no ``add_image`` method, so optional image
        logging can never abort a training run.
        """
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_image"):
            return
        img_index = torch.randint(0, len(features), (1,)).item()
        # Takes index 0 along the second axis; the result is logged as "HW".
        # Assumes features is (batch, channel, height, width) - TODO confirm.
        img = features[img_index][0]
        low = img.min()
        span = img.max() - low
        if span > 0:
            img = (img - low) / span
        else:
            # A constant image would otherwise become 0/0 = NaN everywhere.
            img = torch.zeros_like(img, dtype=torch.float32)
        experiment.add_image(
            f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW"
        )

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> None:
        """Log validation metrics and the validation loss for one batch.

        The loss is logged under ``val/loss``, the key that the scheduler
        configured in ``configure_optimizers`` monitors.
        """
        x, y = batch
        preds = self.model(x)
        metrics = calculate_metrics(
            preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
        )
        metrics["val/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True)

    def test_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> None:
        """Log test metrics for one batch."""
        x, y = batch
        preds = self.model(x)
        self.log_dict(
            calculate_metrics(
                preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
            ),
            prog_bar=True,
        )

    def configure_optimizers(self):
        """Return Adam plus a ``ReduceLROnPlateau`` scheduler.

        The scheduler runs in ``"min"`` mode and monitors ``val/loss``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val/loss",
        }
|