Spaces:
Running
Running
File size: 4,284 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
# Explicit public API of ``onnx.defs``: ``from onnx.defs import *`` re-exports
# exactly these names. ``onnx_ml_opset_version`` is listed alongside its
# ``ai.onnx`` counterpart ``onnx_opset_version``; it is a public helper of this
# module and was previously missing from the list.
__all__ = [
    "C",
    "ONNX_DOMAIN",
    "ONNX_ML_DOMAIN",
    "AI_ONNX_PREVIEW_TRAINING_DOMAIN",
    "has",
    "register_schema",
    "deregister_schema",
    "get_schema",
    "get_all_schemas",
    "get_all_schemas_with_history",
    "onnx_opset_version",
    "onnx_ml_opset_version",
    "get_function_ops",
    "OpSchema",
    "SchemaError",
]
from typing import List
import onnx.onnx_cpp2py_export.defs as C # noqa: N812
from onnx import AttributeProto, FunctionProto
# Operator-set domain identifiers. These strings are the keys used to look up
# version ranges in ``C.schema_version_map()``. The empty string denotes the
# default ``ai.onnx`` domain.
ONNX_DOMAIN: str = ""
ONNX_ML_DOMAIN: str = "ai.onnx.ml"
AI_ONNX_PREVIEW_TRAINING_DOMAIN: str = "ai.onnx.preview.training"
# Direct re-exports of schema-registry functions from the binding module ``C``.

# Queries.
has = C.has_schema
get_schema = C.get_schema
get_all_schemas = C.get_all_schemas
get_all_schemas_with_history = C.get_all_schemas_with_history

# Removal of a previously registered schema.
deregister_schema = C.deregister_schema
def onnx_opset_version() -> int:
    """Return current opset for domain `ai.onnx`.

    This is the upper bound of the version range recorded for ``ONNX_DOMAIN``
    in ``C.schema_version_map()``.
    """
    _lowest, highest = C.schema_version_map()[ONNX_DOMAIN]
    return highest
def onnx_ml_opset_version() -> int:
    """Return current opset for domain `ai.onnx.ml`.

    This is the upper bound of the version range recorded for
    ``ONNX_ML_DOMAIN`` in ``C.schema_version_map()``.
    """
    _lowest, highest = C.schema_version_map()[ONNX_ML_DOMAIN]
    return highest
@property  # type: ignore
def _function_proto(self):  # type: ignore
    """Decode the schema's serialized ``_function_body`` into a ``FunctionProto``.

    A fresh message is built and parsed on every access.
    """
    proto = FunctionProto()
    proto.ParseFromString(self._function_body)
    return proto


# Python-side name for the schema class exported by
# ``onnx.onnx_cpp2py_export.defs``.
OpSchema = C.OpSchema  # type: ignore

# Expose the decoded body as the read-only ``function_body`` property.
OpSchema.function_body = _function_proto  # type: ignore
@property  # type: ignore
def _attribute_default_value(self):  # type: ignore
    """Decode the attribute's serialized ``_default_value`` into an ``AttributeProto``.

    A fresh message is built and parsed on every access.
    """
    parsed = AttributeProto()
    parsed.ParseFromString(self._default_value)
    return parsed


# Expose the decoded default as the read-only ``default_value`` property.
OpSchema.Attribute.default_value = _attribute_default_value  # type: ignore
def _op_schema_repr(self) -> str:
return f"""\
OpSchema(
name={self.name!r},
domain={self.domain!r},
since_version={self.since_version!r},
doc={self.doc!r},
type_constraints={self.type_constraints!r},
inputs={self.inputs!r},
outputs={self.outputs!r},
attributes={self.attributes!r}
)"""
# Install ``_op_schema_repr`` as the repr of ``OpSchema`` instances.
OpSchema.__repr__ = _op_schema_repr # type: ignore
def _op_schema_formal_parameter_repr(self) -> str:
return (
f"OpSchema.FormalParameter(name={self.name!r}, type_str={self.type_str!r}, "
f"description={self.description!r}, param_option={self.option!r}, "
f"is_homogeneous={self.is_homogeneous!r}, min_arity={self.min_arity!r}, "
f"differentiation_category={self.differentiation_category!r})"
)
# Install ``_op_schema_formal_parameter_repr`` as the repr of ``OpSchema.FormalParameter``.
OpSchema.FormalParameter.__repr__ = _op_schema_formal_parameter_repr # type: ignore
def _op_schema_type_constraint_param_repr(self) -> str:
return (
f"OpSchema.TypeConstraintParam(type_param_str={self.type_param_str!r}, "
f"allowed_type_strs={self.allowed_type_strs!r}, description={self.description!r})"
)
# Install ``_op_schema_type_constraint_param_repr`` as the repr of ``OpSchema.TypeConstraintParam``.
OpSchema.TypeConstraintParam.__repr__ = _op_schema_type_constraint_param_repr # type: ignore
def _op_schema_attribute_repr(self) -> str:
return (
f"OpSchema.Attribute(name={self.name!r}, type={self.type!r}, description={self.description!r}, "
f"default_value={self.default_value!r}, required={self.required!r})"
)
# Install ``_op_schema_attribute_repr`` as the repr of ``OpSchema.Attribute``.
OpSchema.Attribute.__repr__ = _op_schema_attribute_repr # type: ignore
def get_function_ops() -> List[OpSchema]:
    """Return operators defined as functions.

    A schema qualifies when it reports either ``has_function`` or
    ``has_context_dependent_function``. Candidates come from
    ``C.get_all_schemas()``.
    """

    def _defined_as_function(schema) -> bool:
        return schema.has_function or schema.has_context_dependent_function  # type: ignore

    candidates = C.get_all_schemas()
    return [schema for schema in candidates if _defined_as_function(schema)]
# Re-export of ``C.SchemaError`` as part of this module's public API.
SchemaError = C.SchemaError
def register_schema(schema: OpSchema) -> None:
    """Register a user provided OpSchema.

    If ``schema.domain`` has no recorded opset version range, or its
    ``since_version`` lies outside that range, the range is first widened to
    include ``since_version`` via ``C.set_domain_to_version``. The schema is
    then handed to ``C.register_schema``.

    Args:
        schema: The OpSchema to register.
    """
    known_ranges = C.schema_version_map()
    domain = schema.domain
    version = schema.since_version

    if domain in known_ranges:
        lowest, highest = known_ranges[domain]
        out_of_range = not (lowest <= version <= highest)
    else:
        # Unknown domain: start a range that contains only this version.
        lowest = highest = version
        out_of_range = True

    if out_of_range:
        C.set_domain_to_version(domain, min(lowest, version), max(highest, version))
    C.register_schema(schema)
|