from transformers import AutoModel, AutoTokenizer, AutoConfig, PreTrainedModel
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from dataclasses import dataclass
from .configuration import GECToRConfig
from typing import List, Union, Optional, Tuple
import os
import json
from huggingface_hub import snapshot_download, ModelCard


@dataclass
class GECToROutput:
    loss: Optional[torch.Tensor] = None
    loss_d: Optional[torch.Tensor] = None
    loss_labels: Optional[torch.Tensor] = None
    logits_d: Optional[torch.Tensor] = None
    logits_labels: Optional[torch.Tensor] = None
    accuracy: Optional[torch.Tensor] = None
    accuracy_d: Optional[torch.Tensor] = None


@dataclass
class GECToRPredictionOutput:
    probability_labels: Optional[torch.Tensor] = None
    probability_d: Optional[torch.Tensor] = None
    pred_labels: Optional[List[List[str]]] = None
    pred_label_ids: Optional[torch.Tensor] = None
    max_error_probability: Optional[torch.Tensor] = None


class GECToR(PreTrainedModel):
    config_class = GECToRConfig

    def __init__(
        self,
        config: GECToRConfig
    ):
        super().__init__(config)
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id
        )
        # Skip the encoder's pooling layer when the backbone accepts the
        # add_pooling_layer argument; it is unused for token classification.
        if self.config.has_add_pooling_layer:
            self.bert = AutoModel.from_pretrained(
                self.config.model_id,
                add_pooling_layer=False
            )
        else:
            self.bert = AutoModel.from_pretrained(
                self.config.model_id
            )

        # Leave room for one extra token appended to the tokenizer vocabulary.
        self.bert.resize_token_embeddings(self.bert.config.vocab_size + 1)
        # Two projection heads: token-level correction labels and error detection.
        self.label_proj_layer = nn.Linear(
            self.bert.config.hidden_size,
            self.config.num_labels - 1
        )
        self.d_proj_layer = nn.Linear(
            self.bert.config.hidden_size,
            self.config.d_num_labels - 1
        )
        self.dropout = nn.Dropout(self.config.p_dropout)
        self.loss_fn = CrossEntropyLoss(
            label_smoothing=self.config.label_smoothing
        )

        self.post_init()
        # Freeze the encoder by default; call tune_bert(True) to fine-tune it.
        self.tune_bert(False)

    def init_weight(self) -> None:
        self._init_weights(self.label_proj_layer)
        self._init_weights(self.d_proj_layer)

    def _init_weights(self, module) -> None:
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(
                mean=0.0,
                std=self.config.initializer_range
            )
            if module.bias is not None:
                module.bias.data.zero_()
        return

    def tune_bert(self, tune=True):
        # Enable or disable gradient updates for the encoder parameters.
        for param in self.bert.parameters():
            param.requires_grad = tune
        return

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        d_labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        word_masks: Optional[torch.Tensor] = None,
    ) -> GECToROutput:
        bert_logits = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        ).last_hidden_state
        logits_d = self.d_proj_layer(bert_logits)
        logits_labels = self.label_proj_layer(self.dropout(bert_logits))
        loss_d, loss_labels, loss = None, None, None
        accuracy, accuracy_d = None, None
        if d_labels is not None and labels is not None:
            pad_id = self.config.label2id[self.config.label_pad_token]
            # Map padded positions to -100 so CrossEntropyLoss ignores them.
            labels[labels == pad_id] = -100
            d_labels[labels == -100] = -100
            loss_d = self.loss_fn(
                logits_d.view(-1, self.config.d_num_labels - 1),
                d_labels.view(-1)
            )
            loss_labels = self.loss_fn(
                logits_labels.view(-1, self.config.num_labels - 1),
                labels.view(-1)
            )
            loss = loss_d + loss_labels

            # Token-level accuracies, averaged over non-masked word positions.
            pred_labels = torch.argmax(logits_labels, dim=-1)
            accuracy = torch.sum(
                (labels == pred_labels) * word_masks
            ) / torch.sum(word_masks)
            pred_d = torch.argmax(logits_d, dim=-1)
            accuracy_d = torch.sum(
                (d_labels == pred_d) * word_masks
            ) / torch.sum(word_masks)

        return GECToROutput(
            loss=loss,
            loss_d=loss_d,
            loss_labels=loss_labels,
            logits_d=logits_d,
            logits_labels=logits_labels,
            accuracy=accuracy,
            accuracy_d=accuracy_d
        )

    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_masks: torch.Tensor,
        keep_confidence: float = 0.0,
        min_error_prob: float = 0.0
    ) -> GECToRPredictionOutput:
        with torch.no_grad():
            outputs = self.forward(
                input_ids,
                attention_mask
            )
        probability_labels = F.softmax(outputs.logits_labels, dim=-1)
        probability_d = F.softmax(outputs.logits_d, dim=-1)

        # Bias the keep label by keep_confidence so predictions edit more
        # conservatively.
        keep_index = self.config.label2id[self.config.keep_label]
        probability_labels[:, :, keep_index] += keep_confidence
        # Sentence-level error probability: the highest "incorrect" probability
        # over the masked word positions of each sentence.
        incor_idx = self.config.d_label2id[self.config.incorrect_label]
        probability_d = probability_d[:, :, incor_idx]
        max_error_probability = torch.max(probability_d * word_masks, dim=-1)[0]
        # Force the keep label for sentences whose error probability is below
        # the min_error_prob threshold.
        probability_labels[max_error_probability < min_error_prob, :, keep_index] \
            = float('inf')
        pred_label_ids = torch.argmax(probability_labels, dim=-1)

        def convert_ids_to_labels(ids, id2label):
            labels = []
            for label_id in ids.tolist():
                labels.append(id2label[label_id])
            return labels

        pred_labels = []
        for ids in pred_label_ids:
            labels = convert_ids_to_labels(
                ids,
                self.config.id2label
            )
            pred_labels.append(labels)

        return GECToRPredictionOutput(
            probability_labels=probability_labels,
            probability_d=probability_d,
            pred_labels=pred_labels,
            pred_label_ids=pred_label_ids,
            max_error_probability=max_error_probability
        )
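
# Usage sketch (illustrative only, not part of this module's API). It assumes a
# GECToRConfig named `config` built elsewhere with a valid `model_id` and label
# vocabularies, and a batch produced by this model's own tokenizer; `config`,
# `batch`, and the use of the attention mask as a stand-in for word-level masks
# are placeholders, not guarantees about the real preprocessing pipeline.
#
#   model = GECToR(config)
#   batch = model.tokenizer(
#       ["This sentences contain an error ."],
#       return_tensors="pt",
#       padding=True,
#   )
#   out = model.predict(
#       input_ids=batch["input_ids"],
#       attention_mask=batch["attention_mask"],
#       word_masks=batch["attention_mask"],  # a real pipeline marks word-initial subwords
#       keep_confidence=0.0,
#       min_error_prob=0.0,
#   )
#   print(out.pred_labels)             # per-token edit tags
#   print(out.max_error_probability)   # sentence-level error probabilities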