# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes to support Vision-Encoder-Text-Decoder architectures."""
import timeit
from typing import Optional

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.modeling_utils import PreTrainedModel
# from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.utils import logging
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
import inspect

from .gpt2 import ThisGPT2LMHeadModel
from .gpt2 import ThisGPT2Config
from .xglm import ThisXGLMForCausalLM
from .xglm import ThisXGLMConfig
from .opt import ThisOPTForCausalLM
from .opt import ThisOPTConfig
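
# The `This*` classes are local variants of the GPT-2, XGLM and OPT decoders shipped with this
# repository. Judging from how they are configured in `from_encoder_decoder_pretrained` below,
# they appear to accept an extra `cross_attention_reduce_factor` option that shrinks the
# cross-attention projections relative to the decoder's hidden size.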

# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
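
# Illustration (hypothetical token ids): with decoder_start_token_id=0 and pad_token_id=1,
# labels [[5, -100, 6]] are shifted to [[0, 5, -100]] and the -100 ignore index is then
# replaced by the pad token, giving decoder inputs [[0, 5, 1]] that can be embedded safely.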

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "SmallCapConfig"
VISION_ENCODER_DECODER_START_DOCSTRING = r""" | |
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model | |
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via | |
[`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`] | |
function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream | |
generative task, like image captioning. | |
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation | |
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation | |
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi | |
Zhou, Wei Li, Peter J. Liu. | |
Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained | |
Models](https://arxiv.org/abs/2109.10282) it is shown how leveraging large pretrained vision models for optical | |
character recognition (OCR) yields a significant performance improvement. | |
After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any | |
other models (see the examples for more information). | |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the | |
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads | |
etc.) | |
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. | |
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage | |
and behavior. | |
Parameters: | |
config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. | |
Initializing with a config file does not load the weights associated with the model, only the | |
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. | |
""" | |

VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using a feature extractor (e.g. if you use ViT as the encoder,
            you should use [`ViTFeatureExtractor`]). See [`ViTFeatureExtractor.__call__`] for details.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
            right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        encoder_outputs (`tuple(torch.FloatTensor)`, *optional*):
            This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
            `last_hidden_state` (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`) is a tensor
            of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the
            decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. This is useful if you want more control over how to convert `decoder_input_ids` indices
            into associated vectors than the model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0,
            ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple.
        kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:

            - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
            - With a *decoder_* prefix which will be input as `**decoder_kwargs` for the decoder forward function.
"""


class SmallCapConfig(VisionEncoderDecoderConfig):
    model_type = "smallcap"

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)


class SmallCap(PreTrainedModel):
    r"""
    [`SmallCap`] is a generic model class that will be instantiated as a transformer architecture with one of the base
    vision model classes of the library as encoder and another one as decoder when created with the
    [`~transformers.AutoModel.from_pretrained`] class method for the encoder and the
    [`~transformers.AutoModelForCausalLM.from_pretrained`] class method for the decoder.
    """

    config_class = SmallCapConfig
    base_model_prefix = "smallcap"
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = SmallCapConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        if config.decoder.cross_attention_hidden_size is not None:
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
                    " `config.encoder.hidden_size`."
                )

        # initialize with config
        # make sure input & output embeddings are not tied
        config.tie_word_embeddings = False
        super().__init__(config)

        if encoder is None:
            encoder = AutoModel.from_config(config.encoder)

        if decoder is None:
            decoder = AutoModelForCausalLM.from_config(config.decoder)

        self.encoder = encoder.vision_model
        self.encoder.main_input_name = "pixel_values"
        self.decoder = decoder

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # At the moment fast initialization is not supported for composite models
        if kwargs.get("_fast_init", False):
            logger.warning(
                "Fast initialization is currently not supported for VisionEncoderDecoderModel. "
                "Falling back to slow initialization..."
            )
        kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)

    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        cross_attention_reduce_factor: int = None,
        *model_args,
        **kwargs,
    ) -> PreTrainedModel:
        r"""
        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
        checkpoints.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you need to first set it back in training mode with `model.train()`.

        Params:
            encoder_pretrained_model_name_or_path (`str`, *optional*):
                Information necessary to initiate the image encoder. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An
                      example is `google/vit-base-patch16-224-in21k`.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the text decoder. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args (remaining positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.

            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`).

                - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
                - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
                - To update the parent model configuration, do not use a prefix for each configuration parameter.

                Behaves differently depending on whether a `config` is provided or automatically loaded.

        Example:

        ```python
        >>> from transformers import VisionEncoderDecoderModel

        >>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
        >>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google/vit-base-patch16-224-in21k", "bert-base-uncased"
        ... )
        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./vit-bert")
        >>> # load fine-tuned model
        >>> model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")
        ```"""
        kwargs_encoder = {
            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # remove encoder, decoder kwargs from kwargs
        for key in kwargs_encoder.keys():
            del kwargs["encoder_" + key]
        for key in kwargs_decoder.keys():
            del kwargs["decoder_" + key]

        # Load and initialize the encoder and decoder
        # The distinction between encoder and decoder at the model level is made
        # by the value of the flag `is_decoder` that we need to set correctly.
        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            if encoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_encoder:
                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
                )

                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                    logger.info(
                        f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
                        "from a decoder model. Cross-attention and causal mask are disabled."
                    )
                    encoder_config.is_decoder = False
                    encoder_config.add_cross_attention = False

                kwargs_encoder["config"] = encoder_config

            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            if decoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_decoder:
                if "xglm" in decoder_pretrained_model_name_or_path:
                    decoder_config, kwargs_decoder = ThisXGLMConfig.from_pretrained(
                        decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
                    )
                elif "opt" in decoder_pretrained_model_name_or_path:
                    decoder_config, kwargs_decoder = ThisOPTConfig.from_pretrained(
                        decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
                    )
                else:
                    decoder_config, kwargs_decoder = ThisGPT2Config.from_pretrained(
                        decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
                    )

                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                    logger.info(
                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
                        f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
                        f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
                    )
                    decoder_config.is_decoder = True
                    decoder_config.add_cross_attention = True
                    decoder_config.encoder_hidden_size = encoder.config.vision_config.hidden_size
                    decoder_config.cross_attention_reduce_factor = cross_attention_reduce_factor

                kwargs_decoder["config"] = decoder_config

            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
                logger.warning(
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )

            if "xglm" in decoder_pretrained_model_name_or_path:
                decoder = ThisXGLMForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
            elif "opt" in decoder_pretrained_model_name_or_path:
                decoder = ThisOPTForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
            else:
                decoder = ThisGPT2LMHeadModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        # instantiate config with corresponding kwargs
        config = SmallCapConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)

        # make sure input & output embeddings are not tied
        config.tie_word_embeddings = False
        return cls(encoder=encoder, decoder=decoder, config=config)
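
    # Illustrative call (model ids and the reduce factor are examples, not fixed defaults):
    #
    #   model = SmallCap.from_encoder_decoder_pretrained(
    #       "openai/clip-vit-base-patch32",   # CLIP-style checkpoint; only its vision tower is kept as the encoder
    #       "gpt2",                           # autoregressive decoder; cross-attention layers are added
    #       cross_attention_reduce_factor=4,  # hypothetical value shrinking the cross-attention projections
    #   )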

    def forward(
        self,
        pixel_values=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        >>> import requests
        >>> from PIL import Image
        >>> import torch

        >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
        >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

        >>> # load image from the IAM dataset
        >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

        >>> # training
        >>> model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
        >>> model.config.pad_token_id = processor.tokenizer.pad_token_id
        >>> model.config.vocab_size = model.config.decoder.vocab_size

        >>> pixel_values = processor(image, return_tensors="pt").pixel_values
        >>> text = "hello world"
        >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
        >>> outputs = model(pixel_values=pixel_values, labels=labels)
        >>> loss = outputs.loss

        >>> # inference (generation)
        >>> generated_ids = model.generate(pixel_values)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)
        else:
            encoder_outputs = BaseModelOutput(encoder_outputs, None)

        encoder_hidden_states = encoder_outputs[0]
        encoder_attention_mask = None

        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict
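
    # Note: `generate()` calls `prepare_inputs_for_generation` at every decoding step. The wrapped
    # decoder's own implementation typically trims `input_ids` to the last generated token once
    # `past` key/value states are available (as in the standard GPT-2 decoder), so only the new
    # token and the cached states are fed back into `forward`.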

    def resize_token_embeddings(self, *args, **kwargs):
        raise NotImplementedError(
            "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported. Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past, beam_idx):
        # apply decoder cache reordering here
        return self.decoder._reorder_cache(past, beam_idx)