Shaltiel committed on
Commit
23ee9eb
•
1 Parent(s): 6c92b71

Upload BertForJointParsing.py

Files changed (1)
  1. BertForJointParsing.py +174 -4
BertForJointParsing.py CHANGED
@@ -8,7 +8,7 @@ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
 from transformers.models.bert.modeling_bert import BertOnlyMLMHead
 from transformers.utils import ModelOutput
 from .BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
-from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking
+from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking, get_prefixes_from_str
 from .BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits
 
 import warnings
@@ -186,11 +186,14 @@ class BertForJointParsing(BertPreTrainedModel):
             morph_logits=morph_logits
         )
 
-    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False):
+    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, as_iahlt_ud=False, as_htb_ud=False):
         is_single_sentence = isinstance(sentences, str)
         if is_single_sentence:
             sentences = [sentences]
-
+
+        if (as_htb_ud or as_iahlt_ud) and (self.prefix is None or self.morph is None or self.syntax is None or self.lex is None):
+            raise ValueError("Cannot output UD format unless the prefix, morph, syntax, and lex heads are all loaded.")
+
         # predict the logits for the sentence
         if self.prefix is not None:
             inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
@@ -230,10 +233,15 @@ class BertForJointParsing(BertPreTrainedModel):
                 merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
                 final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
 
+        if as_iahlt_ud or as_htb_ud:
+            final_output = convert_output_to_ud(final_output, htb_extras=as_htb_ud)
+
         if is_single_sentence:
             final_output = final_output[0]
         return final_output
 
+
+
 def aggregate_ner_tokens(predictions):
     entities = []
     prev = None
@@ -302,4 +310,166 @@ def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
             ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
             continue
         ret.append((token, tokenizer._convert_id_to_token(predictions[batch_idx, tok_idx])))
-    return batch_ret
+    return batch_ret
+
+ud_prefixes_to_pos = {
+    'ש': ['SCONJ'],
+    'מש': ['SCONJ'],
+    'כש': ['SCONJ'],
+    'לכש': ['SCONJ'],
+    'בש': ['SCONJ'],
+    'לש': ['SCONJ'],
+    'ו': ['CCONJ'],
+    'ל': ['ADP'],
+    'ה': ['DET', 'SCONJ'],
+    'מ': ['ADP', 'SCONJ'],
+    'ב': ['ADP'],
+    'כ': ['ADP', 'ADV'],
+}
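+# Illustrative note (assumed behavior of get_prefixes_from_str, imported above):
+# for a token such as 'וכשהלך', the prefix segment 'וכש' should decompose greedily
+# into 'ו' and 'כש', which the table above maps to CCONJ and SCONJ respectively.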
+ud_suffix_to_htb_str = {
+    'Gender=Masc|Number=Sing|Person=3': '_הוא',
+    'Gender=Masc|Number=Plur|Person=3': '_הם',
+    'Gender=Fem|Number=Sing|Person=3': '_היא',
+    'Gender=Fem|Number=Plur|Person=3': '_הן',
+    'Gender=Fem,Masc|Number=Plur|Person=1': '_אנחנו',
+    'Gender=Fem,Masc|Number=Sing|Person=1': '_אני',
+    'Gender=Masc|Number=Plur|Person=2': '_אתם',
+    'Gender=Masc|Number=Sing|Person=2': '_אתה',
+    'Gender=Fem|Number=Sing|Person=2': '_את',
+}
+def convert_output_to_ud(output_sentences, htb_extras=False):
+    final_output = []
+    for sent_idx, sentence in enumerate(output_sentences):
+        # next, go through each word and insert it in the UD format. Store in a temp format for the post-process
+        intermediate_output = []
+        ranges = []
+        # store a mapping between each word index and the actual line it appears in
+        idx_to_key = {-1: 0}
+        for word_idx, word in enumerate(sentence['tokens']):
+            # handle blank lexemes
+            if word['lex'] == '[BLANK]':
+                word['lex'] = word['seg'][-1]
+
+            start = len(intermediate_output)
+            idx_to_key[word_idx] = len(intermediate_output) + 1
+            # add in all the prefixes
+            if len(word['seg']) > 1:
+                for pre in get_prefixes_from_str(word['seg'][0], greedy=True):
+                    # pos - just take the first valid pos that appears in the predicted prefixes list
+                    pos = next((pos for pos in ud_prefixes_to_pos[pre] if pos in word['morph']['prefixes']), ud_prefixes_to_pos[pre][0])
+                    dep, func = ud_get_prefix_dep(pre, word, word_idx)
+                    intermediate_output.append(dict(word=pre, lex=pre, pos=pos, dep=dep, func=func, feats='_'))
+
+                # if there was an implicit heh (definite article), add it in, depending on the output format
+                if 'ה' not in pre and intermediate_output[-1]['pos'] == 'ADP' and 'DET' in word['morph']['prefixes']:
+                    if htb_extras:
+                        intermediate_output.append(dict(word='ה_', lex='ה', pos='DET', dep=word_idx, func='det', feats='_'))
+                    else:
+                        intermediate_output[-1]['feats'] = 'Definite=Def|PronType=Art'
+
+            # add the main word in!
+            intermediate_output.append(dict(
+                word=word['seg'][-1], lex=word['lex'], pos=word['morph']['pos'],
+                dep=word['syntax']['dep_head_idx'], func=word['syntax']['dep_func'],
+                feats='|'.join(f'{k}={v}' for k, v in word['morph']['feats'].items())))
+
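+            # worked example (illustrative): for a word like 'ביתו' ('his house') with
+            # lex 'בית', the possessive suffix 'ו' is split off into its own PRON row below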
+            # if we have suffixes, this changes things
+            if word['morph']['suffix']:
+                # first determine the dependency info:
+                # for ADP, NUM, DET - the main word points here, and the suffix points to the dependency
+                if word['morph']['pos'] in ['ADP', 'NUM', 'DET']:
+                    intermediate_output[-1]['dep'] = len(intermediate_output)
+                    intermediate_output[-1]['absolute_dep'] = True
+                    intermediate_output[-1]['func'] = 'case'
+                    dep = word['syntax']['dep_head_idx']
+                    func = word['syntax']['dep_func']
+                else:
+                    # if pos is verb -> obj, num -> dep, default to -> nmod:poss
+                    dep = word_idx
+                    func = {'VERB': 'obj', 'NUM': 'dep'}.get(word['morph']['pos'], 'nmod:poss')
+
+                s_word, s_lex = word['seg'][-1], word['lex']
+                # update the surface form of the main word and extract the string of the suffix
+                # for IAHLT:
+                if not htb_extras:
+                    # we need to shorten the main word and extract the suffix;
+                    # if it is longer than the lexeme - just take off the lexeme
+                    if len(s_word) > len(s_lex):
+                        idx = len(s_lex)
+                    # otherwise, try to find the last letter of the lexeme, and failing that just take the last letter
+                    else:
+                        # take either len-1, or the last occurrence (which can be -1 === len-1)
+                        idx = min([len(s_word) - 1, s_word.rfind(s_lex[-1])])
+                    # extract the suffix and update the main word
+                    suf = s_word[idx:]
+                    intermediate_output[-1]['word'] = s_word[:idx]
+                # for HTB:
+                else:
+                    # the main word becomes the lexeme, and the suffix is reconstructed from the features
+                    intermediate_output[-1]['word'] = s_lex + '_'
+                    suf_feats = word['morph']['suffix_feats']
+                    suf = ud_suffix_to_htb_str.get(f"Gender={suf_feats['Gender']}|Number={suf_feats['Number']}|Person={suf_feats['Person']}", '_הוא')
+                    # for HTB, if the function is poss, then add a 'של' pointing to the next word
+                    if func == 'nmod:poss':
+                        intermediate_output.append(dict(word='_של_', lex='של', pos='ADP', dep=len(intermediate_output) + 1, func='case', feats='_', absolute_dep=True))
+                # add the main suffix in
+                intermediate_output.append(dict(word=suf, lex='הוא', pos='PRON', dep=dep, func=func, feats='|'.join(f'{k}={v}' for k, v in word['morph']['suffix_feats'].items())))
+            end = len(intermediate_output)
+            ranges.append((start, end, word['token']))
+
+        # now that we have the intermediate output, combine it into the final output
+        cur_output = []
+        final_output.append(cur_output)
+        # first, add the headers
+        cur_output.append(f'# sent_id = {sent_idx + 1}')
+        cur_output.append(f'# text = {sentence["text"]}')
+
+        # add in all the actual entries
+        for start, end, token in ranges:
+            if end - start > 1:
+                cur_output.append(f'{start + 1}-{end}\t{token}\t_\t_\t_\t_\t_\t_\t_\t_')
+            for idx, output in enumerate(intermediate_output[start:end], start + 1):
+                # compute the actual dependency location
+                dep = output['dep'] if output.get('absolute_dep', False) else idx_to_key[output['dep']]
+                # and add the full UD string in
+                cur_output.append('\t'.join([
+                    str(idx),
+                    output['word'],
+                    output['lex'],
+                    output['pos'],
+                    output['pos'],
+                    output['feats'],
+                    str(dep),
+                    output['func'],
+                    '_', '_'
+                ]))
+    return final_output
+
+
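+# ud_get_prefix_dep decides where a prefix row attaches: to the head that the main
+# word points to (dep_head_idx) when does_follow_main holds, otherwise to the main
+# word's own row (word_idx), along with the prefix's dependency function.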
+def ud_get_prefix_dep(pre, word, word_idx):
+    does_follow_main = False
+
+    # shin goes to the main word for verbs, otherwise follows the word
+    if pre.endswith('ש'):
+        does_follow_main = word['morph']['pos'] != 'VERB'
+        func = 'mark'
+    # vav goes to the main word if the function is in the list, otherwise follows
+    elif pre == 'ו':
+        does_follow_main = word['syntax']['dep_func'] in ["conj", "acl:relcl", "parataxis", "root", "acl", "amod", "list", "appos", "dep", "flat", "ccomp"]
+        func = 'cc'
+    else:
+        # for adj, noun, propn, pron, verb - prefixes go to the main word
+        if word['morph']['pos'] in ["ADJ", "NOUN", "PROPN", "PRON", "VERB"]:
+            does_follow_main = True
+        # otherwise - the prefix follows the word if the function is in the list
+        else:
+            does_follow_main = word['syntax']['dep_func'] in ["compound:affix", "det", "aux", "nummod", "advmod", "dep", "cop", "mark", "fixed"]
+
+        func = 'case'
+        if pre == 'ה':
+            func = 'det' if 'DET' in word['morph']['prefixes'] else 'mark'
+
+    return (word['syntax']['dep_head_idx'] if does_follow_main else word_idx), func
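
For context, a minimal usage sketch of the UD output path added in this commit. The checkpoint name, the AutoModel/trust_remote_code loading pattern, and the sample sentence are illustrative assumptions; only the predict signature and the as_iahlt_ud / as_htb_ud flags come from the file above.

    from transformers import AutoModel, AutoTokenizer

    # assumed checkpoint that ships this BertForJointParsing class as remote code
    tokenizer = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint')
    model = AutoModel.from_pretrained('dicta-il/dictabert-joint', trust_remote_code=True)
    model.eval()

    # as_iahlt_ud=True converts the joint prediction into UD (CoNLL-U) lines;
    # as_htb_ud=True applies the HTB-style conventions instead
    ud_lines = model.predict(['הוא הלך הביתה'], tokenizer, as_iahlt_ud=True)
    print('\n'.join(ud_lines[0]))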