Text Classification
PyTorch
Safetensors
English
eurovoc
Inference Endpoints
scampion commited on
Commit
bca3e8e
·
1 Parent(s): f3a36cd

Upload eurovoc.py

Browse files
Files changed (1) hide show
  1. eurovoc.py +212 -0
eurovoc.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import numpy as np
4
+ import pytorch_lightning as pl
5
+ import torch.nn as nn
6
+ from transformers import BertTokenizerFast as BertTokenizer, AdamW, get_linear_schedule_with_warmup, AutoTokenizer, AutoModel
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+
9
+
10
+ class EurovocDataset(Dataset):
11
+
12
+ def __init__(
13
+ self,
14
+ text: np.array,
15
+ labels: np.array,
16
+ tokenizer: BertTokenizer,
17
+ max_token_len: int = 128
18
+ ):
19
+ self.tokenizer = tokenizer
20
+ self.text = text
21
+ self.labels = labels
22
+ self.max_token_len = max_token_len
23
+
24
+ def __len__(self):
25
+ return len(self.labels)
26
+
27
+ def __getitem__(self, index: int):
28
+ text = self.text[index][0]
29
+ labels = self.labels[index]
30
+
31
+ encoding = self.tokenizer.encode_plus(
32
+ text,
33
+ add_special_tokens=True,
34
+ max_length=self.max_token_len,
35
+ return_token_type_ids=False,
36
+ padding="max_length",
37
+ truncation=True,
38
+ return_attention_mask=True,
39
+ return_tensors='pt',
40
+ )
41
+
42
+ return dict(
43
+ text=text,
44
+ input_ids=encoding["input_ids"].flatten(),
45
+ attention_mask=encoding["attention_mask"].flatten(),
46
+ labels=torch.FloatTensor(labels)
47
+ )
48
+
49
+
50
+ class EuroVocLongTextDataset(Dataset):
51
+
52
+ def __splitter__(text, max_lenght):
53
+ l = text.split()
54
+ for i in range(0, len(l), max_lenght):
55
+ yield l[i:i + max_lenght]
56
+
57
+ def __init__(
58
+ self,
59
+ text: np.array,
60
+ labels: np.array,
61
+ tokenizer: BertTokenizer,
62
+ max_token_len: int = 128
63
+ ):
64
+ self.tokenizer = tokenizer
65
+ self.text = text
66
+ self.labels = labels
67
+ self.max_token_len = max_token_len
68
+
69
+ self.chunks_and_labels = [(c, l) for t, l in zip(self.text, self.labels) for c in self.__splitter__(t)]
70
+
71
+ self.encoding = self.tokenizer.batch_encode_plus(
72
+ [c for c, _ in self.chunks_and_labels],
73
+ add_special_tokens=True,
74
+ max_length=self.max_token_len,
75
+ return_token_type_ids=False,
76
+ padding="max_length",
77
+ truncation=True,
78
+ return_attention_mask=True,
79
+ return_tensors='pt',
80
+ )
81
+
82
+ def __len__(self):
83
+ return len(self.chunks_and_labels)
84
+
85
+ def __getitem__(self, index: int):
86
+ text, labels = self.chunks_and_labels[index]
87
+
88
+ return dict(
89
+ text=text,
90
+ input_ids=self.encoding[index]["input_ids"].flatten(),
91
+ attention_mask=self.encoding[index]["attention_mask"].flatten(),
92
+ labels=torch.FloatTensor(labels)
93
+ )
94
+
95
+
96
+ class EurovocDataModule(pl.LightningDataModule):
97
+
98
+ def __init__(self, bert_model_name, x_tr, y_tr, x_test, y_test, batch_size=8, max_token_len=512):
99
+ super().__init__()
100
+
101
+ self.batch_size = batch_size
102
+ self.x_tr = x_tr
103
+ self.y_tr = y_tr
104
+ self.x_test = x_test
105
+ self.y_test = y_test
106
+ self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
107
+ self.max_token_len = max_token_len
108
+
109
+ def setup(self, stage=None):
110
+ self.train_dataset = EurovocDataset(
111
+ self.x_tr,
112
+ self.y_tr,
113
+ self.tokenizer,
114
+ self.max_token_len
115
+ )
116
+
117
+ self.test_dataset = EurovocDataset(
118
+ self.x_test,
119
+ self.y_test,
120
+ self.tokenizer,
121
+ self.max_token_len
122
+ )
123
+
124
+ def train_dataloader(self):
125
+ return DataLoader(
126
+ self.train_dataset,
127
+ batch_size=self.batch_size,
128
+ shuffle=True,
129
+ num_workers=2
130
+ )
131
+
132
+ def val_dataloader(self):
133
+ return DataLoader(
134
+ self.test_dataset,
135
+ batch_size=self.batch_size,
136
+ num_workers=2
137
+ )
138
+
139
+ def test_dataloader(self):
140
+ return DataLoader(
141
+ self.test_dataset,
142
+ batch_size=self.batch_size,
143
+ num_workers=2
144
+ )
145
+
146
+
147
+ class EurovocTagger(pl.LightningModule, PyTorchModelHubMixin):
148
+
149
+ def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
150
+ super().__init__()
151
+ self.bert = AutoModel.from_pretrained(bert_model_name)
152
+ self.dropout = nn.Dropout(p=0.2)
153
+ self.classifier1 = nn.Linear(self.bert.config.hidden_size, n_classes)
154
+ self.criterion = nn.BCELoss()
155
+ self.lr = lr
156
+ self.eps = eps
157
+
158
+ def forward(self, input_ids, attention_mask, labels=None):
159
+ output = self.bert(input_ids, attention_mask=attention_mask)
160
+ output = self.dropout(output.pooler_output)
161
+ output = self.classifier1(output)
162
+ output = torch.sigmoid(output)
163
+ loss = 0
164
+ if labels is not None:
165
+ loss = self.criterion(output, labels)
166
+ return loss, output
167
+
168
+ def training_step(self, batch, batch_idx):
169
+ input_ids = batch["input_ids"]
170
+ attention_mask = batch["attention_mask"]
171
+ labels = batch["labels"]
172
+ loss, outputs = self(input_ids, attention_mask, labels)
173
+ self.log("train_loss", loss, prog_bar=True, logger=True)
174
+ return {"loss": loss, "predictions": outputs, "labels": labels}
175
+
176
+ def validation_step(self, batch, batch_idx):
177
+ input_ids = batch["input_ids"]
178
+ attention_mask = batch["attention_mask"]
179
+ labels = batch["labels"]
180
+ loss, outputs = self(input_ids, attention_mask, labels)
181
+ self.log("val_loss", loss, prog_bar=True, logger=True)
182
+ return loss
183
+
184
+ def test_step(self, batch, batch_idx):
185
+ input_ids = batch["input_ids"]
186
+ attention_mask = batch["attention_mask"]
187
+ labels = batch["labels"]
188
+ loss, outputs = self(input_ids, attention_mask, labels)
189
+ self.log("test_loss", loss, prog_bar=True, logger=True)
190
+ return loss
191
+
192
+ def on_train_epoch_end(self, *args, **kwargs):
193
+ return
194
+ #labels = []
195
+ #predictions = []
196
+ #for output in args['outputs']:
197
+ # for out_labels in output["labels"].detach().cpu():
198
+ # labels.append(out_labels)
199
+ # for out_predictions in output["predictions"].detach().cpu():
200
+ # predictions.append(out_predictions)
201
+
202
+ #labels = torch.stack(labels).int()
203
+ #predictions = torch.stack(predictions)
204
+
205
+ #for i, name in enumerate(mlb.classes_):
206
+ # class_roc_auc = auroc(predictions[:, i], labels[:, i])
207
+ # self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)
208
+
209
+
210
+ def configure_optimizers(self):
211
+ return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
212
+