|
import pytorch_lightning as pl |
|
import torch |
|
|
|
from datasets import load_metric |
|
from torch import nn |
|
from transformers import SegformerForSemanticSegmentation |
|
from typing import Dict |
|
|
|
|
|
class SidewalkSegmentationModel(pl.LightningModule):
    """
    Lightning wrapper around a pretrained SegFormer for semantic segmentation.

    Parameters
    ----------
    num_labels : int
        Number of segmentation classes requested for the decode head.
    id2label : Dict[int, str]
        Mapping from class index to class name. Its inverse is passed to the
        model as ``label2id``, and its length is used as the class count when
        computing mean IoU.
    model_flavor : int
        Suffix of the ``nvidia/mit-b{model_flavor}`` checkpoint to load.
    learning_rate : float
        Learning rate used by the AdamW optimizer.
    ignore_index : int
        Label value excluded from both the model's loss and the mean-IoU
        metric. Defaults to 255, the value this module previously hard-coded.
    """

    def __init__(
        self,
        num_labels: int,
        id2label: Dict[int, str],
        model_flavor: int = 0,
        learning_rate: float = 6e-5,
        ignore_index: int = 255,
    ):
        super().__init__()
        self.id2label = id2label
        self.label2id = {v: k for k, v in id2label.items()}
        self.learning_rate = learning_rate
        self.ignore_index = ignore_index
        # Separate accumulators per stage so train and val batches never mix.
        self.metrics = {
            "train": load_metric("mean_iou"),
            "val": load_metric("mean_iou"),
        }

        self.model = SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-b{model_flavor}",
            num_labels=num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
            # Keep the label ignored by the loss identical to the one the
            # metric ignores; otherwise the two would silently disagree.
            semantic_loss_ignore_index=ignore_index,
        )
        self.save_hyperparameters()

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped SegFormer model."""
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        """
        Run one training step.

        Parameters
        ----------
        batch : dict
            Mapping with ``"pixel_values"`` and ``"labels"`` tensors.
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        dict
            ``{"loss": loss}`` for Lightning's optimization loop.
        """
        loss = self._shared_step("train", batch)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """
        Run one validation step.

        Parameters
        ----------
        batch : dict
            Mapping with ``"pixel_values"`` and ``"labels"`` tensors.
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        dict
            ``{"val_loss": loss}``.
        """
        loss = self._shared_step("val", batch)
        return {"val_loss": loss}

    def training_epoch_end(self, training_step_outputs):
        """
        Log the training metrics.
        """
        self._log_epoch_metrics("train")

    def validation_epoch_end(self, validation_step_outputs):
        """
        Log the validation metrics.
        """
        self._log_epoch_metrics("val")

    def _shared_step(self, stage: str, batch) -> torch.Tensor:
        """Forward a batch, feed the stage metric, log ``{stage}_loss`` and return the loss."""
        labels = batch["labels"]
        outputs = self(pixel_values=batch["pixel_values"], labels=labels)
        loss = outputs.loss

        self.add_batch_to_metric(stage, outputs.logits, labels)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def _log_epoch_metrics(self, stage: str) -> None:
        """Compute the accumulated mean-IoU metric for ``stage`` and log IoU / accuracy."""
        metrics = self.metrics[stage].compute(
            num_labels=len(self.id2label),
            ignore_index=self.ignore_index,
            reduce_labels=False,
        )
        self.log(f"{stage}_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log(f"{stage}_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)

    def add_batch_to_metric(self, stage: str, logits: torch.Tensor, labels: torch.Tensor):
        """
        Add the current batch to the metric.

        The logits are bilinearly upsampled to the label resolution before the
        per-pixel argmax, so predictions and references have the same size.

        Parameters
        ----------
        stage : str
            Stage of the training. Either "train" or "val".
        logits : torch.Tensor
            Predicted logits. Must be 4-D with class scores on dim 1, as
            required by bilinear interpolation and ``argmax(dim=1)``.
        labels : torch.Tensor
            Ground truth labels whose last two dims give the target size.
        """
        # no_grad already guarantees nothing here is tracked, so no detach needed.
        with torch.no_grad():
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)
            self.metrics[stage].add_batch(
                predictions=predicted.cpu().numpy(), references=labels.cpu().numpy()
            )

    def configure_optimizers(self) -> torch.optim.AdamW:
        """
        Configure the optimizer.

        Returns
        -------
        torch.optim.AdamW
            AdamW over all module parameters with ``self.learning_rate``.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
|
|