import math
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Union

import torch
from torch import nn

from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
from transformers.utils import ModelOutput

# Dependency relation labels predicted by the function classifier of the syntax head.
ALL_FUNCTION_LABELS = ["nsubj", "nsubj:cop", "punct", "mark", "mark:q", "case", "case:gen", "case:acc", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound:smixut", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]


@dataclass
class SyntaxLogitsOutput(ModelOutput):
    dependency_logits: torch.FloatTensor = None
    function_logits: torch.FloatTensor = None
    dependency_head_indices: torch.LongTensor = None

    def detach(self):
        # Return a copy of this output with all tensors detached from the graph.
        return SyntaxLogitsOutput(self.dependency_logits.detach(), self.function_logits.detach(), self.dependency_head_indices.detach())


@dataclass
class SyntaxTaggingOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[SyntaxLogitsOutput] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class SyntaxLabels(ModelOutput):
    dependency_labels: Optional[torch.LongTensor] = None
    function_labels: Optional[torch.LongTensor] = None

    def detach(self):
        return SyntaxLabels(self.dependency_labels.detach(), self.function_labels.detach())

    def to(self, device):
        return SyntaxLabels(self.dependency_labels.to(device), self.function_labels.to(device))


class BertSyntaxParsingHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Arc scorer: project each token into query/key spaces and score every
        # (dependent, head) token pair with a scaled dot product.
        self.head_size = config.syntax_head_size
        self.query = nn.Linear(config.hidden_size, self.head_size)
        self.key = nn.Linear(config.hidden_size, self.head_size)

        # Relation classifier: predicts the dependency function from the concatenation of
        # a token's hidden state and the hidden state of its (predicted or gold) head.
        self.num_function_classes = len(ALL_FUNCTION_LABELS)
        self.cls = nn.Linear(config.hidden_size * 2, self.num_function_classes)

    def forward(
            self,
            hidden_states: torch.Tensor,
            extended_attention_mask: Optional[torch.Tensor],
            labels: Optional[SyntaxLabels] = None,
            compute_mst: bool = False) -> Tuple[torch.Tensor, SyntaxLogitsOutput]:

        # Score every (dependent, head) token pair: [batch, seq, seq].
        query_layer = self.query(hidden_states)
        key_layer = self.key(hidden_states)
        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.head_size)

        # Mask padding positions so they are never selected as heads.
        if extended_attention_mask is not None:
            if extended_attention_mask.ndim == 4:
                extended_attention_mask = extended_attention_mask.squeeze(1)
            attention_scores += extended_attention_mask

        # Choose a head index for every token.
        if self.training and labels is not None:
            # During training, feed the gold heads to the function classifier (teacher forcing);
            # clamp_min(0) maps ignored/padding labels (e.g. -100) to a valid index.
            dep_indices = labels.dependency_labels.clamp_min(0)
        elif compute_mst:
            # Decode a well-formed tree (single root, no cycles).
            dep_indices = compute_mst_tree(attention_scores)
        else:
            # Greedy decoding: each token independently picks its highest-scoring head.
            dep_indices = torch.argmax(attention_scores, dim=-1)

        # Gather the hidden state of each token's head.
        batch_indices = torch.arange(dep_indices.size(0)).view(-1, 1).expand(-1, dep_indices.size(1)).to(dep_indices.device)
        dep_vectors = hidden_states[batch_indices, dep_indices, :]

        # Classify the dependency function from the (dependent, head) representation pair.
        cls_inputs = torch.cat((hidden_states, dep_vectors), dim=-1)
        function_logits = self.cls(cls_inputs)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Arc loss: each row of attention_scores scores the candidate heads for one token.
            loss = loss_fct(attention_scores.view(-1, hidden_states.size(-2)), labels.dependency_labels.view(-1))
            # Relation (function) loss.
            loss += loss_fct(function_logits.view(-1, self.num_function_classes), labels.function_labels.view(-1))

        return (loss, SyntaxLogitsOutput(attention_scores, function_logits, dep_indices))


class BertForSyntaxParsing(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.syntax = BertSyntaxParsingHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[SyntaxLabels] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        compute_syntax_mst: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        extended_attention_mask = None
        if attention_mask is not None:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())

        # Run the parsing head on the (dropout-regularized) sequence output.
        loss, logits = self.syntax(self.dropout(bert_outputs[0]), extended_attention_mask, labels, compute_syntax_mst)

        if not return_dict:
            return (loss, (logits.dependency_logits, logits.function_logits)) + bert_outputs[2:]

        return SyntaxTaggingOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )

    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, compute_mst=True):
        if isinstance(sentences, str):
            sentences = [sentences]

        # Tokenize the batch, run the model, and convert the logits into dependency trees.
        inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        logits = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_mst).logits
        return parse_logits(inputs, sentences, tokenizer, logits)


def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
    outputs = []
    special_toks = set(tokenizer.all_special_tokens)
    for sent_idx in range(len(sentences)):
        deps = logits.dependency_head_indices[sent_idx].tolist()
        funcs = logits.function_logits.argmax(-1)[sent_idx].tolist()
        # Drop special tokens ([CLS], [SEP], and any padding) rather than slicing [1:-1],
        # so padded batches are handled correctly. Position i in `toks` corresponds to
        # position i + 1 in the model input, since only [CLS] precedes the real tokens.
        toks = [tok for tok in tokenizer.convert_ids_to_tokens(inputs['input_ids'][sent_idx]) if tok not in special_toks]

        # Map every wordpiece index to the index of the whole word it belongs to
        # (continuation pieces map to the same word as their first piece); -1 stays -1 (root).
        idx_mapping = {-1: -1}
        real_idx = -1
        for i in range(len(toks)):
            if not toks[i].startswith('##'):
                real_idx += 1
            idx_mapping[i] = real_idx

        # Build the tree, merging wordpieces back into whole words.
        tree = []
        root_idx = 0
        for i in range(len(toks)):
            if toks[i].startswith('##'):
                tree[-1]['word'] += toks[i][2:]
                continue

            dep_idx = deps[i + 1] - 1  # +1 skips [CLS]; -1 converts back to token-index space
            dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
            dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]

            if dep_head == 'root': root_idx = len(tree)
            tree.append(dict(word=toks[i], dep_head_idx=idx_mapping[dep_idx], dep_func=dep_func))

        # Resolve each head index to the head word; the sentence root keeps the label 'root'.
        for d in tree:
            d['dep_head'] = 'root' if d['dep_head_idx'] == -1 else tree[d['dep_head_idx']]['word']

        outputs.append(dict(tree=tree, root_idx=root_idx))
    return outputs
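
# Each element of the list returned by `parse_logits` (and `BertForSyntaxParsing.predict`)
# corresponds to one input sentence and has the form (keys shown for illustration):
#   {'tree': [{'word': ..., 'dep_head_idx': ..., 'dep_func': ..., 'dep_head': ...}, ...],
#    'root_idx': ...}
# where 'dep_head_idx' is the index of the head word within 'tree' (-1 for the sentence root)
# and 'dep_func' is one of ALL_FUNCTION_LABELS.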


def compute_mst_tree(attention_scores: torch.Tensor):
    # attention_scores should be 3-dimensional: batch x seq x seq (unsqueeze a single matrix).
    if attention_scores.ndim == 2: attention_scores = attention_scores.unsqueeze(0)
    if attention_scores.ndim != 3 or attention_scores.shape[1] != attention_scores.shape[2]:
        raise ValueError(f'Expected attention scores to be of shape batch x seq x seq, instead got {attention_scores.shape}')

    batch_size, seq_len, _ = attention_scores.shape

    # Softmax first, so the scores of different rows are comparable.
    attention_scores = attention_scores.softmax(dim=-1)

    # Keep the special positions out of the tree: mask the [CLS] row, the final row, and the
    # final column with a large negative score (column 0 is kept, since attaching to [CLS]
    # is how the root is marked).
    attention_scores[:, 0, :] = -10000
    attention_scores[:, -1, :] = -10000
    attention_scores[:, :, -1] = -10000

    # Pick a single root per sentence: the token with the highest score for attaching to
    # [CLS] (column 0) is forced to be the root, and every other token is barred from it.
    root_cands = torch.argsort(attention_scores[:, :, 0], dim=-1)
    batch_indices = torch.arange(batch_size, device=root_cands.device)
    attention_scores[batch_indices.unsqueeze(1), root_cands, 0] = -10000
    attention_scores[batch_indices, root_cands[:, -1], 0] = 10000

    # Start from the greedy argmax head for every token, then repeatedly break cycles.
    sorted_indices = torch.argsort(attention_scores, dim=-1, descending=True)
    indices = sorted_indices[:, :, 0].clone()

    for batch_idx in range(batch_size):
        # A cycle can never contain the root, so every cycle can be broken by redirecting
        # one of its nodes to the best-scoring head outside the cycle.
        has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])
        while has_cycle:
            base_idx, head_idx = choose_contracting_arc(indices[batch_idx], sorted_indices[batch_idx], cycle_nodes, attention_scores[batch_idx])
            indices[batch_idx, base_idx] = head_idx
            # Look for the next cycle.
            has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])

    return indices


def detect_cycle(indices: torch.LongTensor):
    # Walk up the head pointers from every token; if we revisit a node on the current path,
    # we have found a cycle. Index 0 ([CLS]/root) and the last index ([SEP]) are skipped.
    visited = set()
    for node in range(1, len(indices) - 1):
        if node in visited:
            continue
        current_path = set()
        while node not in visited:
            visited.add(node)
            current_path.add(node)
            node = indices[node].item()
            if node == 0: break  # reached the root - no cycle along this path
        if node in current_path:
            return True, current_path
    return False, None
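
# Quick illustrative check of `detect_cycle` (made-up head indices): for
#   indices = torch.tensor([0, 2, 3, 1, 0])
# positions 1 -> 2 -> 3 -> 1 form a cycle, so detect_cycle(indices) returns (True, {1, 2, 3}).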


def choose_contracting_arc(indices: torch.LongTensor, sorted_indices: torch.LongTensor, cycle_nodes: set, scores: torch.FloatTensor):
    # For every node in the cycle, look at its best alternative head outside the cycle
    # (ranked below its current head) and pick the single replacement arc with the highest score.
    best_base_idx, best_head_idx = -1, -1
    score = float('-inf')

    currents = indices.tolist()
    for base_node in cycle_nodes:
        # Candidate heads are scanned in descending score order: skip until we pass the
        # current head, then take the first candidate that lies outside the cycle and is
        # not position 0 (the root attachment).
        current = currents[base_node]
        found_current = False

        for head_node in sorted_indices[base_node].tolist():
            if head_node == current:
                found_current = True
                continue
            if not found_current or head_node in cycle_nodes or head_node == 0:
                continue

            current_score = scores[base_node, head_node].item()
            if current_score > score:
                best_base_idx, best_head_idx, score = base_node, head_node, current_score
            break

    return best_base_idx, best_head_idx
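

# Minimal usage sketch. The checkpoint name and the sentence below are placeholders:
# this assumes a model repository whose config defines `syntax_head_size` and whose
# weights were trained with this parsing head; substitute the checkpoint you actually use.
if __name__ == '__main__':
    checkpoint = 'your-org/bert-syntax-parser'  # placeholder, not a real model id
    tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
    model = BertForSyntaxParsing.from_pretrained(checkpoint)
    model.eval()

    with torch.no_grad():
        parsed = model.predict(['Example sentence to parse.'], tokenizer, compute_mst=True)
    for word in parsed[0]['tree']:
        print(word['word'], word['dep_func'], word['dep_head'])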