iakarshu commited on
Commit
a519960
1 Parent(s): 104451d

Upload utils.py

Browse files

Contains the utility function

Files changed (1) hide show
  1. utils.py +122 -0
utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from modeling import DocFormerEncoder,ResNetFeatureExtractor,DocFormerEmbeddings,LanguageFeatureExtractor
3
+
4
+ class DocFormerForClassification(nn.Module):
5
+
6
+ def __init__(self, config):
7
+ super(DocFormerForClassification, self).__init__()
8
+
9
+ self.resnet = ResNetFeatureExtractor(hidden_dim = config['max_position_embeddings'])
10
+ self.embeddings = DocFormerEmbeddings(config)
11
+ self.lang_emb = LanguageFeatureExtractor()
12
+ self.config = config
13
+ self.dropout = nn.Dropout(config['hidden_dropout_prob'])
14
+ self.linear_layer = nn.Linear(in_features = config['hidden_size'], out_features = 16) ## Number of Classes
15
+ self.encoder = DocFormerEncoder(config)
16
+
17
+ def forward(self, batch_dict):
18
+
19
+ x_feat = batch_dict['x_features']
20
+ y_feat = batch_dict['y_features']
21
+
22
+ token = batch_dict['input_ids']
23
+ img = batch_dict['resized_scaled_img']
24
+
25
+ v_bar_s, t_bar_s = self.embeddings(x_feat,y_feat)
26
+ v_bar = self.resnet(img)
27
+ t_bar = self.lang_emb(token)
28
+ out = self.encoder(t_bar,v_bar,t_bar_s,v_bar_s)
29
+ out = self.linear_layer(out)
30
+ out = out[:, 0, :]
31
+ return out
32
+
33
+
34
+ ## Defining pytorch lightning model
35
+ import pytorch_lightning as pl
36
+ from sklearn.metrics import accuracy_score, confusion_matrix
37
+ import pandas as pd
38
+ import matplotlib.pyplot as plt
39
+ import seaborn as sns
40
+ import numpy as np
41
+ import torchmetrics
42
+ import wandb
43
+ import torch
44
+
45
+ class DocFormer(pl.LightningModule):
46
+
47
+ def __init__(self, config , lr = 5e-5):
48
+ super(DocFormer, self).__init__()
49
+
50
+ self.save_hyperparameters()
51
+ self.config = config
52
+ self.docformer = DocFormerForClassification(config)
53
+
54
+ self.num_classes = 16
55
+ self.train_accuracy_metric = torchmetrics.Accuracy()
56
+ self.val_accuracy_metric = torchmetrics.Accuracy()
57
+ self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)
58
+ self.precision_macro_metric = torchmetrics.Precision(
59
+ average="macro", num_classes=self.num_classes
60
+ )
61
+ self.recall_macro_metric = torchmetrics.Recall(
62
+ average="macro", num_classes=self.num_classes
63
+ )
64
+ self.precision_micro_metric = torchmetrics.Precision(average="micro")
65
+ self.recall_micro_metric = torchmetrics.Recall(average="micro")
66
+
67
+ def forward(self, batch_dict):
68
+ logits = self.docformer(batch_dict)
69
+ return logits
70
+
71
+ def training_step(self, batch, batch_idx):
72
+ logits = self.forward(batch)
73
+
74
+ loss = nn.CrossEntropyLoss()(logits, batch['label'])
75
+ preds = torch.argmax(logits, 1)
76
+
77
+ ## Calculating the accuracy score
78
+ train_acc = self.train_accuracy_metric(preds, batch["label"])
79
+
80
+ ## Logging
81
+ self.log('train/loss', loss,prog_bar = True, on_epoch=True, logger=True, on_step=True)
82
+ self.log('train/acc', train_acc, prog_bar = True, on_epoch=True, logger=True, on_step=True)
83
+
84
+ return loss
85
+
86
+ def validation_step(self, batch, batch_idx):
87
+ logits = self.forward(batch)
88
+ loss = nn.CrossEntropyLoss()(logits, batch['label'])
89
+ preds = torch.argmax(logits, 1)
90
+
91
+ labels = batch['label']
92
+ # Metrics
93
+ valid_acc = self.val_accuracy_metric(preds, labels)
94
+ precision_macro = self.precision_macro_metric(preds, labels)
95
+ recall_macro = self.recall_macro_metric(preds, labels)
96
+ precision_micro = self.precision_micro_metric(preds, labels)
97
+ recall_micro = self.recall_micro_metric(preds, labels)
98
+ f1 = self.f1_metric(preds, labels)
99
+
100
+ # Logging metrics
101
+ self.log("valid/loss", loss, prog_bar=True, on_step=True, logger=True)
102
+ self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
103
+ self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
104
+ self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
105
+ self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
106
+ self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
107
+ self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
108
+
109
+ return {"label": batch['label'], "logits": logits}
110
+
111
+ def validation_epoch_end(self, outputs):
112
+ labels = torch.cat([x["label"] for x in outputs])
113
+ logits = torch.cat([x["logits"] for x in outputs])
114
+ preds = torch.argmax(logits, 1)
115
+
116
+ wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())})
117
+ self.logger.experiment.log(
118
+ {"roc": wandb.plot.roc_curve(labels.cpu().numpy(), logits.cpu().numpy())}
119
+ )
120
+
121
+ def configure_optimizers(self):
122
+ return torch.optim.AdamW(self.parameters(), lr = self.hparams['lr'])