268 lines
8.4 KiB
Python
268 lines
8.4 KiB
Python
#!/usr/bin/env python
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import glob
|
|
import os
|
|
import re
|
|
import subprocess
|
|
from textwrap import dedent
|
|
from typing import Iterable
|
|
|
|
# Banner written at the top of every generated .proto / .proto3 file.
autogen_header = """\
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//


"""

# Snippet appended to the generated files when the lite option is requested.
LITE_OPTION = """

// For using protobuf-lite
option optimize_for = LITE_RUNTIME;

"""

# Package name for which generated files keep their plain, unsuffixed names.
DEFAULT_PACKAGE_NAME = "onnx"

# Comment-style directives, each alone on its line, that delimit the
# conditional ONNX-ML sections of the template (consumed by process_ifs).
IF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#if\s+ONNX-ML\s*$")
ENDIF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#endif\s*$")
ELSE_ONNX_ML_REGEX = re.compile(r"\s*//\s*#else\s*$")
|
|
|
|
|
|
def process_ifs(lines: Iterable[str], onnx_ml: bool) -> Iterable[str]:
    """Yield the lines that survive the ``#if ONNX-ML`` conditional sections.

    Directive lines (``// #if ONNX-ML``, ``// #else``, ``// #endif``) are
    never yielded. Lines outside any conditional section are always yielded;
    lines in the ``#if`` branch are yielded only when ``onnx_ml`` is truthy,
    and lines in the ``#else`` branch only when it is falsy.

    Args:
        lines: Template lines to filter.
        onnx_ml: Selects which branch of each conditional section is kept.

    Yields:
        The retained lines, in their original order.

    Raises:
        AssertionError: On a nested ``#if``, an ``#else`` that does not follow
            an ``#if``, or an ``#endif`` outside a section. Because this is a
            generator, the error surfaces while iterating.
    """
    outside, if_branch, else_branch = 0, 1, 2
    state = outside
    for text in lines:
        if IF_ONNX_ML_REGEX.match(text):
            assert state == outside
            state = if_branch
            continue
        if ELSE_ONNX_ML_REGEX.match(text):
            assert state == if_branch
            state = else_branch
            continue
        if ENDIF_ONNX_ML_REGEX.match(text):
            assert state in (if_branch, else_branch)
            state = outside
            continue

        # Ordinary line: keep it when the current section applies.
        keep = (
            state == outside
            or (state == if_branch and onnx_ml)
            or (state == else_branch and not onnx_ml)
        )
        if keep:
            yield text
|
|
|
|
|
|
# Matches a whole-line `import "<name>.proto";` statement. Group 1 is the
# leading whitespace; group 2 is <name> (the import path without ".proto").
IMPORT_REGEX = re.compile(r'(\s*)import\s*"([^"]*)\.proto";\s*$')
# Matches the literal `{PACKAGE_NAME}` placeholder.
PACKAGE_NAME_REGEX = re.compile(r"\{PACKAGE_NAME\}")
# Matches a name containing "-ml"; group 1 is the text before the last "-ml".
ML_REGEX = re.compile(r"(.*)\-ml")
|
|
|
|
|
|
def process_package_name(lines: Iterable[str], package_name: str) -> Iterable[str]:
    """Substitute the package name into template lines.

    When ``package_name`` differs from ``DEFAULT_PACKAGE_NAME``, every
    ``import "<name>.proto";`` line is rewritten so the imported file name
    carries the package suffix (``<name>_<package>``, or
    ``<prefix>_<package>-ml`` when the name matches ``ML_REGEX``). All other
    lines have each ``{PACKAGE_NAME}`` placeholder replaced by ``package_name``.

    Args:
        lines: Template lines to process.
        package_name: Package name to substitute.

    Yields:
        One output line per input line.
    """
    renaming = package_name != DEFAULT_PACKAGE_NAME
    for text in lines:
        import_match = IMPORT_REGEX.match(text) if renaming else None
        if not import_match:
            yield PACKAGE_NAME_REGEX.sub(package_name, text)
            continue

        indent, target = import_match.group(1, 2)
        ml_match = ML_REGEX.match(target)
        if ml_match:
            # Keep the "-ml" marker at the very end of the file name.
            target = f"{ml_match.group(1)}_{package_name}-ml"
        else:
            target = f"{target}_{package_name}"
        yield f'{indent}import "{target}.proto";'
|
|
|
|
|
|
# Matches a whole-line `syntax = "proto2";` declaration; group 1 is the
# leading whitespace.
PROTO_SYNTAX_REGEX = re.compile(r'(\s*)syntax\s*=\s*"proto2"\s*;\s*$')
# Matches a line whose first token is `optional` followed by one whitespace
# character; group 1 is the leading whitespace, group 2 the rest of the line.
OPTIONAL_REGEX = re.compile(r"(\s*)optional\s(.*)$")
|
|
|
|
|
|
def convert_to_proto3(lines: Iterable[str]) -> Iterable[str]:
    """Rewrite proto2 lines into their proto3 form.

    Each line is checked against three rewrites, and the first that matches
    is applied:

    1. ``syntax = "proto2";`` becomes ``syntax = "proto3";``.
    2. A leading ``optional`` keyword is dropped.
    3. ``import "<name>.proto";`` becomes ``import "<name>.proto3";``.

    Lines matching none of these are passed through unchanged.

    Args:
        lines: proto2 lines to convert.

    Yields:
        One converted line per input line.
    """

    def rewrite(text: str) -> str:
        # Set the syntax specifier
        hit = PROTO_SYNTAX_REGEX.match(text)
        if hit:
            return f'{hit.group(1)}syntax = "proto3";'

        # Remove optional keywords
        hit = OPTIONAL_REGEX.match(text)
        if hit:
            return hit.group(1) + hit.group(2)

        # Rewrite import
        hit = IMPORT_REGEX.match(text)
        if hit:
            return f'{hit.group(1)}import "{hit.group(2)}.proto3";'

        return text

    for text in lines:
        yield rewrite(text)
|
|
|
|
|
|
def gen_proto3_code(
    protoc_path: str, proto3_path: str, include_path: str, cpp_out: str, python_out: str
) -> None:
    """Run protoc on a proto3 file, emitting C++ and Python code.

    Args:
        protoc_path: Executable to invoke.
        proto3_path: File passed to protoc as its input.
        include_path: Directory passed via ``-I``.
        cpp_out: Directory passed via ``--cpp_out``.
        python_out: Directory passed via ``--python_out``.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    print(f"Generate pb3 code using {protoc_path}")
    command = [
        protoc_path,
        proto3_path,
        "-I",
        include_path,
        "--cpp_out",
        cpp_out,
        "--python_out",
        python_out,
    ]
    subprocess.check_call(command)
|
|
|
|
|
|
def translate(source: str, proto: int, onnx_ml: bool, package_name: str) -> str:
    """Render template text as proto2 or proto3 source.

    The text is split into lines, filtered by ``process_ifs``, passed through
    ``process_package_name`` and, for ``proto == 3``, through
    ``convert_to_proto3``. The resulting lines are joined with ``"\\n"``.

    Args:
        source: Full text of the template.
        proto: Target syntax version; must be 2 or 3.
        onnx_ml: Forwarded to ``process_ifs``.
        package_name: Forwarded to ``process_package_name``.

    Returns:
        The translated text.

    Raises:
        AssertionError: If ``proto`` is neither 2 nor 3.
    """
    pipeline: Iterable[str] = process_package_name(
        process_ifs(source.splitlines(), onnx_ml=onnx_ml),
        package_name=package_name,
    )
    if proto != 3:  # noqa: PLR2004
        assert proto == 2  # noqa: PLR2004
    else:
        pipeline = convert_to_proto3(pipeline)
    return "\n".join(pipeline)  # TODO: not Windows friendly
|
|
|
|
|
|
def qualify(f: str, pardir: str | None = None) -> str:
    """Join a file name onto a parent directory.

    Args:
        f: File name (or relative path) to qualify.
        pardir: Directory to join onto. When ``None``, the real path of the
            directory containing this script is used.

    Returns:
        ``os.path.join`` of the chosen directory and ``f``.
    """
    base = os.path.realpath(os.path.dirname(__file__)) if pardir is None else pardir
    return os.path.join(base, f)
|
|
|
|
|
|
def convert(
    stem: str,
    package_name: str,
    output: str,
    do_onnx_ml: bool = False,
    lite: bool = False,
    protoc_path: str = "",
) -> None:
    """Generate the proto2/proto3 files and Python shim for one template stem.

    Reads ``<stem>.in.proto`` from this script's directory and writes into
    ``output``:

    * ``<base>.proto`` and ``<base>.proto3``, where ``<base>`` is the stem
      plus a ``_<package_name>`` suffix when the package is not the default
      and a ``-ml`` suffix in ML mode;
    * when the package is renamed, a ``<stem>[-ml].pb.h`` header that simply
      includes ``<base>.pb.h``;
    * ``<stem>_pb.py`` (dashes replaced by underscores), which star-imports
      the matching ``_pb2`` module.

    Args:
        stem: Template name without the ``.in.proto`` extension.
        package_name: Package name substituted into the generated files.
        output: Directory that receives the generated files.
        do_onnx_ml: Generate the ML variant. Ignored for file naming when the
            stem contains ``onnx-data``; the template is still translated
            with the originally requested ML setting.
        lite: Append ``LITE_OPTION`` to the generated proto files.
        protoc_path: When non-empty, protoc executable run on the generated
            proto3 file.

    Raises:
        OSError: If the template cannot be read or an output file cannot be
            written.
        Exception: Anything raised by ``gen_proto3_code`` when ``protoc_path``
            is given is propagated unchanged.
    """
    proto_in = qualify(f"{stem}.in.proto")
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    # Having a separate variable for import_ml ensures that the import statements for the generated
    # proto files can be set separately from the ONNX_ML environment variable setting.
    import_ml = do_onnx_ml
    # We do not want to generate the onnx-data-ml.proto files for onnx-data.in.proto,
    # as there is no change between onnx-data.proto and the ML version.
    # The stem is tested rather than the qualified path, so that a checkout
    # directory whose name happens to contain "onnx-data" does not switch off
    # ML naming for every other stem.
    if "onnx-data" in stem:
        do_onnx_ml = False
    if do_onnx_ml:
        proto_base = f"{stem}_{package_name}-ml" if need_rename else f"{stem}-ml"
    else:
        proto_base = f"{stem}_{package_name}" if need_rename else f"{stem}"
    proto = qualify(f"{proto_base}.proto", pardir=output)
    proto3 = qualify(f"{proto_base}.proto3", pardir=output)

    print(f"Processing {proto_in}")
    with open(proto_in, encoding="utf-8") as fin:
        source = fin.read()

    # The proto2 and proto3 files differ only in the target syntax version.
    for target, syntax_version in ((proto, 2), (proto3, 3)):
        print(f"Writing {target}")
        # newline="" writes the "\n" separators exactly as produced.
        with open(target, "w", newline="", encoding="utf-8") as fout:
            fout.write(autogen_header)
            fout.write(
                translate(
                    source,
                    proto=syntax_version,
                    onnx_ml=import_ml,
                    package_name=package_name,
                )
            )
            if lite:
                fout.write(LITE_OPTION)

    if protoc_path:
        proto3_dir = os.path.dirname(proto3)
        # The parent of the output directory serves as both the include root
        # and the destination for the generated C++ / Python code.
        base_dir = os.path.dirname(proto3_dir)
        gen_proto3_code(protoc_path, proto3, base_dir, base_dir, base_dir)
        # Remove every "<base>.proto3.*" file found next to the proto3 file.
        pb3_files = glob.glob(os.path.join(proto3_dir, f"{proto_base}.proto3.*"))
        for pb3_file in pb3_files:
            print(f"Removing {pb3_file}")
            os.remove(pb3_file)

    if need_rename:
        if do_onnx_ml:
            proto_header = qualify(f"{stem}-ml.pb.h", pardir=output)
        else:
            proto_header = qualify(f"{stem}.pb.h", pardir=output)
        print(f"Writing {proto_header}")
        with open(proto_header, "w", newline="", encoding="utf-8") as fout:
            fout.write("#pragma once\n")
            fout.write(f'#include "{proto_base}.pb.h"\n')

    # Generate py mapping
    # "-" is invalid in python module name, replaces '-' with '_'
    pb_py = qualify(f"{stem.replace('-', '_')}_pb.py", pardir=output)
    # proto_base already carries the package and "-ml" suffixes, so a single
    # expression covers the renamed, ML and plain cases alike.
    pb2_py = qualify(f"{proto_base.replace('-', '_')}_pb2.py", pardir=output)
    pb2_module = os.path.splitext(os.path.basename(pb2_py))[0]

    print(f"generating {pb_py}")
    with open(pb_py, "w", encoding="utf-8") as f:
        f.write(
            dedent(
                f"""\
                # This file is generated by setup.py. DO NOT EDIT!


                from .{pb2_module} import * # noqa
                """
            )
        )
|
|
|
|
|
|
def main() -> None:
    """Parse the command line and generate proto files for each requested stem.

    The output directory is created when missing, then ``convert`` is called
    once per stem with the parsed options.

    Raises:
        OSError: If the output directory cannot be created, including when the
            path already exists but is not a directory.
    """
    parser = argparse.ArgumentParser(
        description="Generates .proto file variations from .in.proto"
    )
    parser.add_argument(
        "-p",
        "--package",
        # Same constant the rename logic compares against, so the two cannot
        # drift apart.
        default=DEFAULT_PACKAGE_NAME,
        help="package name in the generated proto files (default: %(default)s)",
    )
    parser.add_argument("-m", "--ml", action="store_true", help="ML mode")
    parser.add_argument(
        "-l",
        "--lite",
        action="store_true",
        help="generate lite proto to use with protobuf-lite",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=os.path.realpath(os.path.dirname(__file__)),
        help="output directory (default: %(default)s)",
    )
    parser.add_argument(
        "--protoc_path", default="", help="path to protoc for proto3 file validation"
    )
    parser.add_argument(
        "stems",
        nargs="*",
        default=["onnx", "onnx-operators", "onnx-data"],
        help="list of .in.proto file stems (default: %(default)s)",
    )
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race when several generator
    # processes target the same output directory.
    os.makedirs(args.output, exist_ok=True)

    for stem in args.stems:
        convert(
            stem,
            package_name=args.package,
            output=args.output,
            do_onnx_ml=args.ml,
            lite=args.lite,
            protoc_path=args.protoc_path,
        )
|
|
|
|
|
|
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()
|