Hugging Face's logo Hugging Face Models Datasets Spaces Posts Docs Pricing CATIE-AQ / FAT5-small-UL2-fr private Feature Extraction Transformers Safetensors French flash_t5 flash-attention UL2 FAT5 custom_code Carbon Emissions 7 papers Model card Files and versions Community Settings FAT5-small-UL2-fr / custom_heads_flash_t5.py bourdoiscatie's picture bourdoiscatie Update custom_heads_flash_t5.py 3477ecb verified 8 days ago raw history blame edit delete No virus 16.9 kB import torch import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss import copy from typing import Optional, Union, Tuple, List from transformers.modeling_outputs import ( Seq2SeqQuestionAnsweringModelOutput, QuestionAnsweringModelOutput, TokenClassifierOutput, BaseModelOutput, Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput ) from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model, FlashT5EncoderModel from .configuration_flash_t5 import FlashT5Config ################## Encoder only head ################## class FlashT5ForTokenClassification(FlashT5PreTrainedModel): def __init__(self, config: FlashT5Config): super().__init__(config) self.num_labels = config.num_labels self.shared = nn.Embedding(config.vocab_size, config.d_model) self.encoder = FlashT5Stack(config, self.shared) self.dropout = nn.Dropout(config.classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() # Initialize classifier self.classifier.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0) self.classifier.bias.data.zero_() self.model_parallel = False def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Returns: """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] hidden_states = self.dropout(hidden_states) logits = self.classifier(hidden_states) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits, outputs[2:-1]) return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class FlashT5ClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" def __init__(self, config: FlashT5Config): super().__init__() self.dense = nn.Linear(config.d_model, config.d_model) self.dropout = nn.Dropout(p=config.classifier_dropout) self.out_proj = nn.Linear(config.d_model, config.num_labels) # initialize weights factor = config.initializer_factor self.dense.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5)) if hasattr(self.dense, "bias") and self.dense.bias is not None: self.dense.bias.data.zero_() self.out_proj.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5)) if hasattr(self.out_proj, "bias") and self.out_proj.bias is not None: self.out_proj.bias.data.zero_() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dropout(hidden_states) hidden_states = self.dense(hidden_states) hidden_states = torch.tanh(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.out_proj(hidden_states) return hidden_states class FlashT5ForSequenceClassification(FlashT5PreTrainedModel): _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"] def __init__(self, config: FlashT5Config): super().__init__(config) self.model_dim = config.d_model self.config.problem_type = None self.config.is_encoder_decoder = False self.shared = nn.Embedding(config.vocab_size, config.d_model) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.is_encoder_decoder = False encoder_config.use_cache = False self.encoder = FlashT5Stack(encoder_config, self.shared) self.classification_head = FlashT5ClassificationHead(config) # Initialize weights and apply final processing self.post_init() self.model_parallel = False def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Returns: """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: use_cache = False if input_ids is None and inputs_embeds is not None: raise NotImplementedError( f"Passing input embeddings is currently not supported for {self.__class__.__name__}" ) outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") batch_size, _, hidden_size = sequence_output.shape sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :] logits = self.classification_head(sentence_representation) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.config.num_labels == 1: self.config.problem_type = "regression" elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = nn.MSELoss() if self.config.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = nn.BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions ) class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel): _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"] def __init__(self, config: FlashT5Config): super().__init__(config) self.transformer = FlashT5EncoderModel(config) self.num_labels = config.num_labels self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() # Model parallel self.model_parallel = False def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.Tensor] = None, end_positions: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence are not taken into account for computing the loss. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence are not taken into account for computing the loss. Returns: """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_outputs = self.transformer( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = encoder_outputs[0] logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1).to(start_logits.device) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1).to(end_logits.device) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + encoder_outputs[1:] return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel): _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"] def __init__(self, config: FlashT5Config): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.d_model) encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.is_encoder_decoder = False self.encoder = FlashT5Stack(encoder_config, self.shared) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0) self.qa_outputs.bias.data.zero_() self.model_parallel = False def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" Returns: Example: ```python >>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering >>> tokenizer = AutoTokenizer.from_pretrained("MTx-small") >>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" ... ).input_ids # Batch size 1 >>> outputs = model(input_ids=input_ids) >>> start_logits = outputs.start_logits >>> end_logits = outputs.end_logits ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.encoder( input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, ) sequence_output = outputs[0] logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1).to(start_logits.device) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1).to(end_logits.device) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + outputs[1:] return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )