import copy
import dataclasses
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union

import numpy as np
from packaging import version

from ..utils import TensorType, is_torch_available, is_vision_available, logging
from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size


if TYPE_CHECKING:
    from ..configuration_utils import PretrainedConfig
    from ..feature_extraction_utils import FeatureExtractionMixin
    from ..image_processing_utils import ImageProcessingMixin
    from ..tokenization_utils_base import PreTrainedTokenizerBase


if is_vision_available():
    from PIL import Image


logger = logging.get_logger(__name__)


DEFAULT_ONNX_OPSET = 11

# 2 GB
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024


@dataclasses.dataclass
class PatchingSpec:
    """
    Data class that holds patching specifications.

    Args:
        o: Module / object where the op to patch is located
        name: Name of the op to monkey patch
        custom_op: Custom op that patches the original op
        orig_op: Original op that is being patched
        op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
            It is useful for ops that are class or static methods for instance.
    """

    o: Any
    name: str
    custom_op: Callable
    orig_op: Optional[Callable] = None
    op_wrapper: Optional[Callable] = None
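

# Illustrative sketch (not part of the library): a PatchingSpec describes a temporary monkey patch that
# `OnnxConfig.patch_ops()` installs before export and `restore_ops()` reverts afterwards. The module and
# replacement function below are hypothetical placeholders.
#
#     spec = PatchingSpec(
#         o=some_module,                      # hypothetical module holding the op
#         name="problematic_op",              # attribute name to patch on that module
#         custom_op=onnx_friendly_op,         # hypothetical export-friendly replacement
#     )
#     onnx_config = MyModelOnnxConfig(config, patching_specs=[spec])
#     onnx_config.patch_ops()                 # setattr(some_module, "problematic_op", onnx_friendly_op)
#     # ... run the ONNX export here ...
#     onnx_config.restore_ops()               # puts the original op back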


class OnnxConfig(ABC):
    """
    Base class for ONNX-exportable models, describing metadata on how to export the model through the ONNX format.
    """

    default_fixed_batch = 2
    default_fixed_sequence = 8
    default_fixed_num_choices = 4
    torch_onnx_minimum_version = version.parse("1.8")
    _tasks_to_common_outputs = {
        "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
        "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "image-segmentation": OrderedDict(
            {
                "logits": {0: "batch", 1: "sequence"},
                "pred_boxes": {0: "batch", 1: "sequence"},
                "pred_masks": {0: "batch", 1: "sequence"},
            }
        ),
        "masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "multiple-choice": OrderedDict({"logits": {0: "batch"}}),
        "object-detection": OrderedDict(
            {
                "logits": {0: "batch", 1: "sequence"},
                "pred_boxes": {0: "batch", 1: "sequence"},
            }
        ),
        "question-answering": OrderedDict(
            {
                "start_logits": {0: "batch", 1: "sequence"},
                "end_logits": {0: "batch", 1: "sequence"},
            }
        ),
        "semantic-segmentation": OrderedDict({"logits": {0: "batch", 1: "num_labels", 2: "height", 3: "width"}}),
        "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
        "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
        "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
    }

    def __init__(
        self, config: "PretrainedConfig", task: str = "default", patching_specs: Optional[List[PatchingSpec]] = None
    ):
        self._config = config

        if task not in self._tasks_to_common_outputs:
            raise ValueError(
                f"{task} is not a supported task, supported tasks: {self._tasks_to_common_outputs.keys()}"
            )
        self.task = task

        self._patching_specs = []
        for spec in patching_specs if patching_specs is not None else []:
            final_spec = spec
            if spec.orig_op is None:
                final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
            self._patching_specs.append(final_spec)

    @classmethod
    def from_model_config(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxConfig":
        """
        Instantiate an OnnxConfig for a specific model.

        Args:
            config: The model's configuration to use when exporting to ONNX
            task: The task to export the model for (defaults to "default")

        Returns:
            OnnxConfig for this model
        """
        return cls(config, task=task)
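
    # Illustrative sketch (hypothetical subclass and axis names): a concrete subclass only has to define
    # the abstract `inputs` property, everything else is inherited.
    #
    #     class MyModelOnnxConfig(OnnxConfig):
    #         @property
    #         def inputs(self) -> Mapping[str, Mapping[int, str]]:
    #             return OrderedDict(
    #                 [
    #                     ("input_ids", {0: "batch", 1: "sequence"}),
    #                     ("attention_mask", {0: "batch", 1: "sequence"}),
    #                 ]
    #             )
    #
    #     onnx_config = MyModelOnnxConfig.from_model_config(config, task="sequence-classification")
    #     # onnx_config.outputs is an OrderedDict mapping "logits" to {0: "batch"} for this task.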

    @property
    @abstractmethod
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Mapping containing the axis definition of the input tensors to provide to the model.

        Returns:
            For each input: its name associated with the axes' symbolic names and their positions within the tensor
        """
        raise NotImplementedError()

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Mapping containing the axis definition of the output tensors returned by the model.

        Returns:
            For each output: its name associated with the axes' symbolic names and their positions within the tensor
        """
        common_outputs = self._tasks_to_common_outputs[self.task]
        return copy.deepcopy(common_outputs)

    @property
    def values_override(self) -> Optional[Mapping[str, Any]]:
        """
        Dictionary of keys to override in the model's config before exporting.

        Returns:
            Dictionary with the keys (and their corresponding values) to override
        """
        if hasattr(self._config, "use_cache"):
            return {"use_cache": False}

        return None

    @property
    def default_batch_size(self) -> int:
        """
        The default batch size to use if no other indication is provided.

        Returns:
            Integer > 0
        """
        # Using 2 avoids ONNX making assumptions about single-sample batches
        return OnnxConfig.default_fixed_batch

    @property
    def default_sequence_length(self) -> int:
        """
        The default sequence length to use if no other indication is provided.

        Returns:
            Integer > 0
        """
        return OnnxConfig.default_fixed_sequence

    @property
    def default_num_choices(self) -> int:
        """
        The default number of choices to use if no other indication is provided.

        Returns:
            Integer > 0
        """
        return OnnxConfig.default_fixed_num_choices

    @property
    def default_onnx_opset(self) -> int:
        """
        Which ONNX opset to use when exporting the model.

        Returns:
            Integer ONNX Opset version
        """
        return DEFAULT_ONNX_OPSET

    @property
    def atol_for_validation(self) -> float:
        """
        What absolute tolerance value to use during model conversion validation.

        Returns:
            Float absolute tolerance value.
        """
        return 1e-5

    @property
    def is_torch_support_available(self) -> bool:
        """
        Whether the installed version of PyTorch meets the minimum version required to export the model.

        Returns:
            `bool`: Whether the installed version of PyTorch is compatible with the model.
        """
        if is_torch_available():
            from transformers.utils import get_torch_version

            return version.parse(get_torch_version()) >= self.torch_onnx_minimum_version
        else:
            return False

    @staticmethod
    def use_external_data_format(num_parameters: int) -> bool:
        """
        Flag indicating if the model requires using the external data format.

        Args:
            num_parameters: Number of parameters in the model

        Returns:
            True if the serialized parameter size (num_parameters * size_of(float32)) is >= 2 GB, False otherwise
        """
        return (
            compute_serialized_parameters_size(num_parameters, ParameterFormat.Float)
            >= EXTERNAL_DATA_FORMAT_SIZE_LIMIT
        )
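
    # Worked example (illustrative, not part of the original code): a model with 600M float32 parameters
    # serializes to roughly 600e6 * 4 bytes ~= 2.4 GB, which exceeds the 2 GB protobuf limit, so
    # `OnnxConfig.use_external_data_format(600_000_000)` returns True, while a 110M-parameter model
    # (~0.44 GB) returns False.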

    def _generate_dummy_images(
        self, batch_size: int = 2, num_channels: int = 3, image_height: int = 40, image_width: int = 40
    ):
        images = []
        for _ in range(batch_size):
            data = np.random.rand(image_height, image_width, num_channels) * 255
            images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
        return images

    def _generate_dummy_audio(
        self, batch_size: int = 2, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220
    ):
        audio_data = []
        for _ in range(batch_size):
            # time variable
            t = np.linspace(0, time_duration, int(time_duration * sampling_rate), endpoint=False)

            # generate pure sine wave with `frequency`
            audio_data.append(0.5 * np.sin(2 * np.pi * frequency * t))

        return audio_data
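
    # Illustrative note (not part of the original code): with the defaults above, each generated waveform
    # is a 1-D numpy array of int(5.0 * 22050) = 110250 samples containing a 220 Hz sine at amplitude 0.5.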

    def generate_dummy_inputs(
        self,
        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin", "ImageProcessingMixin"],
        batch_size: int = -1,
        seq_length: int = -1,
        num_choices: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
        num_channels: int = 3,
        image_width: int = 40,
        image_height: int = 40,
        sampling_rate: int = 22050,
        time_duration: float = 5.0,
        frequency: int = 220,
        tokenizer: "PreTrainedTokenizerBase" = None,
    ) -> Mapping[str, Any]:
        """
        Generate inputs to provide to the ONNX exporter for the specific framework.

        Args:
            preprocessor ([`PreTrainedTokenizerBase`], [`FeatureExtractionMixin`], or [`ImageProcessingMixin`]):
                The preprocessor associated with this model configuration.
            batch_size (`int`, *optional*, defaults to -1):
                The batch size to export the model for (-1 means dynamic axis).
            num_choices (`int`, *optional*, defaults to -1):
                The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
            seq_length (`int`, *optional*, defaults to -1):
                The sequence length to export the model for (-1 means dynamic axis).
            is_pair (`bool`, *optional*, defaults to `False`):
                Indicate if the input is a pair (sentence 1, sentence 2).
            framework (`TensorType`, *optional*, defaults to `None`):
                The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
            num_channels (`int`, *optional*, defaults to 3):
                The number of channels of the generated images.
            image_width (`int`, *optional*, defaults to 40):
                The width of the generated images.
            image_height (`int`, *optional*, defaults to 40):
                The height of the generated images.
            sampling_rate (`int`, *optional*, defaults to 22050):
                The sampling rate for audio data generation.
            time_duration (`float`, *optional*, defaults to 5.0):
                Total seconds of sampling for audio data generation.
            frequency (`int`, *optional*, defaults to 220):
                The desired natural frequency of generated audio.

        Returns:
            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
        """
        from ..feature_extraction_utils import FeatureExtractionMixin
        from ..image_processing_utils import ImageProcessingMixin
        from ..tokenization_utils_base import PreTrainedTokenizerBase

        if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
            raise ValueError("You cannot provide both a tokenizer and a preprocessor to generate dummy inputs.")
        if tokenizer is not None:
            warnings.warn(
                "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
                " `preprocessor` instead.",
                FutureWarning,
            )
            logger.warning("Overwriting the `preprocessor` argument with `tokenizer` to generate dummy inputs.")
            preprocessor = tokenizer
        if isinstance(preprocessor, PreTrainedTokenizerBase):
            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
            batch_size = compute_effective_axis_dimension(
                batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
            )
            # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
            token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
            seq_length = compute_effective_axis_dimension(
                seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
            )
            # Generate dummy inputs according to the computed batch and sequence sizes
            input_token = (
                preprocessor.unk_token
                if (preprocessor.unk_token is not None and len(preprocessor.unk_token) > 0)
                else "0"
            )
            # Space-separate `seq_length` copies of the token for each example in the batch
            dummy_input = [" ".join([input_token] * seq_length)] * batch_size
            if self.task == "multiple-choice":
                # If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid
                # optimizations made by ONNX
                num_choices = compute_effective_axis_dimension(
                    num_choices, fixed_dimension=OnnxConfig.default_fixed_num_choices, num_token_to_add=0
                )
                dummy_input = dummy_input * num_choices
                # The shape of the tokenized inputs values is [batch_size * num_choices, seq_length]
                tokenized_input = preprocessor(dummy_input, text_pair=dummy_input)
                # Unflatten the tokenized inputs values expanding it to the shape [batch_size, num_choices, seq_length]
                for k, v in tokenized_input.items():
                    tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
                return dict(tokenized_input.convert_to_tensors(tensor_type=framework))
            return dict(preprocessor(dummy_input, return_tensors=framework))
        elif isinstance(preprocessor, ImageProcessingMixin):
            if preprocessor.model_input_names[0] != "pixel_values":
                raise ValueError(
                    f"The `preprocessor` is an image processor ({preprocessor.__class__.__name__}) and expects"
                    f' `model_input_names[0]` to be "pixel_values", but got {preprocessor.model_input_names[0]}'
                )
            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
            batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
            dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
            return dict(preprocessor(images=dummy_input, return_tensors=framework))
        elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
            batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
            dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
            return dict(preprocessor(images=dummy_input, return_tensors=framework))
        elif (
            isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "input_features"
        ):
            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
            batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
            dummy_input = self._generate_dummy_audio(batch_size, sampling_rate, time_duration, frequency)
            return dict(preprocessor(dummy_input, return_tensors=framework))
        else:
            raise ValueError(
                "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
            )
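
    # Illustrative usage sketch (assumes a BERT-style checkpoint and a hypothetical `MyModelOnnxConfig`
    # subclass defining `inputs`; not part of the original code):
    #
    #     from transformers import AutoConfig, AutoTokenizer
    #
    #     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    #     config = AutoConfig.from_pretrained("bert-base-uncased")
    #     onnx_config = MyModelOnnxConfig(config, task="default")
    #     dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
    #     # dummy is a dict such as {"input_ids": ..., "token_type_ids": ..., "attention_mask": ...},
    #     # each tensor of shape (default_fixed_batch, default_fixed_sequence) = (2, 8).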

    def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
        models which have the encoder and decoder exported as separate ONNX files.

        Args:
            reference_model_inputs (`Mapping[str, Tensor]`):
                Reference inputs for the model.

        Returns:
            `Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function
        """
        return reference_model_inputs

    def patch_ops(self):
        for spec in self._patching_specs:
            custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
            setattr(spec.o, spec.name, custom_op)

    def restore_ops(self):
        for spec in self._patching_specs:
            orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
            setattr(spec.o, spec.name, orig_op)

    @classmethod
    def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:
        """
        Flatten any potential nested structure, expanding the name of the field with the index of the element within
        the structure.

        Args:
            name: The name of the nested structure
            field: The structure to potentially be flattened

        Returns:
            (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
        """
        from itertools import chain

        return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
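
    # Illustrative example (not part of the original code): flattening a nested collection of two
    # 2-tuples produces index-suffixed keys:
    #
    #     OnnxConfig.flatten_output_collection_property("hidden_states", [(a, b), (c, d)])
    #     # -> {"hidden_states.0": a, "hidden_states.1": b, "hidden_states.2": c, "hidden_states.3": d}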


class OnnxConfigWithPast(OnnxConfig, ABC):
    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "default",
        patching_specs: Optional[List[PatchingSpec]] = None,
        use_past: bool = False,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs)
        self.use_past = use_past

    @classmethod
    def with_past(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxConfigWithPast":
        """
        Instantiate an OnnxConfig with the `use_past` attribute set to True.

        Args:
            config: The underlying model's config to use when exporting to ONNX

        Returns:
            OnnxConfig with `.use_past = True`
        """
        return cls(config, task=task, use_past=True)
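
    # Illustrative sketch (hypothetical subclass name): exporting a decoder with pre-computed key/values
    # enabled keeps `use_cache` on in the exported config and registers past/present axes:
    #
    #     onnx_config = MyDecoderOnnxConfig.with_past(config, task="causal-lm")
    #     # onnx_config.values_override -> {"use_cache": True} (when the config defines `use_cache`)
    #     # onnx_config.outputs now also contains "present.0.key", "present.0.value", ... besides "logits".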

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        common_outputs = super().outputs
        if self.use_past:
            self.fill_with_past_key_values_(common_outputs, direction="outputs")

        return common_outputs

    @property
    def values_override(self) -> Optional[Mapping[str, Any]]:
        if hasattr(self._config, "use_cache"):
            return {"use_cache": self.use_past}

        return None

    @property
    def num_layers(self) -> int:
        """
        The number of layers attribute retrieved from the model config. Override this for model configs where the
        number of layers attribute is not called `num_layers`.
        """
        if not hasattr(self._config, "num_layers"):
            raise AttributeError(
                "could not find the number of layers attribute in the model configuration, override the num_layers"
                " property of the model OnnxConfig to solve this"
            )
        return self._config.num_layers

    @property
    def num_attention_heads(self) -> int:
        """
        The number of attention heads attribute retrieved from the model config. Override this for model configs where
        the number of attention heads attribute is not called `num_attention_heads`.
        """
        if not hasattr(self._config, "num_attention_heads"):
            raise AttributeError(
                "could not find the number of attention heads attribute in the model configuration, override the"
                " num_attention_heads property of the model OnnxConfig to solve this"
            )
        return self._config.num_attention_heads

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        common_inputs = super().generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch

                batch, seqlen = common_inputs["input_ids"].shape
                # Not using the same length for past_key_values
                past_key_values_length = seqlen + 2
                shape = (
                    batch,
                    self.num_attention_heads,
                    past_key_values_length,
                    self._config.hidden_size // self.num_attention_heads,
                )

                if "attention_mask" in common_inputs:
                    # Extend the attention mask so the past positions are attended to as well
                    mask_dtype = common_inputs["attention_mask"].dtype
                    common_inputs["attention_mask"] = torch.cat(
                        [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)],
                        dim=1,
                    )

                common_inputs["past_key_values"] = []
                for _ in range(self.num_layers):
                    common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))

        return common_inputs
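
    # Shape sketch (illustrative): for a 12-layer, 12-head model with hidden_size 768 and the default
    # dummy batch/sequence of (2, 8), each past_key_values entry is a (key, value) pair of tensors of
    # shape (2, 12, 10, 64), i.e. (batch, num_heads, seq_len + 2, hidden_size / num_heads).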

    def fill_with_past_key_values_(
        self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str, inverted_values_shape: bool = False
    ):
        """
        Fill the input_or_outputs mapping with past_key_values dynamic axes considering the direction.

        Args:
            inputs_or_outputs: The mapping to fill.
            direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the
                output mapping, this is important for axes naming.
            inverted_values_shape:
                If `True`, store values on dynamic axis 1, else on axis 2.
        """
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        name = "past_key_values" if direction == "inputs" else "present"
        for i in range(self.num_layers):
            inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
            if inverted_values_shape:
                inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"}
            else:
                inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
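
    # Illustrative result (not part of the original code): for a 2-layer model and direction="inputs",
    # the mapping gains the dynamic axes
    #     {"past_key_values.0.key":   {0: "batch", 2: "past_sequence + sequence"},
    #      "past_key_values.0.value": {0: "batch", 2: "past_sequence + sequence"},
    #      "past_key_values.1.key":   {0: "batch", 2: "past_sequence + sequence"},
    #      "past_key_values.1.value": {0: "batch", 2: "past_sequence + sequence"}}
    # With direction="outputs", the prefix is "present" instead.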

    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        flattened_output[f"{name}.{idx}.key"] = t[0]
        flattened_output[f"{name}.{idx}.value"] = t[1]

    def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]:
        flattened_output = {}
        if name in ["present", "past_key_values"]:
            for idx, t in enumerate(field):
                self._flatten_past_key_values_(flattened_output, name, idx, t)
        else:
            flattened_output = super().flatten_output_collection_property(name, field)

        return flattened_output


class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        common_outputs = super(OnnxConfigWithPast, self).outputs
        # Rename the sequence axes to encoder/decoder-specific names
        for name, axes_names in common_outputs.items():
            sequence_name = "encoder_sequence" if "encoder" in name else "decoder_sequence"
            for axis_idx, axis_name in axes_names.items():
                if "sequence" in axis_name:
                    axes_names[axis_idx] = sequence_name
                # We reset the value as the order in common_outputs (OrderedDict) is lost otherwise
                else:
                    axes_names[axis_idx] = axis_name
        if self.use_past:
            self.fill_with_past_key_values_(common_outputs, direction="outputs")

        return common_outputs

    @property
    def num_layers(self) -> Tuple[int, int]:
        try:
            num_layers = super().num_layers
            num_layers = (num_layers, num_layers)
        except AttributeError:
            if hasattr(self._config, "encoder_layers") and hasattr(self._config, "decoder_layers"):
                num_layers = (self._config.encoder_layers, self._config.decoder_layers)
            else:
                raise AttributeError(
                    "could not find the number of encoder and decoder layers attributes in the model configuration,"
                    " override the num_layers property of the model OnnxConfig to solve this"
                )

        return num_layers

    @property
    def num_attention_heads(self) -> Tuple[int, int]:
        try:
            num_attention_heads = super().num_attention_heads
            num_attention_heads = (num_attention_heads, num_attention_heads)
        except AttributeError:
            if hasattr(self._config, "encoder_attention_heads") and hasattr(self._config, "decoder_attention_heads"):
                num_attention_heads = (self._config.encoder_attention_heads, self._config.decoder_attention_heads)
            else:
                raise AttributeError(
                    "could not find the number of attention heads for the encoder and the decoder attributes in the"
                    " model configuration, override the num_attention_heads property of the model OnnxConfig to solve"
                    " this"
                )
        return num_attention_heads

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # Generate decoder inputs
        decoder_seq_length = seq_length if not self.use_past else 1
        decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=decoder_seq_length, is_pair=is_pair, framework=framework
        )
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        common_inputs = dict(**encoder_inputs, **decoder_inputs)

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch = common_inputs["input_ids"].shape[0]
            encoder_seq_length = common_inputs["input_ids"].shape[1]
            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
            encoder_shape = (
                batch,
                num_encoder_attention_heads,
                encoder_seq_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )
            decoder_shape = (
                batch,
                num_decoder_attention_heads,
                # Not using the same length for past_key_values
                decoder_seq_length + 3,
                self._config.hidden_size // num_decoder_attention_heads,
            )

            common_inputs["past_key_values"] = []

            num_encoder_layers, num_decoder_layers = self.num_layers
            min_num_layers = min(num_encoder_layers, num_decoder_layers)
            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

            for _ in range(min_num_layers):
                # For each layer shared by encoder and decoder, store one (key, value) pair for the decoder
                # self-attention and one for the cross-attention over the encoder states.
                common_inputs["past_key_values"].append(
                    (
                        torch.zeros(decoder_shape),
                        torch.zeros(decoder_shape),
                        torch.zeros(encoder_shape),
                        torch.zeros(encoder_shape),
                    )
                )

            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
            for _ in range(min_num_layers, max_num_layers):
                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))

        return common_inputs
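
    # Shape sketch (illustrative): for a seq2seq model with 6 encoder and 6 decoder layers, 8 heads,
    # hidden_size 512, use_past=True and the default dummy batch of 2, `common_inputs["past_key_values"]`
    # holds 6 tuples of four tensors each:
    #     decoder self-attention key/value:  (2, 8, decoder_seq_length + 3, 64)
    #     encoder cross-attention key/value: (2, 8, encoder_seq_length, 64)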

    def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        name = "past_key_values" if direction == "inputs" else "present"

        num_encoder_layers, num_decoder_layers = self.num_layers
        min_num_layers = min(num_encoder_layers, num_decoder_layers)
        max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
        remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

        encoder_sequence = "past_encoder_sequence"
        decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence"

        for i in range(min_num_layers):
            inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch", 2: decoder_sequence}
            inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence}
            inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence}
            inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence}

        for i in range(min_num_layers, max_num_layers):
            if remaining_side_name == "encoder":
                axes_info = {0: "batch", 2: encoder_sequence}
            else:
                axes_info = {0: "batch", 2: decoder_sequence}
            inputs_or_outputs[f"{name}.{i}.{remaining_side_name}.key"] = axes_info

    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        flattened_output[f"{name}.{idx}.decoder.key"] = t[0]
        flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
        flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
        flattened_output[f"{name}.{idx}.encoder.value"] = t[3]
|
|