Spaces:
Running
Running
#!/usr/bin/env python | |
# SPDX-License-Identifier: Apache-2.0 | |
import argparse | |
import glob | |
import os | |
import re | |
import subprocess | |
from textwrap import dedent | |
from typing import Iterable, Optional | |
# Banner written at the start of every generated .proto / .proto3 file.
autogen_header = """\
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//
"""
# Appended to both generated files when lite output is requested (--lite).
LITE_OPTION = """
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
"""
# Package name the templates use as-is; any other package name causes the
# generated files (and their imports) to be renamed with a "_<package>" suffix.
DEFAULT_PACKAGE_NAME = "onnx"
IF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#if\s+ONNX-ML\s*$")
ENDIF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#endif\s*$")
ELSE_ONNX_ML_REGEX = re.compile(r"\s*//\s*#else\s*$")

# States of the process_ifs scanner.
_OUTSIDE_IF = 0  # not inside an ONNX-ML section
_IN_IF = 1  # between "// #if ONNX-ML" and "// #else" / "// #endif"
_IN_ELSE = 2  # between "// #else" and "// #endif"


def process_ifs(lines: Iterable[str], onnx_ml: bool) -> Iterable[str]:
    """Resolve ``// #if ONNX-ML`` / ``// #else`` / ``// #endif`` sections.

    The directive lines themselves are always dropped. Lines in the
    ``#if ONNX-ML`` branch are kept only when *onnx_ml* is truthy; lines in the
    ``#else`` branch are kept only when it is falsy. Lines outside any section
    are always kept. Sections may not be nested.

    Args:
        lines: Template lines without line terminators.
        onnx_ml: Whether to keep the ONNX-ML branch instead of the ``#else`` one.

    Yields:
        The retained lines, in their original order.

    Raises:
        ValueError: On a nested ``#if ONNX-ML``, a stray ``#else``/``#endif``,
            or an ``#if ONNX-ML`` that is never closed. (Explicit exceptions
            replace the former ``assert``s, which vanish under ``python -O``.)
    """
    state = _OUTSIDE_IF
    opened_at = 0  # line number of the currently open "#if ONNX-ML"
    for lineno, line in enumerate(lines, start=1):
        if IF_ONNX_ML_REGEX.match(line):
            if state != _OUTSIDE_IF:
                raise ValueError(
                    f"line {lineno}: nested '#if ONNX-ML' "
                    f"(previous one opened on line {opened_at})"
                )
            state, opened_at = _IN_IF, lineno
        elif ELSE_ONNX_ML_REGEX.match(line):
            if state != _IN_IF:
                raise ValueError(
                    f"line {lineno}: '#else' outside an '#if ONNX-ML' branch"
                )
            state = _IN_ELSE
        elif ENDIF_ONNX_ML_REGEX.match(line):
            if state == _OUTSIDE_IF:
                raise ValueError(
                    f"line {lineno}: '#endif' without a matching '#if ONNX-ML'"
                )
            state = _OUTSIDE_IF
        elif (
            state == _OUTSIDE_IF
            or (state == _IN_IF and onnx_ml)
            or (state == _IN_ELSE and not onnx_ml)
        ):
            yield line
    # Without this check an unclosed section silently swallows (or keeps)
    # every remaining line of the file.
    if state != _OUTSIDE_IF:
        raise ValueError(f"'#if ONNX-ML' opened on line {opened_at} is never closed")
IMPORT_REGEX = re.compile(r'(\s*)import\s*"([^"]*)\.proto";\s*$')
PACKAGE_NAME_REGEX = re.compile(r"\{PACKAGE_NAME\}")
ML_REGEX = re.compile(r"(.*)\-ml")


def process_package_name(lines: Iterable[str], package_name: str) -> Iterable[str]:
    """Substitute *package_name* into template lines.

    Every ``{PACKAGE_NAME}`` placeholder is replaced with *package_name*. When
    *package_name* differs from ``DEFAULT_PACKAGE_NAME``, each
    ``import "X.proto";`` line is instead rewritten to reference the renamed
    file: ``X-ml`` becomes ``X_<package>-ml`` and any other ``X`` becomes
    ``X_<package>``.

    Args:
        lines: Template lines without line terminators.
        package_name: Package name to substitute.

    Yields:
        The processed lines, in order.
    """
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    for line in lines:
        m = IMPORT_REGEX.match(line) if need_rename else None
        if m is None:
            # A callable replacement inserts package_name literally; a plain
            # string would have its backslashes parsed as template escapes.
            yield PACKAGE_NAME_REGEX.sub(lambda _match: package_name, line)
            continue
        indent, include_name = m.group(1), m.group(2)
        # fullmatch: only a trailing "-ml" marks the ML variant. With .match a
        # name such as "foo-mlir" was truncated to "foo" and renamed wrongly.
        ml = ML_REGEX.fullmatch(include_name)
        if ml:
            include_name = f"{ml.group(1)}_{package_name}-ml"
        else:
            include_name = f"{include_name}_{package_name}"
        yield indent + f'import "{include_name}.proto";'
PROTO_SYNTAX_REGEX = re.compile(r'(\s*)syntax\s*=\s*"proto2"\s*;\s*$')
OPTIONAL_REGEX = re.compile(r"(\s*)optional\s(.*)$")


def _proto3_line(line: str) -> str:
    """Return the proto3 form of one proto2 source line (first matching rule wins)."""
    syntax = PROTO_SYNTAX_REGEX.match(line)
    if syntax:
        # The syntax declaration switches from proto2 to proto3.
        return syntax.group(1) + 'syntax = "proto3";'
    optional = OPTIONAL_REGEX.match(line)
    if optional:
        # A leading "optional" label is dropped, keeping the indentation.
        return optional.group(1) + optional.group(2)
    imported = IMPORT_REGEX.match(line)
    if imported:
        # Imports point at the sibling .proto3 file.
        return imported.group(1) + f'import "{imported.group(2)}.proto3";'
    return line


def convert_to_proto3(lines: Iterable[str]) -> Iterable[str]:
    """Lazily rewrite proto2 template lines into proto3 syntax.

    Each line is passed through ``_proto3_line``, which switches the
    ``syntax`` declaration to proto3, strips a leading ``optional`` keyword,
    and points ``import "X.proto";`` at ``X.proto3``. Other lines pass through
    untouched.
    """
    for line in lines:
        yield _proto3_line(line)
def gen_proto3_code(
    protoc_path: str, proto3_path: str, include_path: str, cpp_out: str, python_out: str
) -> None:
    """Run ``protoc`` on *proto3_path*, emitting C++ and Python sources.

    Args:
        protoc_path: Path of the ``protoc`` executable.
        proto3_path: The ``.proto3`` file to compile.
        include_path: Directory passed to ``protoc -I``.
        cpp_out: Destination passed to ``--cpp_out``.
        python_out: Destination passed to ``--python_out``.

    Raises:
        subprocess.CalledProcessError: If ``protoc`` exits with a non-zero status.
    """
    print(f"Generate pb3 code using {protoc_path}")
    command = [
        protoc_path,
        proto3_path,
        "-I",
        include_path,
        "--cpp_out",
        cpp_out,
        "--python_out",
        python_out,
    ]
    subprocess.check_call(command)
def translate(source: str, proto: int, onnx_ml: bool, package_name: str) -> str:
    """Render an ``.in.proto`` template as proto2 or proto3 text.

    The template is processed line by line: ONNX-ML sections are resolved
    (``process_ifs``), the package name is substituted (``process_package_name``),
    and for ``proto == 3`` the result is converted to proto3 syntax
    (``convert_to_proto3``).

    Args:
        source: Full text of the template.
        proto: Target syntax version, 2 or 3.
        onnx_ml: Whether to keep the ``#if ONNX-ML`` branches.
        package_name: Package name to substitute into the template.

    Returns:
        The processed lines joined with ``"\\n"`` (no trailing newline).

    Raises:
        ValueError: If *proto* is neither 2 nor 3.
    """
    # Validate before doing any work. The former ``assert proto == 2`` vanished
    # under ``python -O``, and any other value then silently produced proto2.
    if proto not in (2, 3):
        raise ValueError(f"unsupported proto version {proto!r}; expected 2 or 3")
    lines: Iterable[str] = source.splitlines()
    lines = process_ifs(lines, onnx_ml=onnx_ml)
    lines = process_package_name(lines, package_name=package_name)
    if proto == 3:
        lines = convert_to_proto3(lines)
    return "\n".join(lines)  # TODO: not Windows friendly
def qualify(f: str, pardir: Optional[str] = None) -> str:
    """Join *f* onto a parent directory.

    Args:
        f: File name or relative path to place under the parent directory.
        pardir: Parent directory. ``None`` selects the directory containing
            this module, with symlinks resolved.

    Returns:
        ``os.path.join(pardir, f)`` using the resolved parent directory.
    """
    base = os.path.realpath(os.path.dirname(__file__)) if pardir is None else pardir
    return os.path.join(base, f)
def _write_generated_proto(
    path: str,
    source: str,
    proto: int,
    onnx_ml: bool,
    package_name: str,
    lite: bool,
) -> None:
    """Write the autogen banner plus *source* translated to proto *proto* into *path*."""
    print(f"Writing {path}")
    with open(path, "w", newline="", encoding="utf-8") as fout:
        fout.write(autogen_header)
        fout.write(
            translate(source, proto=proto, onnx_ml=onnx_ml, package_name=package_name)
        )
        if lite:
            fout.write(LITE_OPTION)


def _write_py_mapping(pb_py: str, pb2_py: str) -> None:
    """Write the Python shim *pb_py*, which star-imports the module in *pb2_py*."""
    module = os.path.splitext(os.path.basename(pb2_py))[0]
    print(f"generating {pb_py}")
    with open(pb_py, "w", encoding="utf-8") as f:
        f.write(
            dedent(
                f"""\
                # This file is generated by setup.py. DO NOT EDIT!
                from .{module} import * # noqa
                """
            )
        )


def convert(
    stem: str,
    package_name: str,
    output: str,
    do_onnx_ml: bool = False,
    lite: bool = False,
    protoc_path: str = "",
) -> None:
    """Generate the proto2/proto3 variants of ``<stem>.in.proto`` into *output*.

    Writes ``<base>.proto`` and ``<base>.proto3``, where ``<base>`` is
    ``<stem>[_<package>][-ml]``. Also writes a ``<stem>_pb.py`` shim (with "-"
    replaced by "_") that re-exports the matching ``_pb2`` module. With a
    renamed package, it additionally writes ``<stem>[-ml].pb.h``, which includes
    the renamed ``<base>.pb.h``.

    Args:
        stem: Template name; ``<stem>.in.proto`` is read from this script's directory.
        package_name: Package substituted into the template. Any value other
            than ``DEFAULT_PACKAGE_NAME`` renames the generated files.
        output: Directory that receives every generated file.
        do_onnx_ml: Keep the ``#if ONNX-ML`` sections. Except for the
            ``onnx-data`` template, also use ``-ml`` file names.
        lite: Append the protobuf-lite ``optimize_for`` option to both outputs.
        protoc_path: If non-empty, run this protoc on the ``.proto3`` output
            (validation), then delete files matching ``<base>.proto3.*`` in the
            ``.proto3`` file's directory.
    """
    proto_in = qualify(f"{stem}.in.proto")
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    # Having a separate variable for import_ml ensures that the import statements for the
    # generated proto files can be set separately from the ONNX_ML environment variable setting.
    import_ml = do_onnx_ml
    # We do not want to generate the onnx-data-ml.proto files for onnx-data.in.proto,
    # as there is no change between onnx-data.proto and the ML version.
    # Test the stem, not the full input path. Otherwise a parent directory whose
    # name contains "onnx-data" would disable ML naming for every stem.
    if "onnx-data" in stem:
        do_onnx_ml = False

    ml_suffix = "-ml" if do_onnx_ml else ""
    # Base name the outputs have under the default package; a renamed package
    # inserts "_<package>" before the optional "-ml" suffix.
    default_base = f"{stem}{ml_suffix}"
    proto_base = f"{stem}_{package_name}{ml_suffix}" if need_rename else default_base
    proto = qualify(f"{proto_base}.proto", pardir=output)
    proto3 = qualify(f"{proto_base}.proto3", pardir=output)

    print(f"Processing {proto_in}")
    with open(proto_in, encoding="utf-8") as fin:
        source = fin.read()
    _write_generated_proto(proto, source, 2, import_ml, package_name, lite)
    _write_generated_proto(proto3, source, 3, import_ml, package_name, lite)

    if protoc_path:
        proto3_dir = os.path.dirname(proto3)
        base_dir = os.path.dirname(proto3_dir)
        gen_proto3_code(protoc_path, proto3, base_dir, base_dir, base_dir)
        for pb3_file in glob.glob(os.path.join(proto3_dir, f"{proto_base}.proto3.*")):
            print(f"Removing {pb3_file}")
            os.remove(pb3_file)

    if need_rename:
        # Keep the un-renamed header name usable by forwarding to the renamed one.
        proto_header = qualify(f"{default_base}.pb.h", pardir=output)
        print(f"Writing {proto_header}")
        with open(proto_header, "w", newline="", encoding="utf-8") as fout:
            fout.write("#pragma once\n")
            fout.write(f'#include "{proto_base}.pb.h"\n')

    # Generate py mapping
    # "-" is invalid in python module name, replaces '-' with '_'.
    # In every naming case the pb2 module is derived from proto_base.
    pb_py = qualify(f"{stem.replace('-', '_')}_pb.py", pardir=output)
    pb2_py = qualify(f"{proto_base.replace('-', '_')}_pb2.py", pardir=output)
    _write_py_mapping(pb_py, pb2_py)
def main(argv: Optional[Iterable[str]] = None) -> None:
    """Parse command-line options and run ``convert`` for every requested stem.

    Args:
        argv: Arguments without the program name. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so ``main()`` behaves as before.
            Passing a list allows programmatic invocation.
    """
    parser = argparse.ArgumentParser(
        description="Generates .proto file variations from .in.proto"
    )
    parser.add_argument(
        "-p",
        "--package",
        default=DEFAULT_PACKAGE_NAME,
        help="package name in the generated proto files (default: %(default)s)",
    )
    parser.add_argument("-m", "--ml", action="store_true", help="ML mode")
    parser.add_argument(
        "-l",
        "--lite",
        action="store_true",
        help="generate lite proto to use with protobuf-lite",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=os.path.realpath(os.path.dirname(__file__)),
        help="output directory (default: %(default)s)",
    )
    parser.add_argument(
        "--protoc_path", default="", help="path to protoc for proto3 file validation"
    )
    parser.add_argument(
        "stems",
        nargs="*",
        default=["onnx", "onnx-operators", "onnx-data"],
        help="list of .in.proto file stems (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    # exist_ok=True replaces the exists()/makedirs() pair, which failed if the
    # directory was created between the check and the call.
    os.makedirs(args.output, exist_ok=True)
    for stem in args.stems:
        convert(
            stem,
            package_name=args.package,
            output=args.output,
            do_onnx_ml=args.ml,
            lite=args.lite,
            protoc_path=args.protoc_path,
        )
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()