# NOTE: removed non-code export residue ("Spaces:" / "Runtime error" lines)
# that preceded the source in the original dump.
import torch.nn as nn
from modeling import DocFormerEncoder, ResNetFeatureExtractor, DocFormerEmbeddings, LanguageFeatureExtractor
class DocFormerForClassification(nn.Module):
    """DocFormer backbone with a linear classification head.

    Fuses spatial embeddings (from x/y coordinate features), visual features
    (ResNet over the page image) and language features (token ids) through the
    DocFormer encoder, then classifies using the encoder output at the first
    sequence position.
    """

    def __init__(self, config, num_classes=16):
        """
        Args:
            config: dict-like DocFormer configuration; must provide
                'max_position_embeddings', 'hidden_dropout_prob' and
                'hidden_size'.
            num_classes: size of the label space. Defaults to 16, matching
                the previously hard-coded head, so existing callers are
                unaffected.
        """
        super(DocFormerForClassification, self).__init__()
        self.resnet = ResNetFeatureExtractor(hidden_dim=config['max_position_embeddings'])
        self.embeddings = DocFormerEmbeddings(config)
        self.lang_emb = LanguageFeatureExtractor()
        self.config = config
        # NOTE(review): dropout is constructed but never applied in forward();
        # kept for state-dict/backward compatibility — confirm intent.
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])
        self.linear_layer = nn.Linear(in_features=config['hidden_size'], out_features=num_classes)
        self.encoder = DocFormerEncoder(config)

    def forward(self, batch_dict):
        """Return (batch, num_classes) logits.

        Expects keys 'x_features', 'y_features', 'input_ids' and
        'resized_scaled_img' in batch_dict (shapes are defined by the
        project feature extractors and not visible here).
        """
        x_feat = batch_dict['x_features']
        y_feat = batch_dict['y_features']
        token = batch_dict['input_ids']
        img = batch_dict['resized_scaled_img']

        v_bar_s, t_bar_s = self.embeddings(x_feat, y_feat)
        v_bar = self.resnet(img)
        t_bar = self.lang_emb(token)
        out = self.encoder(t_bar, v_bar, t_bar_s, v_bar_s)
        # Slice the first ([CLS]-style) position BEFORE projecting: the
        # original projected every sequence position and then sliced —
        # identical result, wasted compute.
        logits = self.linear_layer(out[:, 0, :])
        return logits
## Defining pytorch lightning model
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score, confusion_matrix
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torchmetrics
import wandb
import torch
class DocFormer(pl.LightningModule):
    """LightningModule wrapping DocFormerForClassification for a 16-class
    document-classification task.

    Logs loss/accuracy during training, a fuller metric suite during
    validation, and pushes a confusion matrix + ROC curve to W&B at the end
    of each validation epoch.
    """

    def __init__(self, config, lr=5e-5):
        """
        Args:
            config: DocFormer configuration dict, passed through to
                DocFormerForClassification.
            lr: AdamW learning rate (stored via save_hyperparameters and
                read back in configure_optimizers).
        """
        super(DocFormer, self).__init__()
        self.save_hyperparameters()
        self.config = config
        self.docformer = DocFormerForClassification(config)
        self.num_classes = 16

        # Loss is stateless — build it once instead of re-instantiating a
        # module on every training/validation step (original behavior was
        # nn.CrossEntropyLoss()(...) per step; results are identical).
        self.criterion = nn.CrossEntropyLoss()

        # NOTE(review): torchmetrics >= 0.11 requires task="multiclass" and
        # num_classes for these constructors — confirm the pinned
        # torchmetrics version before upgrading.
        self.train_accuracy_metric = torchmetrics.Accuracy()
        self.val_accuracy_metric = torchmetrics.Accuracy()
        self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)
        self.precision_macro_metric = torchmetrics.Precision(
            average="macro", num_classes=self.num_classes
        )
        self.recall_macro_metric = torchmetrics.Recall(
            average="macro", num_classes=self.num_classes
        )
        self.precision_micro_metric = torchmetrics.Precision(average="micro")
        self.recall_micro_metric = torchmetrics.Recall(average="micro")

    def forward(self, batch_dict):
        """Return classification logits for a batch dict (see
        DocFormerForClassification.forward for the expected keys)."""
        logits = self.docformer(batch_dict)
        return logits

    def training_step(self, batch, batch_idx):
        """One optimization step: cross-entropy loss + running accuracy."""
        logits = self.forward(batch)
        loss = self.criterion(logits, batch['label'])
        preds = torch.argmax(logits, 1)

        # Accuracy over this step's predictions (accumulated by torchmetrics).
        train_acc = self.train_accuracy_metric(preds, batch["label"])

        self.log('train/loss', loss, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log('train/acc', train_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation loss/metrics and stash labels+logits for the
        epoch-end confusion matrix and ROC curve."""
        logits = self.forward(batch)
        loss = self.criterion(logits, batch['label'])
        preds = torch.argmax(logits, 1)
        labels = batch['label']

        # Metrics
        valid_acc = self.val_accuracy_metric(preds, labels)
        precision_macro = self.precision_macro_metric(preds, labels)
        recall_macro = self.recall_macro_metric(preds, labels)
        precision_micro = self.precision_micro_metric(preds, labels)
        recall_micro = self.recall_micro_metric(preds, labels)
        f1 = self.f1_metric(preds, labels)

        # Logging metrics
        self.log("valid/loss", loss, prog_bar=True, on_step=True, logger=True)
        self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
        return {"label": batch['label'], "logits": logits}

    def validation_epoch_end(self, outputs):
        """Aggregate the epoch's labels/logits and log a confusion matrix and
        ROC curve to Weights & Biases.

        NOTE(review): wandb.log is called directly here, bypassing the
        Lightning logger used elsewhere — confirm a W&B run is always active.
        """
        labels = torch.cat([x["label"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])
        preds = torch.argmax(logits, 1)

        wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())})
        self.logger.experiment.log(
            {"roc": wandb.plot.roc_curve(labels.cpu().numpy(), logits.cpu().numpy())}
        )

    def configure_optimizers(self):
        """AdamW over all parameters with the saved learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])