Hiveurban committed
Commit 44ef5e1 • 1 Parent(s): 2267576

Upload BertForJointParsing.py with huggingface_hub

Files changed (1)
  1. BertForJointParsing.py +523 -0
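
A commit message like the one above is what the huggingface_hub client writes by default. A minimal sketch of such an upload (the repo id below is a placeholder, not taken from this commit):

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="BertForJointParsing.py",  # local file to upload
        path_in_repo="BertForJointParsing.py",     # destination path in the repo
        repo_id="your-org/your-model",             # placeholder repo id
        commit_message="Upload BertForJointParsing.py with huggingface_hub",
    )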
BertForJointParsing.py ADDED
@@ -0,0 +1,523 @@
+from dataclasses import dataclass
+import re
+from operator import itemgetter
+import torch
+from torch import nn
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
+from transformers.models.bert.modeling_bert import BertOnlyMLMHead
+from transformers.utils import ModelOutput
+from .BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
+from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking, get_prefixes_from_str
+from .BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits
+
+import warnings
+
+@dataclass
+class JointParsingOutput(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
+    # logits will contain the optional predictions for the given labels
+    logits: Optional[Union[SyntaxLogitsOutput, None]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    # if no labels are given, we will always include the syntax logits separately
+    syntax_logits: Optional[SyntaxLogitsOutput] = None
+    ner_logits: Optional[torch.FloatTensor] = None
+    prefix_logits: Optional[torch.FloatTensor] = None
+    lex_logits: Optional[torch.FloatTensor] = None
+    morph_logits: Optional[MorphLogitsOutput] = None
+
+# wrapper class to wrap a torch.nn.Module so that you can store a module in multiple linked
+# properties without registering the parameter multiple times
+class ModuleRef:
+    def __init__(self, module: torch.nn.Module):
+        self.module = module
+
+    def forward(self, *args, **kwargs):
+        return self.module.forward(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
+class BertForJointParsing(BertPreTrainedModel):
+    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+    def __init__(self, config, do_syntax=None, do_ner=None, do_prefix=None, do_lex=None, do_morph=None, syntax_head_size=64):
+        super().__init__(config)
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # create all the heads as None, and then populate them as defined
+        self.syntax, self.ner, self.prefix, self.lex, self.morph = (None,)*5
+
+        if do_syntax is not None:
+            config.do_syntax = do_syntax
+            config.syntax_head_size = syntax_head_size
+        if do_ner is not None: config.do_ner = do_ner
+        if do_prefix is not None: config.do_prefix = do_prefix
+        if do_lex is not None: config.do_lex = do_lex
+        if do_morph is not None: config.do_morph = do_morph
+
+        # add all the individual heads
+        if config.do_syntax:
+            self.syntax = BertSyntaxParsingHead(config)
+        if config.do_ner:
+            self.num_labels = config.num_labels
+            self.classifier = nn.Linear(config.hidden_size, config.num_labels) # name it the same as in BertForTokenClassification
+            self.ner = ModuleRef(self.classifier)
+        if config.do_prefix:
+            self.prefix = BertPrefixMarkingHead(config)
+        if config.do_lex:
+            self.cls = BertOnlyMLMHead(config) # name it the same as in BertForMaskedLM
+            self.lex = ModuleRef(self.cls)
+        if config.do_morph:
+            self.morph = BertMorphTaggingHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder if self.lex is not None else None
+
+    def set_output_embeddings(self, new_embeddings):
+        if self.lex is not None:
+            self.cls.predictions.decoder = new_embeddings
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        prefix_class_id_options: Optional[torch.Tensor] = None,
+        labels: Optional[Union[SyntaxLabels, MorphLabels, torch.Tensor]] = None,
+        labels_type: Optional[Literal['syntax', 'ner', 'prefix', 'lex', 'morph']] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        compute_syntax_mst: Optional[bool] = None
+    ):
+        if return_dict is False:
+            warnings.warn("Specified `return_dict=False` but the flag is ignored and treated as always True in this model.")
+
+        if labels is not None and labels_type is None:
+            raise ValueError("Cannot specify labels without labels_type")
+
+        if labels_type == 'prefix' and prefix_class_id_options is None:
+            raise ValueError('Cannot calculate prefix logits without prefix_class_id_options')
+
+        if compute_syntax_mst is not None and self.syntax is None:
+            raise ValueError("Cannot compute syntax MST when the syntax head isn't loaded")
+
+        bert_outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+
+        # calculate the extended attention mask for any child that might need it
+        extended_attention_mask = None
+        if attention_mask is not None:
+            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())
+
+        # extract the hidden states, and apply the dropout
+        hidden_states = self.dropout(bert_outputs[0])
+
+        loss = None
+        logits = None
+        syntax_logits = None
+        ner_logits = None
+        prefix_logits = None
+        lex_logits = None
+        morph_logits = None
+
+        # Calculate the syntax
+        if self.syntax is not None and (labels is None or labels_type == 'syntax'):
+            # apply the syntax head
+            loss, syntax_logits = self.syntax(hidden_states, extended_attention_mask, labels, compute_syntax_mst)
+            logits = syntax_logits
+
+        # Calculate the NER
+        if self.ner is not None and (labels is None or labels_type == 'ner'):
+            ner_logits = self.ner(hidden_states)
+            logits = ner_logits
+            if labels is not None:
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        # Calculate the segmentation
+        if self.prefix is not None and (labels is None or labels_type == 'prefix'):
+            loss, prefix_logits = self.prefix(hidden_states, prefix_class_id_options, labels)
+            logits = prefix_logits
+
+        # Calculate the lexeme
+        if self.lex is not None and (labels is None or labels_type == 'lex'):
+            lex_logits = self.lex(hidden_states)
+            logits = lex_logits
+            if labels is not None:
+                loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
+                loss = loss_fct(lex_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        # Calculate the morphology
+        if self.morph is not None and (labels is None or labels_type == 'morph'):
+            loss, morph_logits = self.morph(hidden_states, labels)
+            logits = morph_logits
+
+        # no labels => logits = None
+        if labels is None: logits = None
+
+        return JointParsingOutput(
+            loss,
+            logits,
+            hidden_states=bert_outputs.hidden_states,
+            attentions=bert_outputs.attentions,
+            # all the predicted logits section
+            syntax_logits=syntax_logits,
+            ner_logits=ner_logits,
+            prefix_logits=prefix_logits,
+            lex_logits=lex_logits,
+            morph_logits=morph_logits
+        )
+
+    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
+        is_single_sentence = isinstance(sentences, str)
+        if is_single_sentence:
+            sentences = [sentences]
+
+        if output_style not in ['json', 'ud', 'iahlt_ud']:
+            raise ValueError('output_style must be in json/ud/iahlt_ud')
+        if output_style in ['ud', 'iahlt_ud'] and (self.prefix is None or self.morph is None or self.syntax is None or self.lex is None):
+            raise ValueError("Cannot output UD format when any of the prefix, morph, syntax, or lex heads isn't loaded.")
+
+        # predict the logits for the sentences
+        if self.prefix is not None:
+            inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
+        else:
+            inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
+
+        offset_mapping = inputs.pop('offset_mapping')
+        # Copy the tensors to the right device, and parse!
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+        output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
+
+        input_ids = inputs['input_ids'].tolist() # convert once
+        final_output = [dict(text=sentence, tokens=combine_token_wordpieces(ids, offsets, tokenizer)) for sentence, ids, offsets in zip(sentences, input_ids, offset_mapping)]
+        # Syntax logits: each sentence gets a dict(tree: List[dict(word, dep_head, dep_head_idx, dep_func)], root_idx: int)
+        if output.syntax_logits is not None:
+            for sent_idx, parsed in enumerate(syntax_parse_logits(input_ids, sentences, tokenizer, output.syntax_logits)):
+                merge_token_list(final_output[sent_idx]['tokens'], parsed['tree'], 'syntax')
+                final_output[sent_idx]['root_idx'] = parsed['root_idx']
+
+        # Prefix logits: each sentence gets a list([prefix_segment, word_without_prefix]) - **WITH CLS & SEP**
+        if output.prefix_logits is not None:
+            for sent_idx, parsed in enumerate(prefix_parse_logits(input_ids, sentences, tokenizer, output.prefix_logits)):
+                merge_token_list(final_output[sent_idx]['tokens'], map(tuple, parsed[1:-1]), 'seg')
+
+        # Lex logits: each sentence gets a list(tuple(word, lexeme))
+        if output.lex_logits is not None:
+            for sent_idx, parsed in enumerate(lex_parse_logits(input_ids, sentences, tokenizer, output.lex_logits)):
+                merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'lex')
+
+        # Morph logits: each sentence gets a dict(text=str, tokens=list(dict(token, pos, feats, prefixes, suffix, suffix_feats?)))
+        if output.morph_logits is not None:
+            for sent_idx, parsed in enumerate(morph_parse_logits(input_ids, sentences, tokenizer, output.morph_logits)):
+                merge_token_list(final_output[sent_idx]['tokens'], parsed['tokens'], 'morph')
+
+        # NER logits: each sentence gets a list(tuple(word, ner))
+        if output.ner_logits is not None:
+            for sent_idx, parsed in enumerate(ner_parse_logits(input_ids, sentences, tokenizer, output.ner_logits, self.config.id2label)):
+                if per_token_ner:
+                    merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
+                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
+
+        if output_style in ['ud', 'iahlt_ud']:
+            final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
+
+        if is_single_sentence:
+            final_output = final_output[0]
+        return final_output
+
+
+def aggregate_ner_tokens(final_output, parsed):
+    entities = []
+    prev = None
+    for token_idx, (d, (word, pred)) in enumerate(zip(final_output['tokens'], parsed)):
+        # O does nothing
+        if pred == 'O': prev = None
+        # B-, or an I-entity that doesn't match prev (different entity or none), starts a new entity
+        elif pred.startswith('B-') or pred[2:] != prev:
+            prev = pred[2:]
+            entities.append([[word], dict(label=prev, start=d['offsets']['start'], end=d['offsets']['end'], token_start=token_idx, token_end=token_idx)])
+        else:
+            entities[-1][0].append(word)
+            entities[-1][1]['end'] = d['offsets']['end']
+            entities[-1][1]['token_end'] = token_idx
+
+    return [dict(phrase=' '.join(words), **d) for words, d in entities]
+
+def merge_token_list(src, update, key):
+    for token_src, token_update in zip(src, update):
+        token_src[key] = token_update
+
+def combine_token_wordpieces(input_ids: List[int], offset_mapping: torch.Tensor, tokenizer: BertTokenizerFast):
+    offset_mapping = offset_mapping.tolist()
+    ret = []
+    special_toks = tokenizer.all_special_tokens
+    for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
+        if token in special_toks: continue
+        if token.startswith('##'):
+            ret[-1]['token'] += token[2:]
+            ret[-1]['offsets']['end'] = offsets[1]
+        else: ret.append(dict(token=token, offsets=dict(start=offsets[0], end=offsets[1])))
+    return ret
+
+def ner_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
+    predictions = torch.argmax(logits, dim=-1).tolist()
+    batch_ret = []
+
+    special_toks = tokenizer.all_special_tokens
+    for batch_idx in range(len(sentences)):
+        ret = []
+        batch_ret.append(ret)
+
+        tokens = tokenizer.convert_ids_to_tokens(input_ids[batch_idx])
+        for tok_idx in range(len(tokens)):
+            token = tokens[tok_idx]
+            if token in special_toks: continue
+
+            # continuation wordpieces belong to the previous word, which already received
+            # its label from its first wordpiece, so they are simply skipped here
+            if token.startswith('##'):
+                continue
+            # for each word-initial token, append a (token, label) tuple
+            ret.append((token, id2label[predictions[batch_idx][tok_idx]]))
+
+    return batch_ret
+
+def lex_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
+    predictions = torch.argsort(logits, dim=-1, descending=True)[..., :3].tolist()
+    batch_ret = []
+
+    special_toks = tokenizer.all_special_tokens
+    for batch_idx in range(len(sentences)):
+        intermediate_ret = []
+        tokens = tokenizer.convert_ids_to_tokens(input_ids[batch_idx])
+        for tok_idx in range(len(tokens)):
+            token = tokens[tok_idx]
+            if token in special_toks: continue
+
+            # wordpieces should just be appended to the previous word
+            if token.startswith('##'):
+                intermediate_ret[-1] = (intermediate_ret[-1][0] + token[2:], intermediate_ret[-1][1])
+                continue
+            intermediate_ret.append((token, tokenizer.convert_ids_to_tokens(predictions[batch_idx][tok_idx])))
+
+        # build the final output taking into account valid letters
+        ret = []
+        batch_ret.append(ret)
+        for (token, lexemes) in intermediate_ret:
+            # must overlap on at least 2 non-אהוי letters
+            possible_lets = set(c for c in token if c not in 'אהוי')
+            final_lex = '[BLANK]'
+            for lex in lexemes:
+                if sum(c in possible_lets for c in lex) >= min([2, len(possible_lets), len([c for c in lex if c not in 'אהוי'])]):
+                    final_lex = lex
+                    break
+            ret.append((token, final_lex))
+
+    return batch_ret
+
+ud_prefixes_to_pos = {
+    'ש': ['SCONJ'],
+    'מש': ['SCONJ'],
+    'כש': ['SCONJ'],
+    'לכש': ['SCONJ'],
+    'בש': ['SCONJ'],
+    'לש': ['SCONJ'],
+    'ו': ['CCONJ'],
+    'ל': ['ADP'],
+    'ה': ['DET', 'SCONJ'],
+    'מ': ['ADP', 'SCONJ'],
+    'ב': ['ADP'],
+    'כ': ['ADP', 'ADV'],
+}
+ud_suffix_to_htb_str = {
+    'Gender=Masc|Number=Sing|Person=3': '_הוא',
+    'Gender=Masc|Number=Plur|Person=3': '_הם',
+    'Gender=Fem|Number=Sing|Person=3': '_היא',
+    'Gender=Fem|Number=Plur|Person=3': '_הן',
+    'Gender=Fem,Masc|Number=Plur|Person=1': '_אנחנו',
+    'Gender=Fem,Masc|Number=Sing|Person=1': '_אני',
+    'Gender=Masc|Number=Plur|Person=2': '_אתם',
+    'Gender=Masc|Number=Sing|Person=2': '_אתה',
+    'Gender=Fem|Number=Sing|Person=2': '_את'
+}
+def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
+    if style not in ['htb', 'iahlt']:
+        raise ValueError('style must be htb/iahlt')
+
+    final_output = []
+    for sent_idx, sentence in enumerate(output_sentences):
+        # go through each word and insert it in the UD format; store in a temporary format for post-processing
+        intermediate_output = []
+        ranges = []
+        # store a mapping between each word index and the actual line it appears in
+        idx_to_key = {-1: 0}
+        for word_idx, word in enumerate(sentence['tokens']):
+            try:
+                # handle blank lexemes
+                if word['lex'] == '[BLANK]':
+                    word['lex'] = word['seg'][-1]
+            except KeyError:
+                import json
+                print(json.dumps(sentence, ensure_ascii=False, indent=2))
+                exit(0)
+
+            start = len(intermediate_output)
+            # Add in all the prefixes
+            if len(word['seg']) > 1:
+                for pre in get_prefixes_from_str(word['seg'][0], greedy=True):
+                    # pos - just take the first valid pos that appears in the predicted prefixes list.
+                    pos = next((pos for pos in ud_prefixes_to_pos[pre] if pos in word['morph']['prefixes']), ud_prefixes_to_pos[pre][0])
+                    dep, func = ud_get_prefix_dep(pre, word, word_idx)
+                    intermediate_output.append(dict(word=pre, lex=pre, pos=pos, dep=dep, func=func, feats='_'))
+
+                    # if there was an implicit heh, add it in depending on the style
+                    if not 'ה' in pre and intermediate_output[-1]['pos'] == 'ADP' and 'DET' in word['morph']['prefixes']:
+                        if style == 'htb':
+                            intermediate_output.append(dict(word='ה_', lex='ה', pos='DET', dep=word_idx, func='det', feats='_'))
+                        elif style == 'iahlt':
+                            intermediate_output[-1]['feats'] = 'Definite=Def|PronType=Art'
+
+            idx_to_key[word_idx] = len(intermediate_output) + 1
+            # add the main word in!
+            intermediate_output.append(dict(
+                word=word['seg'][-1], lex=word['lex'], pos=word['morph']['pos'],
+                dep=word['syntax']['dep_head_idx'], func=word['syntax']['dep_func'],
+                feats='|'.join(f'{k}={v}' for k, v in word['morph']['feats'].items())))
+
+            # if we have suffixes, this changes things
+            if word['morph']['suffix']:
+                # first determine the dependency info:
+                # For ADP, NUM, DET - the main word points to here, and the suffix points to the dependency
+                entry_to_assign_suf_dep = None
+                if word['morph']['pos'] in ['ADP', 'NUM', 'DET']:
+                    entry_to_assign_suf_dep = intermediate_output[-1]
+                    intermediate_output[-1]['func'] = 'case'
+                    dep = word['syntax']['dep_head_idx']
+                    func = word['syntax']['dep_func']
+                else:
+                    # if pos is verb -> obj, num -> dep, default to -> nmod:poss
+                    dep = word_idx
+                    func = {'VERB': 'obj', 'NUM': 'dep'}.get(word['morph']['pos'], 'nmod:poss')
+
+                s_word, s_lex = word['seg'][-1], word['lex']
+                # update the main word's form and extract the suffix string
+                # for IAHLT:
+                if style == 'iahlt':
+                    # we need to shorten the main word and extract the suffix
+                    # if it is longer than the lexeme - just take off the lexeme.
+                    if len(s_word) > len(s_lex):
+                        idx = len(s_lex)
+                    # Otherwise, try to find the last letter of the lexeme, and failing that just take the last letter
+                    else:
+                        # take either len-1, or the last occurrence (which can be -1, equivalent to len-1 here)
+                        idx = min([len(s_word) - 1, s_word.rfind(s_lex[-1])])
+                    # extract the suffix and update the main word
+                    suf = s_word[idx:]
+                    intermediate_output[-1]['word'] = s_word[:idx]
+                # for HTB:
+                elif style == 'htb':
+                    # main word becomes the lexeme, the suffix is based on the features
+                    intermediate_output[-1]['word'] = (s_lex if s_lex != s_word else s_word[:-1]) + '_'
+                    suf_feats = word['morph']['suffix_feats']
+                    suf = ud_suffix_to_htb_str.get(f"Gender={suf_feats.get('Gender', 'Fem,Masc')}|Number={suf_feats.get('Number', 'Sing')}|Person={suf_feats.get('Person', '3')}", "_הוא")
+                    # for HTB, if the function is poss, then add a shel pointing to the next word
+                    if func == 'nmod:poss' and s_lex != 'של':
+                        intermediate_output.append(dict(word='_של_', lex='של', pos='ADP', dep=len(intermediate_output) + 2, func='case', feats='_', absolute_dep=True))
+                # add the main suffix in
+                intermediate_output.append(dict(word=suf, lex='הוא', pos='PRON', dep=dep, func=func, feats='|'.join(f'{k}={v}' for k, v in word['morph']['suffix_feats'].items())))
+                if entry_to_assign_suf_dep:
+                    entry_to_assign_suf_dep['dep'] = len(intermediate_output)
+                    entry_to_assign_suf_dep['absolute_dep'] = True
+
+            end = len(intermediate_output)
+            ranges.append((start, end, word['token']))
+
+        # now that we have the intermediate output, combine it into the final output
+        cur_output = []
+        final_output.append(cur_output)
+        # first, add the headers
+        cur_output.append(f'# sent_id = {sent_idx + 1}')
+        cur_output.append(f'# text = {sentence["text"]}')
+
+        # add in all the actual entries
+        for start, end, token in ranges:
+            if end - start > 1:
+                cur_output.append(f'{start + 1}-{end}\t{token}\t_\t_\t_\t_\t_\t_\t_\t_')
+            for idx, output in enumerate(intermediate_output[start:end], start + 1):
+                # compute the actual dependency location
+                dep = output['dep'] if output.get('absolute_dep', False) else idx_to_key[output['dep']]
+                func = normalize_dep_rel(output['func'], style)
+                # and add the full ud string in
+                cur_output.append('\t'.join([
+                    str(idx),
+                    output['word'],
+                    output['lex'],
+                    output['pos'],
+                    output['pos'],
+                    output['feats'],
+                    str(dep),
+                    func,
+                    '_', '_'
+                ]))
+    return final_output
+
+def normalize_dep_rel(dep, style: Literal['htb', 'iahlt']):
+    if style == 'iahlt':
+        if dep == 'compound:smixut': return 'compound'
+        if dep == 'nsubj:cop': return 'nsubj'
+        if dep == 'mark:q': return 'mark'
+        if dep == 'case:gen' or dep == 'case:acc': return 'case'
+    return dep
+
+
+def ud_get_prefix_dep(pre, word, word_idx):
+    does_follow_main = False
+
+    # shin goes to the main word for verbs, otherwise follows the word
+    if pre.endswith('ש'):
+        does_follow_main = word['morph']['pos'] != 'VERB'
+        func = 'mark'
+    # vav goes to the main word if the function is in the list, otherwise follows
+    elif pre == 'ו':
+        does_follow_main = word['syntax']['dep_func'] not in ["conj", "acl:recl", "parataxis", "root", "acl", "amod", "list", "appos", "dep", "flatccomp"]
+        func = 'cc'
+    else:
+        # for adj, noun, propn, pron, verb - prefixes go to the main word
+        if word['morph']['pos'] in ["ADJ", "NOUN", "PROPN", "PRON", "VERB"]:
+            does_follow_main = False
+        # otherwise - prefix follows the word if the function is in the list
+        else: does_follow_main = word['syntax']['dep_func'] in ["compound:affix", "det", "aux", "nummod", "advmod", "dep", "cop", "mark", "fixed"]
+
+        func = 'case'
+        if pre == 'ה':
+            func = 'det' if 'DET' in word['morph']['prefixes'] else 'mark'
+
+    return (word['syntax']['dep_head_idx'] if does_follow_main else word_idx), func
+
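
For reference, a minimal usage sketch of the class added above. It assumes the checkpoint's config maps AutoModel to BertForJointParsing via auto_map and that the sibling BertForSyntaxParsing/BertForPrefixMarking/BertForMorphTagging modules ship in the same repo; the repo id and sentence below are placeholders, not taken from this commit.

    from transformers import AutoModel, AutoTokenizer

    repo = "your-org/your-hebrew-joint-parser"  # placeholder repo id
    tokenizer = AutoTokenizer.from_pretrained(repo)
    model = AutoModel.from_pretrained(repo, trust_remote_code=True)
    model.eval()

    # predict() accepts a single string or a list of strings.
    # With output_style='json' each sentence yields its tokens plus the merged
    # seg/lex/morph/syntax/ner annotations; 'ud' and 'iahlt_ud' return
    # CoNLL-U style lines instead.
    parsed = model.predict("<Hebrew sentence here>", tokenizer, output_style='json')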