Spaces:
Runtime error
Runtime error
Upload utils.py
Browse filesContains the utility function
utils.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from modeling import DocFormerEncoder,ResNetFeatureExtractor,DocFormerEmbeddings,LanguageFeatureExtractor
|
3 |
+
|
4 |
+
class DocFormerForClassification(nn.Module):
|
5 |
+
|
6 |
+
def __init__(self, config):
|
7 |
+
super(DocFormerForClassification, self).__init__()
|
8 |
+
|
9 |
+
self.resnet = ResNetFeatureExtractor(hidden_dim = config['max_position_embeddings'])
|
10 |
+
self.embeddings = DocFormerEmbeddings(config)
|
11 |
+
self.lang_emb = LanguageFeatureExtractor()
|
12 |
+
self.config = config
|
13 |
+
self.dropout = nn.Dropout(config['hidden_dropout_prob'])
|
14 |
+
self.linear_layer = nn.Linear(in_features = config['hidden_size'], out_features = 16) ## Number of Classes
|
15 |
+
self.encoder = DocFormerEncoder(config)
|
16 |
+
|
17 |
+
def forward(self, batch_dict):
|
18 |
+
|
19 |
+
x_feat = batch_dict['x_features']
|
20 |
+
y_feat = batch_dict['y_features']
|
21 |
+
|
22 |
+
token = batch_dict['input_ids']
|
23 |
+
img = batch_dict['resized_scaled_img']
|
24 |
+
|
25 |
+
v_bar_s, t_bar_s = self.embeddings(x_feat,y_feat)
|
26 |
+
v_bar = self.resnet(img)
|
27 |
+
t_bar = self.lang_emb(token)
|
28 |
+
out = self.encoder(t_bar,v_bar,t_bar_s,v_bar_s)
|
29 |
+
out = self.linear_layer(out)
|
30 |
+
out = out[:, 0, :]
|
31 |
+
return out
|
32 |
+
|
33 |
+
|
34 |
+
## Defining pytorch lightning model
|
35 |
+
import pytorch_lightning as pl
|
36 |
+
from sklearn.metrics import accuracy_score, confusion_matrix
|
37 |
+
import pandas as pd
|
38 |
+
import matplotlib.pyplot as plt
|
39 |
+
import seaborn as sns
|
40 |
+
import numpy as np
|
41 |
+
import torchmetrics
|
42 |
+
import wandb
|
43 |
+
import torch
|
44 |
+
|
45 |
+
class DocFormer(pl.LightningModule):
|
46 |
+
|
47 |
+
def __init__(self, config , lr = 5e-5):
|
48 |
+
super(DocFormer, self).__init__()
|
49 |
+
|
50 |
+
self.save_hyperparameters()
|
51 |
+
self.config = config
|
52 |
+
self.docformer = DocFormerForClassification(config)
|
53 |
+
|
54 |
+
self.num_classes = 16
|
55 |
+
self.train_accuracy_metric = torchmetrics.Accuracy()
|
56 |
+
self.val_accuracy_metric = torchmetrics.Accuracy()
|
57 |
+
self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)
|
58 |
+
self.precision_macro_metric = torchmetrics.Precision(
|
59 |
+
average="macro", num_classes=self.num_classes
|
60 |
+
)
|
61 |
+
self.recall_macro_metric = torchmetrics.Recall(
|
62 |
+
average="macro", num_classes=self.num_classes
|
63 |
+
)
|
64 |
+
self.precision_micro_metric = torchmetrics.Precision(average="micro")
|
65 |
+
self.recall_micro_metric = torchmetrics.Recall(average="micro")
|
66 |
+
|
67 |
+
def forward(self, batch_dict):
|
68 |
+
logits = self.docformer(batch_dict)
|
69 |
+
return logits
|
70 |
+
|
71 |
+
def training_step(self, batch, batch_idx):
|
72 |
+
logits = self.forward(batch)
|
73 |
+
|
74 |
+
loss = nn.CrossEntropyLoss()(logits, batch['label'])
|
75 |
+
preds = torch.argmax(logits, 1)
|
76 |
+
|
77 |
+
## Calculating the accuracy score
|
78 |
+
train_acc = self.train_accuracy_metric(preds, batch["label"])
|
79 |
+
|
80 |
+
## Logging
|
81 |
+
self.log('train/loss', loss,prog_bar = True, on_epoch=True, logger=True, on_step=True)
|
82 |
+
self.log('train/acc', train_acc, prog_bar = True, on_epoch=True, logger=True, on_step=True)
|
83 |
+
|
84 |
+
return loss
|
85 |
+
|
86 |
+
def validation_step(self, batch, batch_idx):
|
87 |
+
logits = self.forward(batch)
|
88 |
+
loss = nn.CrossEntropyLoss()(logits, batch['label'])
|
89 |
+
preds = torch.argmax(logits, 1)
|
90 |
+
|
91 |
+
labels = batch['label']
|
92 |
+
# Metrics
|
93 |
+
valid_acc = self.val_accuracy_metric(preds, labels)
|
94 |
+
precision_macro = self.precision_macro_metric(preds, labels)
|
95 |
+
recall_macro = self.recall_macro_metric(preds, labels)
|
96 |
+
precision_micro = self.precision_micro_metric(preds, labels)
|
97 |
+
recall_micro = self.recall_micro_metric(preds, labels)
|
98 |
+
f1 = self.f1_metric(preds, labels)
|
99 |
+
|
100 |
+
# Logging metrics
|
101 |
+
self.log("valid/loss", loss, prog_bar=True, on_step=True, logger=True)
|
102 |
+
self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
|
103 |
+
self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
|
104 |
+
self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
|
105 |
+
self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
|
106 |
+
self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
|
107 |
+
self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
|
108 |
+
|
109 |
+
return {"label": batch['label'], "logits": logits}
|
110 |
+
|
111 |
+
def validation_epoch_end(self, outputs):
|
112 |
+
labels = torch.cat([x["label"] for x in outputs])
|
113 |
+
logits = torch.cat([x["logits"] for x in outputs])
|
114 |
+
preds = torch.argmax(logits, 1)
|
115 |
+
|
116 |
+
wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())})
|
117 |
+
self.logger.experiment.log(
|
118 |
+
{"roc": wandb.plot.roc_curve(labels.cpu().numpy(), logits.cpu().numpy())}
|
119 |
+
)
|
120 |
+
|
121 |
+
def configure_optimizers(self):
|
122 |
+
return torch.optim.AdamW(self.parameters(), lr = self.hparams['lr'])
|