Spaces:
Running
Running
# Copyright (c) ONNX Project Contributors | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
# Names exported by a star-import of this module. Every public (non-underscore)
# function defined below is listed, including onnx_ml_opset_version, which sits
# alongside onnx_opset_version.
__all__ = [
    "C",
    "ONNX_DOMAIN",
    "ONNX_ML_DOMAIN",
    "AI_ONNX_PREVIEW_TRAINING_DOMAIN",
    "has",
    "register_schema",
    "deregister_schema",
    "get_schema",
    "get_all_schemas",
    "get_all_schemas_with_history",
    "onnx_opset_version",
    "onnx_ml_opset_version",
    "get_function_ops",
    "OpSchema",
    "SchemaError",
]
from typing import List | |
import onnx.onnx_cpp2py_export.defs as C # noqa: N812 | |
from onnx import AttributeProto, FunctionProto | |
# Operator-set domain identifiers. The default ONNX domain is the empty string.
ONNX_DOMAIN = ""
ONNX_ML_DOMAIN = "ai.onnx.ml"
AI_ONNX_PREVIEW_TRAINING_DOMAIN = "ai.onnx.preview.training"

# Public aliases that re-export schema-registry functions from the
# `C` (onnx.onnx_cpp2py_export.defs) module unchanged.
(
    has,
    get_schema,
    get_all_schemas,
    get_all_schemas_with_history,
    deregister_schema,
) = (
    C.has_schema,
    C.get_schema,
    C.get_all_schemas,
    C.get_all_schemas_with_history,
    C.deregister_schema,
)
def onnx_opset_version() -> int:
    """Return current opset for domain `ai.onnx`.

    ``C.schema_version_map()`` maps each domain to a ``(min, max)`` version
    pair; the second element (the newest registered version) is returned.
    """
    default_domain_range = C.schema_version_map()[ONNX_DOMAIN]
    return default_domain_range[1]
def onnx_ml_opset_version() -> int:
    """Return current opset for domain `ai.onnx.ml`.

    Looks up the ``(min, max)`` version pair registered for ``ai.onnx.ml``
    and returns its upper bound.
    """
    registered = C.schema_version_map()
    ml_range = registered[ONNX_ML_DOMAIN]
    return ml_range[1]
@property  # type: ignore
def _function_proto(self):  # type: ignore
    """Parse the schema's serialized function body into a ``FunctionProto``.

    Installed as the read-only property ``OpSchema.function_body``, so callers
    get the parsed proto as a value. The leftover ``# type: ignore`` line in
    this spot shows the ``@property`` decorator had been lost; without it,
    ``function_body`` would be a plain method.
    """
    func_proto = FunctionProto()
    # `_function_body` holds the serialized proto bytes; decode them on access.
    func_proto.ParseFromString(self._function_body)
    return func_proto


OpSchema = C.OpSchema  # type: ignore
OpSchema.function_body = _function_proto  # type: ignore
@property  # type: ignore
def _attribute_default_value(self):  # type: ignore
    """Parse an attribute's serialized default into an ``AttributeProto``.

    Installed as the read-only property ``OpSchema.Attribute.default_value``.
    It must be a property, not a method: ``_op_schema_attribute_repr`` formats
    ``self.default_value`` directly as a value, and would otherwise print a
    bound-method object.
    """
    attr = AttributeProto()
    # `_default_value` holds the serialized proto bytes; decode them on access.
    attr.ParseFromString(self._default_value)
    return attr


OpSchema.Attribute.default_value = _attribute_default_value  # type: ignore
def _op_schema_repr(self) -> str:
    """Render an ``OpSchema`` as a multi-line, constructor-style string.

    Each field is written on its own line as ``field=repr(value)``, separated
    by ``,\\n``, and wrapped in ``OpSchema(`` ... ``)``.
    """
    shown_fields = (
        "name",
        "domain",
        "since_version",
        "doc",
        "type_constraints",
        "inputs",
        "outputs",
        "attributes",
    )
    lines = ",\n".join(f"{field}={getattr(self, field)!r}" for field in shown_fields)
    return f"OpSchema(\n{lines}\n)"
# Install the custom multi-line repr on the extension-defined OpSchema type.
setattr(OpSchema, "__repr__", _op_schema_repr)
def _op_schema_formal_parameter_repr(self) -> str:
    """Return a single-line, constructor-style repr of an ``OpSchema.FormalParameter``.

    Note that the ``option`` attribute is shown under the label ``param_option``.
    """
    parts = [
        f"name={self.name!r}",
        f"type_str={self.type_str!r}",
        f"description={self.description!r}",
        f"param_option={self.option!r}",
        f"is_homogeneous={self.is_homogeneous!r}",
        f"min_arity={self.min_arity!r}",
        f"differentiation_category={self.differentiation_category!r}",
    ]
    return "OpSchema.FormalParameter(" + ", ".join(parts) + ")"
# Install the single-line repr on the nested FormalParameter type.
setattr(OpSchema.FormalParameter, "__repr__", _op_schema_formal_parameter_repr)
def _op_schema_type_constraint_param_repr(self) -> str:
    """Return a constructor-style repr of an ``OpSchema.TypeConstraintParam``."""
    pairs = (
        ("type_param_str", self.type_param_str),
        ("allowed_type_strs", self.allowed_type_strs),
        ("description", self.description),
    )
    inner = ", ".join(f"{key}={value!r}" for key, value in pairs)
    return f"OpSchema.TypeConstraintParam({inner})"
# Install the constructor-style repr on the nested TypeConstraintParam type.
setattr(OpSchema.TypeConstraintParam, "__repr__", _op_schema_type_constraint_param_repr)
def _op_schema_attribute_repr(self) -> str:
    """Return a constructor-style repr of an ``OpSchema.Attribute``.

    ``default_value`` is formatted as a value, so it is expected to be the
    parsed default rather than a callable.
    """
    template = (
        "OpSchema.Attribute(name={!r}, type={!r}, description={!r}, "
        "default_value={!r}, required={!r})"
    )
    return template.format(
        self.name,
        self.type,
        self.description,
        self.default_value,
        self.required,
    )
# Install the constructor-style repr on the nested Attribute type.
setattr(OpSchema.Attribute, "__repr__", _op_schema_attribute_repr)
def get_function_ops() -> List[OpSchema]:
    """Return operators defined as functions.

    A schema qualifies when it has a function body (``has_function``) or a
    context-dependent one (``has_context_dependent_function``). Schemas are
    returned in the order ``C.get_all_schemas()`` yields them.
    """
    function_ops = []
    for candidate in C.get_all_schemas():
        if candidate.has_function or candidate.has_context_dependent_function:  # type: ignore
            function_ops.append(candidate)
    return function_ops


# Re-export the extension module's schema error type.
SchemaError = C.SchemaError
def register_schema(schema: OpSchema) -> None:
    """Register a user provided OpSchema.

    If the schema's domain has no registered version range, or its
    ``since_version`` falls outside that range, the domain's ``(min, max)``
    range is first widened to include ``since_version``. The schema is then
    registered.

    Args:
        schema: The OpSchema to register.
    """
    known_ranges = C.schema_version_map()
    domain = schema.domain
    since = schema.since_version

    if domain in known_ranges:
        lowest, highest = known_ranges[domain]
    else:
        # Unknown domain: start from a single-version range.
        lowest = highest = since

    range_covers_schema = domain in known_ranges and lowest <= since <= highest
    if not range_covers_schema:
        C.set_domain_to_version(schema.domain, min(lowest, since), max(highest, since))

    C.register_schema(schema)