Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
8.89 kB
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
import argparse
import glob
import os
import re
import subprocess
from textwrap import dedent
from typing import Iterable, Optional
# Banner written at the very top of every generated .proto / .proto3 file.
autogen_header = """\
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//
"""
# Trailer appended to the generated files when lite output is requested.
LITE_OPTION = """
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
"""
# Package name for which generated files and imports keep their un-renamed form.
DEFAULT_PACKAGE_NAME = "onnx"
# Comment-style preprocessor directives recognised in the .in.proto sources.
IF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#if\s+ONNX-ML\s*$")
ENDIF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#endif\s*$")
ELSE_ONNX_ML_REGEX = re.compile(r"\s*//\s*#else\s*$")


def process_ifs(lines: Iterable[str], onnx_ml: bool) -> Iterable[str]:
    """Resolve ``// #if ONNX-ML`` / ``// #else`` / ``// #endif`` blocks.

    Directive lines themselves are never emitted. Lines outside any block are
    always emitted; lines in the ``#if`` branch are emitted only when
    ``onnx_ml`` is truthy, and lines in the ``#else`` branch only when it is
    falsy. Blocks cannot be nested.

    Args:
        lines: Source lines to filter.
        onnx_ml: Whether the ONNX-ML branch should be kept.

    Yields:
        The lines that survive filtering, in their original order.

    Raises:
        ValueError: If a directive appears where it is not allowed (nested
            ``#if``, ``#else`` outside an ``#if`` branch, stray ``#endif``)
            or if the input ends inside an unterminated block. The error is
            raised lazily, when iteration reaches the offending line.
    """
    outside, in_if_branch, in_else_branch = 0, 1, 2
    state = outside
    opened_at = 0  # line number of the currently open "#if", for diagnostics

    for lineno, line in enumerate(lines, start=1):
        if IF_ONNX_ML_REGEX.match(line):
            # Explicit raise rather than assert: asserts vanish under -O and
            # malformed input would then be processed silently.
            if state != outside:
                raise ValueError(
                    f"line {lineno}: nested '#if ONNX-ML' "
                    f"(previous block opened at line {opened_at})"
                )
            state = in_if_branch
            opened_at = lineno
        elif ELSE_ONNX_ML_REGEX.match(line):
            if state != in_if_branch:
                raise ValueError(
                    f"line {lineno}: '#else' without a matching '#if ONNX-ML'"
                )
            state = in_else_branch
        elif ENDIF_ONNX_ML_REGEX.match(line):
            if state == outside:
                raise ValueError(
                    f"line {lineno}: '#endif' without a matching '#if ONNX-ML'"
                )
            state = outside
        elif (
            state == outside
            or (state == in_if_branch and onnx_ml)
            or (state == in_else_branch and not onnx_ml)
        ):
            yield line

    # Without this check a missing "#endif" would silently drop (or keep) the
    # whole remainder of the file.
    if state != outside:
        raise ValueError(
            f"unterminated '#if ONNX-ML' block opened at line {opened_at}"
        )
# Matches an `import "<name>.proto";` line: group 1 is the leading whitespace,
# group 2 is the imported name without its ".proto" extension.
IMPORT_REGEX = re.compile(r'(\s*)import\s*"([^"]*)\.proto";\s*$')
# Matches the literal "{PACKAGE_NAME}" placeholder.
PACKAGE_NAME_REGEX = re.compile(r"\{PACKAGE_NAME\}")
# Matches a name containing "-ml": group 1 is everything before the last "-ml".
ML_REGEX = re.compile(r"(.*)\-ml")
def process_package_name(lines: Iterable[str], package_name: str) -> Iterable[str]:
    """Substitute the package name into each line and rename imports if needed.

    When ``package_name`` differs from ``DEFAULT_PACKAGE_NAME``, every
    ``import "<name>.proto";`` line is rewritten so that the imported name
    carries a ``_<package_name>`` suffix (inserted before a trailing ``-ml``
    when the name has one). All other lines, and import lines when no rename
    is needed, have their ``{PACKAGE_NAME}`` placeholders replaced.

    Args:
        lines: Source lines to process.
        package_name: Package name to substitute.

    Yields:
        The rewritten lines, one per input line.
    """
    rename_imports = package_name != DEFAULT_PACKAGE_NAME
    for text in lines:
        import_match = IMPORT_REGEX.match(text) if rename_imports else None
        if import_match is None:
            # Not an import to rename: only expand the placeholder.
            yield PACKAGE_NAME_REGEX.sub(package_name, text)
            continue

        indent, target = import_match.groups()
        ml_match = ML_REGEX.match(target)
        if ml_match:
            # Keep "-ml" as the final component of the renamed import.
            renamed = f"{ml_match.group(1)}_{package_name}-ml"
        else:
            renamed = f"{target}_{package_name}"
        yield f'{indent}import "{renamed}.proto";'
# Matches the `syntax = "proto2";` declaration; group 1 is the leading whitespace.
PROTO_SYNTAX_REGEX = re.compile(r'(\s*)syntax\s*=\s*"proto2"\s*;\s*$')
# Matches a line starting (after optional whitespace) with "optional" plus one
# whitespace character; group 1 is the leading whitespace, group 2 the remainder.
OPTIONAL_REGEX = re.compile(r"(\s*)optional\s(.*)$")
def convert_to_proto3(lines: Iterable[str]) -> Iterable[str]:
    """Rewrite proto2 source lines into their proto3 form.

    Each line is checked against three rewrites in order, and only the first
    that matches is applied: the syntax declaration becomes ``proto3``, a
    leading ``optional`` keyword is dropped, and ``.proto`` imports are
    redirected to ``.proto3``. Any other line passes through unchanged.

    Args:
        lines: proto2 source lines.

    Yields:
        One converted line per input line.
    """

    def rewrite(text: str) -> str:
        # Set the syntax specifier
        syntax = PROTO_SYNTAX_REGEX.match(text)
        if syntax:
            return f'{syntax.group(1)}syntax = "proto3";'

        # Remove optional keywords
        optional = OPTIONAL_REGEX.match(text)
        if optional:
            return optional.group(1) + optional.group(2)

        # Rewrite import
        imported = IMPORT_REGEX.match(text)
        if imported:
            return f'{imported.group(1)}import "{imported.group(2)}.proto3";'

        return text

    for text in lines:
        yield rewrite(text)
def gen_proto3_code(
    protoc_path: str, proto3_path: str, include_path: str, cpp_out: str, python_out: str
) -> None:
    """Invoke protoc on a proto3 file to generate C++ and Python code.

    Args:
        protoc_path: Path of the protoc executable to run.
        proto3_path: The .proto3 file handed to protoc.
        include_path: Directory passed to protoc via ``-I``.
        cpp_out: Directory passed via ``--cpp_out``.
        python_out: Directory passed via ``--python_out``.

    Raises:
        subprocess.CalledProcessError: If protoc exits with a non-zero status.
    """
    print(f"Generate pb3 code using {protoc_path}")
    command = [
        protoc_path,
        proto3_path,
        "-I",
        include_path,
        "--cpp_out",
        cpp_out,
        "--python_out",
        python_out,
    ]
    subprocess.check_call(command)
def translate(source: str, proto: int, onnx_ml: bool, package_name: str) -> str:
    """Turn the text of an ``.in.proto`` file into proto2 or proto3 text.

    The source is split into lines, ONNX-ML conditionals are resolved, the
    package name is substituted, and for ``proto == 3`` the result is also
    converted to proto3 syntax.

    Args:
        source: Full text of the input file.
        proto: Target protobuf syntax version, 2 or 3.
        onnx_ml: Whether the ONNX-ML branches are kept.
        package_name: Package name to substitute into the output.

    Returns:
        The translated lines joined with ``"\\n"``.

    Raises:
        AssertionError: If ``proto`` is neither 2 nor 3 (when asserts are enabled).
    """
    pipeline: Iterable[str] = process_package_name(
        process_ifs(source.splitlines(), onnx_ml=onnx_ml),
        package_name=package_name,
    )
    if proto != 3:  # noqa: PLR2004
        assert proto == 2  # noqa: PLR2004
        return "\n".join(pipeline)  # TODO: not Windows friendly
    return "\n".join(convert_to_proto3(pipeline))  # TODO: not Windows friendly
def qualify(f: str, pardir: Optional[str] = None) -> str:
    """Join ``f`` onto a parent directory.

    Args:
        f: File name (or relative path) to qualify.
        pardir: Directory to join onto. When ``None``, the real path of the
            directory containing this script is used.

    Returns:
        ``os.path.join(pardir, f)``.
    """
    base = os.path.realpath(os.path.dirname(__file__)) if pardir is None else pardir
    return os.path.join(base, f)
def convert(
    stem: str,
    package_name: str,
    output: str,
    do_onnx_ml: bool = False,
    lite: bool = False,
    protoc_path: str = "",
) -> None:
    """Generate the .proto, .proto3 and Python shim files for one stem.

    Reads ``<stem>.in.proto`` from this script's directory and writes into
    ``output``:

    * ``<base>.proto`` and ``<base>.proto3``, the translated sources;
    * when the package is renamed, a forwarding ``<stem>[-ml].pb.h`` header
      that includes ``<base>.pb.h``;
    * ``<stem>_pb.py`` (dashes replaced by underscores), which re-exports the
      matching ``*_pb2`` module.

    ``<base>`` is ``stem`` plus ``_<package_name>`` when the package differs
    from the default, plus ``-ml`` in ML mode.

    Args:
        stem: Stem of the ``.in.proto`` file, e.g. ``"onnx"``.
        package_name: Package name to substitute into the generated files.
        output: Directory receiving the generated files.
        do_onnx_ml: Generate the ONNX-ML variant. File naming ignores this
            for ``onnx-data`` inputs; the translated content still honours it.
        lite: Append the protobuf-lite option to the generated proto files.
        protoc_path: If non-empty, run this protoc on the generated .proto3
            file and remove the ``<base>.proto3.*`` files left next to it.
    """
    proto_in = qualify(f"{stem}.in.proto")
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    # Having a separate variable for import_ml ensures that the import statements for the generated
    # proto files can be set separately from the ONNX_ML environment variable setting.
    import_ml = do_onnx_ml
    # We do not want to generate the onnx-data-ml.proto files for onnx-data.in.proto,
    # as there is no change between onnx-data.proto and the ML version.
    # Only the file name is inspected: matching the full path would switch off
    # "-ml" naming for every stem whenever a parent directory's name happens
    # to contain "onnx-data".
    if "onnx-data" in os.path.basename(proto_in):
        do_onnx_ml = False

    ml_suffix = "-ml" if do_onnx_ml else ""
    if need_rename:
        proto_base = f"{stem}_{package_name}{ml_suffix}"
    else:
        proto_base = f"{stem}{ml_suffix}"
    proto = qualify(f"{proto_base}.proto", pardir=output)
    proto3 = qualify(f"{proto_base}.proto3", pardir=output)

    print(f"Processing {proto_in}")
    with open(proto_in, encoding="utf-8") as fin:
        source = fin.read()

    # The proto2 and proto3 outputs differ only in path and syntax version.
    for path, version in ((proto, 2), (proto3, 3)):
        print(f"Writing {path}")
        # newline="" keeps the "\n" separators as-is on every platform.
        with open(path, "w", newline="", encoding="utf-8") as fout:
            fout.write(autogen_header)
            fout.write(
                translate(
                    source, proto=version, onnx_ml=import_ml, package_name=package_name
                )
            )
            if lite:
                fout.write(LITE_OPTION)

    if protoc_path:
        proto3_dir = os.path.dirname(proto3)
        base_dir = os.path.dirname(proto3_dir)
        gen_proto3_code(protoc_path, proto3, base_dir, base_dir, base_dir)
        # Escape the literal prefix so glob metacharacters in the output path
        # are not interpreted as wildcards; only the trailing "*" is a pattern.
        pattern = glob.escape(os.path.join(proto3_dir, f"{proto_base}.proto3.")) + "*"
        for pb3_file in glob.glob(pattern):
            print(f"Removing {pb3_file}")
            os.remove(pb3_file)

    if need_rename:
        # Forwarding header so includes of the un-renamed name keep working.
        proto_header = qualify(f"{stem}{ml_suffix}.pb.h", pardir=output)
        print(f"Writing {proto_header}")
        with open(proto_header, "w", newline="", encoding="utf-8") as fout:
            fout.write("#pragma once\n")
            fout.write(f'#include "{proto_base}.pb.h"\n')

    # Generate py mapping
    # "-" is invalid in python module name, replaces '-' with '_'
    pb_py = qualify(f"{stem.replace('-', '_')}_pb.py", pardir=output)
    # proto_base already carries the package and "-ml" parts, so one
    # expression yields the pb2 module name for every combination.
    pb2_py = qualify(f"{proto_base.replace('-', '_')}_pb2.py", pardir=output)

    print(f"generating {pb_py}")
    with open(pb_py, "w", encoding="utf-8") as f:
        f.write(
            dedent(
                f"""\
                # This file is generated by setup.py. DO NOT EDIT!
                from .{os.path.splitext(os.path.basename(pb2_py))[0]} import * # noqa
                """
            )
        )
def main() -> None:
    """Parse the command line and generate proto files for each requested stem.

    The output directory is created if it is missing, then ``convert`` is
    called once per stem with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Generates .proto file variations from .in.proto"
    )
    parser.add_argument(
        "-p",
        "--package",
        default=DEFAULT_PACKAGE_NAME,
        help="package name in the generated proto files (default: %(default)s)",
    )
    parser.add_argument("-m", "--ml", action="store_true", help="ML mode")
    parser.add_argument(
        "-l",
        "--lite",
        action="store_true",
        help="generate lite proto to use with protobuf-lite",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=os.path.realpath(os.path.dirname(__file__)),
        help="output directory (default: %(default)s)",
    )
    parser.add_argument(
        "--protoc_path", default="", help="path to protoc for proto3 file validation"
    )
    parser.add_argument(
        "stems",
        nargs="*",
        default=["onnx", "onnx-operators", "onnx-data"],
        help="list of .in.proto file stems (default: %(default)s)",
    )
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race when several generator
    # invocations target the same output directory at the same time.
    os.makedirs(args.output, exist_ok=True)

    for stem in args.stems:
        convert(
            stem,
            package_name=args.package,
            output=args.output,
            do_onnx_ml=args.ml,
            lite=args.lite,
            protoc_path=args.protoc_path,
        )


if __name__ == "__main__":
    main()