Spaces:
Running
Running
# Copyright (c) ONNX Project Contributors | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
import warnings | |
from typing import Any, Dict, NamedTuple, Union, cast | |
import numpy as np | |
from onnx import OptionalProto, SequenceProto, TensorProto | |
class TensorDtypeMap(NamedTuple):
    """Record describing how one ``TensorProto`` element type is represented.

    Every value stored in ``TENSOR_TYPE_MAP`` is one of these records.
    """

    # NumPy dtype used for arrays holding this element type.
    np_dtype: np.dtype
    # TensorProto element type used as the storage type for the values
    # (a key of the storage-type -> field-name table in this module).
    storage_dtype: int
    # Display name of the element type, e.g. "TensorProto.FLOAT".
    name: str
# tensor_dtype -> TensorDtypeMap(numpy type, storage type, string name).
# Each row below is (tensor_dtype, numpy dtype name, storage tensor_dtype, name);
# rows are listed in the order the entries are inserted into the dict.
TENSOR_TYPE_MAP = {
    int(tensor_dtype): TensorDtypeMap(
        np.dtype(np_name), int(storage_dtype), display_name
    )
    for tensor_dtype, np_name, storage_dtype, display_name in (
        (TensorProto.FLOAT, "float32", TensorProto.FLOAT, "TensorProto.FLOAT"),
        (TensorProto.UINT8, "uint8", TensorProto.INT32, "TensorProto.UINT8"),
        (TensorProto.INT8, "int8", TensorProto.INT32, "TensorProto.INT8"),
        (TensorProto.UINT16, "uint16", TensorProto.INT32, "TensorProto.UINT16"),
        (TensorProto.INT16, "int16", TensorProto.INT32, "TensorProto.INT16"),
        (TensorProto.INT32, "int32", TensorProto.INT32, "TensorProto.INT32"),
        (TensorProto.INT64, "int64", TensorProto.INT64, "TensorProto.INT64"),
        (TensorProto.BOOL, "bool", TensorProto.INT32, "TensorProto.BOOL"),
        (
            TensorProto.FLOAT16,
            "float16",
            TensorProto.UINT16,
            "TensorProto.FLOAT16",
        ),
        # Native numpy does not support bfloat16, so float32 stands in for it.
        (
            TensorProto.BFLOAT16,
            "float32",
            TensorProto.UINT16,
            "TensorProto.BFLOAT16",
        ),
        (TensorProto.DOUBLE, "float64", TensorProto.DOUBLE, "TensorProto.DOUBLE"),
        (
            TensorProto.COMPLEX64,
            "complex64",
            TensorProto.FLOAT,
            "TensorProto.COMPLEX64",
        ),
        (
            TensorProto.COMPLEX128,
            "complex128",
            TensorProto.DOUBLE,
            "TensorProto.COMPLEX128",
        ),
        (TensorProto.UINT32, "uint32", TensorProto.UINT32, "TensorProto.UINT32"),
        (TensorProto.UINT64, "uint64", TensorProto.UINT64, "TensorProto.UINT64"),
        (TensorProto.STRING, "object", TensorProto.STRING, "TensorProto.STRING"),
        # Native numpy does not support float8 types, so float32 stands in for them.
        (
            TensorProto.FLOAT8E4M3FN,
            "float32",
            TensorProto.UINT8,
            "TensorProto.FLOAT8E4M3FN",
        ),
        (
            TensorProto.FLOAT8E4M3FNUZ,
            "float32",
            TensorProto.UINT8,
            "TensorProto.FLOAT8E4M3FNUZ",
        ),
        (
            TensorProto.FLOAT8E5M2,
            "float32",
            TensorProto.UINT8,
            "TensorProto.FLOAT8E5M2",
        ),
        (
            TensorProto.FLOAT8E5M2FNUZ,
            "float32",
            TensorProto.UINT8,
            "TensorProto.FLOAT8E5M2FNUZ",
        ),
        # Native numpy does not support uint4/int4, so uint8/int8 stand in for them.
        (TensorProto.UINT4, "uint8", TensorProto.INT32, "TensorProto.UINT4"),
        (TensorProto.INT4, "int8", TensorProto.INT32, "TensorProto.INT4"),
    )
}
class DeprecatedWarningDict(dict):  # type: ignore
    """A ``dict`` that emits a ``DeprecationWarning`` on every ``d[key]`` lookup.

    It keeps this module's legacy public tables usable while pointing callers at
    their replacements. Only subscript access (``__getitem__``) warns; other
    read paths inherited from ``dict`` (``get``, ``items``, ``in``, iteration)
    stay silent.
    """

    def __init__(
        self,
        dictionary: Dict[int, Union[int, str, np.dtype]],
        original_function: str,
        future_function: str = "",
    ) -> None:
        """Copy *dictionary* and remember the names used in the warning text.

        Args:
            dictionary: Initial contents; copied into this dict.
            original_function: Name of the deprecated ``mapping`` attribute.
            future_function: Name of the replacement in ``helper``. When empty,
                the warning suggests an if-else statement instead.
        """
        super().__init__(dictionary)
        self._origin_function = original_function
        self._future_function = future_function

    def __eq__(self, other: object) -> bool:
        """Return True if *other* is a DeprecatedWarningDict with the same names and contents.

        Anything that is not a ``DeprecatedWarningDict`` (including a plain
        ``dict`` with identical items) compares unequal.
        """
        if not isinstance(other, DeprecatedWarningDict):
            return False
        return (
            self._origin_function == other._origin_function
            and self._future_function == other._future_function
            # Also compare contents: two tables that share names but hold
            # different data must not be considered equal.
            and super().__eq__(other)
        )

    def __ne__(self, other: object) -> bool:
        """Return the negation of ``__eq__``."""
        # ``dict`` defines its own ``__ne__`` (a pure content comparison). It
        # would otherwise be inherited unchanged and could disagree with
        # ``__eq__`` above.
        return not self.__eq__(other)

    def __getitem__(self, key: Union[int, str, np.dtype]) -> Any:
        """Return ``self[key]`` after emitting a ``DeprecationWarning``.

        Raises:
            KeyError: If *key* is absent. The warning is still emitted first.
        """
        if self._future_function:
            advice = f"please use `helper.{self._future_function}` instead."
        else:
            advice = (
                "please simply use if-else statement to get the corresponding value."
            )
        warnings.warn(
            f"`mapping.{self._origin_function}` is now deprecated and will be removed "
            f"in a future release. To silence this warning, {advice}",
            DeprecationWarning,
            # Attribute the warning to the code doing the lookup, not to this method.
            stacklevel=2,
        )
        return super().__getitem__(key)
# Deprecated view of TENSOR_TYPE_MAP: TensorProto element type -> NumPy dtype.
# Used when converting TensorProto values into numpy arrays; the deprecation
# warning points callers at helper.tensor_dtype_to_np_dtype.
TENSOR_TYPE_TO_NP_TYPE = DeprecatedWarningDict(
    dictionary={key: entry.np_dtype for key, entry in TENSOR_TYPE_MAP.items()},
    original_function="TENSOR_TYPE_TO_NP_TYPE",
    future_function="tensor_dtype_to_np_dtype",
)
# Deprecated view of TENSOR_TYPE_MAP: TensorProto element type -> storage element type.
# This is only used to get keys into STORAGE_TENSOR_TYPE_TO_FIELD.
# TODO(https://github.com/onnx/onnx/issues/4554): Move these variables into _mapping.py
TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE = DeprecatedWarningDict(
    dictionary={
        key: entry.storage_dtype for key, entry in TENSOR_TYPE_MAP.items()
    },
    original_function="TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE",
    future_function="tensor_dtype_to_storage_tensor_dtype",
)
# NP_TYPE_TO_TENSOR_TYPE will eventually be removed;
# _NP_TYPE_TO_TENSOR_TYPE will then only be used internally.
#
# This is the inverse of TENSOR_TYPE_TO_NP_TYPE. Native numpy has no bfloat16,
# float8 or 4-bit integer dtypes, so TENSOR_TYPE_MAP represents those as
# float32, uint8 or int8. They are skipped here so that each numpy dtype
# reverses to a single TensorProto type. For example, a float32 array reverses
# only to TensorProto.FLOAT.
_NP_TYPE_TO_TENSOR_TYPE = {
    np_dtype: tensor_dtype
    for tensor_dtype, np_dtype in TENSOR_TYPE_TO_NP_TYPE.items()
    if tensor_dtype
    not in (
        TensorProto.BFLOAT16,
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
        TensorProto.UINT4,
        TensorProto.INT4,
    )
}
NP_TYPE_TO_TENSOR_TYPE = DeprecatedWarningDict(
    dictionary=cast(Dict[int, Union[int, str, Any]], _NP_TYPE_TO_TENSOR_TYPE),
    original_function="NP_TYPE_TO_TENSOR_TYPE",
    future_function="np_dtype_to_tensor_dtype",
)
# STORAGE_TENSOR_TYPE_TO_FIELD will eventually be removed;
# _STORAGE_TENSOR_TYPE_TO_FIELD will then only be used internally.
# Storage element type -> name of the data field its values are kept in.
_STORAGE_TENSOR_TYPE_TO_FIELD = {
    int(storage_dtype): field_name
    for storage_dtype, field_name in (
        (TensorProto.FLOAT, "float_data"),
        (TensorProto.INT32, "int32_data"),
        (TensorProto.INT64, "int64_data"),
        (TensorProto.UINT8, "int32_data"),
        (TensorProto.UINT16, "int32_data"),
        (TensorProto.DOUBLE, "double_data"),
        (TensorProto.COMPLEX64, "float_data"),
        (TensorProto.COMPLEX128, "double_data"),
        (TensorProto.UINT32, "uint64_data"),
        (TensorProto.UINT64, "uint64_data"),
        (TensorProto.STRING, "string_data"),
        (TensorProto.BOOL, "int32_data"),
    )
}
# No future_function is given, so the deprecation warning suggests an
# if-else statement instead of a helper function.
STORAGE_TENSOR_TYPE_TO_FIELD = DeprecatedWarningDict(
    dictionary=cast(Dict[int, Union[int, str, Any]], _STORAGE_TENSOR_TYPE_TO_FIELD),
    original_function="STORAGE_TENSOR_TYPE_TO_FIELD",
)
# This map will be removed and there is no replacement for it.
# SequenceProto element type -> name of the repeated SequenceProto field that
# holds elements of that type.
STORAGE_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    {
        int(SequenceProto.TENSOR): "tensor_values",
        int(SequenceProto.SPARSE_TENSOR): "sparse_tensor_values",
        int(SequenceProto.SEQUENCE): "sequence_values",
        int(SequenceProto.MAP): "map_values",
        # SequenceProto keeps optional elements in its repeated `optional_values`
        # field. The singular `optional_value` belongs to OptionalProto. The key
        # uses SequenceProto's own enum; its value equals OptionalProto.OPTIONAL.
        int(SequenceProto.OPTIONAL): "optional_values",
    },
    "STORAGE_ELEMENT_TYPE_TO_FIELD",
)
# This map will be removed and there is no replacement for it.
# OptionalProto element type -> name of the OptionalProto field holding the value.
OPTIONAL_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    dictionary={
        int(elem_type): field_name
        for elem_type, field_name in (
            (OptionalProto.TENSOR, "tensor_value"),
            (OptionalProto.SPARSE_TENSOR, "sparse_tensor_value"),
            (OptionalProto.SEQUENCE, "sequence_value"),
            (OptionalProto.MAP, "map_value"),
            (OptionalProto.OPTIONAL, "optional_value"),
        )
    },
    original_function="OPTIONAL_ELEMENT_TYPE_TO_FIELD",
)