Upload BertForJointParsing.py
Browse files- BertForJointParsing.py +174 -4
BertForJointParsing.py
CHANGED
@@ -8,7 +8,7 @@ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
|
|
8 |
from transformers.models.bert.modeling_bert import BertOnlyMLMHead
|
9 |
from transformers.utils import ModelOutput
|
10 |
from .BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
|
11 |
-
from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking
|
12 |
from .BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits
|
13 |
|
14 |
import warnings
|
@@ -186,11 +186,14 @@ class BertForJointParsing(BertPreTrainedModel):
|
|
186 |
morph_logits=morph_logits
|
187 |
)
|
188 |
|
189 |
-
def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False):
|
190 |
is_single_sentence = isinstance(sentences, str)
|
191 |
if is_single_sentence:
|
192 |
sentences = [sentences]
|
193 |
-
|
|
|
|
|
|
|
194 |
# predict the logits for the sentence
|
195 |
if self.prefix is not None:
|
196 |
inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
|
@@ -230,10 +233,15 @@ class BertForJointParsing(BertPreTrainedModel):
|
|
230 |
merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
|
231 |
final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
|
232 |
|
|
|
|
|
|
|
233 |
if is_single_sentence:
|
234 |
final_output = final_output[0]
|
235 |
return final_output
|
236 |
|
|
|
|
|
237 |
def aggregate_ner_tokens(predictions):
|
238 |
entities = []
|
239 |
prev = None
|
@@ -302,4 +310,166 @@ def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
|
|
302 |
ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
|
303 |
continue
|
304 |
ret.append((token, tokenizer._convert_id_to_token(predictions[batch_idx, tok_idx])))
|
305 |
-
return batch_ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from transformers.models.bert.modeling_bert import BertOnlyMLMHead
|
9 |
from transformers.utils import ModelOutput
|
10 |
from .BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
|
11 |
+
from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking, get_prefixes_from_str
|
12 |
from .BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits
|
13 |
|
14 |
import warnings
|
|
|
186 |
morph_logits=morph_logits
|
187 |
)
|
188 |
|
189 |
+
def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, as_iahlt_ud=False, as_htb_ud=False):
|
190 |
is_single_sentence = isinstance(sentences, str)
|
191 |
if is_single_sentence:
|
192 |
sentences = [sentences]
|
193 |
+
|
194 |
+
if (as_htb_ud or as_iahlt_ud) and (self.prefix is None or self.morph is None or self.syntax is None or self.lex is None):
|
195 |
+
raise ValueError("Cannot output UD format when any of the prefix,morph,syntax,lex heads aren't loaded.")
|
196 |
+
|
197 |
# predict the logits for the sentence
|
198 |
if self.prefix is not None:
|
199 |
inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
|
|
|
233 |
merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
|
234 |
final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
|
235 |
|
236 |
+
if as_iahlt_ud or as_htb_ud:
|
237 |
+
final_output = convert_output_to_ud(final_output, htb_extras=as_htb_ud)
|
238 |
+
|
239 |
if is_single_sentence:
|
240 |
final_output = final_output[0]
|
241 |
return final_output
|
242 |
|
243 |
+
|
244 |
+
|
245 |
def aggregate_ner_tokens(predictions):
|
246 |
entities = []
|
247 |
prev = None
|
|
|
310 |
ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
|
311 |
continue
|
312 |
ret.append((token, tokenizer._convert_id_to_token(predictions[batch_idx, tok_idx])))
|
313 |
+
return batch_ret
|
314 |
+
|
315 |
+
# Candidate UD part-of-speech tags for every Hebrew prefix that can be split
# off a word. The first tag in each list is the fallback used when none of
# the candidates appears among the prefix tags predicted for the word.
ud_prefixes_to_pos = {
    # subordinating shin, alone or fused with a preceding preposition
    'ש': ['SCONJ'],
    'מש': ['SCONJ'],
    'כש': ['SCONJ'],
    'לכש': ['SCONJ'],
    'בש': ['SCONJ'],
    'לש': ['SCONJ'],
    # single-letter prefixes
    'ו': ['CCONJ'],
    'ל': ['ADP'],
    'ה': ['DET', 'SCONJ'],
    'מ': ['ADP', 'SCONJ'],
    'ב': ['ADP'],
    'כ': ['ADP', 'ADV'],
}
|
329 |
+
# Surface form of the pronoun emitted in HTB-style output for a pronominal
# suffix, keyed by the suffix's "Gender=..|Number=..|Person=.." signature.
# Each signature is listed exactly once (duplicate keys would silently
# shadow each other in a dict literal).
ud_suffix_to_htb_str = {
    'Gender=Masc|Number=Sing|Person=3': '_הוא',
    'Gender=Masc|Number=Plur|Person=3': '_הם',
    'Gender=Fem|Number=Sing|Person=3': '_היא',
    'Gender=Fem|Number=Plur|Person=3': '_הן',
    'Gender=Fem,Masc|Number=Plur|Person=1': '_אנחנו',
    'Gender=Fem,Masc|Number=Sing|Person=1': '_אני',
    'Gender=Masc|Number=Plur|Person=2': '_אתם',
    'Gender=Masc|Number=Sing|Person=2': '_אתה',
    'Gender=Fem|Number=Sing|Person=2': '_את',
}
|
342 |
+
def convert_output_to_ud(output_sentences, htb_extras=False):
    """Render parsed sentences as lists of CoNLL-U formatted lines.

    Args:
        output_sentences: iterable of sentence dicts. Each one is read for
            ``'text'`` and ``'tokens'``; every token is read for ``'token'``,
            ``'seg'``, ``'lex'``, ``'morph'`` (``'pos'``, ``'feats'``,
            ``'prefixes'``, ``'suffix'`` and, when a suffix is present,
            ``'suffix_feats'``) and ``'syntax'`` (``'dep_head_idx'``,
            ``'dep_func'``). A head index of ``-1`` denotes the root.
        htb_extras: when true, emit the HTB flavour (explicit implicit
            article, lexeme + pronoun for suffixes, inserted "shel");
            otherwise emit the IAHLT flavour.

    Returns:
        One list of strings per sentence: two ``#`` header lines followed by
        the token lines (with a ``start-end`` range line before every
        multi-part token). The input dicts are not modified.
    """
    final_output = []
    for sent_idx, sentence in enumerate(output_sentences):
        # Entries are first collected in a temporary format; heads are only
        # resolved to line numbers once every word's position is known.
        intermediate_output = []
        ranges = []
        # word index -> 1-based line id of that word's MAIN entry (root -> 0)
        idx_to_key = {-1: 0}
        for word_idx, word in enumerate(sentence['tokens']):
            morph, syntax = word['morph'], word['syntax']
            main_form = word['seg'][-1]
            # a blank lexeme falls back to the surface form of the main segment
            lex = main_form if word['lex'] == '[BLANK]' else word['lex']

            start = len(intermediate_output)

            # Add in all the prefixes
            if len(word['seg']) > 1:
                prefixes = list(get_prefixes_from_str(word['seg'][0], greedy=True))
                for pre in prefixes:
                    # pos - the first candidate that was actually predicted
                    # for this word, else the first candidate overall
                    candidates = ud_prefixes_to_pos[pre]
                    pos = next((p for p in candidates if p in morph['prefixes']), candidates[0])
                    dep, func = ud_get_prefix_dep(pre, word, word_idx)
                    intermediate_output.append(dict(word=pre, lex=pre, pos=pos, dep=dep, func=func, feats='_'))

                # An implicit heh: the last prefix is a preposition without a
                # written article, yet a determiner was predicted.
                if (prefixes and 'ה' not in prefixes[-1]
                        and intermediate_output[-1]['pos'] == 'ADP'
                        and 'DET' in morph['prefixes']):
                    if htb_extras:
                        intermediate_output.append(dict(word='ה_', lex='ה', pos='DET', dep=word_idx, func='det', feats='_'))
                    else:
                        intermediate_output[-1]['feats'] = 'Definite=Def|PronType=Art'

            # Heads that reference this word must resolve to the main entry,
            # not to the first prefix line, so register the id only now.
            idx_to_key[word_idx] = len(intermediate_output) + 1
            main_entry = dict(
                word=main_form, lex=lex, pos=morph['pos'],
                dep=syntax['dep_head_idx'], func=syntax['dep_func'],
                # CoNLL-U fields may not be empty
                feats='|'.join(f'{k}={v}' for k, v in morph['feats'].items()) or '_')
            intermediate_output.append(main_entry)

            # if we have suffixes, this changes things
            if morph['suffix']:
                # For adp, num, det - the main word attaches to the suffix as
                # 'case', and the suffix inherits the word's own dependency.
                main_is_case = morph['pos'] in ('ADP', 'NUM', 'DET')
                if main_is_case:
                    main_entry['func'] = 'case'
                    dep = syntax['dep_head_idx']
                    func = syntax['dep_func']
                else:
                    # the suffix attaches to the main word: object of a verb,
                    # possessor of anything else
                    dep = word_idx
                    func = 'obj' if morph['pos'] == 'VERB' else 'nmod:poss'

                if not htb_extras:
                    # IAHLT: split the surface form into stem + suffix.
                    if len(main_form) > len(lex):
                        # longer than the lexeme - just take off the lexeme
                        idx = len(lex)
                    else:
                        # last occurrence of the lexeme's final letter, capped
                        # at the last character (rfind's -1 also selects it)
                        idx = min(len(main_form) - 1, main_form.rfind(lex[-1]))
                    suf = main_form[idx:]
                    main_entry['word'] = main_form[:idx]
                else:
                    # HTB: main word becomes the lexeme, the suffix string is
                    # chosen by the suffix features
                    main_entry['word'] = lex + '_'
                    suf_feats = morph['suffix_feats']
                    suf_key = (f"Gender={suf_feats.get('Gender')}"
                               f"|Number={suf_feats.get('Number')}"
                               f"|Person={suf_feats.get('Person')}")
                    suf = ud_suffix_to_htb_str.get(suf_key, '_הוא')
                    # for a possessive, add a "shel" attached to the pronoun
                    # that follows it (its own id is len + 1, the pronoun's
                    # is len + 2)
                    if func == 'nmod:poss':
                        intermediate_output.append(dict(
                            word='_של_', lex='של', pos='ADP',
                            dep=len(intermediate_output) + 2, func='case',
                            feats='_', absolute_dep=True))

                # the suffix is appended next, so this is its 1-based line id
                suffix_id = len(intermediate_output) + 1
                if main_is_case:
                    main_entry['dep'] = suffix_id
                    main_entry['absolute_dep'] = True
                intermediate_output.append(dict(
                    word=suf, lex='הוא', pos='PRON', dep=dep, func=func,
                    feats='|'.join(f'{k}={v}' for k, v in morph['suffix_feats'].items()) or '_'))

            end = len(intermediate_output)
            ranges.append((start, end, word['token']))

        # now that we have the intermediate output, combine it to the final output
        cur_output = [
            f'# sent_id = {sent_idx + 1}',
            f'# text = {sentence["text"]}',
        ]
        final_output.append(cur_output)

        for start, end, token in ranges:
            # multi-part tokens get a range line covering their parts
            if end - start > 1:
                cur_output.append(f'{start + 1}-{end}\t{token}\t_\t_\t_\t_\t_\t_\t_\t_')
            # ids are 1-based, matching the range line and the head values
            for line_id, entry in enumerate(intermediate_output[start:end], start + 1):
                # absolute heads are already line ids; the rest are word indices
                head = entry['dep'] if entry.get('absolute_dep', False) else idx_to_key[entry['dep']]
                cur_output.append('\t'.join([
                    str(line_id),
                    entry['word'],
                    entry['lex'],
                    entry['pos'],
                    entry['pos'],
                    entry['feats'],
                    str(head),
                    entry['func'],
                    '_', '_',
                ]))
    return final_output
|
450 |
+
|
451 |
+
|
452 |
+
def ud_get_prefix_dep(pre, word, word_idx):
    """Choose the head and dependency function for a prefix of ``word``.

    Args:
        pre: the prefix string that was split off the word.
        word: the token dict; ``word['morph']['pos']``,
            ``word['morph']['prefixes']``, ``word['syntax']['dep_func']`` and
            ``word['syntax']['dep_head_idx']`` are read.
        word_idx: index of the word inside its sentence.

    Returns:
        ``(head, func)``. ``head`` is ``word_idx`` when the prefix attaches to
        the word itself, or the word's own ``dep_head_idx`` when the prefix
        "follows" the word, i.e. shares the word's head.
    """
    pos = word['morph']['pos']
    dep_func = word['syntax']['dep_func']

    if pre.endswith('ש'):
        # shin goes to the main word for verbs, otherwise follows the word
        does_follow_main = pos != 'VERB'
        func = 'mark'
    elif pre == 'ו':
        # vav goes to the main word if the function is in the list,
        # otherwise follows the word
        does_follow_main = dep_func not in (
            "conj", "acl:relcl", "parataxis", "root", "acl", "amod",
            "list", "appos", "dep", "flat", "ccomp",
        )
        func = 'cc'
    else:
        if pos in ("ADJ", "NOUN", "PROPN", "PRON", "VERB"):
            # for adj, noun, propn, pron, verb - prefixes go to the main word
            does_follow_main = False
        else:
            # otherwise - the prefix follows the word if the function is in the list
            does_follow_main = dep_func in (
                "compound:affix", "det", "aux", "nummod", "advmod",
                "dep", "cop", "mark", "fixed",
            )

        func = 'case'
        if pre == 'ה':
            # heh is the article when a determiner was predicted, else a marker
            func = 'det' if 'DET' in word['morph']['prefixes'] else 'mark'

    return (word['syntax']['dep_head_idx'] if does_follow_main else word_idx), func
|
475 |
+
|