emanuelaboros committed
Commit d0242b2 · verified · 1 Parent(s): f6fd959

Update generic_ner.py

Files changed (1):
  1. generic_ner.py (+1 -30)
generic_ner.py CHANGED

@@ -111,7 +111,7 @@ def get_entities(tokens, tags, confidences, text):
                 "score": np.average(confidences[idx : idx + len(subtree)])
                 * 100,
                 "index": (idx, idx + len(subtree)),
-                "word": original_string,
+                "word": text[entity_start_position:entity_end_position],  # original_string
                 "start": entity_start_position,
                 "end": entity_end_position,
             }
@@ -242,35 +242,6 @@ class MultitaskTokenClassificationPipeline(Pipeline):
         outputs = self.model(input_ids, attention_mask)
         return outputs, text_sentences, text
 
-    # def _forward(self, inputs):
-    #     inputs, text_sentences, text = inputs
-    #     all_logits = {}
-    #
-    #     for i in range(len(text_sentences)):
-    #         print(inputs["input_ids"][i].shape)
-    #         input_ids = torch.tensor([inputs["input_ids"][i]], dtype=torch.long).to(
-    #             self.model.device
-    #         )
-    #         attention_mask = torch.tensor(
-    #             [inputs["attention_mask"][i]], dtype=torch.long
-    #         ).to(self.model.device)
-    #
-    #         with torch.no_grad():
-    #             outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
-    #
-    #         # Accumulate logits for each task
-    #         if not all_logits:
-    #             all_logits = {task: logits for task, logits in outputs.logits.items()}
-    #         else:
-    #             for task in all_logits:
-    #                 all_logits[task] = torch.cat(
-    #                     (all_logits[task], outputs.logits[task]), dim=1
-    #                 )
-    #
-    #         # Replace outputs.logits with accumulated logits
-    #         outputs.logits = all_logits
-    #
-    #     return outputs, text_sentences, text
 
     def postprocess(self, outputs, **kwargs):
         """
 