|
import torch |
|
import torch.nn as nn |
|
import pytorch_lightning as pl |
|
|
|
from torchmetrics.classification import BinaryAccuracy |
|
|
|
class BiLSTM(pl.LightningModule):
    """Bidirectional LSTM classifier trained with ``BCEWithLogitsLoss``.

    The network is a bidirectional ``nn.LSTM`` (``batch_first=True``) followed
    by dropout and a linear layer that maps the concatenated forward/backward
    features (``hidden_size * 2``) to ``num_classes`` logits.

    Args:
        lr: Learning rate handed to the Adam optimizer.
        num_classes: Output width of the final linear layer.
        input_size: Feature size of each LSTM input element.
        hidden_size: Hidden size of the LSTM, per direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the LSTM output before the
            linear layer. It is not forwarded to ``nn.LSTM`` itself, so no
            dropout is applied between the stacked LSTM layers.
    """

    def __init__(self, lr, num_classes, input_size, hidden_size=300, num_layers=2, dropout=0.5):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # A bidirectional LSTM emits hidden_size features per direction.
        self.output_layer = nn.Linear(hidden_size * 2, num_classes)
        self.criterion = nn.BCEWithLogitsLoss()
        # One metric object per stage so the state accumulated for one stage is
        # never mixed with another's (validation can run in the middle of a
        # training epoch). ``accuracy_metric`` keeps its original name and now
        # serves the training stage only.
        self.accuracy_metric = BinaryAccuracy()
        self.val_accuracy_metric = BinaryAccuracy()
        self.test_accuracy_metric = BinaryAccuracy()
        self.lr = lr
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        """Return raw logits: LSTM output -> dropout -> linear layer.

        The linear layer is applied to the full LSTM output (every element of
        the sequence), not only to the last step. No sigmoid is applied here;
        the loss used by the step methods expects logits.
        """
        lstm_output, _ = self.lstm(X)
        return self.output_layer(self.dropout(lstm_output))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch, stage, accuracy_metric):
        """Compute loss and accuracy for one batch and log them under ``stage``.

        Args:
            batch: An ``(X, target)`` pair.
            stage: Prefix for the logged keys (``'train'``, ``'val'`` or ``'test'``),
                producing ``'<stage>_loss'`` and ``'<stage>_accuracy'``.
            accuracy_metric: The metric instance belonging to this stage.

        Returns:
            The loss tensor for the batch.
        """
        X, target = batch

        # squeeze(1) only removes dim 1 when it has size 1.
        # NOTE(review): assumes the model output has a singleton dim 1 so that
        # the logits line up with ``target`` - TODO confirm against the data loader.
        logits = self(X).squeeze(1)
        loss = self.criterion(logits, target.float())

        # The criterion consumes logits; the metric is handed sigmoid outputs.
        accuracy = accuracy_metric(self.sigmoid(logits), target)

        self.log_dict(
            {f'{stage}_loss': loss, f'{stage}_accuracy': accuracy},
            prog_bar=True,
            on_epoch=True,
        )
        return loss

    def training_step(self, train_batch, batch_idx):
        """Run one training batch; logs ``train_loss`` / ``train_accuracy`` and returns the loss."""
        return self._shared_step(train_batch, 'train', self.accuracy_metric)

    def validation_step(self, valid_batch, batch_idx):
        """Run one validation batch; logs ``val_loss`` / ``val_accuracy`` and returns the loss."""
        return self._shared_step(valid_batch, 'val', self.val_accuracy_metric)

    def test_step(self, test_batch, batch_idx):
        """Run one test batch; logs ``test_loss`` / ``test_accuracy`` and returns the loss."""
        return self._shared_step(test_batch, 'test', self.test_accuracy_metric)
|
|