Shaltiel committed
Commit 9e48f1b
1 Parent(s): f4bee92

Added major speedup

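The speedup comes almost entirely from converting tensors to plain Python lists once with .tolist() and then indexing those lists, instead of indexing the tensors element by element (and calling .item()) inside the per-token parsing loops. A minimal sketch of the pattern, with illustrative shapes and variable names that are not taken from the repository:

import torch

logits = torch.randn(8, 128, 50)           # batch x sequence x labels
predictions = torch.argmax(logits, dim=-1)

# slow: every index crosses into the tensor library and .item() extracts a scalar each time
slow = [[predictions[b, t].item() for t in range(predictions.shape[1])]
        for b in range(predictions.shape[0])]

# fast: convert once, then stay in plain Python for the per-token loops
pred_list = predictions.tolist()
fast = [[pred_list[b][t] for t in range(len(pred_list[b]))]
        for b in range(len(pred_list))]

assert slow == fast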
BertForJointParsing.py CHANGED
@@ -208,31 +208,32 @@ class BertForJointParsing(BertPreTrainedModel):
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
 
-        final_output = [dict(text=sentence, tokens=combine_token_wordpieces(ids, offsets, tokenizer)) for sentence, ids, offsets in zip(sentences, inputs['input_ids'], offset_mapping)]
+        input_ids = inputs['input_ids'].tolist() # convert once
+        final_output = [dict(text=sentence, tokens=combine_token_wordpieces(ids, offsets, tokenizer)) for sentence, ids, offsets in zip(sentences, input_ids, offset_mapping)]
         # Syntax logits: each sentence gets a dict(tree: List[dict(word,dep_head,dep_head_idx,dep_func)], root_idx: int)
         if output.syntax_logits is not None:
-            for sent_idx,parsed in enumerate(syntax_parse_logits(inputs, sentences, tokenizer, output.syntax_logits)):
+            for sent_idx,parsed in enumerate(syntax_parse_logits(input_ids, sentences, tokenizer, output.syntax_logits)):
                 merge_token_list(final_output[sent_idx]['tokens'], parsed['tree'], 'syntax')
                 final_output[sent_idx]['root_idx'] = parsed['root_idx']
 
         # Prefix logits: each sentence gets a list([prefix_segment, word_without_prefix]) - **WITH CLS & SEP**
         if output.prefix_logits is not None:
-            for sent_idx,parsed in enumerate(prefix_parse_logits(inputs, sentences, tokenizer, output.prefix_logits)):
+            for sent_idx,parsed in enumerate(prefix_parse_logits(input_ids, sentences, tokenizer, output.prefix_logits)):
                 merge_token_list(final_output[sent_idx]['tokens'], map(tuple, parsed[1:-1]), 'seg')
 
         # Lex logits each sentence gets a list(tuple(word, lexeme))
         if output.lex_logits is not None:
-            for sent_idx, parsed in enumerate(lex_parse_logits(inputs, sentences, tokenizer, output.lex_logits)):
+            for sent_idx, parsed in enumerate(lex_parse_logits(input_ids, sentences, tokenizer, output.lex_logits)):
                 merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'lex')
 
         # morph logits each sentences get a dict(text=str, tokens=list(dict(token, pos, feats, prefixes, suffix, suffix_feats?)))
         if output.morph_logits is not None:
-            for sent_idx,parsed in enumerate(morph_parse_logits(inputs, sentences, tokenizer, output.morph_logits)):
+            for sent_idx,parsed in enumerate(morph_parse_logits(input_ids, sentences, tokenizer, output.morph_logits)):
                 merge_token_list(final_output[sent_idx]['tokens'], parsed['tokens'], 'morph')
 
         # NER logits each sentence gets a list(tuple(word, ner))
         if output.ner_logits is not None:
-            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label)):
+            for sent_idx,parsed in enumerate(ner_parse_logits(input_ids, sentences, tokenizer, output.ner_logits, self.config.id2label)):
                 if per_token_ner:
                     merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
                 final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
@@ -267,31 +268,32 @@ def merge_token_list(src, update, key):
     for token_src, token_update in zip(src, update):
         token_src[key] = token_update
 
-def combine_token_wordpieces(input_ids: torch.Tensor, offset_mapping: torch.Tensor, tokenizer: BertTokenizerFast):
+def combine_token_wordpieces(input_ids: List[int], offset_mapping: torch.Tensor, tokenizer: BertTokenizerFast):
     offset_mapping = offset_mapping.tolist()
     ret = []
+    special_toks = tokenizer.all_special_tokens
     for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
-        if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]: continue
+        if token in special_toks: continue
         if token.startswith('##'):
            ret[-1]['token'] += token[2:]
            ret[-1]['offsets']['end'] = offsets[1]
        else: ret.append(dict(token=token, offsets=dict(start=offsets[0], end=offsets[1])))
     return ret
 
-def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
-    input_ids = inputs['input_ids']
-
-    predictions = torch.argmax(logits, dim=-1)
+def ner_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
+    predictions = torch.argmax(logits, dim=-1).tolist()
     batch_ret = []
+
+    special_toks = tokenizer.all_special_tokens
     for batch_idx in range(len(sentences)):
+
         ret = []
         batch_ret.append(ret)
-        for tok_idx in range(input_ids.shape[1]):
-            token_id = input_ids[batch_idx, tok_idx]
-            # ignore cls, sep, pad
-            if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]: continue
 
-            token = tokenizer._convert_id_to_token(token_id)
+        tokens = tokenizer.convert_ids_to_tokens(input_ids[batch_idx])
+        for tok_idx in range(len(tokens)):
+            token = tokens[tok_idx]
+            if token in special_toks: continue
 
             # wordpieces should just be appended to the previous word
             # we modify the last token in ret
@@ -299,29 +301,29 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
             if token.startswith('##'):
                 continue
             # for each token, we append a tuple containing: token, label, start position, end position
-            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()]))
+            ret.append((token, id2label[predictions[batch_idx][tok_idx]]))
 
     return batch_ret
 
-def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
-    input_ids = inputs['input_ids']
+def lex_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
 
-    predictions = torch.argsort(logits, dim=-1, descending=True)[..., :3]
+    predictions = torch.argsort(logits, dim=-1, descending=True)[..., :3].tolist()
     batch_ret = []
+
+    special_toks = tokenizer.all_special_tokens
     for batch_idx in range(len(sentences)):
         intermediate_ret = []
-        for tok_idx in range(input_ids.shape[1]):
-            token_id = input_ids[batch_idx, tok_idx]
-            # ignore cls, sep, pad
-            if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]: continue
+        tokens = tokenizer.convert_ids_to_tokens(input_ids[batch_idx])
+        for tok_idx in range(len(tokens)):
+            token = tokens[tok_idx]
+            if token in special_toks: continue
 
-            token = tokenizer._convert_id_to_token(token_id)
             # wordpieces should just be appended to the previous word
             if token.startswith('##'):
                 intermediate_ret[-1] = (intermediate_ret[-1][0] + token[2:], intermediate_ret[-1][1])
                 continue
-            intermediate_ret.append((token, tokenizer.convert_ids_to_tokens(predictions[batch_idx, tok_idx])))
-
+            intermediate_ret.append((token, tokenizer.convert_ids_to_tokens(predictions[batch_idx][tok_idx])))
+
         # build the final output taking into account valid letters
         ret = []
         batch_ret.append(ret)
@@ -376,10 +378,15 @@ def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
         # store a mapping between each word index and the actual line it appears in
         idx_to_key = {-1: 0}
         for word_idx,word in enumerate(sentence['tokens']):
-            # handle blank lexemes
-            if word['lex'] == '[BLANK]':
-                word['lex'] = word['seg'][-1]
-
+            try:
+                # handle blank lexemes
+                if word['lex'] == '[BLANK]':
+                    word['lex'] = word['seg'][-1]
+            except KeyError:
+                import json
+                print(json.dumps(sentence, ensure_ascii=False, indent=2))
+                exit(0)
+
             start = len(intermediate_output)
             # Add in all the prefixes
             if len(word['seg']) > 1:
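For reference, the wordpiece handling shared by combine_token_wordpieces and the *_parse_logits helpers above works the same way throughout: special tokens are skipped, and a token that starts with '##' is glued onto the previous token, extending its end offset. A standalone sketch of that merging logic, with a hardcoded token/offset list standing in for real BertTokenizerFast output (not code from the repository):

from typing import List, Tuple

def merge_wordpieces(tokens: List[str], offsets: List[Tuple[int, int]], special_toks: set):
    # mirrors the combine_token_wordpieces logic: skip special tokens, fold '##' pieces into the previous token
    ret = []
    for token, (start, end) in zip(tokens, offsets):
        if token in special_toks:
            continue
        if token.startswith('##'):
            ret[-1]['token'] += token[2:]
            ret[-1]['offsets']['end'] = end
        else:
            ret.append(dict(token=token, offsets=dict(start=start, end=end)))
    return ret

# hypothetical tokenizer output for the text "unbelievable stuff"
tokens  = ['[CLS]', 'un', '##believ', '##able', 'stuff', '[SEP]']
offsets = [(0, 0), (0, 2), (2, 8), (8, 12), (13, 18), (0, 0)]
print(merge_wordpieces(tokens, offsets, {'[CLS]', '[SEP]', '[PAD]'}))
# [{'token': 'unbelievable', 'offsets': {'start': 0, 'end': 12}}, {'token': 'stuff', 'offsets': {'start': 13, 'end': 18}}]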
BertForMorphTagging.py CHANGED
@@ -159,42 +159,42 @@ class BertForMorphTagging(BertPreTrainedModel):
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         # calculate the logits
         logits = self.forward(**inputs, return_dict=True).logits
-        return parse_logits(inputs, sentences, tokenizer, logits)
+        return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
 
-def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: MorphLogitsOutput):
+def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: MorphLogitsOutput):
     prefix_logits, pos_logits, feats_logits, suffix_logits, suffix_feats_logits = \
         logits.prefix_logits, logits.pos_logits, logits.features_logits, logits.suffix_logits, logits.suffix_features_logits
 
-    prefix_predictions = (prefix_logits > 0.5).int() # Threshold at 0.5 for multi-label classification
-    pos_predictions = pos_logits.argmax(axis=-1)
-    suffix_predictions = suffix_logits.argmax(axis=-1)
-    feats_predictions = [logits.argmax(axis=-1) for logits in feats_logits]
-    suffix_feats_predictions = [logits.argmax(axis=-1) for logits in suffix_feats_logits]
+    prefix_predictions = (prefix_logits > 0.5).int().tolist() # Threshold at 0.5 for multi-label classification
+    pos_predictions = pos_logits.argmax(axis=-1).tolist()
+    suffix_predictions = suffix_logits.argmax(axis=-1).tolist()
+    feats_predictions = [logits.argmax(axis=-1).tolist() for logits in feats_logits]
+    suffix_feats_predictions = [logits.argmax(axis=-1).tolist() for logits in suffix_feats_logits]
 
     # create the return dictionary
     # for each sentence, return a dict object with the following files { text, tokens }
     # Where tokens is a list of dicts, where each dict is:
     # { pos: str, feats: dict, prefixes: List[str], suffix: str | bool, suffix_feats: dict | None}
-    special_tokens = set([tokenizer.pad_token, tokenizer.cls_token, tokenizer.sep_token])
+    special_toks = tokenizer.all_special_tokens
     ret = []
     for sent_idx,sentence in enumerate(sentences):
-        input_id_strs = tokenizer.convert_ids_to_tokens(inputs['input_ids'][sent_idx])
+        input_id_strs = tokenizer.convert_ids_to_tokens(input_ids[sent_idx])
         # iterate through each token in the sentence, ignoring special tokens
         tokens = []
         for token_idx,token_str in enumerate(input_id_strs):
-            if not token_str in special_tokens:
-                if token_str.startswith('##'):
-                    tokens[-1]['token'] += token_str[2:]
-                    continue
-                tokens.append(dict(
-                    token=token_str,
-                    pos=ALL_POS[pos_predictions[sent_idx, token_idx]],
-                    feats=get_features_dict_from_predictions(feats_predictions, (sent_idx, token_idx)),
-                    prefixes=[ALL_PREFIX_POS[idx] for idx,i in enumerate(prefix_predictions[sent_idx, token_idx]) if i > 0],
-                    suffix=get_suffix_or_false(ALL_SUFFIX_POS[suffix_predictions[sent_idx, token_idx]]),
-                ))
-                if tokens[-1]['suffix']:
-                    tokens[-1]['suffix_feats'] = get_features_dict_from_predictions(suffix_feats_predictions, (sent_idx, token_idx))
+            if token_str in special_toks: continue
+            if token_str.startswith('##'):
+                tokens[-1]['token'] += token_str[2:]
+                continue
+            tokens.append(dict(
+                token=token_str,
+                pos=ALL_POS[pos_predictions[sent_idx][token_idx]],
+                feats=get_features_dict_from_predictions(feats_predictions, (sent_idx, token_idx)),
+                prefixes=[ALL_PREFIX_POS[idx] for idx,i in enumerate(prefix_predictions[sent_idx][token_idx]) if i > 0],
+                suffix=get_suffix_or_false(ALL_SUFFIX_POS[suffix_predictions[sent_idx][token_idx]]),
+            ))
+            if tokens[-1]['suffix']:
+                tokens[-1]['suffix_feats'] = get_features_dict_from_predictions(suffix_feats_predictions, (sent_idx, token_idx))
         ret.append(dict(text=sentence, tokens=tokens))
     return ret
 
@@ -204,7 +204,7 @@ def get_suffix_or_false(suffix):
 def get_features_dict_from_predictions(predictions, idx):
     ret = {}
     for (feat_idx, (feat_name, feat_values)) in enumerate(ALL_FEATURES):
-        val = feat_values[predictions[feat_idx][idx]]
+        val = feat_values[predictions[feat_idx][idx[0]][idx[1]]]
         if val != 'none':
             ret[feat_name] = val
     return ret
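The morphology head mixes two decoding schemes that the diff above converts to list form: the prefix head is multi-label (thresholded at 0.5), while POS and suffix are single-label argmax. A small sketch of that decoding with dummy tensors and made-up label lists standing in for the real ALL_PREFIX_POS / ALL_POS:

import torch

ALL_PREFIX_POS = ['CONJ', 'DET', 'ADP']        # placeholder label lists, not the real ones
ALL_POS = ['NOUN', 'VERB', 'ADJ']

prefix_logits = torch.tensor([[0.9, 0.1, 0.7]])   # 1 token x 3 prefix classes
pos_logits = torch.tensor([[0.2, 1.5, -0.3]])     # 1 token x 3 POS classes

# multi-label: threshold at 0.5, then keep every class whose indicator is set
prefix_predictions = (prefix_logits > 0.5).int().tolist()
prefixes = [ALL_PREFIX_POS[idx] for idx, i in enumerate(prefix_predictions[0]) if i > 0]

# single-label: plain argmax
pos = ALL_POS[pos_logits.argmax(axis=-1).tolist()[0]]

print(prefixes, pos)   # ['CONJ', 'ADP'] VERB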
BertForPrefixMarking.py CHANGED
@@ -154,15 +154,15 @@ class BertForPrefixMarking(BertPreTrainedModel):
 
         # run through bert
         logits = self.forward(**inputs, return_dict=True).logits
-        return parse_logits(inputs, sentences, tokenizer, logits)
+        return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
 
-def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor):
+def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor):
     # extract the predictions by argmaxing the final dimension (batch x sequence x prefixes x prediction)
-    logit_preds = torch.argmax(logits, axis=3)
+    logit_preds = torch.argmax(logits, axis=3).tolist()
 
     ret = []
 
-    for sent_idx,sent_ids in enumerate(inputs['input_ids']):
+    for sent_idx,sent_ids in enumerate(input_ids):
         tokens = tokenizer.convert_ids_to_tokens(sent_ids)
         ret.append([])
         for tok_idx,token in enumerate(tokens):
@@ -176,7 +176,7 @@ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor):
                 token += tokens[next_tok_idx][2:]
                 next_tok_idx += 1
 
-            prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx, tok_idx])
+            prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx][tok_idx])
 
             if not prefix_len:
                 ret[-1].append([token])
@@ -232,7 +232,7 @@ def get_predicted_prefix_len_from_logits(token, token_logits):
         seen_prefixes.add(prefix)
 
         # check if we predicted this prefix
-        if token_logits[PREFIXES_TO_CLASS[prefix]].item():
+        if token_logits[PREFIXES_TO_CLASS[prefix]]:
             cur_len += len(prefix)
             if last_check: break
             skip_next = len(prefix) > 1
BertForSyntaxParsing.py CHANGED
@@ -73,7 +73,7 @@ class BertSyntaxParsingHead(nn.Module):
             dep_indices = labels.dependency_labels.clamp_min(0)
         # Otherwise - check if he wants the MST or just the argmax
         elif compute_mst:
-            dep_indices = compute_mst_tree(attention_scores)
+            dep_indices = compute_mst_tree(attention_scores, extended_attention_mask)
         else:
             dep_indices = torch.argmax(attention_scores, dim=-1)
 
@@ -160,14 +160,16 @@ class BertForSyntaxParsing(BertPreTrainedModel):
         inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         logits = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_mst).logits
-        return parse_logits(inputs, sentences, tokenizer, logits)
+        return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
 
-def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
+def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
     outputs = []
+
+    special_toks = tokenizer.all_special_tokens
     for i in range(len(sentences)):
         deps = logits.dependency_head_indices[i].tolist()
         funcs = logits.function_logits.argmax(-1)[i].tolist()
-        toks = tokenizer.convert_ids_to_tokens(inputs['input_ids'][i])[1:-1] # ignore cls and sep
+        toks = [tok for tok in tokenizer.convert_ids_to_tokens(input_ids[i]) if tok not in special_toks]
 
         # first, go through the tokens and create a mapping between each dependency index and the index without wordpieces
         # wordpieces. At the same time, append the wordpieces in
@@ -187,6 +189,8 @@ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
                 continue
 
             dep_idx = deps[i + 1] - 1 # increase 1 for cls, decrease 1 for cls
+            if dep_idx == len(toks): dep_idx = i - 1 # if he predicts sep, then just point to the previous word
+
             dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
             dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]
 
@@ -200,7 +204,7 @@ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
     return outputs
 
 
-def compute_mst_tree(attention_scores: torch.Tensor):
+def compute_mst_tree(attention_scores: torch.Tensor, extended_attention_mask: torch.LongTensor):
     # attention scores should be 3 dimensions - batch x seq x seq (if it is 2 - just unsqueeze)
     if attention_scores.ndim == 2: attention_scores = attention_scores.unsqueeze(0)
     if attention_scores.ndim != 3 or attention_scores.shape[1] != attention_scores.shape[2]:
@@ -209,40 +213,58 @@ def compute_mst_tree(attention_scores: torch.Tensor):
     batch_size, seq_len, _ = attention_scores.shape
     # start by softmaxing so the scores are comparable
     attention_scores = attention_scores.softmax(dim=-1)
+
+    batch_indices = torch.arange(batch_size, device=attention_scores.device)
+    seq_indices = torch.arange(seq_len, device=attention_scores.device)
+
+    seq_lens = torch.full((batch_size,), seq_len)
+
+    if extended_attention_mask is not None:
+        seq_lens = torch.argmax((extended_attention_mask != 0).int(), dim=2).squeeze(1)
+        # zero out any padding
+        attention_scores[extended_attention_mask.squeeze(1) != 0] = 0
 
     # set the values for the CLS and sep to all by very low, so they never get chosen as a replacement arc
-    attention_scores[:, 0, :] = -10000
-    attention_scores[:, -1, :] = -10000
-    attention_scores[:, :, -1] = -10000 # can never predict sep
+    attention_scores[:, 0, :] = 0
+    attention_scores[batch_indices, seq_lens - 1, :] = 0
+    attention_scores[batch_indices, :, seq_lens - 1] = 0 # can never predict sep
+    # set the values for each token pointing to itself be 0
+    attention_scores[:, seq_indices, seq_indices] = 0
 
     # find the root, and make him super high so we never have a conflict
     root_cands = torch.argsort(attention_scores[:, :, 0], dim=-1)
-    batch_indices = torch.arange(batch_size, device=root_cands.device)
-    attention_scores[batch_indices.unsqueeze(1), root_cands, 0] = -10000
-    attention_scores[batch_indices, root_cands[:, -1], 0] = 10000
-
+    attention_scores[batch_indices.unsqueeze(1), root_cands, 0] = 0
+    attention_scores[batch_indices, root_cands[:, -1], 0] = 1.0
+
     # we start by getting the argmax for each score, and then computing the cycles and contracting them
     sorted_indices = torch.argsort(attention_scores, dim=-1, descending=True)
     indices = sorted_indices[:, :, 0].clone() # take the argmax
 
+    attention_scores = attention_scores.tolist()
+    seq_lens = seq_lens.tolist()
+    sorted_indices = [[sub_l[:slen] for sub_l in l[:slen]] for l,slen in zip(sorted_indices.tolist(), seq_lens)]
+
+
     # go through each batch item and make sure our tree works
     for batch_idx in range(batch_size):
         # We have one root - detect the cycles and contract them. A cycle can never contain the root so really
         # for every cycle, we look at all the nodes, and find the highest arc out of the cycle for any values. Replace that and tada
-        has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])
+        has_cycle, cycle_nodes = detect_cycle(indices[batch_idx], seq_lens[batch_idx])
+        contracted_arcs = set()
         while has_cycle:
-            base_idx, head_idx = choose_contracting_arc(indices[batch_idx], sorted_indices[batch_idx], cycle_nodes, attention_scores[batch_idx])
+            base_idx, head_idx = choose_contracting_arc(indices[batch_idx], sorted_indices[batch_idx], cycle_nodes, contracted_arcs, seq_lens[batch_idx], attention_scores[batch_idx])
             indices[batch_idx, base_idx] = head_idx
+            contracted_arcs.add(base_idx)
             # find the next cycle
-            has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])
-
+            has_cycle, cycle_nodes = detect_cycle(indices[batch_idx], seq_lens[batch_idx])
+
     return indices
 
-def detect_cycle(indices: torch.LongTensor):
+def detect_cycle(indices: torch.LongTensor, seq_len: int):
     # Simple cycle detection algorithm
     # Returns a boolean indicating if a cycle is detected and the nodes involved in the cycle
     visited = set()
-    for node in range(1, len(indices) - 1): # ignore the CLS/SEP tokens
+    for node in range(1, seq_len - 1): # ignore the CLS/SEP tokens
         if node in visited:
             continue
         current_path = set()
@@ -255,31 +277,36 @@ def detect_cycle(indices: torch.LongTensor):
                 return True, current_path # Cycle detected
     return False, None
 
-def choose_contracting_arc(indices: torch.LongTensor, sorted_indices: torch.LongTensor, cycle_nodes: set, scores: torch.FloatTensor):
+def choose_contracting_arc(indices: torch.LongTensor, sorted_indices: List[List[int]], cycle_nodes: set, contracted_arcs: set, seq_len: int, scores: List[List[float]]):
     # Chooses the highest-scoring, non-cycling arc from a graph. Iterates through 'cycle_nodes' to find
     # the best arc based on 'scores', avoiding cycles and zero node connections.
     # For each node, we only look at the next highest scoring non-cycling arc
     best_base_idx, best_head_idx = -1, -1
-    score = float('-inf')
+    score = 0
 
     # convert the indices to a list once, to avoid multiple conversions (saves a few seconds)
     currents = indices.tolist()
     for base_node in cycle_nodes:
+        if base_node in contracted_arcs: continue
         # we don't want to take anything that has a higher score than the current value - we can end up in an endless loop
         # Since the indices are sorted, as soon as we find our current item, we can move on to the next.
         current = currents[base_node]
         found_current = False
 
-        for head_node in sorted_indices[base_node].tolist():
+        for head_node in sorted_indices[base_node]:
             if head_node == current:
                 found_current = True
                 continue
+            if head_node in contracted_arcs: continue
             if not found_current or head_node in cycle_nodes or head_node == 0:
                 continue
 
-            current_score = scores[base_node, head_node].item()
+            current_score = scores[base_node][head_node]
             if current_score > score:
                 best_base_idx, best_head_idx, score = base_node, head_node, current_score
                 break
 
+    if best_base_idx == -1:
+        raise ValueError('Stuck in endless loop trying to compute syntax mst. Please try again setting compute_syntax_mst=False')
+
     return best_base_idx, best_head_idx
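The MST changes above keep the same overall approach: take the argmax head for every token, then repeatedly detect a cycle and re-point one node in it to its best head outside the cycle (the new contracted_arcs set and the ValueError guard prevent re-contracting the same node forever). A self-contained sketch of the cycle-detection step on a plain list of head indices, in the spirit of detect_cycle but not the repository's exact code:

from typing import List, Optional, Set, Tuple

def detect_cycle(heads: List[int], seq_len: int) -> Tuple[bool, Optional[Set[int]]]:
    # heads[i] is the predicted head index of token i; index 0 plays the role of the CLS/root slot
    visited = set()
    for node in range(1, seq_len - 1):          # skip the CLS/SEP positions
        if node in visited:
            continue
        current_path = set()
        current = node
        while current not in visited:
            visited.add(current)
            current_path.add(current)
            current = heads[current]
            if current in current_path:         # walked back into this path: cycle found
                return True, current_path
    return False, None

# token 1 -> 2 -> 3 -> 1 forms a cycle; token 4 attaches to the root slot
heads = [0, 2, 3, 1, 0, 0]
print(detect_cycle(heads, len(heads)))   # (True, {1, 2, 3})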