Shaltiel committed
Commit f4bee92
1 Parent(s): 8896709

Upload BertForJointParsing.py

Files changed (1):
  1. BertForJointParsing.py +8 -7
BertForJointParsing.py CHANGED
@@ -81,6 +81,7 @@ class BertForJointParsing(BertPreTrainedModel):
 
     def set_output_embeddings(self, new_embeddings):
         if self.lex is not None:
+
             self.cls.predictions.decoder = new_embeddings
 
     def forward(
@@ -248,18 +249,19 @@ class BertForJointParsing(BertPreTrainedModel):
 def aggregate_ner_tokens(final_output, parsed):
     entities = []
     prev = None
-    for d, (word, pred) in zip(final_output['tokens'], parsed):
+    for token_idx, (d, (word, pred)) in enumerate(zip(final_output['tokens'], parsed)):
         # O does nothing
         if pred == 'O': prev = None
         # B- || I-entity != prev (different entity or none)
         elif pred.startswith('B-') or pred[2:] != prev:
             prev = pred[2:]
-            entities.append([[word], prev, d['offsets']['start'], d['offsets']['end']])
+            entities.append([[word], dict(label=prev, start=d['offsets']['start'], end=d['offsets']['end'], token_start=token_idx, token_end=token_idx)])
         else:
             entities[-1][0].append(word)
-            entities[-1][3] = d['offsets']['end']
+            entities[-1][1]['end'] = d['offsets']['end']
+            entities[-1][1]['token_end'] = token_idx
 
-    return [dict(phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end in entities]
+    return [dict(phrase=' '.join(words), **d) for words, d in entities]
 
 def merge_token_list(src, update, key):
     for token_src, token_update in zip(src, update):
@@ -268,13 +270,12 @@ def merge_token_list(src, update, key):
 def combine_token_wordpieces(input_ids: torch.Tensor, offset_mapping: torch.Tensor, tokenizer: BertTokenizerFast):
     offset_mapping = offset_mapping.tolist()
     ret = []
-    for token_idx, (token, offsets) in enumerate(zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping)):
+    for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
         if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]: continue
         if token.startswith('##'):
             ret[-1]['token'] += token[2:]
-            ret[-1]['token_idxs'].append(token_idx)
             ret[-1]['offsets']['end'] = offsets[1]
-        else: ret.append(dict(token=token, token_idxs=[token_idx], offsets=dict(start=offsets[0], end=offsets[1])))
+        else: ret.append(dict(token=token, offsets=dict(start=offsets[0], end=offsets[1])))
     return ret
 
 def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
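For context, a minimal sketch of what the revised aggregate_ner_tokens returns; the sentence, token dicts, and BIO tags below are invented for illustration, and the function is assumed to be called directly as the module-level helper shown above:

    # Token dicts mirror the shape produced by combine_token_wordpieces.
    final_output = {'tokens': [
        {'token': 'George',     'offsets': {'start': 0,  'end': 6}},
        {'token': 'Washington', 'offsets': {'start': 7,  'end': 17}},
        {'token': 'slept',      'offsets': {'start': 18, 'end': 23}},
    ]}
    parsed = [('George', 'B-PER'), ('Washington', 'I-PER'), ('slept', 'O')]

    aggregate_ner_tokens(final_output, parsed)
    # -> [{'phrase': 'George Washington', 'label': 'PER',
    #      'start': 0, 'end': 17, 'token_start': 0, 'token_end': 1}]

Compared with the old positional list [words, label, start, end], each entity now keeps its fields in a dict, and the new token_start/token_end fields record which word-level tokens the span covers.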
 
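Likewise, a small sketch of the slimmed-down combine_token_wordpieces; bert-base-uncased stands in here for the model's actual tokenizer, so the exact wordpiece split is vocabulary-dependent:

    from transformers import BertTokenizerFast

    # Stand-in tokenizer; any BertTokenizerFast that returns an offset mapping works.
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    enc = tokenizer('a new gpu', return_offsets_mapping=True, return_tensors='pt')

    combine_token_wordpieces(enc['input_ids'][0], enc['offset_mapping'][0], tokenizer)
    # -> [{'token': 'a',   'offsets': {'start': 0, 'end': 1}},
    #     {'token': 'new', 'offsets': {'start': 2, 'end': 5}},
    #     {'token': 'gpu', 'offsets': {'start': 6, 'end': 9}}]  # 'gp' + '##u' merged

The only behavioral change is that entries no longer carry a token_idxs list; word-level positions are now recovered downstream in aggregate_ner_tokens via enumerate.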