Shaltiel committed
Commit 81c680b
1 Parent(s): b20f15c

add start,end to entities (#1)


- add start,end to entities (26aaacaf5bae033452c5a36cd9d5ee60b836a925)

Files changed (2)
  1. BertForJointParsing.py +23 -14
  2. BertForPrefixMarking.py +1 -2
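
With this change, each NER entity dict returned by the joint parser carries character offsets (`start`, `end`) into the original sentence, alongside the existing `phrase` and `label` fields. A minimal sketch of the new output shape — the sentence and entity labels below are invented for illustration; only the field names match the dict built in `aggregate_ner_tokens`:

```python
# Hypothetical output after this commit; values are made up for illustration.
sentence = 'George Washington lived in Virginia'
ner_entities = [
    {'phrase': 'George Washington', 'label': 'PER', 'start': 0, 'end': 17},
    {'phrase': 'Virginia', 'label': 'LOC', 'start': 27, 'end': 35},
]
# The offsets index directly into the original string (end is exclusive):
assert sentence[0:17] == 'George Washington'
assert sentence[27:35] == 'Virginia'
```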
BertForJointParsing.py CHANGED

@@ -200,8 +200,9 @@ class BertForJointParsing(BertPreTrainedModel):
         if self.prefix is not None:
             inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
         else:
-            inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_tensors='pt')
-
+            inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
+
+        offset_mapping = inputs.pop('offset_mapping')
         # Copy the tensors to the right device, and parse!
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
@@ -230,7 +231,7 @@ class BertForJointParsing(BertPreTrainedModel):
 
         # NER logits each sentence gets a list(tuple(word, ner))
         if output.ner_logits is not None:
-            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label)):
+            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label, offset_mapping)):
                 if per_token_ner:
                     merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
                 final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
@@ -247,17 +248,18 @@ class BertForJointParsing(BertPreTrainedModel):
 def aggregate_ner_tokens(predictions):
     entities = []
     prev = None
-    for word,pred in predictions:
+    for word, pred, start, end in predictions:
         # O does nothing
         if pred == 'O': prev = None
         # B- || I-entity != prev (different entity or none)
         elif pred.startswith('B-') or pred[2:] != prev:
             prev = pred[2:]
-            entities.append(([word], prev))
-        else: entities[-1][0].append(word)
-
-    return [dict(phrase=' '.join(words), label=label) for words,label in entities]
-
+            entities.append([[word], prev, start, end])
+        else:
+            entities[-1][0].append(word)
+            entities[-1][3] = end
+
+    return [dict(phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end in entities]
 
 def merge_token_list(src, update, key):
     for token_src, token_update in zip(src, update):
@@ -272,9 +274,9 @@ def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFa
         else: ret.append(token)
     return ret
 
-def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
-    input_ids = inputs['input_ids']
-
+def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping):
+    input_ids = inputs['input_ids']
+
     predictions = torch.argmax(logits, dim=-1)
     batch_ret = []
     for batch_idx in range(len(sentences)):
@@ -286,11 +288,18 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
             if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]: continue
 
             token = tokenizer._convert_id_to_token(token_id)
+
+            # get the offsets for this token
+            start_pos, end_pos = offset_mapping[batch_idx, tok_idx]
             # wordpieces should just be appended to the previous word
+            # we modify the last token in ret
+            # by discarding the original end position and replacing it with the new token's end position
             if token.startswith('##'):
-                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
+                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item())
                 continue
-            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()]))
+            # for each token, we append a tuple containing: token, label, start position, end position
+            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()))
+
     return batch_ret
 
 def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
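
The diff above leans on the fast tokenizer's `return_offsets_mapping=True`, which returns a `(batch, seq_len, 2)` tensor of character spans, with special tokens mapped to the degenerate span `(0, 0)`. A small self-contained sketch of the mechanics, including the `##` wordpiece merge that keeps the first piece's start offset and adopts the last piece's end offset, as `ner_parse_logits` does. Here `bert-base-uncased` is only a stand-in checkpoint, and the exact wordpieces depend on the vocabulary:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # any fast tokenizer works
enc = tokenizer('unaffable', return_offsets_mapping=True, return_tensors='pt')
offsets = enc.pop('offset_mapping')  # shape: (batch, seq_len, 2)

words = []
for tok_idx, token_id in enumerate(enc['input_ids'][0]):
    # skip special tokens, mirroring the check in ner_parse_logits
    if token_id in (tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id):
        continue
    token = tokenizer._convert_id_to_token(token_id.item())
    start, end = offsets[0, tok_idx].tolist()
    if token.startswith('##'):
        # merge the wordpiece into the previous word: keep its start
        # offset, replace its end offset with this piece's end offset
        prev_word, prev_start, _ = words[-1]
        words[-1] = (prev_word + token[2:], prev_start, end)
    else:
        words.append((token, start, end))

print(words)  # e.g. [('unaffable', 0, 9)] once the '##' pieces are merged
```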
BertForPrefixMarking.py CHANGED

@@ -184,8 +184,7 @@ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenize
     return ret
 
 def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest', truncation=True):
-    inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_tensors='pt')
-
+    inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
     # create our prefix_id_options array which will be like the input ids shape but with an addtional
     # dimension containing for each prefix whether it can be for that word
     prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)
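
One detail worth noting across both files: `return_offsets_mapping=True` adds an `offset_mapping` key to the returned `BatchEncoding`, and both encoding paths now produce it, which is why the caller in `BertForJointParsing` pops it before splatting the encoding into `forward()`. A minimal sketch of the failure mode being avoided, again with `bert-base-uncased` as a stand-in checkpoint:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # stand-in checkpoint
inputs = tokenizer(['a sentence'], padding='longest', truncation=True,
                   return_offsets_mapping=True, return_tensors='pt')
print(sorted(inputs.keys()))
# ['attention_mask', 'input_ids', 'offset_mapping', 'token_type_ids']

# BatchEncoding is dict-like; pop the extra key before model(**inputs),
# since BERT's forward() does not accept an 'offset_mapping' argument
# and would raise a TypeError on the unexpected keyword.
offset_mapping = inputs.pop('offset_mapping')
print(sorted(inputs.keys()))
# ['attention_mask', 'input_ids', 'token_type_ids']
```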