import lightning as L
import torch
import torch.nn.functional as F
from timm.models import VisionTransformer
from torch import optim
from torchmetrics.classification import Accuracy, F1Score


class ViTTinyClassifier(L.LightningModule):
    """A tiny Vision Transformer (timm) wrapped as a LightningModule for binary image classification."""

    def __init__(
        self,
        img_size: int = 224,
        num_classes: int = 2,  # binary classification: two output logits
        embed_dim: int = 64,
        depth: int = 6,
        num_heads: int = 2,
        patch_size: int = 16,
        mlp_ratio: float = 3.0,
        pre_norm: bool = False,
        lr: float = 1e-3,
        weight_decay: float = 1e-5,
        factor: float = 0.1,
        patience: int = 10,
        min_lr: float = 1e-6,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Vision Transformer backbone (timm includes the classification head)
        self.model = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=3,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=True,
            pre_norm=pre_norm,
            global_pool="token",  # classify from the [CLS] token
        )
        # Define accuracy and F1 metrics for binary classification
        self.train_acc = Accuracy(task="binary")
        self.val_acc = Accuracy(task="binary")
        self.test_acc = Accuracy(task="binary")
        self.train_f1 = F1Score(task="binary")
        self.val_f1 = F1Score(task="binary")
        self.test_f1 = F1Score(task="binary")
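        # Note: TorchMetrics objects accumulate state across batches, so one
        # instance per stage keeps train/val/test statistics independent.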

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, stage):
        x, y = batch
        logits = self(x)  # shape: [batch_size, num_classes]
        loss = F.cross_entropy(logits, y)  # two-class cross-entropy
        preds = torch.argmax(logits, dim=1)  # predicted class (0 or 1)
        # Update the metrics for this stage
        acc = getattr(self, f"{stage}_acc")
        f1 = getattr(self, f"{stage}_f1")
        acc(preds, y)
        f1(preds, y)
        # Log loss and the metric objects; Lightning computes and resets
        # logged TorchMetrics at epoch boundaries automatically
        self.log(f"{stage}_loss", loss, prog_bar=True, on_epoch=True)
        self.log(f"{stage}_acc", acc, prog_bar=True, on_epoch=True)
        self.log(f"{stage}_f1", f1, prog_bar=True, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")
    def configure_optimizers(self):
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=self.hparams.factor,
            patience=self.hparams.patience,
            min_lr=self.hparams.min_lr,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
            },
        }


if __name__ == "__main__":
    model = ViTTinyClassifier()
    print(model)
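
    # Usage sketch (an assumption, not from the original file): smoke-test the
    # model on random data. `random_loader` is a made-up helper; shapes follow
    # the defaults above (224x224 RGB, two classes).
    from torch.utils.data import DataLoader, TensorDataset

    def random_loader(n: int = 8) -> DataLoader:
        x = torch.randn(n, 3, 224, 224)  # fake RGB images
        y = torch.randint(0, 2, (n,))    # fake binary labels
        return DataLoader(TensorDataset(x, y), batch_size=4)

    # Forward-pass check: expect logits of shape (2, 2)
    print(model(torch.randn(2, 3, 224, 224)).shape)

    # One short train/val run so ReduceLROnPlateau sees its "val_loss" monitor
    trainer = L.Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=random_loader(), val_dataloaders=random_loader())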