import torch
import torch.nn as nn
import torchmetrics
import pytorch_lightning as pl
import wandb  ## needed for the confusion-matrix / ROC logging in LiLTPL

from modeling import LiLT
## The 16 RVL-CDIP document classes
id2label = [
    'scientific_report',
    'resume',
    'memo',
    'file_folder',
    'specification',
    'news_article',
    'letter',
    'form',
    'budget',
    'handwritten',
    'email',
    'invoice',
    'presentation',
    'scientific_publication',
    'questionnaire',
    'advertisement',
]
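
## Hypothetical convenience helper (not used by the classes below): reverse
## map from class name to integer id, handy when encoding string labels.
label2id = {label: idx for idx, label in enumerate(id2label)}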
class LiLTForClassification(nn.Module):
    def __init__(self, config):
        super(LiLTForClassification, self).__init__()
        self.lilt = LiLT(config)
        self.config = config
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])
        ## Layout and text streams are concatenated, hence hidden_size * 2
        self.linear_layer = nn.Linear(
            in_features=config['hidden_size'] * 2,
            out_features=len(id2label),  ## number of classes
        )

    def forward(self, batch_dict):
        encodings = self.lilt(batch_dict['input_words'], batch_dict['input_boxes'])
        ## Concatenate the final layout and text hidden states, then keep the
        ## first ([CLS]) position as the pooled sequence representation
        final_out = torch.cat(
            [encodings['layout_hidden_states'][-1],
             encodings['text_hidden_states'][-1]],
            dim=-1,
        )[:, 0, :]
        ## Apply dropout before the classification head
        final_out = self.linear_layer(self.dropout(final_out))
        return final_out
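
## A minimal shape-check sketch (commented out; `config`, `vocab_size`, and the
## sequence length are assumed placeholders, and modeling.LiLT is assumed to
## take (token_ids, boxes) and return the hidden-state dicts used above):
##
##   model = LiLTForClassification(config)
##   batch = {'input_words': torch.randint(0, vocab_size, (2, 512)),
##            'input_boxes': torch.randint(0, 1000, (2, 512, 4))}
##   print(model(batch).shape)  ## expected: (2, len(id2label))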
## Defining the pytorch lightning model
class LiLTPL(pl.LightningModule):
    def __init__(self, config, lr=5e-5):
        super(LiLTPL, self).__init__()
        self.save_hyperparameters()
        self.config = config
        self.lilt = LiLTForClassification(config)
        self.num_classes = len(id2label)

        ## torchmetrics < 0.11 API; newer releases also require task="multiclass"
        self.train_accuracy_metric = torchmetrics.Accuracy()
        self.val_accuracy_metric = torchmetrics.Accuracy()
        self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)
        self.precision_macro_metric = torchmetrics.Precision(
            average="macro", num_classes=self.num_classes
        )
        self.recall_macro_metric = torchmetrics.Recall(
            average="macro", num_classes=self.num_classes
        )
        self.precision_micro_metric = torchmetrics.Precision(average="micro")
        self.recall_micro_metric = torchmetrics.Recall(average="micro")
    def forward(self, batch_dict):
        logits = self.lilt(batch_dict)
        return logits
    def training_step(self, batch, batch_idx):
        logits = self.forward(batch)
        loss = nn.CrossEntropyLoss()(logits, batch['label'])
        preds = torch.argmax(logits, 1)

        ## Calculating the accuracy score
        train_acc = self.train_accuracy_metric(preds, batch["label"])

        ## Logging
        self.log('train/loss', loss, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log('train/acc', train_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        return loss
    def validation_step(self, batch, batch_idx):
        logits = self.forward(batch)
        loss = nn.CrossEntropyLoss()(logits, batch['label'])
        preds = torch.argmax(logits, 1)
        labels = batch['label']

        # Metrics
        valid_acc = self.val_accuracy_metric(preds, labels)
        precision_macro = self.precision_macro_metric(preds, labels)
        recall_macro = self.recall_macro_metric(preds, labels)
        precision_micro = self.precision_micro_metric(preds, labels)
        recall_micro = self.recall_micro_metric(preds, labels)
        f1 = self.f1_metric(preds, labels)

        # Logging metrics
        self.log("valid/loss", loss, prog_bar=True, on_step=True, logger=True)
        self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
        return {"label": batch['label'], "logits": logits}
    ## pytorch_lightning < 2.0 hook (renamed on_validation_epoch_end in 2.x)
    def validation_epoch_end(self, outputs):
        labels = torch.cat([x["label"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])
        preds = torch.argmax(logits, 1)

        ## wandb's ROC plot expects per-class probabilities, so softmax the
        ## raw logits before handing them over
        probas = torch.softmax(logits, dim=-1)
        wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(
            labels.cpu().numpy(), preds.cpu().numpy(), labels=id2label)})
        self.logger.experiment.log(
            {"roc": wandb.plot.roc_curve(labels.cpu().numpy(), probas.cpu().numpy(), labels=id2label)}
        )
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
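
## A hedged end-to-end sketch (commented out): assumes `config` is the same
## dict used to build LiLT elsewhere in this repo, and that the dataloaders
## yield dicts with 'input_words', 'input_boxes', and 'label' keys, as the
## training/validation steps above expect.
##
##   model = LiLTPL(config, lr=5e-5)
##   trainer = pl.Trainer(max_epochs=10, gpus=1)  ## pre-2.0 Trainer API
##   trainer.fit(model, train_dataloader, val_dataloader)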