|
''' |
|
Adapted from https://github.com/huggingface/transformers |
|
''' |
|
|
|
from transformers import T5Config, T5ForConditionalGeneration |
|
from transformers.models.t5.modeling_t5 import T5Stack, __HEAD_MASK_WARNING_MSG, T5EncoderModel |
|
import copy |
|
import math |
|
import os |
|
import warnings |
|
from typing import Optional, Tuple, Union |
|
import torch |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutput, |
|
Seq2SeqLMOutput, |
|
) |
|
|
|
# `__HEAD_MASK_WARNING_MSG` is imported at module level. However, any identifier
# beginning with two underscores that appears textually inside a class body is
# name-mangled (here to `_T5ForMultimodalGeneration__HEAD_MASK_WARNING_MSG`).
# Referencing the imported name directly from `forward` therefore raises
# NameError. Re-bind it under a single-underscore name, which is not mangled.
_HEAD_MASK_WARNING_MSG = __HEAD_MASK_WARNING_MSG


class T5ForMultimodalGeneration(T5ForConditionalGeneration):
    """T5 encoder-decoder whose encoder output is fused with image features.

    Image features are projected to ``d_model`` by ``image_dense``. The text
    encoder states then attend over them through a single-head
    ``MultiheadAttention`` (text states are the queries). The result is mixed
    back into the encoder states with a learned sigmoid gate::

        gate = sigmoid(gate_dense([h; attn]))
        h'   = (1 - gate) * h + gate * attn

    The decoder cross-attends to the fused states ``h'``.
    """

    _keys_to_ignore_on_load_missing = [
        r"encoder.embed_tokens.weight",
        r"decoder.embed_tokens.weight",
        r"lm_head.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]

    def __init__(self, config: T5Config, patch_size, padding_idx, save_dir):
        """Build the T5 stacks and the image-fusion layers.

        Args:
            config: T5 configuration. ``d_model``/``hidden_size``,
                ``vocab_size`` and ``num_decoder_layers`` are read here.
            patch_size: pair unpacked as ``(patch_num, patch_dim)``.
                ``patch_dim`` is the input width of ``image_dense``.
            padding_idx: stored as ``self.padding_idx``. This class does not
                use it directly.
            save_dir: directory in which ``gate.txt`` is created (truncated)
                and kept open for writing as ``self.out``.
        """
        super().__init__(config)
        self.model_dim = config.d_model

        self.padding_idx = padding_idx
        # NOTE(review): this handle is never written to or closed by this class.
        # An open file attribute also cannot be pickled, so
        # `copy.deepcopy(model)` / `torch.save(model)` of the whole module will
        # fail. It is kept only in case external code uses `model.out`.
        self.out = open(os.path.join(save_dir, 'gate.txt'), 'w')

        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.patch_num, self.patch_dim = patch_size

        # Image-fusion layers, used by `_fuse_image_features`.
        self.image_dense = nn.Linear(self.patch_dim, config.d_model)
        self.mha_layer = torch.nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            kdim=config.hidden_size,
            vdim=config.hidden_size,
            num_heads=1,
            batch_first=True,
        )
        self.gate_dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
        self.sigmoid = nn.Sigmoid()

        # Encoder stack: no causal masking, no KV cache.
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # Decoder stack: may have a different depth than the encoder.
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

        # Model parallelism stays disabled until explicitly enabled.
        self.model_parallel = False
        self.device_map = None

    def _fuse_image_features(self, hidden_states, image_ids):
        """Gate projected image features into the text encoder states.

        Args:
            hidden_states: encoder last hidden state, used as attention queries.
            image_ids: image features whose last dimension is ``patch_dim``
                (presumably ``(batch, patch_num, patch_dim)``; TODO confirm).

        Returns:
            ``(1 - gate) * hidden_states + gate * image_att``, where
            ``gate = sigmoid(gate_dense(cat([hidden_states, image_att], -1)))``.
        """
        image_embedding = self.image_dense(image_ids)
        # Text states (queries) attend over the projected image features.
        image_att, _ = self.mha_layer(hidden_states, image_embedding, image_embedding)

        merge = torch.cat([hidden_states, image_att], dim=-1)
        gate = self.sigmoid(self.gate_dense(merge))
        return (1 - gate) * hidden_states + gate * image_att

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_ids=None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Run encoder -> image fusion -> decoder -> LM head.

        The encoder's last hidden state is passed through
        ``_fuse_image_features`` before the decoder cross-attends to it.
        ``image_ids`` must be provided; it is fed straight to ``image_dense``.
        When ``labels`` are given and no decoder inputs are, the decoder inputs
        are obtained with ``self._shift_right(labels)``, and a cross-entropy
        loss (ignoring index -100) is computed.

        Returns:
            A ``Seq2SeqLMOutput`` when ``return_dict`` is true. Otherwise a
            tuple ``([loss,] lm_logits, *decoder_outputs[1:], *encoder_outputs)``.
            The ``encoder_*`` entries hold the un-fused encoder outputs.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Legacy behaviour: reuse `head_mask` for the decoder when both stacks
        # have the same depth, and warn about it.
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode unless precomputed encoder outputs were supplied.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # Wrap tuple-style encoder outputs so attribute access works below.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = self._fuse_image_features(encoder_outputs[0], image_ids)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # Teacher forcing: decoder inputs are the labels shifted right.
            decoder_input_ids = self._shift_right(labels)

        # Move decoder inputs to the decoder's first device for model parallelism.
        if self.model_parallel:
            first_device = self.decoder.first_device
            torch.cuda.set_device(first_device)
            hidden_states = hidden_states.to(first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(first_device)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Under model parallelism the LM head lives on the encoder's first device.
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # With tied embeddings, rescale by d_model**-0.5 before the
            # vocabulary projection.
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # Labels may sit on a different device than the logits under model
            # parallelism; `.to` does nothing when they already match.
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )