File size: 4,284 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

# Names exported by ``from <this module> import *``. Every public helper
# defined below is listed, including ``onnx_ml_opset_version`` alongside its
# sibling ``onnx_opset_version``.
__all__ = [
    "C",
    "ONNX_DOMAIN",
    "ONNX_ML_DOMAIN",
    "AI_ONNX_PREVIEW_TRAINING_DOMAIN",
    "has",
    "register_schema",
    "deregister_schema",
    "get_schema",
    "get_all_schemas",
    "get_all_schemas_with_history",
    "onnx_opset_version",
    "onnx_ml_opset_version",
    "get_function_ops",
    "OpSchema",
    "SchemaError",
]

from typing import List

import onnx.onnx_cpp2py_export.defs as C  # noqa: N812
from onnx import AttributeProto, FunctionProto

# Operator-set domain names. The default ``ai.onnx`` domain is spelled as the
# empty string, which is the key ``onnx_opset_version`` looks up.
(
    ONNX_DOMAIN,
    ONNX_ML_DOMAIN,
    AI_ONNX_PREVIEW_TRAINING_DOMAIN,
) = ("", "ai.onnx.ml", "ai.onnx.preview.training")


# Schema-registry entry points re-exported directly from the C++ extension
# module; they are plain aliases with no Python-side wrapping.
get_schema = C.get_schema
get_all_schemas = C.get_all_schemas
get_all_schemas_with_history = C.get_all_schemas_with_history
has = C.has_schema  # public name differs from the binding's ``has_schema``
deregister_schema = C.deregister_schema


def onnx_opset_version() -> int:
    """Return the current (maximum) opset version for domain `ai.onnx`.

    ``C.schema_version_map()`` maps each domain to a ``(min, max)`` version
    pair; the second element is the current opset.
    """
    domain_range = C.schema_version_map()[ONNX_DOMAIN]
    return domain_range[1]


def onnx_ml_opset_version() -> int:
    """Return the current (maximum) opset version for domain `ai.onnx.ml`.

    Reads the ``(min, max)`` pair recorded for the ML domain in
    ``C.schema_version_map()`` and returns its upper bound.
    """
    versions_by_domain = C.schema_version_map()
    ml_range = versions_by_domain[ONNX_ML_DOMAIN]
    return ml_range[1]


@property  # type: ignore
def _function_proto(self):  # type: ignore
    """Decode the schema's serialized ``_function_body`` into a ``FunctionProto``.

    A new message is built and parsed on every access.
    """
    serialized_body = self._function_body
    decoded = FunctionProto()
    decoded.ParseFromString(serialized_body)
    return decoded


# Re-export the C++ OpSchema binding and attach a ``function_body`` property
# that decodes the serialized ``_function_body`` into a FunctionProto.
OpSchema = C.OpSchema  # type: ignore
OpSchema.function_body = _function_proto  # type: ignore


@property  # type: ignore
def _attribute_default_value(self):  # type: ignore
    """Decode an attribute's serialized ``_default_value`` into an ``AttributeProto``.

    A new message is built and parsed on every access.
    """
    raw_default = self._default_value
    decoded_attr = AttributeProto()
    decoded_attr.ParseFromString(raw_default)
    return decoded_attr


# Expose ``default_value`` on schema attributes as a decoded AttributeProto.
OpSchema.Attribute.default_value = _attribute_default_value  # type: ignore


def _op_schema_repr(self) -> str:
    return f"""\

OpSchema(

    name={self.name!r},

    domain={self.domain!r},

    since_version={self.since_version!r},

    doc={self.doc!r},

    type_constraints={self.type_constraints!r},

    inputs={self.inputs!r},

    outputs={self.outputs!r},

    attributes={self.attributes!r}

)"""


# Replace the binding's default repr with a field-by-field one.
OpSchema.__repr__ = _op_schema_repr  # type: ignore


def _op_schema_formal_parameter_repr(self) -> str:
    return (
        f"OpSchema.FormalParameter(name={self.name!r}, type_str={self.type_str!r}, "
        f"description={self.description!r}, param_option={self.option!r}, "
        f"is_homogeneous={self.is_homogeneous!r}, min_arity={self.min_arity!r}, "
        f"differentiation_category={self.differentiation_category!r})"
    )


# Give formal parameters a readable, single-line repr.
OpSchema.FormalParameter.__repr__ = _op_schema_formal_parameter_repr  # type: ignore


def _op_schema_type_constraint_param_repr(self) -> str:
    return (
        f"OpSchema.TypeConstraintParam(type_param_str={self.type_param_str!r}, "
        f"allowed_type_strs={self.allowed_type_strs!r}, description={self.description!r})"
    )


# Give type-constraint parameters a readable, single-line repr.
OpSchema.TypeConstraintParam.__repr__ = _op_schema_type_constraint_param_repr  # type: ignore


def _op_schema_attribute_repr(self) -> str:
    return (
        f"OpSchema.Attribute(name={self.name!r}, type={self.type!r}, description={self.description!r}, "
        f"default_value={self.default_value!r}, required={self.required!r})"
    )


# Give schema attributes a readable, single-line repr.
OpSchema.Attribute.__repr__ = _op_schema_attribute_repr  # type: ignore


def get_function_ops() -> List[OpSchema]:
    """Return the schemas from ``C.get_all_schemas()`` whose operators are defined as functions.

    A schema qualifies when it has either a function body or a
    context-dependent function body.
    """

    def _is_function_op(candidate) -> bool:
        return candidate.has_function or candidate.has_context_dependent_function

    all_schemas = C.get_all_schemas()
    return [candidate for candidate in all_schemas if _is_function_op(candidate)]  # type: ignore


# Re-export the C extension's SchemaError so callers can catch it without
# importing the binding module directly.
SchemaError = C.SchemaError


def register_schema(schema: OpSchema) -> None:
    """Register a user-provided OpSchema.

    Before registering, the known opset version range of ``schema.domain`` is
    widened if needed so that it includes ``schema.since_version``. A domain
    that is not yet known gets the range ``(since_version, since_version)``.

    Args:
        schema: The OpSchema to register.
    """
    known_ranges = C.schema_version_map()
    domain = schema.domain
    version = schema.since_version

    if domain in known_ranges:
        lowest, highest = known_ranges[domain]
        already_covered = lowest <= version <= highest
    else:
        lowest = highest = version
        already_covered = False

    if not already_covered:
        C.set_domain_to_version(domain, min(lowest, version), max(highest, version))
    C.register_schema(schema)