File size: 8,887 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
#!/usr/bin/env python

# SPDX-License-Identifier: Apache-2.0

import argparse
import glob
import os
import re
import subprocess
from textwrap import dedent
from typing import Iterable, Optional

# Banner that convert() writes at the top of every generated .proto / .proto3 file.
# NOTE(review): the banner always names onnx.in.proto, even for outputs generated
# from other stems (e.g. onnx-operators.in.proto, onnx-data.in.proto).
autogen_header = """\

//

// WARNING: This file is automatically generated!  Please edit onnx.in.proto.

//





"""

# Appended by convert() to each generated file when `lite` is requested
# (--lite on the command line), so protoc emits protobuf-lite code.
LITE_OPTION = """



// For using protobuf-lite

option optimize_for = LITE_RUNTIME;



"""

# Package name that needs no renaming; any other --package value causes imports
# and output file names to carry a "_<package>" suffix.
DEFAULT_PACKAGE_NAME = "onnx"

# Conditional-compilation directives understood in .in.proto templates. They are
# written as proto comments so the template itself stays valid proto syntax.
IF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#if\s+ONNX-ML\s*$")
ENDIF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#endif\s*$")
ELSE_ONNX_ML_REGEX = re.compile(r"\s*//\s*#else\s*$")


def process_ifs(lines: Iterable[str], onnx_ml: bool) -> Iterable[str]:
    """Resolve ``// #if ONNX-ML`` / ``// #else`` / ``// #endif`` directives.

    Lines outside any directive block are always yielded. Lines in the
    ``#if ONNX-ML`` branch are yielded only when ``onnx_ml`` is true, and lines
    in the ``#else`` branch only when it is false. The directive lines
    themselves are never yielded. Blocks may not be nested.

    Args:
        lines: Template lines, without line terminators.
        onnx_ml: Select the ML branch (True) or the ``#else`` branch (False).

    Yields:
        The lines that survive directive resolution, in their original order.

    Raises:
        ValueError: On malformed directive structure: a nested ``#if``, an
            ``#else``/``#endif`` with no open ``#if``, a second ``#else``, or an
            ``#if`` that is still open at end of input. Because this is a
            generator, the error surfaces while the result is being iterated.
    """
    # Parser states: outside any block, inside the #if branch, inside the #else branch.
    outside, in_if_branch, in_else_branch = 0, 1, 2
    state = outside
    for lineno, line in enumerate(lines, start=1):
        if IF_ONNX_ML_REGEX.match(line):
            if state != outside:
                raise ValueError(
                    f"line {lineno}: nested '#if ONNX-ML' blocks are not supported"
                )
            state = in_if_branch
        elif ELSE_ONNX_ML_REGEX.match(line):
            if state != in_if_branch:
                raise ValueError(
                    f"line {lineno}: unexpected '#else' (no open '#if ONNX-ML' branch)"
                )
            state = in_else_branch
        elif ENDIF_ONNX_ML_REGEX.match(line):
            if state == outside:
                raise ValueError(
                    f"line {lineno}: '#endif' without a matching '#if ONNX-ML'"
                )
            state = outside
        elif (
            state == outside
            or (state == in_if_branch and onnx_ml)
            or (state == in_else_branch and not onnx_ml)
        ):
            yield line
    # An unterminated block would otherwise silently swallow or keep the rest of the file.
    if state != outside:
        raise ValueError("'#if ONNX-ML' block is not closed by '#endif' before end of input")


# `import "<name>.proto";` lines; group 1 is the leading indentation, group 2 the
# imported path without the ".proto" suffix.
IMPORT_REGEX = re.compile(r'(\s*)import\s*"([^"]*)\.proto";\s*$')
# Placeholder substituted with the target package name.
PACKAGE_NAME_REGEX = re.compile(r"\{PACKAGE_NAME\}")
# An import path ending in "-ml"; group 1 is everything before that suffix.
ML_REGEX = re.compile(r"(.*)\-ml")


def process_package_name(lines: Iterable[str], package_name: str) -> Iterable[str]:
    """Substitute the package name and, when renaming, redirect proto imports.

    Each literal ``{PACKAGE_NAME}`` in a line is replaced by ``package_name``.
    When ``package_name`` differs from ``DEFAULT_PACKAGE_NAME``, lines of the
    form ``import "<name>.proto";`` are instead rewritten (keeping their
    indentation) to point at the renamed files:

    * ``<base>-ml`` becomes ``<base>_<package_name>-ml``;
    * any other ``<name>`` becomes ``<name>_<package_name>``.

    This mirrors the output naming used by ``convert``.

    Args:
        lines: Template lines, without line terminators.
        package_name: Target package name.

    Yields:
        The processed lines.
    """
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    for line in lines:
        m = IMPORT_REGEX.match(line) if need_rename else None
        if m is None:
            yield PACKAGE_NAME_REGEX.sub(package_name, line)
            continue
        indent, include_name = m.group(1), m.group(2)
        # fullmatch: only a trailing "-ml" marks the ML variant. A plain match()
        # would also accept names such as "foo-mlx" or "foo-ml-bar" and truncate
        # them to "foo_<pkg>-ml".
        ml = ML_REGEX.fullmatch(include_name)
        if ml:
            include_name = f"{ml.group(1)}_{package_name}-ml"
        else:
            include_name = f"{include_name}_{package_name}"
        yield indent + f'import "{include_name}.proto";'


PROTO_SYNTAX_REGEX = re.compile(r'(\s*)syntax\s*=\s*"proto2"\s*;\s*$')
OPTIONAL_REGEX = re.compile(r"(\s*)optional\s(.*)$")


def convert_to_proto3(lines: Iterable[str]) -> Iterable[str]:
    """Rewrite proto2 template lines into proto3 form.

    For each line, the first applicable rule wins:

    1. ``syntax = "proto2";`` becomes ``syntax = "proto3";``.
    2. A leading ``optional`` keyword is dropped.
    3. ``import "<name>.proto";`` becomes ``import "<name>.proto3";``.

    Leading indentation is kept; any other line is yielded unchanged.
    """
    # (pattern, rewrite) pairs, tried in priority order.
    rules = (
        (PROTO_SYNTAX_REGEX, lambda found: found.group(1) + 'syntax = "proto3";'),
        (OPTIONAL_REGEX, lambda found: found.group(1) + found.group(2)),
        (
            IMPORT_REGEX,
            lambda found: found.group(1) + f'import "{found.group(2)}.proto3";',
        ),
    )
    for text in lines:
        for pattern, rewrite in rules:
            found = pattern.match(text)
            if found:
                yield rewrite(found)
                break
        else:
            # No rule applied: pass the line through untouched.
            yield text


def gen_proto3_code(
    protoc_path: str, proto3_path: str, include_path: str, cpp_out: str, python_out: str
) -> None:
    """Run ``protoc`` on a ``.proto3`` file to emit C++ and Python sources.

    Args:
        protoc_path: Path of the ``protoc`` executable to invoke.
        proto3_path: The ``.proto3`` file to compile.
        include_path: Directory passed to ``protoc`` via ``-I``.
        cpp_out: Destination directory for ``--cpp_out``.
        python_out: Destination directory for ``--python_out``.

    Raises:
        subprocess.CalledProcessError: If ``protoc`` exits with a non-zero status.
    """
    print(f"Generate pb3 code using {protoc_path}")
    command = [
        protoc_path,
        proto3_path,
        "-I",
        include_path,
        "--cpp_out",
        cpp_out,
        "--python_out",
        python_out,
    ]
    subprocess.check_call(command)


def translate(source: str, proto: int, onnx_ml: bool, package_name: str) -> str:
    """Render ``.in.proto`` template text as proto2 or proto3 source.

    The template is split into lines and passed through ``process_ifs``
    (``#if ONNX-ML`` branch selection), then ``process_package_name`` (package
    substitution and import renaming), then, for proto3 only,
    ``convert_to_proto3``.

    Args:
        source: Full text of the ``.in.proto`` template.
        proto: Target syntax version, 2 or 3.
        onnx_ml: Keep the ``#if ONNX-ML`` branches instead of the ``#else`` ones.
        package_name: Package name substituted into the template.

    Returns:
        The translated lines joined by newline characters. A trailing line
        break in ``source`` is not reproduced, because ``str.splitlines``
        drops it.

    Raises:
        ValueError: If ``proto`` is neither 2 nor 3.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would silently fall back to proto2 output.
    if proto not in (2, 3):
        raise ValueError(f"unsupported proto version {proto!r}; expected 2 or 3")
    lines: Iterable[str] = source.splitlines()
    lines = process_ifs(lines, onnx_ml=onnx_ml)
    lines = process_package_name(lines, package_name=package_name)
    if proto == 3:  # noqa: PLR2004
        lines = convert_to_proto3(lines)
    # Always "\n"-separated. convert() writes the result with newline="", so these
    # separators reach disk unchanged on every platform.
    return "\n".join(lines)


def qualify(f: str, pardir: Optional[str] = None) -> str:
    """Join file name ``f`` onto ``pardir``.

    When ``pardir`` is None, the resolved (symlink-free) directory containing
    this script is used instead.
    """
    base = pardir if pardir is not None else os.path.realpath(os.path.dirname(__file__))
    return os.path.join(base, f)


def convert(
    stem: str,
    package_name: str,
    output: str,
    do_onnx_ml: bool = False,
    lite: bool = False,
    protoc_path: str = "",
) -> None:
    """Generate the proto2/proto3 variants and Python shim for one template.

    Reads ``<stem>.in.proto`` from the directory containing this script and
    writes into ``output``:

    * ``<base>.proto`` (proto2) and ``<base>.proto3`` (proto3). Each starts
      with ``autogen_header`` and, if ``lite``, ends with ``LITE_OPTION``.
    * If ``package_name`` differs from ``DEFAULT_PACKAGE_NAME``, a forwarding
      header ``<stem>.pb.h`` (``<stem>-ml.pb.h`` in ML mode) that includes
      ``<base>.pb.h``.
    * ``<stem>_pb.py`` (with ``-`` replaced by ``_``), which star-imports the
      ``<base>_pb2`` module.

    ``<base>`` is ``stem``, followed by ``_<package_name>`` when renaming and
    then by ``-ml`` in ML mode.

    Args:
        stem: Template name without the ``.in.proto`` suffix, e.g. ``"onnx"``.
        package_name: Replacement for ``{PACKAGE_NAME}``. Any value other than
            ``DEFAULT_PACKAGE_NAME`` also renames imports and output files.
        output: Directory that receives the generated files.
        do_onnx_ml: Keep the ``#if ONNX-ML`` branches and use ``-ml`` output
            names. For ``onnx-data`` stems only the branch selection applies;
            the output names never get ``-ml``.
        lite: Append the protobuf-lite ``optimize_for`` option.
        protoc_path: If non-empty, run this ``protoc`` on the ``.proto3``
            output as a check, then delete any ``<base>.proto3.*`` files in
            the output directory.

    Raises:
        OSError: If the template cannot be read or an output cannot be written.
        subprocess.CalledProcessError: If ``protoc`` fails.
    """
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    proto_in = qualify(f"{stem}.in.proto")

    # import_ml keeps the caller's flag and decides which "#if ONNX-ML" branches
    # (including import statements) go into the generated files. do_onnx_ml may
    # be cleared below; that only affects output file naming.
    import_ml = do_onnx_ml
    # onnx-data has no ML-specific content, so no onnx-data-ml.* files are made.
    # Test the stem, not proto_in: proto_in is an absolute path, and a checkout
    # directory whose name contains "onnx-data" would otherwise disable ML
    # naming for every stem.
    if "onnx-data" in stem:
        do_onnx_ml = False

    proto_base = stem
    if need_rename:
        proto_base += f"_{package_name}"
    if do_onnx_ml:
        proto_base += "-ml"
    proto = qualify(f"{proto_base}.proto", pardir=output)
    proto3 = qualify(f"{proto_base}.proto3", pardir=output)

    print(f"Processing {proto_in}")
    # Read the template once and release the handle before writing outputs and
    # running protoc.
    with open(proto_in, encoding="utf-8") as fin:
        source = fin.read()

    for out_path, version in ((proto, 2), (proto3, 3)):
        print(f"Writing {out_path}")
        # newline="" writes translate()'s "\n" separators unchanged on every OS.
        with open(out_path, "w", newline="", encoding="utf-8") as fout:
            fout.write(autogen_header)
            fout.write(
                translate(
                    source, proto=version, onnx_ml=import_ml, package_name=package_name
                )
            )
            if lite:
                fout.write(LITE_OPTION)

    if protoc_path:
        proto3_dir = os.path.dirname(proto3)
        base_dir = os.path.dirname(proto3_dir)
        # Compile only as a validity check, then remove the matching artifacts.
        gen_proto3_code(protoc_path, proto3, base_dir, base_dir, base_dir)
        pb3_files = glob.glob(os.path.join(proto3_dir, f"{proto_base}.proto3.*"))
        for pb3_file in pb3_files:
            print(f"Removing {pb3_file}")
            os.remove(pb3_file)

    if need_rename:
        # Forwarding header under the unrenamed name that includes the renamed one.
        header_name = f"{stem}-ml.pb.h" if do_onnx_ml else f"{stem}.pb.h"
        proto_header = qualify(header_name, pardir=output)
        print(f"Writing {proto_header}")
        with open(proto_header, "w", newline="", encoding="utf-8") as fout:
            fout.write("#pragma once\n")
            fout.write(f'#include "{proto_base}.pb.h"\n')

    # Generate the Python mapping module.
    # "-" is invalid in a Python module name, so both file names use "_" instead.
    pb_py = qualify(f"{stem.replace('-', '_')}_pb.py", pardir=output)
    pb2_py = qualify(f"{proto_base.replace('-', '_')}_pb2.py", pardir=output)
    pb2_module = os.path.splitext(os.path.basename(pb2_py))[0]

    print(f"generating {pb_py}")
    with open(pb_py, "w", encoding="utf-8") as f:
        f.write(
            dedent(
                f"""\

                # This file is generated by setup.py. DO NOT EDIT!





                from .{pb2_module} import *  # noqa

                """
            )
        )


def main() -> None:
    """Command-line entry point.

    Parses the options, creates the output directory if it does not exist, and
    calls ``convert`` once per requested ``.in.proto`` stem.
    """
    parser = argparse.ArgumentParser(
        description="Generates .proto file variations from .in.proto"
    )
    parser.add_argument(
        "-p",
        "--package",
        # Shares the constant that convert() uses to decide whether to rename.
        default=DEFAULT_PACKAGE_NAME,
        help="package name in the generated proto files (default: %(default)s)",
    )
    parser.add_argument("-m", "--ml", action="store_true", help="ML mode")
    parser.add_argument(
        "-l",
        "--lite",
        action="store_true",
        help="generate lite proto to use with protobuf-lite",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=os.path.realpath(os.path.dirname(__file__)),
        help="output directory (default: %(default)s)",
    )
    parser.add_argument(
        "--protoc_path", default="", help="path to protoc for proto3 file validation"
    )
    parser.add_argument(
        "stems",
        nargs="*",
        default=["onnx", "onnx-operators", "onnx-data"],
        help="list of .in.proto file stems (default: %(default)s)",
    )
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of a separate os.path.exists() test.
    os.makedirs(args.output, exist_ok=True)

    for stem in args.stems:
        convert(
            stem,
            package_name=args.package,
            output=args.output,
            do_onnx_ml=args.ml,
            lite=args.lite,
            protoc_path=args.protoc_path,
        )


if __name__ == "__main__":
    main()