#!/usr/bin/env python

# SPDX-License-Identifier: Apache-2.0
"""Generate ``.proto`` / ``.proto3`` variants from ``*.in.proto`` templates.

For each template stem this script resolves the ``// #if ONNX-ML`` blocks,
substitutes ``{PACKAGE_NAME}`` (renaming imports when the package is not the
default ``onnx``), writes a proto2 and a proto3 flavour, and emits a small
Python shim module (plus a forwarding C++ header when the package is renamed).
"""

import argparse
import glob
import os
import re
import subprocess
from textwrap import dedent
from typing import Iterable, Optional

autogen_header = """\
//
// WARNING: This file is automatically generated!  Please edit onnx.in.proto.
//


"""

LITE_OPTION = """

// For using protobuf-lite
option optimize_for = LITE_RUNTIME;

"""

DEFAULT_PACKAGE_NAME = "onnx"

IF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#if\s+ONNX-ML\s*$")
ENDIF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#endif\s*$")
ELSE_ONNX_ML_REGEX = re.compile(r"\s*//\s*#else\s*$")


def process_ifs(lines: Iterable[str], onnx_ml: bool) -> Iterable[str]:
    """Resolve ``// #if ONNX-ML`` / ``// #else`` / ``// #endif`` directives.

    Lines between ``#if`` and ``#else`` (or ``#endif``) are kept only when
    *onnx_ml* is true; lines between ``#else`` and ``#endif`` are kept only
    when it is false. The directive lines themselves are dropped. Nested
    blocks are not supported.

    Args:
        lines: Template lines.
        onnx_ml: Whether the ONNX-ML branch should be kept.

    Yields:
        The lines that survive preprocessing, in order.

    Raises:
        AssertionError: On a misplaced directive, or when an ``#if`` block is
            still open at the end of the input.
    """
    # 0 = outside any block, 1 = inside the #if branch, 2 = inside the #else branch.
    in_if = 0
    for line in lines:
        if IF_ONNX_ML_REGEX.match(line):
            assert in_if == 0
            in_if = 1
        elif ELSE_ONNX_ML_REGEX.match(line):
            assert in_if == 1
            in_if = 2
        elif ENDIF_ONNX_ML_REGEX.match(line):
            assert in_if == 1 or in_if == 2  # noqa: PLR1714, PLR2004
            in_if = 0
        else:  # noqa: PLR5501
            if in_if == 0:
                yield line
            elif in_if == 1 and onnx_ml:
                yield line
            elif in_if == 2 and not onnx_ml:  # noqa: PLR2004
                yield line
    # Without this check a missing `// #endif` would silently drop (or keep)
    # everything after the `#if` instead of failing loudly.
    assert in_if == 0, "unterminated '// #if ONNX-ML' block (missing '// #endif')"


IMPORT_REGEX = re.compile(r'(\s*)import\s*"([^"]*)\.proto";\s*$')
PACKAGE_NAME_REGEX = re.compile(r"\{PACKAGE_NAME\}")
ML_REGEX = re.compile(r"(.*)\-ml")


def process_package_name(lines: Iterable[str], package_name: str) -> Iterable[str]:
    """Substitute ``{PACKAGE_NAME}`` and, for non-default packages, rename imports.

    When *package_name* differs from ``DEFAULT_PACKAGE_NAME``, each
    ``import "X.proto";`` line is rewritten to ``import "X_<package>.proto";``,
    or to ``import "Y_<package>-ml.proto";`` when ``X`` is ``Y-ml``. All other
    lines (and import lines when no rename is needed) have every
    ``{PACKAGE_NAME}`` placeholder replaced by *package_name*.

    Args:
        lines: Template lines.
        package_name: Protobuf package name to substitute.

    Yields:
        The rewritten lines, in order.
    """
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    for line in lines:
        m = IMPORT_REGEX.match(line) if need_rename else None
        if m:
            include_name = m.group(2)
            # fullmatch: only names that *end* in "-ml" are ML imports;
            # `match` would also accept e.g. "onnx-mlx" and drop its tail.
            ml = ML_REGEX.fullmatch(include_name)
            if ml:
                include_name = f"{ml.group(1)}_{package_name}-ml"
            else:
                include_name = f"{include_name}_{package_name}"
            yield m.group(1) + f'import "{include_name}.proto";'
        else:
            # A callable replacement inserts the name literally; a plain string
            # would have backslash escapes (e.g. "\1") interpreted by re.sub.
            yield PACKAGE_NAME_REGEX.sub(lambda _match: package_name, line)


PROTO_SYNTAX_REGEX = re.compile(r'(\s*)syntax\s*=\s*"proto2"\s*;\s*$')
OPTIONAL_REGEX = re.compile(r"(\s*)optional\s(.*)$")


def convert_to_proto3(lines: Iterable[str]) -> Iterable[str]:
    """Rewrite proto2 lines into their proto3 form.

    * ``syntax = "proto2";`` becomes ``syntax = "proto3";``;
    * a leading ``optional`` label is removed;
    * ``import "X.proto";`` becomes ``import "X.proto3";``.

    Every other line is passed through unchanged.
    """
    for line in lines:
        # Set the syntax specifier
        m = PROTO_SYNTAX_REGEX.match(line)
        if m:
            yield m.group(1) + 'syntax = "proto3";'
            continue

        # Remove optional keywords
        m = OPTIONAL_REGEX.match(line)
        if m:
            yield m.group(1) + m.group(2)
            continue

        # Rewrite import
        m = IMPORT_REGEX.match(line)
        if m:
            yield m.group(1) + f'import "{m.group(2)}.proto3";'
            continue

        yield line


def gen_proto3_code(
    protoc_path: str, proto3_path: str, include_path: str, cpp_out: str, python_out: str
) -> None:
    """Run protoc on *proto3_path*, emitting C++ and Python code.

    Raises:
        subprocess.CalledProcessError: If protoc exits with a non-zero status.
    """
    print(f"Generate pb3 code using {protoc_path}")
    build_args = [protoc_path, proto3_path, "-I", include_path]
    build_args.extend(["--cpp_out", cpp_out, "--python_out", python_out])
    subprocess.check_call(build_args)


def translate(source: str, proto: int, onnx_ml: bool, package_name: str) -> str:
    """Turn the text of a ``.in.proto`` template into proto2 or proto3 text.

    Args:
        source: Template contents.
        proto: Target syntax version, 2 or 3.
        onnx_ml: Whether to keep the ONNX-ML branches.
        package_name: Protobuf package to substitute.

    Returns:
        The translated lines joined with ``"\\n"`` (no trailing newline).

    Raises:
        AssertionError: If *proto* is neither 2 nor 3.
    """
    lines: Iterable[str] = source.splitlines()
    lines = process_ifs(lines, onnx_ml=onnx_ml)
    lines = process_package_name(lines, package_name=package_name)
    if proto == 3:  # noqa: PLR2004
        lines = convert_to_proto3(lines)
    else:
        assert proto == 2  # noqa: PLR2004
    return "\n".join(lines)


# TODO: not Windows friendly
def qualify(f: str, pardir: Optional[str] = None) -> str:
    """Join *f* onto *pardir*, defaulting to this script's (symlink-resolved) directory."""
    if pardir is None:
        pardir = os.path.realpath(os.path.dirname(__file__))
    return os.path.join(pardir, f)


def convert(
    stem: str,
    package_name: str,
    output: str,
    do_onnx_ml: bool = False,
    lite: bool = False,
    protoc_path: str = "",
) -> None:
    """Generate the proto2/proto3 files and the Python shim for one template.

    Reads ``{stem}.in.proto`` from this script's directory and writes into
    *output*:

    * ``{proto_base}.proto`` and ``{proto_base}.proto3``, each prefixed with
      ``autogen_header`` and, when *lite* is set, suffixed with ``LITE_OPTION``;
    * when the package is renamed, ``{stem}.pb.h`` (or ``{stem}-ml.pb.h``),
      which just includes ``{proto_base}.pb.h``;
    * ``{stem}_pb.py`` (with ``-`` replaced by ``_``), which star-imports the
      matching ``*_pb2`` module.

    ``proto_base`` is *stem*, plus ``_{package_name}`` when the package is
    renamed, plus ``-ml`` when *do_onnx_ml* is set (never for ``onnx-data``
    stems).

    Args:
        stem: Template name without the ``.in.proto`` suffix.
        package_name: Protobuf package to substitute for ``{PACKAGE_NAME}``.
        output: Directory receiving the generated files.
        do_onnx_ml: Keep the ONNX-ML branches and use ``-ml`` file names.
        lite: Append the protobuf-lite ``optimize_for`` option.
        protoc_path: If non-empty, compile the ``.proto3`` file with this
            protoc as a validation step and delete the files it produced.
    """
    proto_in = qualify(f"{stem}.in.proto")
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    # Having a separate variable for import_ml ensures that the import statements for the generated
    # proto files can be set separately from the ONNX_ML environment variable setting.
    import_ml = do_onnx_ml
    # We do not want to generate the onnx-data-ml.proto files for onnx-data.in.proto,
    # as there is no change between onnx-data.proto and the ML version.
    # Test the stem rather than the full input path: a checkout directory whose
    # name happens to contain "onnx-data" must not disable ML output everywhere.
    if "onnx-data" in stem:
        do_onnx_ml = False

    if do_onnx_ml:
        proto_base = f"{stem}_{package_name}-ml" if need_rename else f"{stem}-ml"
    else:
        proto_base = f"{stem}_{package_name}" if need_rename else f"{stem}"

    proto = qualify(f"{proto_base}.proto", pardir=output)
    proto3 = qualify(f"{proto_base}.proto3", pardir=output)

    print(f"Processing {proto_in}")
    with open(proto_in, encoding="utf-8") as fin:
        source = fin.read()
    print(f"Writing {proto}")
    with open(proto, "w", newline="", encoding="utf-8") as fout:
        fout.write(autogen_header)
        fout.write(
            translate(source, proto=2, onnx_ml=import_ml, package_name=package_name)
        )
        if lite:
            fout.write(LITE_OPTION)
    print(f"Writing {proto3}")
    with open(proto3, "w", newline="", encoding="utf-8") as fout:
        fout.write(autogen_header)
        fout.write(
            translate(source, proto=3, onnx_ml=import_ml, package_name=package_name)
        )
        if lite:
            fout.write(LITE_OPTION)
    if protoc_path:
        proto3_dir = os.path.dirname(proto3)
        base_dir = os.path.dirname(proto3_dir)
        gen_proto3_code(protoc_path, proto3, base_dir, base_dir, base_dir)
        # protoc is run only to validate the .proto3 file; remove whatever it
        # produced next to it that matches "<proto_base>.proto3.*".
        pb3_files = glob.glob(os.path.join(proto3_dir, f"{proto_base}.proto3.*"))
        for pb3_file in pb3_files:
            print(f"Removing {pb3_file}")
            os.remove(pb3_file)

    if need_rename:
        if do_onnx_ml:
            proto_header = qualify(f"{stem}-ml.pb.h", pardir=output)
        else:
            proto_header = qualify(f"{stem}.pb.h", pardir=output)
        print(f"Writing {proto_header}")
        with open(proto_header, "w", newline="", encoding="utf-8") as fout:
            fout.write("#pragma once\n")
            fout.write(f'#include "{proto_base}.pb.h"\n')

    # Generate py mapping
    # "-" is invalid in a python module name, so replace '-' with '_'
    pb_py = qualify(f"{stem.replace('-', '_')}_pb.py", pardir=output)
    if need_rename:
        pb2_py = qualify(f"{proto_base.replace('-', '_')}_pb2.py", pardir=output)
    elif do_onnx_ml:
        pb2_py = qualify(f"{stem.replace('-', '_')}_ml_pb2.py", pardir=output)
    else:
        pb2_py = qualify(f"{stem.replace('-', '_')}_pb2.py", pardir=output)

    print(f"generating {pb_py}")
    with open(pb_py, "w", encoding="utf-8") as f:
        f.write(
            dedent(
                f"""\
                # This file is generated by setup.py. DO NOT EDIT!


                from .{os.path.splitext(os.path.basename(pb2_py))[0]} import *  # noqa
                """
            )
        )


def main() -> None:
    """Parse command-line arguments and run ``convert`` for every requested stem."""
    parser = argparse.ArgumentParser(
        description="Generates .proto file variations from .in.proto"
    )
    parser.add_argument(
        "-p",
        "--package",
        default="onnx",
        help="package name in the generated proto files (default: %(default)s)",
    )
    parser.add_argument("-m", "--ml", action="store_true", help="ML mode")
    parser.add_argument(
        "-l",
        "--lite",
        action="store_true",
        help="generate lite proto to use with protobuf-lite",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=os.path.realpath(os.path.dirname(__file__)),
        help="output directory (default: %(default)s)",
    )
    parser.add_argument(
        "--protoc_path", default="", help="path to protoc for proto3 file validation"
    )
    parser.add_argument(
        "stems",
        nargs="*",
        default=["onnx", "onnx-operators", "onnx-data"],
        help="list of .in.proto file stems (default: %(default)s)",
    )
    args = parser.parse_args()

    # exist_ok avoids the race of a separate "does it exist?" check.
    os.makedirs(args.output, exist_ok=True)

    for stem in args.stems:
        convert(
            stem,
            package_name=args.package,
            output=args.output,
            do_onnx_ml=args.ml,
            lite=args.lite,
            protoc_path=args.protoc_path,
        )


if __name__ == "__main__":
    main()