# mart9992's picture
# m
# 9231ab9
import os
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union
import transformers
from .. import PretrainedConfig, is_tf_available, is_torch_available
from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
from .config import OnnxConfig
if TYPE_CHECKING:
    # Only needed for the string annotations on the FeaturesManager helpers.
    from transformers import PreTrainedModel, TFPreTrainedModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# PyTorch auto-model classes; imported only when a PyTorch install is detected.
if is_torch_available():
    from transformers.models.auto import (
        AutoModel,
        AutoModelForCausalLM,
        AutoModelForImageClassification,
        AutoModelForImageSegmentation,
        AutoModelForMaskedImageModeling,
        AutoModelForMaskedLM,
        AutoModelForMultipleChoice,
        AutoModelForObjectDetection,
        AutoModelForQuestionAnswering,
        AutoModelForSemanticSegmentation,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForSpeechSeq2Seq,
        AutoModelForTokenClassification,
        AutoModelForVision2Seq,
    )

# TensorFlow auto-model classes; imported only when a TensorFlow install is detected.
if is_tf_available():
    from transformers.models.auto import (
        TFAutoModel,
        TFAutoModelForCausalLM,
        TFAutoModelForMaskedLM,
        TFAutoModelForMultipleChoice,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSemanticSegmentation,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
        TFAutoModelForTokenClassification,
    )

# With neither backend present, warn once at import time that exporting is impossible.
if not (is_torch_available() or is_tf_available()):
    logger.warning(
        "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models"
        " without one of these libraries installed."
    )
def supported_features_mapping(
    *supported_features: str, onnx_config_cls: Optional[str] = None
) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
    """
    Generate the mapping between the supported features and their corresponding OnnxConfig for a given model.

    Args:
        *supported_features: The names of the supported features. A name containing `-with-past` is bound to the
            config class's `with_past` constructor, using the task obtained by removing `-with-past` from the name;
            every other name is bound to `from_model_config` with the name itself as the task.
        onnx_config_cls: The OnnxConfig full name corresponding to the model, as a dotted path resolved from the
            `transformers` package (e.g. `"models.bert.BertOnnxConfig"`).

    Returns:
        The dictionary mapping a feature to an OnnxConfig constructor.

    Raises:
        ValueError: If `onnx_config_cls` is not provided.
    """
    if onnx_config_cls is None:
        raise ValueError("A OnnxConfig class must be provided")

    # Walk the dotted path one attribute at a time, starting from the top-level `transformers` package.
    config_cls = reduce(getattr, onnx_config_cls.split("."), transformers)

    mapping = {}
    for feature in supported_features:
        if "-with-past" in feature:
            task = feature.replace("-with-past", "")
            mapping[feature] = partial(config_cls.with_past, task=task)
        else:
            mapping[feature] = partial(config_cls.from_model_config, task=feature)
    return mapping
class FeaturesManager:
    """
    Registry of the ONNX export features supported by each model type. It also provides the helpers that resolve a
    feature into an AutoModel class, a loaded model, or an OnnxConfig constructor.
    """

    # Task name -> AutoModel class for each framework; each stays empty when its framework is not installed.
    _TASKS_TO_AUTOMODELS = {}
    _TASKS_TO_TF_AUTOMODELS = {}
    if is_torch_available():
        _TASKS_TO_AUTOMODELS = {
            "default": AutoModel,
            "masked-lm": AutoModelForMaskedLM,
            "causal-lm": AutoModelForCausalLM,
            "seq2seq-lm": AutoModelForSeq2SeqLM,
            "sequence-classification": AutoModelForSequenceClassification,
            "token-classification": AutoModelForTokenClassification,
            "multiple-choice": AutoModelForMultipleChoice,
            "object-detection": AutoModelForObjectDetection,
            "question-answering": AutoModelForQuestionAnswering,
            "image-classification": AutoModelForImageClassification,
            "image-segmentation": AutoModelForImageSegmentation,
            "masked-im": AutoModelForMaskedImageModeling,
            "semantic-segmentation": AutoModelForSemanticSegmentation,
            "vision2seq-lm": AutoModelForVision2Seq,
            "speech2seq-lm": AutoModelForSpeechSeq2Seq,
        }
    if is_tf_available():
        _TASKS_TO_TF_AUTOMODELS = {
            "default": TFAutoModel,
            "masked-lm": TFAutoModelForMaskedLM,
            "causal-lm": TFAutoModelForCausalLM,
            "seq2seq-lm": TFAutoModelForSeq2SeqLM,
            "sequence-classification": TFAutoModelForSequenceClassification,
            "token-classification": TFAutoModelForTokenClassification,
            "multiple-choice": TFAutoModelForMultipleChoice,
            "question-answering": TFAutoModelForQuestionAnswering,
            "semantic-segmentation": TFAutoModelForSemanticSegmentation,
        }

    # Set of model topologies we support associated to the features supported by each topology and the factory
    _SUPPORTED_MODEL_TYPE = {
        "albert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.albert.AlbertOnnxConfig",
        ),
        "bart": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "sequence-classification",
            "question-answering",
            onnx_config_cls="models.bart.BartOnnxConfig",
        ),
        # BEiT cannot be used with the masked image modeling autoclass, so this feature is excluded here
        "beit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.beit.BeitOnnxConfig"
        ),
        "bert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.bert.BertOnnxConfig",
        ),
        "big-bird": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.big_bird.BigBirdOnnxConfig",
        ),
        "bigbird-pegasus": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "sequence-classification",
            "question-answering",
            onnx_config_cls="models.bigbird_pegasus.BigBirdPegasusOnnxConfig",
        ),
        "blenderbot": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.blenderbot.BlenderbotOnnxConfig",
        ),
        "blenderbot-small": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
        ),
        "bloom": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.bloom.BloomOnnxConfig",
        ),
        "camembert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.camembert.CamembertOnnxConfig",
        ),
        "clip": supported_features_mapping(
            "default",
            onnx_config_cls="models.clip.CLIPOnnxConfig",
        ),
        "codegen": supported_features_mapping(
            "default",
            "causal-lm",
            onnx_config_cls="models.codegen.CodeGenOnnxConfig",
        ),
        "convbert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.convbert.ConvBertOnnxConfig",
        ),
        "convnext": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.convnext.ConvNextOnnxConfig",
        ),
        "data2vec-text": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.data2vec.Data2VecTextOnnxConfig",
        ),
        "data2vec-vision": supported_features_mapping(
            "default",
            "image-classification",
            # ONNX doesn't support `adaptive_avg_pool2d` yet
            # "semantic-segmentation",
            onnx_config_cls="models.data2vec.Data2VecVisionOnnxConfig",
        ),
        "deberta": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.deberta.DebertaOnnxConfig",
        ),
        "deberta-v2": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.deberta_v2.DebertaV2OnnxConfig",
        ),
        "deit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.deit.DeiTOnnxConfig"
        ),
        "detr": supported_features_mapping(
            "default",
            "object-detection",
            "image-segmentation",
            onnx_config_cls="models.detr.DetrOnnxConfig",
        ),
        "distilbert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.distilbert.DistilBertOnnxConfig",
        ),
        "electra": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.electra.ElectraOnnxConfig",
        ),
        "flaubert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.flaubert.FlaubertOnnxConfig",
        ),
        "gpt2": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.gpt2.GPT2OnnxConfig",
        ),
        "gptj": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "question-answering",
            "sequence-classification",
            onnx_config_cls="models.gptj.GPTJOnnxConfig",
        ),
        "gpt-neo": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "sequence-classification",
            onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
        ),
        "groupvit": supported_features_mapping(
            "default",
            onnx_config_cls="models.groupvit.GroupViTOnnxConfig",
        ),
        "ibert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.ibert.IBertOnnxConfig",
        ),
        "imagegpt": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig"
        ),
        "layoutlm": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
        ),
        "layoutlmv3": supported_features_mapping(
            "default",
            "question-answering",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
        ),
        "levit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.levit.LevitOnnxConfig"
        ),
        "longt5": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.longt5.LongT5OnnxConfig",
        ),
        "longformer": supported_features_mapping(
            "default",
            "masked-lm",
            "multiple-choice",
            "question-answering",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.longformer.LongformerOnnxConfig",
        ),
        "marian": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "causal-lm",
            "causal-lm-with-past",
            onnx_config_cls="models.marian.MarianOnnxConfig",
        ),
        "mbart": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "sequence-classification",
            "question-answering",
            onnx_config_cls="models.mbart.MBartOnnxConfig",
        ),
        "mobilebert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
        ),
        "mobilenet-v1": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.mobilenet_v1.MobileNetV1OnnxConfig",
        ),
        "mobilenet-v2": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.mobilenet_v2.MobileNetV2OnnxConfig",
        ),
        "mobilevit": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
        ),
        "mt5": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.mt5.MT5OnnxConfig",
        ),
        "m2m-100": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
        ),
        "owlvit": supported_features_mapping(
            "default",
            onnx_config_cls="models.owlvit.OwlViTOnnxConfig",
        ),
        "perceiver": supported_features_mapping(
            "image-classification",
            "masked-lm",
            "sequence-classification",
            onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
        ),
        "poolformer": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.poolformer.PoolFormerOnnxConfig"
        ),
        "rembert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.rembert.RemBertOnnxConfig",
        ),
        "resnet": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.resnet.ResNetOnnxConfig",
        ),
        "roberta": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.roberta.RobertaOnnxConfig",
        ),
        "roformer": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "token-classification",
            "multiple-choice",
            "question-answering",
            onnx_config_cls="models.roformer.RoFormerOnnxConfig",
        ),
        "segformer": supported_features_mapping(
            "default",
            "image-classification",
            "semantic-segmentation",
            onnx_config_cls="models.segformer.SegformerOnnxConfig",
        ),
        "squeezebert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
        ),
        "swin": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.swin.SwinOnnxConfig"
        ),
        "t5": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.t5.T5OnnxConfig",
        ),
        "vision-encoder-decoder": supported_features_mapping(
            "vision2seq-lm", onnx_config_cls="models.vision_encoder_decoder.VisionEncoderDecoderOnnxConfig"
        ),
        "vit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.vit.ViTOnnxConfig"
        ),
        "whisper": supported_features_mapping(
            "default",
            "default-with-past",
            "speech2seq-lm",
            "speech2seq-lm-with-past",
            onnx_config_cls="models.whisper.WhisperOnnxConfig",
        ),
        "xlm": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.xlm.XLMOnnxConfig",
        ),
        "xlm-roberta": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.xlm_roberta.XLMRobertaOnnxConfig",
        ),
        "yolos": supported_features_mapping(
            "default",
            "object-detection",
            onnx_config_cls="models.yolos.YolosOnnxConfig",
        ),
    }

    # Sorted union of every feature name registered for any model type above.
    AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))

    @staticmethod
    def get_supported_features_for_model_type(
        model_type: str, model_name: Optional[str] = None
    ) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
        """
        Tries to retrieve the feature -> OnnxConfig constructor map from the model type.

        Args:
            model_type (`str`):
                The model type to retrieve the supported features for. It is lower-cased before the lookup.
            model_name (`str`, *optional*):
                The name attribute of the model object, only used for the exception message.

        Returns:
            The dictionary mapping each feature to a corresponding OnnxConfig constructor.

        Raises:
            KeyError: If the model type is not registered in `_SUPPORTED_MODEL_TYPE`.
        """
        model_type = model_type.lower()
        if model_type not in FeaturesManager._SUPPORTED_MODEL_TYPE:
            model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type
            raise KeyError(
                f"{model_type_and_model_name} is not supported yet. "
                f"Only {list(FeaturesManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
                f"If you want to support {model_type} please propose a PR or open up an issue."
            )
        return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type]

    @staticmethod
    def feature_to_task(feature: str) -> str:
        """Strip the `-with-past` marker from a feature name to obtain the underlying task name."""
        return feature.replace("-with-past", "")

    @staticmethod
    def _validate_framework_choice(framework: str):
        """
        Validates if the framework requested for the export is both correct and available, otherwise throws an
        exception.

        Raises:
            ValueError: If `framework` is neither `"pt"` nor `"tf"`.
            RuntimeError: If the requested framework is not installed.
        """
        if framework not in ["pt", "tf"]:
            raise ValueError(
                f"Only two frameworks are supported for ONNX export: pt or tf, but {framework} was provided."
            )
        elif framework == "pt" and not is_torch_available():
            raise RuntimeError("Cannot export model to ONNX using PyTorch because no PyTorch package was found.")
        elif framework == "tf" and not is_tf_available():
            raise RuntimeError("Cannot export model to ONNX using TensorFlow because no TensorFlow package was found.")

    @staticmethod
    def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
        """
        Attempts to retrieve an AutoModel class from a feature name.

        Args:
            feature (`str`):
                The feature required.
            framework (`str`, *optional*, defaults to `"pt"`):
                The framework to use for the export.

        Returns:
            The AutoModel class corresponding to the feature.

        Raises:
            KeyError: If the feature's task has no AutoModel class for the selected framework.
        """
        task = FeaturesManager.feature_to_task(feature)
        FeaturesManager._validate_framework_choice(framework)
        if framework == "pt":
            task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS
        else:
            task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS
        if task not in task_to_automodel:
            # List the valid task names for the framework actually selected (not the PyTorch model classes).
            raise KeyError(f"Unknown task: {feature}. Possible values are {list(task_to_automodel.keys())}")
        return task_to_automodel[task]

    @staticmethod
    def determine_framework(model: str, framework: Optional[str] = None) -> str:
        """
        Determines the framework to use for the export.

        The priority is in the following order:
            1. User input via `framework`.
            2. If local checkpoint is provided, use the same framework as the checkpoint.
            3. Available framework in environment, with priority given to PyTorch

        Args:
            model (`str`):
                The name of the model to export.
            framework (`str`, *optional*, defaults to `None`):
                The framework to use for the export. See above for priority if none provided.

        Returns:
            The framework to use for the export.

        Raises:
            FileNotFoundError: If `model` is a local directory containing neither PyTorch nor TensorFlow weights.
            EnvironmentError: If no framework is requested and neither PyTorch nor TensorFlow is installed.
        """
        if framework is not None:
            return framework

        framework_map = {"pt": "PyTorch", "tf": "TensorFlow"}
        exporter_map = {"pt": "torch", "tf": "tf2onnx"}

        if os.path.isdir(model):
            if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
                framework = "pt"
            elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
                framework = "tf"
            else:
                raise FileNotFoundError(
                    "Cannot determine framework from given checkpoint location."
                    f" There should be a {WEIGHTS_NAME} for PyTorch"
                    f" or {TF2_WEIGHTS_NAME} for TensorFlow."
                )
            logger.info(f"Local {framework_map[framework]} model found.")
        else:
            if is_torch_available():
                framework = "pt"
            elif is_tf_available():
                framework = "tf"
            else:
                raise EnvironmentError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.")

        logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.")

        return framework

    @staticmethod
    def get_model_from_feature(
        feature: str, model: str, framework: Optional[str] = None, cache_dir: Optional[str] = None
    ) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
        """
        Attempts to retrieve a model from a model's name and the feature to be enabled.

        Args:
            feature (`str`):
                The feature required.
            model (`str`):
                The name of the model to export.
            framework (`str`, *optional*, defaults to `None`):
                The framework to use for the export. See `FeaturesManager.determine_framework` for the priority should
                none be provided.
            cache_dir (`str`, *optional*, defaults to `None`):
                Forwarded as `cache_dir` to the AutoModel class's `from_pretrained`.

        Returns:
            The instance of the model.
        """
        framework = FeaturesManager.determine_framework(model, framework)
        model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
        try:
            model = model_class.from_pretrained(model, cache_dir=cache_dir)
        except OSError:
            # The first load raised OSError; retry with cross-framework weight conversion.
            if framework == "pt":
                logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
                model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
            else:
                logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.")
                model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
        return model

    @staticmethod
    def check_supported_model_or_raise(
        model: Union["PreTrainedModel", "TFPreTrainedModel"], feature: str = "default"
    ) -> Tuple[str, Callable]:
        """
        Check whether or not the model has the requested features.

        Args:
            model: The model to export.
            feature: The name of the feature to check if it is available.

        Returns:
            A tuple `(model_type, onnx_config_constructor)`: the model's `config.model_type`, and the callable that
            builds the OnnxConfig holding the export properties for this feature.

        Raises:
            KeyError: If the model type is not supported at all.
            ValueError: If the model type does not support `feature`.
        """
        model_type = model.config.model_type.replace("_", "-")
        model_name = getattr(model, "name", "")
        model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
        if feature not in model_features:
            raise ValueError(
                f"{model.config.model_type} doesn't support feature {feature}. "
                f"Supported values are: {list(model_features.keys())}"
            )
        # Reuse the mapping already looked up (with the normalized key) instead of indexing the registry again.
        return model.config.model_type, model_features[feature]

    @staticmethod
    def get_config(model_type: str, feature: str) -> Callable[[PretrainedConfig], OnnxConfig]:
        """
        Gets the OnnxConfig constructor for a model_type and feature combination.

        Args:
            model_type (`str`):
                The model type to retrieve the config for. It must match a registry key exactly (e.g. `"gpt-neo"`).
            feature (`str`):
                The feature to retrieve the config for.

        Returns:
            The callable that builds the `OnnxConfig` for the combination.

        Raises:
            KeyError: If the model type or the feature is not registered.
        """
        return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]