I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,5 @@
# This is __init__.pyi, not __init__.py
# This directory is only considered for typing information, not for actual
# module implementations.
ONNX_ML: bool = ...

View File

@ -0,0 +1,21 @@
class CheckerContext:
ir_version: int = ...
opset_imports: dict[str, int] = ...
class LexicalScopeContext:
ir_version: int = ...
opset_imports: dict[str, int] = ...
class ValidationError(Exception): ...
def check_value_info(bytes: bytes, checker_context: CheckerContext) -> None: ... # noqa: A002
def check_tensor(bytes: bytes, checker_context: CheckerContext) -> None: ... # noqa: A002
def check_sparse_tensor(bytes: bytes, checker_context: CheckerContext) -> None: ... # noqa: A002
def check_attribute(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002
def check_node(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002
def check_function(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002
def check_graph(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002
def check_model(bytes: bytes, full_check: bool, skip_opset_compatibility_check: bool, check_custom_domain: bool) -> None: ... # noqa: A002
def check_model_path(path: str, full_check: bool, skip_opset_compatibility_check: bool, check_custom_domain: bool) -> None: ...

View File

@ -0,0 +1,200 @@
"""Submodule containing all the ONNX schema definitions."""
from typing import Sequence, overload
from onnx import AttributeProto, FunctionProto
class SchemaError(Exception): ...
class OpSchema:
def __init__(
self,
name: str,
domain: str,
since_version: int,
doc: str = "",
*,
inputs: Sequence[OpSchema.FormalParameter] = (),
outputs: Sequence[OpSchema.FormalParameter] = (),
type_constraints: Sequence[tuple[str, Sequence[str], str]] = (),
attributes: Sequence[OpSchema.Attribute] = (),
) -> None: ...
@property
def file(self) -> str: ...
@property
def line(self) -> int: ...
@property
def support_level(self) -> SupportType: ...
@property
def doc(self) -> str | None: ...
@property
def since_version(self) -> int: ...
@property
def deprecated(self) -> bool: ...
@property
def domain(self) -> str: ...
@property
def name(self) -> str: ...
@property
def min_input(self) -> int: ...
@property
def max_input(self) -> int: ...
@property
def min_output(self) -> int: ...
@property
def max_output(self) -> int: ...
@property
def attributes(self) -> dict[str, Attribute]: ...
@property
def inputs(self) -> Sequence[FormalParameter]: ...
@property
def outputs(self) -> Sequence[FormalParameter]: ...
@property
def type_constraints(self) -> Sequence[TypeConstraintParam]: ...
@property
def has_type_and_shape_inference_function(self) -> bool: ...
@property
def has_data_propagation_function(self) -> bool: ...
@staticmethod
def is_infinite(v: int) -> bool: ...
def consumed(self, schema: OpSchema, i: int) -> tuple[UseType, int]: ...
def _infer_node_outputs(
self,
node_proto: bytes,
value_types: dict[str, bytes],
input_data: dict[str, bytes],
input_sparse_data: dict[str, bytes],
) -> dict[str, bytes]: ...
@property
def function_body(self) -> FunctionProto: ...
class TypeConstraintParam:
def __init__(
self,
type_param_str: str,
allowed_type_strs: Sequence[str],
description: str = "",
) -> None:
"""Type constraint parameter.
Args:
type_param_str: Type parameter string, for example, "T", "T1", etc.
allowed_type_strs: Allowed type strings for this type parameter. E.g. ["tensor(float)"].
description: Type parameter description.
"""
@property
def type_param_str(self) -> str: ...
@property
def description(self) -> str: ...
@property
def allowed_type_strs(self) -> Sequence[str]: ...
class FormalParameterOption:
Single: OpSchema.FormalParameterOption = ...
Optional: OpSchema.FormalParameterOption = ...
Variadic: OpSchema.FormalParameterOption = ...
class DifferentiationCategory:
Unknown: OpSchema.DifferentiationCategory = ...
Differentiable: OpSchema.DifferentiationCategory = ...
NonDifferentiable: OpSchema.DifferentiationCategory = ...
class FormalParameter:
def __init__(
self,
name: str,
type_str: str,
description: str = "",
*,
param_option: OpSchema.FormalParameterOption = OpSchema.FormalParameterOption.Single, # noqa: F821
is_homogeneous: bool = True,
min_arity: int = 1,
differentiation_category: OpSchema.DifferentiationCategory = OpSchema.DifferentiationCategory.Unknown, # noqa: F821
) -> None: ...
@property
def name(self) -> str: ...
@property
def types(self) -> set[str]: ...
@property
def type_str(self) -> str: ...
@property
def description(self) -> str: ...
@property
def option(self) -> OpSchema.FormalParameterOption: ...
@property
def is_homogeneous(self) -> bool: ...
@property
def min_arity(self) -> int: ...
@property
def differentiation_category(self) -> OpSchema.DifferentiationCategory: ...
class AttrType:
FLOAT: OpSchema.AttrType = ...
INT: OpSchema.AttrType = ...
STRING: OpSchema.AttrType = ...
TENSOR: OpSchema.AttrType = ...
GRAPH: OpSchema.AttrType = ...
SPARSE_TENSOR: OpSchema.AttrType = ...
TYPE_PROTO: OpSchema.AttrType = ...
FLOATS: OpSchema.AttrType = ...
INTS: OpSchema.AttrType = ...
STRINGS: OpSchema.AttrType = ...
TENSORS: OpSchema.AttrType = ...
GRAPHS: OpSchema.AttrType = ...
SPARSE_TENSORS: OpSchema.AttrType = ...
TYPE_PROTOS: OpSchema.AttrType = ...
class Attribute:
@overload
def __init__(
self,
name: str,
type: OpSchema.AttrType, # noqa: A002
description: str = "",
*,
required: bool = True,
) -> None: ...
@overload
def __init__(
self,
name: str,
default_value: AttributeProto,
description: str = "",
) -> None: ...
@property
def name(self) -> str: ...
@property
def description(self) -> str: ...
@property
def type(self) -> OpSchema.AttrType: ...
@property
def default_value(self) -> AttributeProto: ...
@property
def required(self) -> bool: ...
class SupportType(int):
COMMON: OpSchema.SupportType = ...
EXPERIMENTAL: OpSchema.SupportType = ...
class UseType:
DEFAULT: OpSchema.UseType = ...
CONSUME_ALLOWED: OpSchema.UseType = ...
CONSUME_ENFORCED: OpSchema.UseType = ...
@overload
def has_schema(op_type: str, domain: str = "") -> bool: ...
@overload
def has_schema(
op_type: str, max_inclusive_version: int, domain: str = ""
) -> bool: ...
def schema_version_map() -> dict[str, tuple[int, int]]: ...
@overload
def get_schema(
op_type: str, max_inclusive_version: int, domain: str = ""
) -> OpSchema: ...
@overload
def get_schema(op_type: str, domain: str = "") -> OpSchema: ...
def get_all_schemas() -> Sequence[OpSchema]: ...
def get_all_schemas_with_history() -> Sequence[OpSchema]: ...
def set_domain_to_version(domain: str, min_version: int, max_version: int, last_release_version: int = -1) -> None: ...
def register_schema(schema: OpSchema) -> None: ...
def deregister_schema(op_type: str, version: int, domain: str) -> None: ...

View File

@ -0,0 +1,15 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
def inline_local_functions(model: bytes, convert_version: bool) -> bytes:
"""Inlines calls to model-local function in input model and returns it.
Both input and output are serialized ModelProtos.
"""
def inline_selected_functions(model: bytes, function_ids: list[tuple[str,str]], exclude: bool) -> bytes:
"""Inlines calls to selected model-local functions in input model and returns it.
Inlines all functions specified in function_ids, unless exclude is true, in which
case it inlines all functions except those specified in function_ids.
Both input and output are serialized ModelProtos.
"""

View File

@ -0,0 +1,28 @@
def parse_model(model: str) -> tuple[bool, bytes, bytes]:
"""Returns (success-flag, error-message, serialized-proto).
If success-flag is true, then serialized-proto contains the parsed ModelProto.
Otherwise, error-message contains a string describing the parse error.
"""
def parse_graph(graph: str) -> tuple[bool, bytes, bytes]:
"""Returns (success-flag, error-message, serialized-proto).
If success-flag is true, then serialized-proto contains the parsed GraphProto.
Otherwise, error-message contains a string describing the parse error.
"""
def parse_function(function: str) -> tuple[bool, bytes, bytes]:
"""Returns (success-flag, error-message, serialized-proto).
If success-flag is true, then serialized-proto contains the parsed FunctionProto.
Otherwise, error-message contains a string describing the parse error.
"""
def parse_node(node: str) -> tuple[bool, bytes, bytes]:
"""Returns (success-flag, error-message, serialized-proto).
If success-flag is true, then serialized-proto contains the parsed NodeProto.
Otherwise, error-message contains a string describing the parse error.
"""

View File

@ -0,0 +1,3 @@
def function_to_text(serialized_function_proto: bytes) -> str: ...
def graph_to_text(serialized_graph_proto: bytes) -> str: ...
def model_to_text(serialized_model_proto: bytes) -> str: ...

View File

@ -0,0 +1,16 @@
class InferenceError(Exception): ...
def infer_shapes(
b: bytes, check_type: bool, strict_mode: bool, data_prop: bool
) -> bytes: ...
def infer_shapes_path(
model_path: str,
output_path: str,
check_type: bool,
strict_mode: bool,
data_prop: bool,
) -> None: ...
def infer_function_output_types(bytes: bytes, input_types: list[bytes], attributes: list[bytes]) -> list[bytes]: ... # noqa: A002

View File

@ -0,0 +1,4 @@
class ConvertError(Exception): ...
# Where the first bytes are a serialized ModelProto
def convert_version(bytes: bytes, target: int) -> bytes: ... # noqa: A002