from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import (
    AutoModel,
    AutoModelForTokenClassification,
    DebertaV2Config,
    DebertaV2Model,
    DebertaV2PreTrainedModel,
)
from transformers.modeling_outputs import TokenClassifierOutput


class DebertaForUTCConfig(DebertaV2Config):
    model_type = "deberta-utc"


def create_projection_layer(hidden_size: int, dropout: float = 0.25, out_dim: Optional[int] = None) -> nn.Sequential:
    """
    Creates a two-layer projection head (`hidden_size -> out_dim * 4 -> out_dim`)
    with a ReLU and dropout in between. `out_dim` defaults to `hidden_size`.
    """
    if out_dim is None:
        out_dim = hidden_size

    return nn.Sequential(
        nn.Linear(hidden_size, out_dim * 4),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(out_dim * 4, out_dim),
    )
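
# A quick, illustrative shape check for the projection head above. The sizes
# here are arbitrary assumptions for the sketch, not values taken from any
# model config:
#
#     proj = create_projection_layer(hidden_size=768, dropout=0.1, out_dim=96)
#     x = torch.randn(2, 10, 768)  # (batch_size, seq_len, hidden_size)
#     proj(x).shape                # torch.Size([2, 10, 96])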


class DebertaForUTCPreTrainedModel(DebertaV2PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """

    config_class = DebertaForUTCConfig


class DebertaV2ForUTC(DebertaForUTCPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.deberta = DebertaV2Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Projects each token representation back to `hidden_size`.
        self.token_rep_layer = create_projection_layer(config.hidden_size, config.hidden_dropout_prob)
        # Projects the first ([CLS]) token to one class embedding per label,
        # i.e. an output of size `hidden_size * num_labels` (reshaped in forward).
        self.prompt_rep_layer = create_projection_layer(
            config.hidden_size, config.hidden_dropout_prob, config.hidden_size * config.num_labels
        )

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        batch_size, seq_len, hidden_size = sequence_output.shape

        # Per-token representations: (batch_size, seq_len, hidden_size).
        token_rep = self.token_rep_layer(sequence_output)

        # Class ("prompt") representations derived from the [CLS] token:
        # (batch_size, num_labels, hidden_size).
        prompt_rep = self.prompt_rep_layer(sequence_output[:, 0, :]).view(batch_size, self.num_labels, hidden_size)

        # Score every token against every class embedding with a batched dot
        # product; logits have shape (batch_size, seq_len, num_labels).
        logits = torch.einsum('BLD,BCD->BLC', token_rep, prompt_rep)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )


# Register the config/model pair so the "deberta-utc" model type resolves
# through the Auto classes. The concrete DebertaV2ForUTC is registered for
# both, since AutoModel must return an instantiable model with weights.
AutoModel.register(DebertaForUTCConfig, DebertaV2ForUTC)
AutoModelForTokenClassification.register(DebertaForUTCConfig, DebertaV2ForUTC)
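

if __name__ == "__main__":
    # Minimal smoke test: a tiny, randomly initialized config, purely for
    # illustration -- these sizes are assumptions, not a released checkpoint.
    # A real checkpoint would instead be loaded with
    # AutoModelForTokenClassification.from_pretrained(<name or path>).
    config = DebertaForUTCConfig(
        vocab_size=128,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=3,
    )
    model = DebertaV2ForUTC(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.randint(0, config.num_labels, (2, 16))

    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(out.logits.shape)  # torch.Size([2, 16, 3])
    print(out.loss)          # scalar cross-entropy loss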