Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
8.46 kB
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
import warnings
from typing import Any, Dict, NamedTuple, Union, cast
import numpy as np
from onnx import OptionalProto, SequenceProto, TensorProto
class TensorDtypeMap(NamedTuple):
    """Describe how one ``TensorProto`` data type maps onto numpy and storage.

    Instances are the values of ``TENSOR_TYPE_MAP``.
    """

    # numpy dtype used to hold values of this tensor type in an array.
    np_dtype: np.dtype
    # TensorProto data type whose storage field holds the raw values.
    storage_dtype: int
    # Display name of the tensor type, e.g. "TensorProto.FLOAT".
    name: str
# Table of every supported TensorProto data type, keyed by its integer value.
# Each row is (tensor type, numpy dtype name, storage tensor type, display name);
# the storage tensor type is the TensorProto type whose field holds the raw values.
TENSOR_TYPE_MAP = {
    int(tensor_dtype): TensorDtypeMap(
        np.dtype(np_dtype_name), int(storage_dtype), display_name
    )
    for tensor_dtype, np_dtype_name, storage_dtype, display_name in (
        (TensorProto.FLOAT, "float32", TensorProto.FLOAT, "TensorProto.FLOAT"),
        (TensorProto.UINT8, "uint8", TensorProto.INT32, "TensorProto.UINT8"),
        (TensorProto.INT8, "int8", TensorProto.INT32, "TensorProto.INT8"),
        (TensorProto.UINT16, "uint16", TensorProto.INT32, "TensorProto.UINT16"),
        (TensorProto.INT16, "int16", TensorProto.INT32, "TensorProto.INT16"),
        (TensorProto.INT32, "int32", TensorProto.INT32, "TensorProto.INT32"),
        (TensorProto.INT64, "int64", TensorProto.INT64, "TensorProto.INT64"),
        (TensorProto.BOOL, "bool", TensorProto.INT32, "TensorProto.BOOL"),
        (TensorProto.FLOAT16, "float16", TensorProto.UINT16, "TensorProto.FLOAT16"),
        # Native numpy does not support bfloat16, so float32 is used instead.
        (TensorProto.BFLOAT16, "float32", TensorProto.UINT16, "TensorProto.BFLOAT16"),
        (TensorProto.DOUBLE, "float64", TensorProto.DOUBLE, "TensorProto.DOUBLE"),
        (TensorProto.COMPLEX64, "complex64", TensorProto.FLOAT, "TensorProto.COMPLEX64"),
        (
            TensorProto.COMPLEX128,
            "complex128",
            TensorProto.DOUBLE,
            "TensorProto.COMPLEX128",
        ),
        (TensorProto.UINT32, "uint32", TensorProto.UINT32, "TensorProto.UINT32"),
        (TensorProto.UINT64, "uint64", TensorProto.UINT64, "TensorProto.UINT64"),
        (TensorProto.STRING, "object", TensorProto.STRING, "TensorProto.STRING"),
        # Native numpy does not support float8 types, so float32 is used instead.
        (
            TensorProto.FLOAT8E4M3FN,
            "float32",
            TensorProto.UINT8,
            "TensorProto.FLOAT8E4M3FN",
        ),
        (
            TensorProto.FLOAT8E4M3FNUZ,
            "float32",
            TensorProto.UINT8,
            "TensorProto.FLOAT8E4M3FNUZ",
        ),
        (
            TensorProto.FLOAT8E5M2,
            "float32",
            TensorProto.UINT8,
            "TensorProto.FLOAT8E5M2",
        ),
        (
            TensorProto.FLOAT8E5M2FNUZ,
            "float32",
            TensorProto.UINT8,
            "TensorProto.FLOAT8E5M2FNUZ",
        ),
        # Native numpy does not support uint4/int4, so uint8/int8 are used instead.
        (TensorProto.UINT4, "uint8", TensorProto.INT32, "TensorProto.UINT4"),
        (TensorProto.INT4, "int8", TensorProto.INT32, "TensorProto.INT4"),
    )
}
class DeprecatedWarningDict(dict):  # type: ignore
    """A ``dict`` that emits a ``DeprecationWarning`` on every ``d[key]`` lookup.

    Keeps the legacy ``mapping.*`` tables usable while steering callers away
    from them. Only ``__getitem__`` warns; other read paths such as ``get``,
    ``items`` or ``in`` are inherited from ``dict`` unchanged.

    Args:
        dictionary: Initial contents of the mapping.
        original_function: Name of the deprecated ``mapping`` attribute, used
            in the warning text.
        future_function: Name of the ``helper`` function to use instead. When
            empty, the warning suggests an if-else statement.
    """

    def __init__(
        self,
        dictionary: Dict[int, Union[int, str, np.dtype]],
        original_function: str,
        future_function: str = "",
    ) -> None:
        super().__init__(dictionary)
        self._origin_function = original_function
        self._future_function = future_function

    def __eq__(self, other: object) -> bool:
        """Compare by deprecation metadata only, not by dictionary contents.

        Returns False for any object that is not a ``DeprecatedWarningDict``.
        """
        if not isinstance(other, DeprecatedWarningDict):
            return False
        return (
            self._origin_function == other._origin_function
            and self._future_function == other._future_function
        )

    def __ne__(self, other: object) -> bool:
        # The inherited dict.__ne__ compares contents, so without this override
        # ``a != b`` could disagree with ``not (a == b)``.
        return not self.__eq__(other)

    def __getitem__(self, key: Union[int, str, np.dtype]) -> Any:
        """Warn that this mapping is deprecated, then return ``dict[key]``.

        Raises:
            KeyError: If ``key`` is missing (the warning is emitted first).
        """
        if self._future_function:
            advice = f"please use `helper.{self._future_function}` instead."
        else:
            advice = (
                "please simply use if-else statement to get the corresponding value."
            )
        warnings.warn(
            f"`mapping.{self._origin_function}` is now deprecated and will be "
            "removed in a future release. "
            f"To silence this warning, {advice}",
            DeprecationWarning,
            # Attribute the warning to the caller's line, not to this method.
            stacklevel=2,
        )
        return super().__getitem__(key)
# Legacy map from a TensorProto data type to the numpy dtype used when
# converting TensorProto values into numpy arrays. Indexing it warns and points
# callers to `helper.tensor_dtype_to_np_dtype`.
TENSOR_TYPE_TO_NP_TYPE = DeprecatedWarningDict(
    {key: info.np_dtype for key, info in TENSOR_TYPE_MAP.items()},
    original_function="TENSOR_TYPE_TO_NP_TYPE",
    future_function="tensor_dtype_to_np_dtype",
)
# Legacy map from a TensorProto data type to the TensorProto type whose field
# stores its raw values. This is only used to get keys into
# STORAGE_TENSOR_TYPE_TO_FIELD.
# TODO(https://github.com/onnx/onnx/issues/4554): Move these variables into _mapping.py
TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE = DeprecatedWarningDict(
    {key: info.storage_dtype for key, info in TENSOR_TYPE_MAP.items()},
    original_function="TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE",
    future_function="tensor_dtype_to_storage_tensor_dtype",
)
# NP_TYPE_TO_TENSOR_TYPE will be eventually removed in the future
# and _NP_TYPE_TO_TENSOR_TYPE will only be used internally.
# Reverse of TENSOR_TYPE_TO_NP_TYPE. Tensor types that borrow another type's
# numpy dtype (bfloat16 and float8 use float32; uint4/int4 use uint8/int8) are
# skipped so each numpy dtype maps back only to its native tensor type.
_NP_TYPE_TO_TENSOR_TYPE = {
    np_dtype: tensor_dtype
    for tensor_dtype, np_dtype in TENSOR_TYPE_TO_NP_TYPE.items()
    if tensor_dtype
    not in {
        TensorProto.BFLOAT16,
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
        TensorProto.UINT4,
        TensorProto.INT4,
    }
}
# Deprecated public view of _NP_TYPE_TO_TENSOR_TYPE. Native numpy has no
# bfloat16, so TensorProto.BFLOAT16 is ignored for now and a numpy float32
# array is only mapped back to TensorProto.FLOAT.
NP_TYPE_TO_TENSOR_TYPE = DeprecatedWarningDict(
    # The cast only satisfies the declared parameter type; the keys here are
    # numpy dtypes rather than ints.
    cast(Dict[int, Union[int, str, Any]], _NP_TYPE_TO_TENSOR_TYPE),
    original_function="NP_TYPE_TO_TENSOR_TYPE",
    future_function="np_dtype_to_tensor_dtype",
)
# STORAGE_TENSOR_TYPE_TO_FIELD will be eventually removed in the future
# and _STORAGE_TENSOR_TYPE_TO_FIELD will only be used internally.
# Maps a storage TensorProto type (see TensorDtypeMap.storage_dtype) to the
# name of the field that holds values stored as that type.
_STORAGE_TENSOR_TYPE_TO_FIELD = {
    int(storage_dtype): field_name
    for storage_dtype, field_name in (
        (TensorProto.FLOAT, "float_data"),
        (TensorProto.INT32, "int32_data"),
        (TensorProto.INT64, "int64_data"),
        # UINT8 and UINT16 are the storage types of the float8 and
        # float16/bfloat16 entries in TENSOR_TYPE_MAP; both live in int32_data.
        (TensorProto.UINT8, "int32_data"),
        (TensorProto.UINT16, "int32_data"),
        (TensorProto.DOUBLE, "double_data"),
        (TensorProto.COMPLEX64, "float_data"),
        (TensorProto.COMPLEX128, "double_data"),
        (TensorProto.UINT32, "uint64_data"),
        (TensorProto.UINT64, "uint64_data"),
        (TensorProto.STRING, "string_data"),
        (TensorProto.BOOL, "int32_data"),
    )
}
# Deprecated public view of _STORAGE_TENSOR_TYPE_TO_FIELD. No helper
# replacement is named, so indexing it warns and suggests an if-else statement.
STORAGE_TENSOR_TYPE_TO_FIELD = DeprecatedWarningDict(
    cast(Dict[int, Union[int, str, Any]], _STORAGE_TENSOR_TYPE_TO_FIELD),
    original_function="STORAGE_TENSOR_TYPE_TO_FIELD",
)
# Maps a SequenceProto element type to the name of the SequenceProto field that
# holds its elements.
# This map will be removed and there is no replacement for it
STORAGE_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    {
        int(SequenceProto.TENSOR): "tensor_values",
        int(SequenceProto.SPARSE_TENSOR): "sparse_tensor_values",
        int(SequenceProto.SEQUENCE): "sequence_values",
        int(SequenceProto.MAP): "map_values",
        # SequenceProto stores optional elements in its `optional_values`
        # field; `optional_value` (singular) is the OptionalProto field name
        # and does not exist on SequenceProto.
        int(SequenceProto.OPTIONAL): "optional_values",
    },
    "STORAGE_ELEMENT_TYPE_TO_FIELD",
)
# Maps an OptionalProto element type to the name of the OptionalProto field
# that holds the contained value.
# This map will be removed and there is no replacement for it
OPTIONAL_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    {
        int(elem_type): field_name
        for elem_type, field_name in (
            (OptionalProto.TENSOR, "tensor_value"),
            (OptionalProto.SPARSE_TENSOR, "sparse_tensor_value"),
            (OptionalProto.SEQUENCE, "sequence_value"),
            (OptionalProto.MAP, "map_value"),
            (OptionalProto.OPTIONAL, "optional_value"),
        )
    },
    original_function="OPTIONAL_ELEMENT_TYPE_TO_FIELD",
)