from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
from configuration import ImpressoConfig
import logging

logger = logging.getLogger(__name__)


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, num_token_labels_dict):
        super().__init__(config)
        self.num_token_labels_dict = num_token_labels_dict
        self.config = config

        # self.bert = AutoModel.from_config(config)
        self.bert = AutoModel.from_pretrained(
            config.name_or_path, config=config.pretrained_config
        )
        if "classifier_dropout" not in config.__dict__:
            classifier_dropout = 0.1
        else:
            classifier_dropout = (
                config.classifier_dropout
                if config.classifier_dropout is not None
                else config.hidden_dropout_prob
            )
        self.dropout = nn.Dropout(classifier_dropout)

        # Additional transformer layers
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )

        # For token classification, create a classifier for each task
        self.token_classifiers = nn.ModuleDict(
            {
                task: nn.Linear(config.hidden_size, num_labels)
                for task, num_labels in num_token_labels_dict.items()
            }
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
            Labels for computing the token classification loss. Keys should match the tasks.
""" return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) bert_kwargs = { "input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "position_ids": position_ids, "head_mask": head_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "return_dict": return_dict, } if any( keyword in self.config.name_or_path.lower() for keyword in ["llama", "deberta"] ): bert_kwargs.pop("token_type_ids") bert_kwargs.pop("head_mask") outputs = self.bert(**bert_kwargs) # For token classification token_output = outputs[0] token_output = self.dropout(token_output) # Pass through additional transformer layers token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose( 0, 1 ) # Collect the logits and compute the loss for each task task_logits = {} total_loss = 0 for task, classifier in self.token_classifiers.items(): logits = classifier(token_output) task_logits[task] = logits if token_labels and task in token_labels: loss_fct = CrossEntropyLoss() loss = loss_fct( logits.view(-1, self.num_token_labels_dict[task]), token_labels[task].view(-1), ) total_loss += loss if not return_dict: output = (task_logits,) + outputs[2:] return ((total_loss,) + output) if total_loss != 0 else output return TokenClassifierOutput( loss=total_loss, logits=task_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )