|
import torch.nn as nn |
|
from modeling import LiLT |
|
import torch |
|
|
|
from sklearn.metrics import accuracy_score, confusion_matrix |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import numpy as np |
|
import torchmetrics |
|
import pytorch_lightning as pl |
|
|
|
|
|
# Ordered class names for the classification head; its output size is
# len(id2label). Presumably position i is the name of integer class id i as
# used in batch['label'] — confirm against the dataset loader.
# NOTE(review): these appear to be the 16 RVL-CDIP document categories — confirm.
id2label = [
    'scientific_report',
    'resume',
    'memo',
    'file_folder',
    'specification',
    'news_article',
    'letter',
    'form',
    'budget',
    'handwritten',
    'email',
    'invoice',
    'presentation',
    'scientific_publication',
    'questionnaire',
    'advertisement',
]
|
|
|
class LiLTForClassification(nn.Module):
    """Sequence classifier: a LiLT encoder followed by a dropout + linear head.

    The head reads sequence position 0 of the last-layer layout and text
    hidden states, concatenated on the feature axis, and maps it to
    ``len(id2label)`` logits.
    """

    def __init__(self, config):
        """Build the encoder and the classification head.

        Args:
            config: Mapping passed unchanged to ``LiLT``. This class also reads
                ``config['hidden_dropout_prob']`` and ``config['hidden_size']``.
        """
        super(LiLTForClassification, self).__init__()

        self.lilt = LiLT(config)
        self.config = config
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])
        # Layout and text features are concatenated before the head, hence
        # twice the hidden size.
        # NOTE(review): assumes both LiLT streams have width
        # config['hidden_size'] — confirm against the LiLT implementation.
        self.linear_layer = nn.Linear(
            in_features=config['hidden_size'] * 2,
            out_features=len(id2label),
        )

    def forward(self, batch_dict):
        """Return classification logits for a batch.

        Args:
            batch_dict: Mapping whose ``'input_words'`` and ``'input_boxes'``
                entries are forwarded positionally to the LiLT encoder.

        Returns:
            The output of ``self.linear_layer``. Presumably shaped
            ``(batch, len(id2label))``, assuming the encoder's hidden states
            are ``(batch, seq, hidden)`` — TODO confirm.
        """
        encodings = self.lilt(batch_dict['input_words'], batch_dict['input_boxes'])

        # Join the last layer of each stream on the feature axis, then keep
        # only sequence position 0 (presumably a [CLS]-style summary token).
        final_out = torch.cat(
            [
                encodings['layout_hidden_states'][-1],
                encodings['text_hidden_states'][-1],
            ],
            dim=-1,
        )[:, 0, :]

        # Regularize the pooled features with the configured dropout
        # (nn.Dropout is a no-op in eval mode). Previously this layer was
        # constructed but never applied.
        final_out = self.dropout(final_out)
        return self.linear_layer(final_out)
|
|
|
|
|
class LiLTPL(pl.LightningModule):
    """PyTorch Lightning module that trains and evaluates ``LiLTForClassification``.

    Args:
        config: Model configuration forwarded to ``LiLTForClassification``.
        lr: AdamW learning rate. It is stored via ``save_hyperparameters`` and
            read back in ``configure_optimizers``.
    """

    def __init__(self, config, lr=5e-5):
        super(LiLTPL, self).__init__()

        self.save_hyperparameters()
        self.config = config
        self.lilt = LiLTForClassification(config)

        self.num_classes = len(id2label)
        # NOTE(review): these constructor signatures (no ``task=`` argument)
        # match pre-0.11 torchmetrics; newer releases require e.g.
        # ``task="multiclass"`` — confirm the pinned torchmetrics version.
        self.train_accuracy_metric = torchmetrics.Accuracy()
        self.val_accuracy_metric = torchmetrics.Accuracy()
        self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)
        self.precision_macro_metric = torchmetrics.Precision(
            average="macro", num_classes=self.num_classes
        )
        self.recall_macro_metric = torchmetrics.Recall(
            average="macro", num_classes=self.num_classes
        )
        self.precision_micro_metric = torchmetrics.Precision(average="micro")
        self.recall_micro_metric = torchmetrics.Recall(average="micro")

    def forward(self, batch_dict):
        """Return classification logits from the wrapped classifier."""
        return self.lilt(batch_dict)

    def training_step(self, batch, batch_idx):
        """Compute cross-entropy loss for one batch and log loss and accuracy."""
        logits = self(batch)
        labels = batch['label']
        loss = nn.functional.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)

        # Calling the metric updates its accumulated state. Logging the Metric
        # object (not the returned batch tensor) lets Lightning compute the true
        # epoch-level value and reset the state at epoch end. Averaging per-batch
        # values gives a different number and never resets the metric.
        self.train_accuracy_metric(preds, labels)

        self.log('train/loss', loss, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log('train/acc', self.train_accuracy_metric, prog_bar=True, on_epoch=True, logger=True, on_step=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation loss and metrics for one batch.

        Returns:
            Dict with the batch ``"label"`` tensor and the model ``"logits"``,
            collected later by ``validation_epoch_end``.
        """
        logits = self(batch)
        labels = batch['label']
        loss = nn.functional.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)

        # Update every metric's state; the Metric objects themselves are logged
        # below so epoch values (e.g. macro precision/recall) are computed over
        # the whole validation set rather than averaged across batches.
        self.val_accuracy_metric(preds, labels)
        self.precision_macro_metric(preds, labels)
        self.recall_macro_metric(preds, labels)
        self.precision_micro_metric(preds, labels)
        self.recall_micro_metric(preds, labels)
        self.f1_metric(preds, labels)

        self.log("valid/loss", loss, prog_bar=True, on_step=True, logger=True)
        self.log("valid/acc", self.val_accuracy_metric, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/precision_macro", self.precision_macro_metric, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/recall_macro", self.recall_macro_metric, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/precision_micro", self.precision_micro_metric, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/recall_micro", self.recall_micro_metric, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/f1", self.f1_metric, prog_bar=True, on_epoch=True)

        return {"label": labels, "logits": logits}

    def validation_epoch_end(self, outputs):
        """Log a confusion matrix and ROC curve over all validation outputs.

        NOTE(review): ``wandb`` is not among this file's visible imports, so
        this hook relies on a ``wandb`` name existing at module scope —
        confirm. It also assumes ``self.logger`` is a W&B logger whose
        ``experiment`` accepts ``wandb.plot`` chart objects.
        """
        labels = torch.cat([x["label"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])
        preds = torch.argmax(logits, dim=1)

        y_true = labels.cpu().numpy()
        # wandb.plot.roc_curve expects per-class probabilities, not raw logits.
        y_probas = torch.softmax(logits.float(), dim=1).cpu().numpy()

        self.logger.experiment.log({
            # wandb.sklearn.plot_confusion_matrix logs its own chart and
            # returns None, so "cm" previously carried nothing; build the
            # chart object explicitly instead.
            "cm": wandb.plot.confusion_matrix(
                y_true=y_true.tolist(),
                preds=preds.cpu().numpy().tolist(),
                class_names=id2label,
            ),
            "roc": wandb.plot.roc_curve(y_true, y_probas, labels=id2label),
        })

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with ``hparams['lr']``."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
|
|