from dataclasses import dataclass
from transformers.models.t5.modeling_t5 import (
    T5Stack, T5Block, T5LayerNorm, T5LayerSelfAttention, T5LayerFF, T5LayerCrossAttention,
    T5PreTrainedModel, T5ForConditionalGeneration
)
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import copy
from transformers.modeling_outputs import (
    ModelOutput, BaseModelOutput, BaseModelOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
)
from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import logging
from transformers import BeamScorer, BeamSearchScorer

logger = logging.get_logger(__name__)


class JointEncoder(T5Stack):
    """T5 encoder stack that adds a learned whole-word embedding to the token embeddings.

    Every token carries the index of the whole word it belongs to
    (``whole_word_ids``). The embedding of that index is summed with the token
    embedding before the first encoder block.
    """

    def __init__(self, config, embed_tokens=None):
        # Skip T5Stack.__init__ on purpose and build the stack by hand, so that
        # the extra whole-word embedding table can be registered alongside it.
        super(T5Stack, self).__init__(config)
        self.config = config

        self.embed_tokens = embed_tokens
        self.is_decoder = self.config.is_decoder
        assert self.config.is_decoder is False

        # Only the first block owns a relative-attention-bias table. forward()
        # computes the bias once from block 0 and shares it with every layer.
        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=(i == 0))
             for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(
            config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # At most 512 whole words per source text: whole_word_ids must be < 512.
        self.whole_word_embeddings = nn.Embedding(
            512, config.d_model  # config.d_model is 768 for base
        )

        self.init_weights()
        self.model_parallel = False
        self.device_map = None

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module used when ``inputs_embeds`` is not given."""
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        whole_word_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode a token sequence, optionally enriched with whole-word embeddings.

        Args:
            input_ids: token ids. They are embedded with ``embed_tokens`` when
                ``inputs_embeds`` is None.
            whole_word_ids: optional per-token whole-word indices (< 512).
                Their embeddings are added to the token embeddings.
            attention_mask: optional 0/1 mask over the sequence. When omitted
                it is derived from ``input_ids != pad_token_id``. If only
                ``inputs_embeds`` is given, every position is attended.
            inputs_embeds: precomputed token embeddings, used instead of ``input_ids``.
            head_mask: optional per-layer head mask, forwarded to every block.
            past_key_values: optional per-layer cache.
            use_cache: whether to collect each block's key/value state.
            output_attentions: forwarded to the blocks. NOTE(review): per-layer
                attention weights are not collected here, so ``attentions`` is
                always an empty tuple when this is set.
            output_hidden_states: if True, return the embedding output plus the
                output of every layer. The last entry is the final-normed output.
            return_dict: if truthy, return ``BaseModelOutputWithPastAndCrossAttentions``.
                Otherwise return a tuple of the non-None outputs.
        """
        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        # Inject whole-word information by summing its embedding into each token.
        if whole_word_ids is not None:
            whole_word_embeds = self.whole_word_embeddings(whole_word_ids)
            assert whole_word_embeds.shape[-1] == inputs_embeds.shape[-1]
            inputs_embeds = inputs_embeds + whole_word_embeds

        B, L = inputs_embeds.size()[:-1]

        if attention_mask is None:
            if input_ids is not None:
                attention_mask = input_ids.ne(self.config.pad_token_id).to(
                    dtype=inputs_embeds.dtype, device=inputs_embeds.device)
            else:
                # Only embeddings were supplied, so no padding can be inferred.
                attention_mask = torch.ones(
                    B, L, dtype=inputs_embeds.dtype, device=inputs_embeds.device)

        # Convert the (B, L) 0/1 mask into the additive, head-broadcastable form
        # produced by the inherited `get_extended_attention_mask`.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, (B, L), inputs_embeds.device)

        # Initialize past_key_values with `None` if there is no past.
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # Prepare the head mask if needed.
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None

        hidden_states = self.dropout(inputs_embeds)

        position_bias = None
        if self.config.num_layers > 0:
            assert self.block[0].layer[0].SelfAttention.has_relative_attention_bias
            # Relative position bias of shape [1, n_heads, L, L], computed once
            # from block 0 and shared by all layers. The padding mask is folded
            # in here, so the blocks receive a bias that already masks padding.
            text_position_bias = self.block[0].layer[0].SelfAttention.compute_bias(L, L)
            position_bias = text_position_bias + extended_attention_mask

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            # Record the input to each layer. The first entry is the embedding output.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                encoder_decoder_position_bias=None,
                layer_head_mask=head_mask[i],
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            # The first two entries are taken as (hidden-states, key-value state).
            # The shared `position_bias` above is reused instead of being re-read
            # from the layer outputs. NOTE(review): the exact layout of
            # layer_outputs varies across transformers versions when
            # use_cache is False; confirm before relying on entries past [0].
            hidden_states, present_key_value_state = layer_outputs[:2]

            # Append this layer's key/value states.
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add the output of the last layer.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


class P5(T5ForConditionalGeneration):
    """T5 encoder-decoder whose encoder is a `JointEncoder` (token + whole-word embeddings).

    Input embeddings are shared between the encoder, the decoder and the LM head.
    """

    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        # Skip T5ForConditionalGeneration.__init__ on purpose, so the encoder
        # can be a JointEncoder instead of a plain T5Stack.
        super(T5ForConditionalGeneration, self).__init__(config)

        self.config = config

        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False

        self.encoder = JointEncoder(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False

        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

        self.model_parallel = False
        self.device_map = None

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared embedding and propagate it to the encoder and decoder."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def extend_vocab(self, vocab_size):
        """Resize the shared token embedding (and the tied LM head) to ``vocab_size`` rows.

        Existing rows are copied, up to the smaller of the old and new sizes.
        Any additional rows keep ``nn.Embedding``'s default initialization.
        The new table is created on the same device and with the same dtype as
        the old one, so this also works after the model was moved or cast.
        The LM head is re-tied to the new embedding weight, and ``vocab_size``
        is written back to this model's, the encoder's and the decoder's configs.
        """
        old_weight = self.shared.weight
        num_copied = min(old_weight.size(0), vocab_size)

        new_shared = nn.Embedding(vocab_size, self.config.d_model).to(
            device=old_weight.device, dtype=old_weight.dtype)
        with torch.no_grad():
            new_shared.weight[:num_copied] = old_weight[:num_copied]
        self.shared = new_shared

        # The LM head's weight is replaced by the shared embedding weight
        # (tied), so its own freshly allocated weight never needs to be filled.
        self.lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False)
        self.lm_head.weight = self.shared.weight

        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

        self.config.vocab_size = vocab_size
        self.encoder.config.vocab_size = vocab_size
        self.decoder.config.vocab_size = vocab_size

    def forward(
        self,
        input_ids=None,
        whole_word_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        labels=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        return_hidden_state=False,
        **kwargs,
    ):
        """Run the encoder (unless ``encoder_outputs`` is given), then the decoder and the LM head.

        Args:
            whole_word_ids: per-token whole-word indices, forwarded to the encoder.
            labels: target token ids. Positions equal to -100 are ignored in the loss.
                When no decoder inputs are given, they are built by right-shifting ``labels``.
            reduce_loss: if True, the loss is the mean cross-entropy over non-ignored
                tokens. If False, it is the per-token loss, flattened to ``(batch * tgt_len,)``.
            return_hidden_state: if True, return the decoder output, scaled by
                ``d_model ** -0.5``, instead of computing logits.
            Other arguments follow ``T5ForConditionalGeneration.forward``.

        Returns:
            A ``P5Seq2SeqLMOutput`` when ``return_dict`` resolves truthy. Otherwise a
            tuple ``(loss?, logits, *decoder_outputs[1:])``, where ``loss`` is present
            only when ``labels`` are given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                whole_word_ids=whole_word_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # Build decoder inputs by shifting the LM labels to the right.
            decoder_input_ids = self._shift_right(labels)

        # When decoding with cached key/value states, only the last tokens
        # need to be fed to the decoder.
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if attention_mask is None:
            if input_ids is not None:
                attention_mask = input_ids.ne(self.config.pad_token_id).to(
                    dtype=hidden_states.dtype, device=hidden_states.device)
            else:
                # Only encoder outputs / embeddings were supplied, so attend everywhere.
                attention_mask = torch.ones(
                    hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device)
        encoder_attention_mask = attention_mask

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        assert self.config.tie_word_embeddings is True

        if self.config.tie_word_embeddings:
            # With tied embeddings the decoder output is rescaled before the
            # LM head, as in T5ForConditionalGeneration.
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        if return_hidden_state:
            return sequence_output

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            loss = loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)),
                labels.view(-1))

        if not return_dict:
            # decoder_outputs is a plain tuple here, so attribute access would
            # fail. Return a tuple instead.
            output = (lm_logits,) + tuple(decoder_outputs[1:])
            return ((loss,) + output) if loss is not None else output

        return P5Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Build the ``forward`` kwargs for one generation step.

        When a cache is present, only the newest decoder token is fed. The cache
        may arrive as ``past`` or as ``past_key_values``, depending on the
        ``transformers`` version driving ``generate``.
        """
        if past is None:
            past = kwargs.get("past_key_values")

        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: torch.LongTensor = None,
        encoder_outputs: ModelOutput = None,
        **model_kwargs
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        """Repeat each batch row ``expand_size`` times (e.g. for beam search).

        ``input_ids``, and ``token_type_ids`` / ``attention_mask`` when present,
        are expanded. For encoder-decoder models, ``encoder_outputs.last_hidden_state``
        is expanded in place and stored back into ``model_kwargs``.
        """
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

        if is_encoder_decoder:
            assert encoder_outputs is not None
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx
            )
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs


@dataclass
class P5Seq2SeqLMOutput(ModelOutput):
    """
    Output of sequence-to-sequence language models such as `P5`.

    Args:
        loss (:obj:`torch.FloatTensor`, `optional`, returned when :obj:`labels` is provided):
            Language modeling loss. This is a scalar mean when ``P5.forward`` is called with
            ``reduce_loss=True``. Otherwise it holds per-token losses of shape
            :obj:`(batch_size * target_sequence_length,)`.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            Pre-computed hidden-states (key and values in the attention blocks) of the decoder. They can be
            used (see ``past_key_values`` input) to speed up sequential decoding.
        decoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the decoder.
        decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.

    Note: ``P5.forward`` only populates ``loss``, ``logits``, ``past_key_values``,
    ``decoder_last_hidden_state`` and ``decoder_hidden_states``. The remaining fields stay None.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None