emanuelaboros committed on
Commit
5915b56
·
verified ·
1 Parent(s): a84fd08

Update generic_ner.py

Browse files
Files changed (1) hide show
  1. generic_ner.py +60 -74
generic_ner.py CHANGED
@@ -1,15 +1,16 @@
 
1
  import numpy as np
 
2
  from nltk.chunk import conlltags2tree
3
  from nltk import pos_tag
4
  from nltk.tree import Tree
5
- import re, string
6
- import pysbd
7
- import torch
8
  import torch.nn.functional as F
9
- from transformers import Pipeline
10
  from langdetect import detect
11
- from nltk.tokenize import sent_tokenize
12
- from typing import List
 
 
13
 
14
 
15
  def tokenize(text):
@@ -201,74 +202,59 @@ class MultitaskTokenClassificationPipeline(Pipeline):
201
  }
202
  return preprocess_kwargs, {}, {}
203
 
204
- class MultitaskTokenClassificationPipeline(Pipeline):
205
-
206
- def _sanitize_parameters(self, **kwargs):
207
- preprocess_kwargs = {}
208
- if "text" in kwargs:
209
- preprocess_kwargs["text"] = kwargs["text"]
210
- self.label_map = self.model.config.label_map
211
- self.id2label = {
212
- task: {id_: label for label, id_ in labels.items()}
213
- for task, labels in self.label_map.items()
214
- }
215
- return preprocess_kwargs, {}, {}
216
-
217
- def preprocess(self, text, **kwargs):
218
-
219
- language = detect(text)
220
- sentences = segment_and_trim_sentences(text, language, 512)
221
-
222
- tokenized_inputs = self.tokenizer(
223
- sentences,
224
- padding="max_length",
225
- truncation=True,
226
- max_length=512,
227
- return_tensors="pt",
228
- )
229
-
230
- text_sentence = [
231
- tokenize(add_spaces_around_punctuation(sentence))
232
- for sentence in sentences
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  ]
234
- return tokenized_inputs, text_sentence, text
235
-
236
- def _forward(self, inputs):
237
- inputs, text_sentence, text = inputs
238
- input_ids = inputs["input_ids"].to(self.model.device)
239
- attention_mask = inputs["attention_mask"].to(self.model.device)
240
-
241
- with torch.no_grad():
242
- outputs = self.model(input_ids, attention_mask)
243
-
244
- return outputs, text_sentence, text
245
-
246
- def postprocess(self, outputs, **kwargs):
247
- tokens_result, text_sentence, text = outputs
248
-
249
- predictions = {}
250
- confidence_scores = {}
251
- for task, logits in tokens_result.logits.items():
252
- predictions[task] = torch.argmax(logits, dim=-1).tolist()
253
- confidence_scores[task] = F.softmax(logits, dim=-1).tolist()
254
-
255
- decoded_predictions = {}
256
- for task, preds in predictions.items():
257
- decoded_predictions[task] = [
258
- [self.id2label[task][label] for label in seq] for seq in preds
259
- ]
260
- entities = {}
261
- for task, preds in predictions.items():
262
- words_list, preds_list, confidence_list = realign(
263
- text_sentence,
264
- preds[0],
265
- confidence_scores[task][0],
266
- self.tokenizer,
267
- self.id2label[task],
268
- )
269
 
270
- entities[task] = get_entities(
271
- words_list, preds_list, confidence_list, text
272
- )
273
 
274
- return entities
 
1
+ from transformers import Pipeline
2
  import numpy as np
3
+ import torch
4
  from nltk.chunk import conlltags2tree
5
  from nltk import pos_tag
6
  from nltk.tree import Tree
7
+ import string
 
 
8
  import torch.nn.functional as F
 
9
  from langdetect import detect
10
+
11
+
12
+ import re, string
13
+ import pysbd
14
 
15
 
16
  def tokenize(text):
 
202
  }
203
  return preprocess_kwargs, {}, {}
204
 
205
+ def preprocess(self, text, **kwargs):
206
+ language = detect(text)
207
+ sentences = segment_and_trim_sentences(text, language, 512)
208
+
209
+ tokenized_inputs = self.tokenizer(
210
+ text, padding="max_length", truncation=True, max_length=512
211
+ )
212
+
213
+ text_sentence = tokenize(add_spaces_around_punctuation(text))
214
+ return tokenized_inputs, text_sentence, text
215
+
216
+ def _forward(self, inputs):
217
+ inputs, text_sentence, text = inputs
218
+ input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
219
+ self.model.device
220
+ )
221
+ attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
222
+ self.model.device
223
+ )
224
+ with torch.no_grad():
225
+ outputs = self.model(input_ids, attention_mask)
226
+ return outputs, text_sentence, text
227
+
228
+ def postprocess(self, outputs, **kwargs):
229
+ """
230
+ Postprocess the outputs of the model
231
+ :param outputs:
232
+ :param kwargs:
233
+ :return:
234
+ """
235
+ tokens_result, text_sentence, text = outputs
236
+
237
+ predictions = {}
238
+ confidence_scores = {}
239
+ for task, logits in tokens_result.logits.items():
240
+ predictions[task] = torch.argmax(logits, dim=-1).tolist()
241
+ confidence_scores[task] = F.softmax(logits, dim=-1).tolist()
242
+
243
+ decoded_predictions = {}
244
+ for task, preds in predictions.items():
245
+ decoded_predictions[task] = [
246
+ [self.id2label[task][label] for label in seq] for seq in preds
247
  ]
248
+ entities = {}
249
+ for task, preds in predictions.items():
250
+ words_list, preds_list, confidence_list = realign(
251
+ text_sentence,
252
+ preds[0],
253
+ confidence_scores[task][0],
254
+ self.tokenizer,
255
+ self.id2label[task],
256
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
+ entities[task] = get_entities(words_list, preds_list, confidence_list, text)
 
 
259
 
260
+ return entities