I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,294 @@
# mypy: allow-untyped-defs
import warnings
from contextlib import contextmanager
from typing import Any, Iterator
import torch._C
# These are imported so users can access them from the `torch.jit` module
from torch._jit_internal import (
_Await,
_drop,
_IgnoreContextManager,
_isinstance,
_overload,
_overload_method,
export,
Final,
Future,
ignore,
is_scripting,
unused,
)
from torch.jit._async import fork, wait
from torch.jit._await import _awaitable, _awaitable_nowait, _awaitable_wait
from torch.jit._decomposition_utils import _register_decomposition
from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
from torch.jit._fuser import (
fuser,
last_executed_optimized_graph,
optimized_execution,
set_fusion_strategy,
)
from torch.jit._ir_utils import _InsertPoint
from torch.jit._script import (
_ScriptProfile,
_unwrap_optional,
Attribute,
CompilationUnit,
interface,
RecursiveScriptClass,
RecursiveScriptModule,
script,
script_method,
ScriptFunction,
ScriptModule,
ScriptWarning,
)
from torch.jit._serialization import (
jit_module_from_flatbuffer,
load,
save,
save_jit_module_to_flatbuffer,
)
from torch.jit._trace import (
_flatten,
_get_trace_graph,
_script_if_tracing,
_unique_state_dict,
is_tracing,
ONNXTracedModule,
TopLevelTracedModule,
trace,
trace_module,
TracedModule,
TracerWarning,
TracingCheckError,
)
from torch.utils import set_module
# Public API of `torch.jit`; names not listed here are considered private.
__all__ = [
    "Attribute",
    "CompilationUnit",
    "Error",
    "Future",
    "ScriptFunction",
    "ScriptModule",
    "annotate",
    "enable_onednn_fusion",
    "export",
    "export_opnames",
    "fork",
    "freeze",
    "interface",
    "ignore",
    "isinstance",
    "load",
    "onednn_fusion_enabled",
    "optimize_for_inference",
    "save",
    "script",
    "script_if_tracing",
    "set_fusion_strategy",
    "strict_fusion",
    "trace",
    "trace_module",
    "unused",
    "wait",
]

# For backwards compatibility
# Underscored aliases kept so external callers using the old private names
# keep working.
_fork = fork
_wait = wait
_set_fusion_strategy = set_fusion_strategy
def export_opnames(m):
    r"""Return the operator names a Script module requires.

    Generates new bytecode for the Script module ``m`` and reports what the op
    list would be based off the current code base.

    If you have a LiteScriptModule and want the currently present list of ops,
    call ``_export_operator_list`` instead.
    """
    compiled_module = m._c
    return torch._C._export_opnames(compiled_module)
# torch.jit.Error
# Re-export the C++ JITException binding as the public `torch.jit.Error`.
Error = torch._C.JITException
set_module(Error, "torch.jit")
# This is not perfect but works in common cases
Error.__name__ = "Error"
Error.__qualname__ = "Error"
# for use in python if using annotate
def annotate(the_type, the_value):
    """Hint the TorchScript compiler that `the_value` has type `the_type`.

    This is a pass-through: it simply returns `the_value`, and is a no-op when
    running outside of TorchScript. It exists because TorchScript's type
    inference can be wrong in some cases, including:

    - Empty containers like `[]` and `{}`, which TorchScript assumes to be
      containers of `Tensor`
    - Optional types like `Optional[T]` assigned a valid value of type `T`,
      which TorchScript would infer as `T` rather than `Optional[T]`

    Note that `annotate()` does not help in the `__init__` method of
    `torch.nn.Module` subclasses because it is executed in eager mode; use
    :meth:`~torch.jit.Attribute` to annotate module attributes instead.

    Example:
        .. testcode::

            import torch
            from typing import Dict

            @torch.jit.script
            def fn():
                # Telling TorchScript that this empty dictionary is a
                # (str -> int) dictionary instead of the default (str -> Tensor).
                d = torch.jit.annotate(Dict[str, int], {})

                # Without `torch.jit.annotate` above, the following statement
                # would fail because of a type mismatch.
                d["name"] = 20

        .. testcleanup::

            del fn

    Args:
        the_type: Python type passed to the TorchScript compiler as a type
            hint for `the_value`
        the_value: Value or expression to hint the type for.

    Returns:
        `the_value` is passed back as the return value.
    """
    return the_value
def script_if_tracing(fn):
    """Compile ``fn`` lazily, the first time it is called during tracing.

    ``torch.jit.script`` has a non-negligible start-up cost on first use due to
    lazy initialization of many compiler builtins, so it should not be used
    directly in library code. If parts of a library must work under tracing
    even though they use control flow, decorate them with
    ``@torch.jit.script_if_tracing`` instead of ``torch.jit.script``.

    Args:
        fn: A function to compile.

    Returns:
        If called during tracing, a :class:`ScriptFunction` created by
        `torch.jit.script` is returned. Otherwise, the original function `fn`
        is returned.
    """
    deferred = _script_if_tracing(fn)
    return deferred
# for torch.jit.isinstance
def isinstance(obj, target_type):
    """Provide container type refinement in TorchScript.

    Can refine parameterized containers of the List, Dict, Tuple, and Optional
    types — e.g. ``List[str]``, ``Dict[str, List[torch.Tensor]]``,
    ``Optional[Tuple[int,str,int]]`` — as well as basic types available in
    TorchScript such as bools and ints.

    Args:
        obj: object to refine the type of
        target_type: type to try to refine obj to

    Returns:
        ``bool``: True if obj was successfully refined to the type of
        target_type, False otherwise with no new type refinement

    Example (using ``torch.jit.isinstance`` for type refinement):
        .. testcode::

            import torch
            from typing import Any, Dict, List

            class MyModule(torch.nn.Module):
                def forward(self, input: Any):  # note the Any type
                    if torch.jit.isinstance(input, List[torch.Tensor]):
                        for t in input:
                            y = t.clamp(0, 0.5)
                    elif torch.jit.isinstance(input, Dict[str, str]):
                        for val in input.values():
                            print(val)

            m = torch.jit.script(MyModule())
            m([torch.rand(3, 3), torch.rand(4, 3)])
            m({"key1": "val1", "key2": "val2"})
    """
    refined = _isinstance(obj, target_type)
    return refined
class strict_fusion:
    """Give errors if not all nodes have been fused in inference, or symbolically differentiated in training.

    Example:
        Forcing fusion of additions.

        .. code-block:: python

            @torch.jit.script
            def foo(x):
                with torch.jit.strict_fusion():
                    return x + x + x
    """

    def __init__(self) -> None:
        # Outside of TorchScript this context manager has no effect, so tell
        # the user rather than failing silently.
        in_script = torch._jit_internal.is_scripting()
        if not in_script:
            warnings.warn("Only works in script mode")

    def __enter__(self):
        return None

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        return None
# Context manager for globally hiding source ranges when printing graphs.
# Note that these functions are exposed to Python as static members of the
# Graph class, so mypy checks need to be skipped.
@contextmanager
def _hide_source_ranges() -> Iterator[None]:
old_enable_source_ranges = torch._C.Graph.global_print_source_ranges # type: ignore[attr-defined]
try:
torch._C.Graph.set_global_print_source_ranges(False) # type: ignore[attr-defined]
yield
finally:
torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges) # type: ignore[attr-defined]
def enable_onednn_fusion(enabled: bool):
    """Turn onednn JIT fusion on or off, according to ``enabled``."""
    torch._C._jit_set_llga_enabled(enabled)
def onednn_fusion_enabled():
    """Return True if onednn JIT fusion is currently enabled."""
    state = torch._C._jit_llga_enabled()
    return state
# `Any` was only needed for the annotations above; remove it so it is not
# accidentally re-exported from `torch.jit`.
del Any

# Initialize the JIT bindings; a failure here means torch._C is unusable.
if not torch._C._jit_init():
    raise RuntimeError("JIT initialization failed")

View File

@ -0,0 +1,102 @@
# mypy: allow-untyped-defs
"""Async API.
This module contains the API for parallelism in TorchScript, notably:
* torch.jit.fork
* torch.jit.wait
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import torch
from torch._jit_internal import Future
from torch.jit._builtins import _register_builtin
from torch.utils import set_module
set_module(Future, "torch.jit")
def fork(func, *args, **kwargs):
    r"""
    Create an asynchronous task executing `func` and a reference to the value of the result of this execution.

    `fork` will return immediately, so the return value of `func` may not have been computed yet. To force completion
    of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
    with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
    nested, and may be invoked with positional and keyword arguments.
    Asynchronous execution will only occur when run in TorchScript. If run in pure python,
    `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
    while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.

    .. warning::
        `fork` tasks will execute non-deterministically. We recommend only spawning
        parallel fork tasks for pure functions that do not modify their inputs,
        module attributes, or global state.

    Args:
        func (callable or torch.nn.Module):  A Python function or `torch.nn.Module`
            that will be invoked. If executed in TorchScript, it will execute asynchronously,
            otherwise it will not. Traced invocations of fork will be captured in the IR.
        ``*args``, ``**kwargs``: arguments to invoke `func` with.

    Returns:
        `torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
        can only be accessed by forcing completion of `func` through `torch.jit.wait`.

    Example (fork a free function):

    .. code-block:: python

        import torch
        from torch import Tensor
        def foo(a : Tensor, b : int) -> Tensor:
            return a + b
        def bar(a):
            fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
            return torch.jit.wait(fut)
        script_bar = torch.jit.script(bar)
        input = torch.tensor(2)
        # only the scripted version executes asynchronously
        assert script_bar(input) == bar(input)
        # trace is not run asynchronously, but fork is captured in IR
        graph = torch.jit.trace(bar, (input,)).graph
        assert "fork" in str(graph)

    Example (fork a module method):

    .. code-block:: python

        import torch
        from torch import Tensor
        class AddMod(torch.nn.Module):
            def forward(self, a: Tensor, b : int):
                return a + b
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod = AddMod()
            def forward(self, input):
                fut = torch.jit.fork(self.mod, input, b=2)
                return torch.jit.wait(fut)
        input = torch.tensor(2)
        mod = Mod()
        assert mod(input) == torch.jit.script(mod).forward(input)
    """
    return torch._C.fork(func, *args, **kwargs)
def wait(future):
    r"""Block until the given `torch.jit.Future[T]` completes and return its value.

    See :func:`~fork` for docs and examples.

    Args:
        future (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`

    Returns:
        `T`: the return value of the completed task
    """
    result = torch._C.wait(future)
    return result


# Make `wait` resolvable as the `aten::wait` builtin inside TorchScript.
_register_builtin(wait, "aten::wait")

View File

@ -0,0 +1,27 @@
# mypy: allow-untyped-defs
import torch
from torch._jit_internal import _Await
from torch.jit._builtins import _register_builtin
from torch.utils import set_module
# Expose `_Await` under the `torch.jit` namespace.
set_module(_Await, "torch.jit")


def _awaitable(func, *args, **kwargs):
    r"""Create Await object that will call specified function with specified args, when it is requested for the result."""
    return torch._C._awaitable(func, *args, **kwargs)
def _awaitable_wait(aw):
r"""Request await the result of execution, if Await is not completed yet, the func will be called immediately."""
return torch._C._awaitable_wait(aw)
def _awaitable_nowait(o):
r"""Create completed Await with specified result."""
return torch._C._awaitable_nowait(o)
_register_builtin(_awaitable_wait, "prim::awaitable_wait")
_register_builtin(_awaitable_nowait, "prim::awaitable_nowait")

View File

@ -0,0 +1,193 @@
# mypy: allow-untyped-defs
import cmath
import math
import warnings
from collections import OrderedDict
from typing import Dict, Optional
import torch
import torch.backends.cudnn as cudnn
from torch.nn.modules.utils import (
_list_with_default,
_pair,
_quadruple,
_single,
_triple,
)
# Lazily-populated map of id(callable) -> "aten::..." schema name; built by
# `_get_builtin_table` on first use.
_builtin_table: Optional[Dict[int, str]] = None

# Every callable attribute of these modules gets registered as a builtin.
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special)  # type: ignore[attr-defined] # noqa: B950

# Seed entries mapping Python callables to their TorchScript schema names;
# `_get_builtin_table` appends more at build time.
_builtin_ops = [
    # Pairs of (function, op_name)
    (_pair, "aten::_pair"),
    (_quadruple, "aten::_quadruple"),
    (_single, "aten::_single"),
    (_triple, "aten::_triple"),
    (_list_with_default, "aten::list_with_default"),
    (OrderedDict, "aten::dict"),
    (dict, "aten::dict"),
    (cudnn.is_acceptable, "aten::cudnn_is_acceptable"),
    # math module functions
    (math.ceil, "aten::ceil"),
    (math.copysign, "aten::copysign"),
    (math.erf, "aten::erf"),
    (math.erfc, "aten::erfc"),
    (math.exp, "aten::exp"),
    (math.expm1, "aten::expm1"),
    (math.fabs, "aten::fabs"),
    (math.floor, "aten::floor"),
    (math.gamma, "aten::gamma"),
    (math.lgamma, "aten::lgamma"),
    (math.log, "aten::log"),
    (math.log10, "aten::log10"),
    (math.log1p, "aten::log1p"),
    (math.pow, "aten::pow"),
    (math.sqrt, "aten::sqrt"),
    (math.isnan, "aten::isnan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.cosh, "aten::cosh"),
    (math.sinh, "aten::sinh"),
    (math.tanh, "aten::tanh"),
    (math.acos, "aten::acos"),
    (math.asin, "aten::asin"),
    (math.atan, "aten::atan"),
    (math.atan2, "aten::atan2"),
    (math.cos, "aten::cos"),
    (math.sin, "aten::sin"),
    (math.tan, "aten::tan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.acosh, "aten::acosh"),
    (math.fmod, "aten::fmod"),
    (math.modf, "aten::modf"),
    (math.factorial, "aten::factorial"),
    (math.frexp, "aten::frexp"),
    (math.isinf, "aten::isinf"),
    (math.degrees, "aten::degrees"),
    (math.radians, "aten::radians"),
    # cmath module functions
    (cmath.isnan, "aten::isnan"),
    (cmath.isfinite, "aten::isfinite"),
    (cmath.isinf, "aten::isinf"),
    (cmath.phase, "aten::angle"),
    (cmath.rect, "aten::polar"),
    (cmath.log, "aten::log"),
    (cmath.log10, "aten::log10"),
    (cmath.sqrt, "aten::sqrt"),
    (cmath.exp, "aten::exp"),
    (cmath.sin, "aten::sin"),
    (cmath.tan, "aten::tan"),
    (cmath.cos, "aten::cos"),
    (cmath.asin, "aten::asin"),
    (cmath.acos, "aten::acos"),
    (cmath.atan, "aten::atan"),
    (cmath.sinh, "aten::sinh"),
    (cmath.cosh, "aten::cosh"),
    (cmath.tanh, "aten::tanh"),
    (cmath.asinh, "aten::asinh"),
    (cmath.acosh, "aten::acosh"),
    (cmath.atanh, "aten::atanh"),
    (math.ldexp, "aten::ldexp"),
    # torch internals
    (torch._assert, "aten::_assert"),
    (torch.autograd.grad, "aten::grad"),
    (torch.autograd.backward, "aten::backward"),
    (torch._C._infer_size, "aten::_infer_size"),
    (torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"),  # type: ignore[attr-defined]
    (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
    (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
    (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
    (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
    (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
    (torch._C._get_tracing_state, "aten::_get_tracing_state"),
    (torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
    (warnings.warn, "aten::warn"),
    (torch._VF.stft, "aten::stft"),  # type: ignore[attr-defined]
    (torch._VF.istft, "aten::istft"),  # type: ignore[attr-defined]
    (torch._VF.cdist, "aten::cdist"),  # type: ignore[attr-defined]
    (torch._VF.norm, "aten::norm"),  # type: ignore[attr-defined]
    (torch._VF.unique_dim, "aten::unique_dim"),
    (torch._VF.unique_consecutive, "aten::unique_consecutive"),  # type: ignore[attr-defined]
    (torch._VF.nuclear_norm, "aten::nuclear_norm"),
    (torch._VF.frobenius_norm, "aten::frobenius_norm"),
    (torch._VF.tensordot, "aten::tensordot"),  # type: ignore[attr-defined]
]
# ops in torch.functional are bound to torch
# in these cases, we want to resolve the function to their python implementation
# instead looking up a builtin "aten::" schema
def _gen_torch_functional_registered_ops():
# eventually ops should encompass all of torch/functional.py, (torch.functional.__all__)
# but we are currently only able to compile some of the functions. additionally,
# some functions directly map to their aten:: implementations.
# TODO: add support for more ops
ops = [
"stft",
"istft",
"lu",
"cdist",
"norm",
"unique",
"unique_consecutive",
"tensordot",
]
return {getattr(torch.functional, name) for name in ops}
# Built once at import time; membership is tested via
# `_is_special_functional_bound_op` below.
_functional_registered_ops = _gen_torch_functional_registered_ops()


def _is_special_functional_bound_op(fn):
    # True if `fn` is one of the torch.functional ops that must resolve to its
    # Python implementation rather than an "aten::" builtin schema.
    return fn in _functional_registered_ops
# lazily built to ensure the correct initialization order
def _get_builtin_table():
    """Build (once) and return the id(callable) -> "aten::..." name table.

    Lazily built to ensure the correct initialization order; subsequent calls
    return the cached table.
    """
    global _builtin_table
    if _builtin_table is not None:
        return _builtin_table
    _builtin_table = {}

    def register_all(mod):
        # Register every callable attribute of `mod`, except ops that must
        # resolve to their Python implementations and the context managers
        # no_grad / autocast.
        for name in dir(mod):
            v = getattr(mod, name)
            if (
                callable(v)
                and not _is_special_functional_bound_op(v)
                and v is not torch.no_grad
                and v is not torch.autocast
            ):
                # Fixup inconsistency in segment_reduce
                if name == "_segment_reduce":
                    name = name[1:]
                _builtin_ops.append((v, "aten::" + name))

    for mod in _modules_containing_builtins:
        register_all(mod)
    _builtin_ops.append((math.gcd, "aten::gcd"))
    _builtin_ops.append((math.isfinite, "aten::isfinite"))
    _builtin_ops.append((math.remainder, "aten::mathremainder"))  # type: ignore[attr-defined]
    # Imported here (not at module top) so distributed support stays optional.
    import torch.distributed.autograd as dist_autograd

    if dist_autograd.is_available():
        _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients"))
        _builtin_ops.append((dist_autograd.backward, "aten::dist_backward"))
    # populate the _builtin_table from _builtin_ops
    for builtin, aten_op in _builtin_ops:
        _builtin_table[id(builtin)] = aten_op
    return _builtin_table
def _register_builtin(fn, op):
    # Associate callable `fn` with schema name `op` in the builtin table.
    _get_builtin_table()[id(fn)] = op
def _find_builtin(fn):
    # Return the "aten::..." name registered for `fn`, or None if absent.
    return _get_builtin_table().get(id(fn))

View File

@ -0,0 +1,249 @@
# mypy: allow-untyped-defs
import ast
import inspect
import textwrap
import warnings
import torch
class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
    """Check the ``__init__`` method of a given ``nn.Module``.

    It ensures that all instance-level attributes can be properly initialized.

    Specifically, we do type inference based on attribute values...even
    if the attribute in question has already been typed using
    Python3-style annotations or ``torch.jit.annotate``. This means that
    setting an instance-level attribute to ``[]`` (for ``List``),
    ``{}`` for ``Dict``), or ``None`` (for ``Optional``) isn't enough
    information for us to properly initialize that attribute.

    An object of this class can walk a given ``nn.Module``'s AST and
    determine if it meets our requirements or not.

    Known limitations

    1. We can only check the AST nodes for certain constructs; we can't
    ``eval`` arbitrary expressions. This means that function calls,
    class instantiations, and complex expressions that resolve to one of
    the "empty" values specified above will NOT be flagged as
    problematic.
    2. We match on string literals, so if the user decides to use a
    non-standard import (e.g. `from typing import List as foo`), we
    won't catch it.

    Example:
        .. code-block:: python

            class M(torch.nn.Module):
                def fn(self):
                    return []

                def __init__(self) -> None:
                    super().__init__()
                    self.x: List[int] = []

                def forward(self, x: List[int]):
                    self.x = x
                    return 1

        The above code will pass the ``AttributeTypeIsSupportedChecker``
        check since we have a function call in ``__init__``. However,
        it will still fail later with the ``RuntimeError`` "Tried to set
        nonexistent attribute: x. Did you forget to initialize it in
        __init__()?".

    Args:
        nn_module - The instance of ``torch.nn.Module`` whose
            ``__init__`` method we wish to check
    """

    def check(self, nn_module: torch.nn.Module) -> None:
        # Grab the textual source of just the `__init__` method.
        source_lines = inspect.getsource(nn_module.__class__.__init__)

        # Ignore comments no matter the indentation
        def is_useless_comment(line):
            line = line.strip()
            # "# type:" comments are kept because they carry type info.
            return line.startswith("#") and not line.startswith("# type:")

        source_lines = "\n".join(
            [l for l in source_lines.split("\n") if not is_useless_comment(l)]
        )

        # This AST only contains the `__init__` method of the nn.Module
        init_ast = ast.parse(textwrap.dedent(source_lines))

        # Get items annotated in the class body
        self.class_level_annotations = list(nn_module.__annotations__.keys())

        # Flag for later
        self.visiting_class_level_ann = False

        self.visit(init_ast)

    def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool:
        # Return True when `node` is the "empty" literal for `ann_type`
        # (one of "List", "Dict", "Optional"); any other ann_type returns
        # True by falling through every branch.
        if ann_type == "List":
            # Assigning `[]` to a `List` type gives you a Node where
            # value=List(elts=[], ctx=Load())
            if not isinstance(node, ast.List):
                return False
            if node.elts:
                return False
        elif ann_type == "Dict":
            # Assigning `{}` to a `Dict` type gives you a Node where
            # value=Dict(keys=[], values=[])
            if not isinstance(node, ast.Dict):
                return False
            if node.keys:
                return False
        elif ann_type == "Optional":
            # Assigning `None` to an `Optional` type gives you a
            # Node where value=Constant(value=None, kind=None)
            if not isinstance(node, ast.Constant):
                return False
            if node.value:  # type: ignore[attr-defined]
                return False
        return True

    def visit_Assign(self, node):
        """Store assignment state when assigning to a Call Node.

        If we're visiting a Call Node (the right-hand side of an
        assignment statement), we won't be able to check the variable
        that we're assigning to (the left-hand side of an assignment).
        Because of this, we need to store this state in visitAssign.
        (Luckily, we only have to do this if we're assigning to a Call
        Node, i.e. ``torch.jit.annotate``. If we're using normal Python
        annotations, we'll be visiting an AnnAssign Node, which has its
        target built in.)
        """
        try:
            if (
                isinstance(node.value, ast.Call)
                and node.targets[0].attr in self.class_level_annotations
            ):
                self.visiting_class_level_ann = True
        except AttributeError:
            # Target has no `.attr` (e.g. a plain Name); nothing to track.
            return
        self.generic_visit(node)
        self.visiting_class_level_ann = False

    def visit_AnnAssign(self, node):
        """Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method.

        It checks if it conforms to our attribute annotation rules."""
        # If we have a local variable
        try:
            if node.target.value.id != "self":
                return
        except AttributeError:
            return

        # If we have an attribute that's already been annotated at the
        # class level
        if node.target.attr in self.class_level_annotations:
            return

        # TODO @ansley: add `Union` once landed

        # NB: Even though `Tuple` is a "container", we don't want to
        # check for it here. `Tuple` functions as an type with an
        # "infinite" number of subtypes, in the sense that you can have
        # `Tuple[())]`, `Tuple[T1]`, `Tuple[T2]`, `Tuple[T1, T2]`,
        # `Tuple[T2, T1]` and so on, and none of these subtypes can be
        # used in place of the other. Therefore, assigning an empty
        # tuple in `__init__` CORRECTLY means that that variable
        # cannot be reassigned later to a non-empty tuple. Same
        # deal with `NamedTuple`
        containers = {"List", "list", "Dict", "dict", "Optional"}

        # If we're not evaluating one of the specified problem types
        try:
            if node.annotation.value.id not in containers:
                return
        except AttributeError:
            # To evaluate a base type (`str`, `int`, etc.), we would
            # have needed to get the name through `node.annotation.id`
            # instead of `node.annotation.value.id`. Seems that we're
            # not evaluating one of our "containers"
            return

        # Check if the assigned variable is empty
        ann_type = node.annotation.value.id
        if not self._is_empty_container(node.value, ann_type):
            return

        warnings.warn(
            "The TorchScript type system doesn't support "
            "instance-level annotations on empty non-base "
            "types in `__init__`. Instead, either 1) use a "
            "type annotation in the class body, or 2) wrap "
            "the type in `torch.jit.Attribute`."
        )

    def visit_Call(self, node):
        """Determine if a Call node is 'torch.jit.annotate' in __init__.

        Visit a Call node in an ``nn.Module``'s ``__init__``
        method and determine if it's ``torch.jit.annotate``. If so,
        see if it conforms to our attribute annotation rules.
        """
        # If we have an attribute that's already been annotated at the
        # class level
        if self.visiting_class_level_ann:
            return

        # If this isn't a call to `torch.jit.annotate`
        # NOTE(review): neither branch below `return`s, so control falls
        # through to the argument checks even for non-annotate calls —
        # looks unintended; confirm before relying on the invariant
        # comment further down.
        try:
            if (
                node.func.value.value.id != "torch"
                or node.func.value.attr != "jit"
                or node.func.attr != "annotate"
            ):
                self.generic_visit(node)
            elif (
                node.func.value.value.id != "jit" or node.func.value.attr != "annotate"
            ):
                self.generic_visit(node)
        except AttributeError:
            # Looks like we didn't even have the right node structure
            # to check for `torch.jit.annotate` in the first place
            self.generic_visit(node)

        # Invariant: we have a `torch.jit.annotate` or a
        # `torch.annotate` call

        # A Call Node for `torch.jit.annotate` should have an `args`
        # list of length 2 where args[0] represents the annotation and
        # args[1] represents the actual value
        if len(node.args) != 2:
            return

        if not isinstance(node.args[0], ast.Subscript):
            return

        # See notes in `visit_AnnAssign` r.e. containers
        containers = {"List", "Dict", "Optional"}

        try:
            ann_type = node.args[0].value.id  # type: ignore[attr-defined]
        except AttributeError:
            return

        if ann_type not in containers:
            return

        # Check if the assigned variable is empty
        if not self._is_empty_container(node.args[1], ann_type):
            return

        warnings.warn(
            "The TorchScript type system doesn't support "
            "instance-level annotations on empty non-base "
            "types in `__init__`. Instead, either 1) use a "
            "type annotation in the class body, or 2) wrap "
            "the type in `torch.jit.Attribute`."
        )

View File

@ -0,0 +1,190 @@
# mypy: allow-untyped-defs
# Functions for synthesizing magic methods for JIT-compiled dataclasses
import ast
import dataclasses
import inspect
import os
from functools import partial
from typing import Callable, Dict, List
from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
from torch._sources import ParsedDef, SourceContext
def _get_fake_filename(cls, method_name):
return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)
def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    """Assemble a synthesized method's source, parse it, and wrap it in a ParsedDef.

    Args:
        cls: the dataclass the method is being synthesized for.
        name: the method name (e.g. "__init__").
        body_lines: unindented statements forming the method body.
        signature: the parenthesized signature text, e.g. "(self) -> str".
    """
    indented = [f"    {line}" for line in body_lines]
    decl = f"def {name}{signature}:\n" + "\n".join(indented)

    try:
        py_ast = ast.parse(decl)
    except SyntaxError as e:
        # This should only happen if there's some unforeseeable change
        # in the dataclasses module that makes our synthesized code fail
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from e

    # Inlined from _get_fake_filename: synthesized methods need a
    # distinctive (fake) filename for source lookups.
    fake_filename = os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, name)
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl, filename=fake_filename, file_lineno=0, leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0,
    )
def synthesize__init__(cls) -> ParsedDef:
    """Synthesize a TorchScript-compilable ``__init__`` for dataclass `cls`.

    Raises:
        NotImplementedError: if any field uses ``default_factory`` —
            supporting default factories the way people expect would require
            compiling lambda functions, which is not currently supported.
    """
    uses_default_factory = any(
        f.default_factory is not dataclasses.MISSING for f in dataclasses.fields(cls)
    )
    if uses_default_factory:
        raise NotImplementedError(
            "Default factory initializers are not supported in TorchScript dataclasses"
        )

    # Simply read off the generated __init__ signature from CPython's
    # implementation. It'll be almost correct except for InitVar annotations,
    # which we need to handle specially. (Only works on Python 3.8+, when a
    # `type` attribute was added to InitVar; see CPython commit
    # https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c)
    signature = inspect.signature(cls.__init__)
    init_vars: List[str] = []
    params = []
    for param_name, param in signature.parameters.items():
        annotation = param.annotation
        if not isinstance(annotation, dataclasses.InitVar):
            params.append(param)
        else:
            # The TorchScript interpreter can't handle InitVar annotations,
            # so we unwrap the underlying type here.
            init_vars.append(param_name)
            params.append(param.replace(annotation=annotation.type))  # type: ignore[attr-defined]
    signature = signature.replace(parameters=params)

    # Assign all attributes to self (InitVars are forwarded to __post_init__
    # instead of being stored).
    body = []
    for field in dataclasses.fields(cls):
        if field.init and field.name not in init_vars:
            body.append(f"self.{field.name} = {field.name}")
    # Call user's impl of __post_init__ if it exists
    if hasattr(cls, "__post_init__"):
        body.append("self.__post_init__(" + ", ".join(init_vars) + ")")
    return compose_fn(cls, "__init__", body or ["pass"], signature=str(signature))
# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
    """Synthesize a ``__repr__`` for dataclass `cls`.

    This is a placeholder at the moment, since the TorchScript interpreter
    doesn't actually call ``__repr__``.
    """
    shown = [f"{f.name}=self.{f.name}" for f in dataclasses.fields(cls) if f.repr]
    body_line = f"return '{cls.__name__}(" + ", ".join(shown) + ")'"
    return compose_fn(cls, "__repr__", [body_line], signature="(self) -> str")
def synthesize__hash__(cls) -> ParsedDef:
    """Synthesize a placeholder ``__hash__`` so compilation succeeds.

    This won't even get called at all right now, because the TorchScript
    interpreter doesn't call custom ``__hash__`` implementations — it only
    needs to exist to keep compilation from failing.
    """
    raise_line = "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
    return compose_fn(cls, "__hash__", [raise_line], signature="(self) -> int")
# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    """Synthesize ``__eq__`` / ``__ne__``: fields compare pairwise via `converse`."""
    inner = [f"if val1 {converse} val2: return False"]
    return synthesize_comparison(
        cls, name, allow_eq=True, raise_on_none=False, inner=inner
    )


def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    """Synthesize an ordering method (``__lt__`` etc.) using `op` per field."""
    inner = [
        f"if val1 {op} val2: return True",
        f"elif val2 {op} val1: return False",
    ]
    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=inner)


def synthesize_comparison(
    cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]
) -> ParsedDef:
    """Synthesize a comparison method whose per-field check is given by `inner`.

    Fields with ``compare=False`` are skipped; when every compared field is
    equal, the method returns `allow_eq`.
    """
    body: List[str] = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.append(f"val1 = self.{field.name}")
        body.append(f"val2 = other.{field.name}")
        if is_optional(field.type):
            # Type refinement for optional fields; we need this to avoid type errors from the interpreter
            none_branch = (
                f" raise TypeError('Cannot compare {cls.__name__} with None')"
                if raise_on_none
                else " return False"
            )
            body.append("if val1 is not None and val2 is not None:")
            body.extend(" " + line for line in inner)
            body.append("elif (val1 is None) != (val2 is None):")
            body.append(none_branch)
        else:
            body.extend(inner)

    body.append(f"return {allow_eq}")
    signature = f"(self, other: {cls.__name__}) -> bool"
    return compose_fn(cls, name, body, signature=signature)
# Map of magic-method name -> synthesizer invoked when scripting a dataclass.
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}

View File

@ -0,0 +1,12 @@
# mypy: allow-untyped-defs
import torch
from torch._ops import OpOverload, OpOverloadPacket
def _register_decomposition(op: OpOverload, graph: torch._C.Graph):
assert not isinstance(
op, OpOverloadPacket
), f"Must pass specific op overload, not overload packet, found {op}"
assert isinstance(op, OpOverload)
torch._C._jit_register_decomposition_for_schema(op._schema, graph)

View File

@ -0,0 +1,137 @@
# mypy: allow-untyped-defs
import torch
from torch import Tensor
aten = torch.ops.aten
import inspect
import warnings
from typing import Callable, Dict, List, Optional, Set, TypeVar
from typing_extensions import ParamSpec
from torch.types import Number
# Global registry of scripted decompositions, keyed by aten schema string.
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
# Function names already registered; JIT serialization requires unique names.
function_name_set: Set[str] = set()

_T = TypeVar("_T")
_P = ParamSpec("_P")
def check_decomposition_has_type_annotations(f):
    """Assert that every parameter of `f` and its return value are annotated.

    TorchScript requires full type annotations to script a decomposition.

    Args:
        f: the Python function to validate.

    Raises:
        AssertionError: if any parameter or the return lacks an annotation.
    """
    inspect_empty = inspect._empty  # type: ignore[attr-defined]
    sig = inspect.signature(f)
    for param in sig.parameters.values():
        # Bug fix: functions expose `__name__`, not `.name` — the old message
        # raised AttributeError instead of this assertion on failure.
        assert (
            param.annotation != inspect_empty
        ), f"No signature on param {param.name} for function {f.__name__}"
    assert (
        sig.return_annotation != inspect_empty
    ), f"No return annotation for function {f.__name__}"
def signatures_match(decomposition_sig, torch_op_sig):
    """Return True when a decomposition's signature is compatible with the op's.

    Only parameter names, annotations, and (when both sides specify one)
    default values are compared. Parameter kind is ignored because
    kwarg-only arguments with defaults are not yet supported in TorchScript.
    """
    empty = inspect._empty  # type: ignore[attr-defined]
    decomp_params = list(decomposition_sig.parameters.values())
    op_params = list(torch_op_sig.parameters.values())

    if len(decomp_params) != len(op_params):
        return False

    for candidate, reference in zip(decomp_params, op_params):
        if candidate.name == "self":
            warnings.warn("PyTorch uses 'input' instead of 'self' on public api")
        # Full equality can't be checked yet: not every field (e.g. default
        # value) is correctly deduced in the torch op signature.
        if candidate.name != reference.name:
            return False
        if candidate.annotation != reference.annotation:
            return False
        # Defaults are not always present on the torch schema, but when both
        # sides carry one they must agree.
        if (
            candidate.default != empty
            and reference.default != empty
            and candidate.default != reference.default
        ):
            return False

    return decomposition_sig.return_annotation == torch_op_sig.return_annotation
def register_decomposition(
    aten_op: torch._ops.OpOverload,
    registry: Optional[Dict[str, torch.jit.ScriptFunction]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """Decorator: script ``f``, optimize its graph, and record it as the
    decomposition for ``aten_op``, keyed by the op's schema string.

    Args:
        aten_op: the specific operator overload being decomposed.
        registry: mapping to record into; defaults to the module-level
            ``decomposition_table``.

    Returns:
        A decorator that returns the original (unscripted) function.
    """
    def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
        nonlocal registry
        if registry is None:
            registry = decomposition_table
        assert isinstance(aten_op, torch._ops.OpOverload)
        # Need unique name for jit function serialization
        assert (
            f.__name__ not in function_name_set
        ), f"Duplicated function name {f.__name__}"
        function_name_set.add(f.__name__)
        scripted_func = torch.jit.script(f)
        torch._C._jit_pass_inline(scripted_func.graph)
        # Two rounds of peephole + constant propagation to settle the graph.
        for _ in range(2):
            torch._C._jit_pass_peephole(scripted_func.graph)
            torch._C._jit_pass_constant_propagation(scripted_func.graph)
        registry[str(aten_op._schema)] = scripted_func
        return f
    return decomposition_decorator
# TODO: replace torch.sigmoid -> aten.sigmoid
@register_decomposition(aten.var.correction)
def var_decomposition(
    input: Tensor,
    dim: Optional[List[int]] = None,
    correction: Optional[Number] = None,
    keepdim: bool = False,
) -> Tensor:
    # TorchScript-compatible decomposition of aten.var.correction:
    # var = sum((x - mean)^2) / max(0, n - correction).
    if dim is None:
        dim_i: List[int] = []
        dim = dim_i
    # n is the number of elements reduced over: the whole tensor when dim is
    # empty, otherwise the product of the sizes of the reduced dimensions.
    if isinstance(dim, (tuple, list)) and len(dim) == 0:
        n = input.numel()
    else:
        n = 1
        for dim_i in dim:  # type: ignore[assignment]
            n *= input.shape[dim_i]  # type: ignore[call-overload]
    mean = aten.mean(input, dim, True)
    sub = input - mean
    sq = sub * sub
    sum = aten.sum(sq, dim, keepdim)
    # Default correction is 1 (Bessel's correction); int and float corrections
    # are handled separately so TorchScript sees consistent types.
    if correction is None:
        denom = float(n - 1)
    else:
        if isinstance(correction, int):
            denom = float(n - correction)
        elif isinstance(correction, float):
            denom = float(n) - correction
        else:
            raise RuntimeError("correction must be int or float")
    # max(0, denom) avoids a negative denominator when correction exceeds n.
    return sum / max(0, denom)
@register_decomposition(aten.var.default)
def var(input: Tensor, unbiased: bool = True) -> Tensor:
    # aten.var.default expressed via aten.var.correction: `unbiased` selects
    # Bessel's correction (1) versus no correction (0).
    return var_decomposition(input, correction=(1 if unbiased else 0))

View File

@ -0,0 +1,228 @@
# mypy: allow-untyped-defs
"""Freezing.
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
from typing import List, Optional
import torch
from torch.jit._script import RecursiveScriptModule, ScriptModule
def freeze(
    mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True
):
    r"""Freeze a ScriptModule, inlining submodules and attributes as constants.

    Clones ``mod`` and inlines its submodules, parameters, and attributes as
    constants in the TorchScript IR graph. ``forward`` is always preserved,
    as are the attributes/methods listed in ``preserved_attrs`` and any
    attribute mutated inside a preserved method. Only modules in eval mode
    can be frozen. Freezing applies generic, machine-independent
    optimizations; run :func:`torch.jit.optimize_for_inference` afterwards
    for server-specific tuning.

    Args:
        mod (:class:`ScriptModule`): module to freeze.
        preserved_attrs (Optional[List[str]]): attributes to preserve in
            addition to ``forward``. Submodule attributes are accepted with
            dotted names (e.g. ``"submodule.version"``).
        optimize_numerics (bool): when ``True``, also run optimization passes
            that do not strictly preserve numerics; see
            :func:`torch.jit.run_frozen_optimizations` for details.

    Returns:
        Frozen :class:`ScriptModule`.

    Note:
        Freezing bakes weights in as constants and removes the module
        hierarchy, so ``to()`` and similar ``nn.Module`` methods no longer
        work. Devices can be remapped with ``map_location`` in
        :func:`torch.jit.load`, though device-specific logic may already be
        baked into the model. If an attribute is unexpectedly not inlined as
        a constant, ``dump_alias_db`` on ``frozen_module.forward.graph``
        shows whether freezing detected a mutation of it.
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "Freezing expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if mod.training:
        raise RuntimeError(
            "Freezing is currently only implemented for modules in eval mode. "
            "Please call .eval() on your module before freezing."
        )

    attrs_to_keep = [] if preserved_attrs is None else preserved_attrs

    frozen = RecursiveScriptModule(torch._C._freeze_module(mod._c, attrs_to_keep))
    RecursiveScriptModule._finalize_scriptmodule(frozen)

    # Only preserved entries that are actually methods get the frozen-graph
    # optimization passes.
    methods_to_optimize = [name for name in attrs_to_keep if mod._c._has_method(name)]
    run_frozen_optimizations(frozen, optimize_numerics, methods_to_optimize)

    return frozen
def run_frozen_optimizations(
    mod, optimize_numerics: bool = True, preserved_methods: Optional[List[str]] = None
):
    r"""Run a series of optimizations looking for patterns that occur in frozen graphs.

    Current passes: dropout removal, pretransposing linear layers,
    concatenating linear layers that share an input tensor, and folding of
    batchnorm / add / sub / mul / div into a preceding conv.

    Args:
        mod (:class:`ScriptModule`): a frozen module to be optimized.
        optimize_numerics (bool): when ``True``, also run passes that do not
            strictly preserve numerics (the conv folding passes). Each such
            pass individually stays within the default rtol/atol of
            ``torch.testing.assert_close``, but stacking many transformations
            in one module may drift outside those tolerances.

    Returns:
        None

    Note:
        On rare occasions this can result in slower execution.
    """
    graphs = []
    if mod._c._has_method("forward"):
        graphs.append(mod.graph)
    # Preserved methods are optimized with the same frozen-graph pass.
    for method_name in preserved_methods or []:
        graphs.append(mod.__getattr__(method_name).graph)
    for graph in graphs:
        torch._C._jit_pass_optimize_frozen_graph(graph, optimize_numerics)
def optimize_for_inference(
    mod: ScriptModule, other_methods: Optional[List[str]] = None
) -> ScriptModule:
    """Perform a set of optimization passes to optimize a model for inference.

    If ``mod`` is not already frozen, it is frozen automatically via
    :func:`torch.jit.freeze`. On top of the generic frozen-graph
    optimizations, this bakes in build-specific settings (such as CUDNN or
    MKLDNN availability), so the result may be faster on this machine and
    slower on another, and serialization of the optimized module is not
    guaranteed. Still a prototype; primary targets so far are vision models
    on CPU and, to a lesser extent, GPU.

    Args:
        mod: the :class:`ScriptModule` to optimize.
        other_methods: methods besides ``forward`` to preserve and optimize.

    Returns:
        The optimized :class:`ScriptModule`.
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "optimize_for_inference expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    preserved = [] if other_methods is None else other_methods

    # A module that still exposes `training` has not been frozen yet.
    if hasattr(mod, "training"):
        mod = freeze(mod.eval(), preserved_attrs=preserved)

    torch._C._jit_pass_optimize_for_inference(mod._c, preserved)

    return mod

View File

@ -0,0 +1,161 @@
# mypy: allow-untyped-defs
import contextlib
from typing import List, Tuple
import torch
@contextlib.contextmanager
def optimized_execution(should_optimize):
    """Context manager that controls whether the JIT's executor will run optimizations before executing a function."""
    previous = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    try:
        yield
    finally:
        # Always restore the executor's previous setting, even on error.
        torch._C._set_graph_executor_optimize(previous)
@contextlib.contextmanager
def fuser(name):
    """Context manager that facilitates switching between backend fusers.

    Valid names:
    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser
    * ``fuser3`` - enables oneDNN Graph
    """
    # Snapshot all fuser-related flags so the previous configuration can be
    # restored on exit.
    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
    old_llga_state = torch._C._jit_llga_enabled()
    if name == "fuser0":  # legacy fuser
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser1":  # NNC
        # NNC needs the profiling executor; these setters return the previous
        # value, which is restored in the finally block below.
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser2":  # nvFuser
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser3":  # oneDNN Graph
        # oneDNN Graph also requires the profiling executor (see fuser1).
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(True)
    elif name == "none":  # Turn Pytorch fuser off
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    else:
        raise Exception(f"unrecognized fuser option (name: {name})")  # noqa: TRY002
    try:
        yield
    finally:
        if name in ["fuser1", "fuser3"]:  # NNC or oneDNN Graph
            torch._C._jit_set_profiling_executor(old_profiling_executor)  # type: ignore[possibly-undefined]
            torch._C._get_graph_executor_optimize(old_profiling_mode)  # type: ignore[possibly-undefined]
        # recover the previous values
        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
        torch._C._jit_set_llga_enabled(old_llga_state)
# Direct alias to the C++ binding that returns the optimized graph used by the
# most recent graph-executor invocation.
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
def _get_differentiable_graph_node(node, diff_node):
if node.kind() == "prim::DifferentiableGraph":
diff_node.append(node)
else:
for block in node.blocks():
for n in block.nodes():
_get_differentiable_graph_node(n, diff_node)
def _graph_for(self, *args, **kwargs):
    # The callable is also the holder of its own debug state, so it is passed
    # both as `self` and as `parent`.
    return _script_method_graph_for(self, self, *args, **kwargs)
def _script_method_graph_for(self, parent, *args, **kwargs):
    """Return ``parent``'s optimized graph, with each differentiable node's
    subgraph swapped for its optimized counterpart; on any failure, fall back
    to executing ``self`` and returning the last executed optimized graph."""
    try:
        dbs = parent.get_debug_state()
        eps = list(dbs.execution_plans.values())
        # Only a single execution plan is expected here.
        assert len(eps) == 1
        graph = eps[0].graph.copy()
        # graph_executor_states for differentiable node
        fw_states = eps[0].code.differentiable_op_executor_states()
        diff_nodes: List[torch._C.Node] = []
        for n in graph.nodes():
            _get_differentiable_graph_node(n, diff_nodes)
        # One executor state per differentiable node, in traversal order.
        assert len(fw_states) == len(diff_nodes)
        # swap each differentiable graph with optimized graph in their execution plan
        for n, state in zip(diff_nodes, fw_states):
            fw_execution_plans = list(state.execution_plans.values())
            # we can only update the subgraph when there's a unique execution
            # plan. Avoid assert here so we would skip the ones that can't be
            # updated while try the best effort to update other nodes.
            if len(fw_execution_plans) == 1:
                n.g_("Subgraph", fw_execution_plans[0].graph)
        return graph
    except Exception:
        # fallback approach, we just ran the graph and return the recorded optimized
        # graph
        self(*args, **kwargs)
        return last_executed_optimized_graph()
def set_fusion_strategy(strategy: List[Tuple[str, int]]):
    """Set the type and number of specializations that can occur during fusion.

    ``strategy`` is a list of ``(type, depth)`` pairs, where ``type`` is
    ``"STATIC"`` or ``"DYNAMIC"`` and ``depth`` is an integer.

    Behavior - static vs dynamic:
        STATIC fusion compiles fused ops for fixed input shapes determined
        from initial profiling runs; DYNAMIC fusion compiles them so that
        multiple input shapes are possible. Both recompile on new striding
        behavior, device, or dtype.

    Behavior - fallback functions & depth:
        When an input does not match the format a specialized op was compiled
        for, a fallback function runs instead. Fallbacks are recursively
        compiled and specialized on the tensor shapes observed. Because
        compilation is slow, ``depth`` bounds how many specializations may be
        compiled before giving up and running a completely un-fused,
        un-specialized implementation.

    The pairs are consumed in order: ``[("STATIC", 2), ("DYNAMIC", 2)]``
    means the first two specializations use static fusion, the next two use
    dynamic fusion, and any input matching none of the four runs unfused.

    NB: as more fusion backends are added, more granular per-fuser APIs may
    be introduced in the future.
    """
    return torch._C._jit_set_fusion_strategy(strategy)

View File

@ -0,0 +1,26 @@
# mypy: allow-untyped-defs
from typing import Union
import torch
class _InsertPoint:
    """Context manager that temporarily moves a graph's insert point.

    On entry the graph's current insert point is saved and replaced by the
    one supplied; on exit the saved insert point is restored.
    """

    def __init__(
        self,
        insert_point_graph: torch._C.Graph,
        insert_point: Union[torch._C.Node, torch._C.Block],
    ):
        self.insert_point = insert_point
        self.g = insert_point_graph
        # NOTE(review): `guard` is never assigned anywhere else in this class;
        # presumably reserved for a C++-side guard object — confirm.
        self.guard = None

    def __enter__(self):
        # Remember the current insert point so it can be restored on exit.
        self.prev_insert_point = self.g.insertPoint()
        self.g.setInsertPoint(self.insert_point)

    def __exit__(self, *args):
        self.g.setInsertPoint(self.prev_insert_point)
def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]):
    # NOTE(review): written to be bound as a method (presumably onto
    # torch._C.Graph elsewhere), so `self` here is the graph — confirm.
    return _InsertPoint(self, insert_point)

View File

@ -0,0 +1,11 @@
import torch
# Python-visible aliases over the JIT logging/statistics bindings so callers
# can reference them from this module directly.
add_stat_value = torch.ops.prim.AddStatValue
set_logger = torch._C._logging_set_logger
LockingLogger = torch._C.LockingLogger
AggregationType = torch._C.AggregationType
NoopLogger = torch._C.NoopLogger
time_point = torch.ops.prim.TimePoint

View File

@ -0,0 +1,194 @@
# mypy: allow-untyped-defs
import inspect
import sys
import typing
from collections import defaultdict
from pathlib import Path
from types import CodeType
from typing import Dict, Iterable, List, Optional
import torch
_IS_MONKEYTYPE_INSTALLED = True
try:
import monkeytype # type: ignore[import]
from monkeytype import trace as monkeytype_trace
from monkeytype.config import _startswith, LIB_PATHS # type: ignore[import]
from monkeytype.db.base import ( # type: ignore[import]
CallTraceStore,
CallTraceStoreLogger,
CallTraceThunk,
)
from monkeytype.tracing import CallTrace, CodeFilter # type: ignore[import]
except ImportError:
_IS_MONKEYTYPE_INSTALLED = False
def is_torch_native_class(cls):
    """Check whether ``cls`` is defined inside the ``torch.*`` module tree."""
    if not hasattr(cls, "__module__"):
        return False
    # The root package of the defining module decides nativeness.
    root_name = cls.__module__.split(".")[0]
    return sys.modules.get(root_name) is torch
def get_type(type):
    """Render ``type`` in a form that TorchScript accepts."""
    if isinstance(type, str):
        # Already rendered as a string; nothing to do.
        return type
    if inspect.getmodule(type) == typing:
        # Types imported from `typing` (Tuple, List, Dict, ...) must lose the
        # `typing.` prefix, since e.g. typing.List is rejected by TorchScript.
        return str(type).replace(type.__module__ + ".", "")
    if is_torch_native_class(type):
        # Torch-native classes need a fully qualified name: module + type name.
        return type.__module__ + "." + type.__name__
    # For all other types, the bare type name suffices.
    return type.__name__
def get_optional_of_element_type(types):
    """Extract element type, return as `Optional[element type]` from consolidated types.

    ``types`` is a two-element collection known to contain ``NoneType`` plus
    one concrete type; the concrete one is wrapped as ``Optional[...]``.
    TODO: remove once Union support lands in TorchScript.
    """
    # Pick whichever of the two entries is not NoneType.
    elem = types[0] if types[0] != type(None) else types[1]
    # Optional is internally Union[type, NoneType], which TorchScript does not
    # support yet, so the optional type is represented as a string.
    return "Optional[" + get_type(elem) + "]"
def get_qualified_name(func):
    # Qualified names (including enclosing class/function scopes) are used as
    # the keys of the trace-record store below.
    return func.__qualname__
if _IS_MONKEYTYPE_INSTALLED:

    class JitTypeTraceStoreLogger(CallTraceStoreLogger):
        """A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore."""

        def __init__(self, store: CallTraceStore):
            super().__init__(store)

        def log(self, trace: CallTrace) -> None:
            # Traces are kept in memory only; they are flushed to the store
            # via CallTraceStore.add().
            self.traces.append(trace)

    class JitTypeTraceStore(CallTraceStore):
        """In-memory CallTraceStore keyed by the traced function's qualified name."""

        def __init__(self) -> None:
            super().__init__()
            # A dictionary keeping all collected CallTrace
            # key is fully qualified name of called function
            # value is list of all CallTrace
            self.trace_records: Dict[str, list] = defaultdict(list)

        def add(self, traces: Iterable[CallTrace]):
            for t in traces:
                qualified_name = get_qualified_name(t.func)
                self.trace_records[qualified_name].append(t)

        def filter(
            self,
            qualified_name: str,
            qualname_prefix: Optional[str] = None,
            limit: int = 2000,
        ) -> List[CallTraceThunk]:
            # `qualname_prefix` and `limit` are accepted for API compatibility
            # with CallTraceStore but are currently ignored.
            return self.trace_records[qualified_name]

        def analyze(self, qualified_name: str) -> Dict:
            # Analyze the types for the given module
            # and create a dictionary of all the types
            # for arguments.
            records = self.trace_records[qualified_name]
            all_args = defaultdict(set)
            for record in records:
                for arg, arg_type in record.arg_types.items():
                    all_args[arg].add(arg_type)
            return all_args

        def consolidate_types(self, qualified_name: str) -> Dict:
            all_args = self.analyze(qualified_name)
            # If there are more types for an argument,
            # then consolidate the type to `Any` and replace the entry
            # by type `Any`.
            for arg, types in all_args.items():
                types = list(types)
                type_length = len(types)
                if type_length == 2 and type(None) in types:
                    # TODO: To remove this check once Union support in TorchScript lands.
                    all_args[arg] = get_optional_of_element_type(types)
                elif type_length > 1:
                    all_args[arg] = "Any"
                elif type_length == 1:
                    all_args[arg] = get_type(types[0])
            return all_args

        def get_args_types(self, qualified_name: str) -> Dict:
            return self.consolidate_types(qualified_name)

    class JitTypeTraceConfig(monkeytype.config.Config):
        """MonkeyType config wiring the JIT store, logger and code filter together."""

        def __init__(self, s: JitTypeTraceStore):
            super().__init__()
            self.s = s

        def trace_logger(self) -> JitTypeTraceStoreLogger:
            """Return a JitCallTraceStoreLogger that logs to the configured trace store."""
            return JitTypeTraceStoreLogger(self.trace_store())

        def trace_store(self) -> CallTraceStore:
            return self.s

        def code_filter(self) -> Optional[CodeFilter]:
            return jit_code_filter
else:
    # When MonkeyType is not installed, we provide dummy class definitions
    # for the below classes.
    class JitTypeTraceStoreLogger:  # type: ignore[no-redef]
        def __init__(self) -> None:
            pass

    class JitTypeTraceStore:  # type: ignore[no-redef]
        def __init__(self) -> None:
            self.trace_records = None

    class JitTypeTraceConfig:  # type: ignore[no-redef]
        def __init__(self) -> None:
            pass

    monkeytype_trace = None  # type: ignore[assignment] # noqa: F811
def jit_code_filter(code: CodeType) -> bool:
    """Codefilter for Torchscript to trace forward calls.

    Like monkeytype's default filter, this excludes the stdlib and
    site-packages — but it always admits functions named ``forward``, even
    without a real source file: FX-traced forwards have a ``co_filename``
    starting with ``'<'``, which the default filter would reject, and all
    forward calls need to be traced.
    """
    is_forward = code.co_name == "forward"
    has_real_file = bool(code.co_filename) and code.co_filename[0] != "<"
    # Reject sourceless code unless it is a 'forward' call.
    if not is_forward and not has_real_file:
        return False
    resolved = Path(code.co_filename).resolve()
    return not any(_startswith(resolved, lib_path) for lib_path in LIB_PATHS)

View File

@ -0,0 +1,47 @@
# mypy: allow-untyped-defs
"""
Tools to help with tensor property propagation.
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
from typing import Any, List
import torch
from torch import TensorType
from torch._C import Graph
def apply_input_props_using_example(graph: Graph, example_input: List[Any]):
    """Apply tensor properties to each graph input using the example supplied.

    Each tensor in ``example_input`` sets the corresponding graph input's
    type; ``None`` entries are skipped.
    """
    graph_inputs = list(graph.inputs())
    if not graph_inputs:
        return

    # Methods carry a leading `self` argument; strip it when present.
    first = graph_inputs[0]
    if isinstance(first.type(), torch._C.ClassType) and first.debugName() == "self":
        graph_inputs = graph_inputs[1:]

    if len(graph_inputs) != len(example_input):
        raise RuntimeError(
            "Number of inputs in graph does not match number of inputs in the example"
        )

    for i, (graph_value, example_value) in enumerate(zip(graph_inputs, example_input)):
        if example_value is None:
            # No information supplied for this input; skip the type check.
            continue

        example_is_tensor = isinstance(example_value, torch.Tensor)
        graph_expects_tensor = isinstance(graph_value.type(), TensorType)
        if example_is_tensor != graph_expects_tensor:
            raise RuntimeError(
                f"Input {i} does not match type of example", graph_value, example_value
            )

        if example_is_tensor:
            graph_value.setType(TensorType.create_from_tensor(example_value))  # type: ignore[arg-type]

View File

@ -0,0 +1,38 @@
# mypy: allow-untyped-defs
# These functions are referenced from the pickle archives produced by
# ScriptModule.save()
# These (`build_*`) functions used to be used by `pickler.cpp` to specify
# the type of the list for certain special types, but now all lists get
# a type attached and restored via `restore_type_tag` below. The legacy
# functions should stick around for backwards-compatibility.
def build_intlist(data):
    # Legacy pickle hook (int lists): lists now carry their own type tag, so
    # the payload is passed through unchanged.
    return data
def build_tensorlist(data):
    # Legacy pickle hook (tensor lists): pass-through, see note above
    # restore_type_tag's file header.
    return data
def build_doublelist(data):
    # Legacy pickle hook (double lists): pass-through.
    return data
def build_boollist(data):
    # Legacy pickle hook (bool lists): pass-through.
    return data
def build_tensor_from_id(data):
    if isinstance(data, int):
        # just the id, can't really do anything
        return data
    # Non-int payloads fall through and implicitly return None.
def restore_type_tag(value, type_str):
    # The type_ptr is used by the jit unpickler to restore the full static type
    # to container types like list when they are re-loaded, but this doesn't
    # matter for Python, so just return the plain value
    return value

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,296 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from typing import Any, Callable, NamedTuple, overload, TypeVar
from typing_extensions import Never, TypeAlias
from _typeshed import Incomplete
import torch
from torch._classes import classes as classes
from torch._jit_internal import _qualified_name as _qualified_name
from torch.jit._builtins import _register_builtin as _register_builtin
from torch.jit._fuser import (
_graph_for as _graph_for,
_script_method_graph_for as _script_method_graph_for,
)
from torch.jit._monkeytype_config import (
JitTypeTraceConfig as JitTypeTraceConfig,
JitTypeTraceStore as JitTypeTraceStore,
monkeytype_trace as monkeytype_trace,
)
from torch.jit._recursive import (
_compile_and_register_class as _compile_and_register_class,
infer_methods_to_compile as infer_methods_to_compile,
ScriptMethodStub as ScriptMethodStub,
wrap_cpp_module as wrap_cpp_module,
)
from torch.jit._serialization import validate_map_location as validate_map_location
from torch.jit._state import (
_enabled as _enabled,
_set_jit_function_cache as _set_jit_function_cache,
_set_jit_overload_cache as _set_jit_overload_cache,
_try_get_jit_cached_function as _try_get_jit_cached_function,
_try_get_jit_cached_overloads as _try_get_jit_cached_overloads,
)
from torch.jit.frontend import (
get_default_args as get_default_args,
get_jit_class_def as get_jit_class_def,
get_jit_def as get_jit_def,
)
from torch.nn import Module as Module
from torch.overrides import (
has_torch_function as has_torch_function,
has_torch_function_unary as has_torch_function_unary,
has_torch_function_variadic as has_torch_function_variadic,
)
from torch.package import (
PackageExporter as PackageExporter,
PackageImporter as PackageImporter,
)
from torch.utils import set_module as set_module
# Runtime alias for the C++ ScriptFunction class.
ScriptFunction = torch._C.ScriptFunction

# Global store of MonkeyType call traces (see torch.jit._monkeytype_config).
type_trace_db: JitTypeTraceStore

# Defined in torch/csrc/jit/python/script_init.cpp
ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]

_ClassVar = TypeVar("_ClassVar", bound=type)

def _reduce(cls) -> None: ...

# (value, type) pair used to declare a typed attribute on a ScriptModule.
class Attribute(NamedTuple):
    value: Incomplete
    type: Incomplete

def _get_type_trace_db(): ...
def _get_function_from_type(cls, name): ...
def _is_new_style_class(cls): ...

# Dict-like wrapper over state stored on the C++ module handle `_c`.
class OrderedDictWrapper:
    _c: Incomplete
    def __init__(self, _c) -> None: ...
    def keys(self): ...
    def values(self): ...
    def __len__(self) -> int: ...
    def __delitem__(self, k) -> None: ...
    def items(self): ...
    def __setitem__(self, k, v) -> None: ...
    def __contains__(self, k) -> bool: ...
    def __getitem__(self, k): ...

# Variant of the wrapper that also tracks the Python-side submodules.
class OrderedModuleDict(OrderedDictWrapper):
    _python_modules: Incomplete
    def __init__(self, module, python_dict) -> None: ...
    def items(self): ...
    def __contains__(self, k) -> bool: ...
    def __setitem__(self, k, v) -> None: ...
    def __getitem__(self, k): ...

# Metaclass of ScriptModule (see `class ScriptModule` below).
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs) -> None: ...

class _CachedForward:
    def __get__(self, obj, cls): ...

class ScriptWarning(Warning): ...

def script_method(fn): ...

# Attribute-style read access over a mapping of compiled constants.
class ConstMap:
    const_mapping: Incomplete
    def __init__(self, const_mapping) -> None: ...
    def __getattr__(self, attr): ...

def unpackage_script_module(
    importer: PackageImporter,
    script_module_id: str,
) -> torch.nn.Module: ...

_magic_methods: Incomplete

# Python wrapper around a TorchScript class instance (`cpp_class`).
class RecursiveScriptClass:
    _c: Incomplete
    _props: Incomplete
    def __init__(self, cpp_class) -> None: ...
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    def forward_magic_method(self, method_name, *args, **kwargs): ...
    def __getstate__(self) -> None: ...
    def __iadd__(self, other): ...
    def method_template(self, *args, **kwargs): ...
# Base class for TorchScript modules; compilation hooks live on the metaclass.
class ScriptModule(Module, metaclass=ScriptMeta):
    # NOTE(review): presumably lists properties excluded from compilation —
    # confirm against torch/jit/_script.py.
    __jit_unused_properties__: Incomplete
    def __init__(self) -> None: ...
    forward: Callable[..., Any]
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    def define(self, src): ...
    def _replicate_for_data_parallel(self): ...
    def __reduce_package__(self, exporter: PackageExporter): ...
    # add __jit_unused_properties__
    @property
    def code(self) -> str: ...
    @property
    def code_with_constants(self) -> tuple[str, ConstMap]: ...
    @property
    def graph(self) -> torch.Graph: ...
    @property
    def inlined_graph(self) -> torch.Graph: ...
    @property
    def original_name(self) -> str: ...
class RecursiveScriptModule(ScriptModule):
    # Stub: the concrete wrapper type produced by torch.jit.script(module).
    _disable_script_meta: bool
    _c: Incomplete  # the underlying torch._C.ScriptModule
    def __init__(self, cpp_module) -> None: ...
    @staticmethod
    def _construct(cpp_module, init_fn): ...
    @staticmethod
    def _finalize_scriptmodule(script_module) -> None: ...
    _concrete_type: Incomplete
    _modules: Incomplete
    _parameters: Incomplete
    _buffers: Incomplete
    __dict__: Incomplete
    def _reconstruct(self, cpp_module) -> None: ...
    def save(self, f, **kwargs): ...
    def _save_for_lite_interpreter(self, *args, **kwargs): ...
    def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): ...
    def save_to_buffer(self, *args, **kwargs): ...
    def get_debug_state(self, *args, **kwargs): ...
    def extra_repr(self): ...
    def graph_for(self, *args, **kwargs): ...
    def define(self, src) -> None: ...
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    def __copy__(self): ...
    def __deepcopy__(self, memo): ...
    def forward_magic_method(self, method_name, *args, **kwargs): ...
    def __iter__(self): ...
    def __getitem__(self, idx): ...
    def __len__(self) -> int: ...
    def __contains__(self, key) -> bool: ...
    def __dir__(self): ...
    def __bool__(self) -> bool: ...
    def _replicate_for_data_parallel(self): ...

# NOTE(review): placed at module level here; the dump's indentation is lost,
# so confirm this is not a classmethod in the original stub.
def _get_methods(cls): ...
# Names of nn.Module methods that are allowed to be compiled.
_compiled_methods_allowlist: Incomplete

# Factory for attribute stubs that raise when a disallowed method is used.
def _make_fail(name): ...
# Recursively invoke __prepare_scriptable__ hooks (memoized via `memo`).
def call_prepare_scriptable_func_impl(obj, memo): ...
def call_prepare_scriptable_func(obj): ...
# Wrap a plain dict/list in its TorchScript reference-semantics counterpart.
def create_script_dict(obj): ...
def create_script_list(obj, type_hint: Incomplete | None = ...): ...
# Typed overloads for `torch.jit.script`. Order matters: the first matching
# overload wins during static type checking.

# Scripting an nn.Module *class* (rather than an instance) is a type error.
@overload
def script(
    obj: type[Module],
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> Never: ...

# A plain dict is wrapped in a reference-semantics ScriptDict.
@overload
def script(  # type: ignore[misc]
    obj: dict,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptDict: ...

# A plain list is wrapped in a reference-semantics ScriptList.
@overload
def script(  # type: ignore[misc]
    obj: list,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptList: ...

# An nn.Module instance compiles to a RecursiveScriptModule.
@overload
def script(  # type: ignore[misc]
    obj: Module,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptModule: ...

# A (non-Module) class is compiled in place and returned unchanged, type-wise.
@overload
def script(  # type: ignore[misc]
    obj: _ClassVar,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> _ClassVar: ...

# A free function compiles to a ScriptFunction.
@overload
def script(  # type: ignore[misc]
    obj: Callable,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> ScriptFunction: ...

# Fallback for anything else: an instance of a scriptable class.
@overload
def script(
    obj: Any,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptClass: ...

# Untyped catch-all implementation signature.
@overload
def script(
    obj,
    optimize: Incomplete | None = ...,
    _frames_up: int = ...,
    _rcb: Incomplete | None = ...,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = ...,
): ...
# Validate that an @overload signature's defaults agree with the implementation.
def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
def _get_overloads(obj): ...
def _check_directly_compile_overloaded(obj) -> None: ...
# Decorator marking a class as a TorchScript interface type.
def interface(obj): ...
def _recursive_compile_class(obj, loc): ...
CompilationUnit: Incomplete
# Pad string `s` to `padding` width (used by the profiler table renderer).
def pad(s: str, padding: int, offset: int = ..., char: str = ...): ...
class _ScriptProfileColumn:
    # One column of per-source-line profiling output.
    header: Incomplete
    alignment: Incomplete
    offset: Incomplete
    rows: Incomplete  # maps line number -> cell value
    def __init__(
        self,
        header: str,
        alignment: int = ...,
        offset: int = ...,
    ) -> None: ...
    def add_row(self, lineno: int, value: Any): ...
    def materialize(self): ...

class _ScriptProfileTable:
    # Renders profiling columns over a range of source lines.
    cols: Incomplete
    source_range: Incomplete
    def __init__(
        self,
        cols: list[_ScriptProfileColumn],
        source_range: list[int],
    ) -> None: ...
    def dump_string(self): ...

class _ScriptProfile:
    # Python-side handle to the TorchScript instruction profiler.
    profile: Incomplete
    def __init__(self) -> None: ...
    def enable(self) -> None: ...
    def disable(self) -> None: ...
    def dump_string(self) -> str: ...
    def dump(self) -> None: ...

# Unchecked unwrap of an Optional value (scripted builtin helper).
def _unwrap_optional(x): ...

View File

@ -0,0 +1,273 @@
# mypy: allow-untyped-defs
"""Serialization.
This module contains functionality for serializing TorchScript modules, notably:
* torch.jit.save
* torch.jit.load
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import os
import torch
from torch._jit_internal import _get_model_id
from torch._utils_internal import log_torchscript_usage
from torch.jit._recursive import wrap_cpp_module
from torch.serialization import validate_cuda_device
def save(m, f, _extra_files=None):
    r"""
    Save an offline version of this module for use in a separate process.

    The saved module serializes all of the methods, submodules, parameters, and
    attributes of this module. It can be loaded into the C++ API using
    ``torch::jit::load(filename)`` or into the Python API with
    :func:`torch.jit.load <torch.jit.load>`.

    To be able to save a module, it must not make any calls to native Python
    functions. This means that all submodules must be subclasses of
    :class:`ScriptModule` as well.

    .. DANGER::
        All modules, no matter their device, are always loaded onto the CPU
        during loading. This is different from :func:`torch.load`'s semantics
        and may change in the future.

    Args:
        m: A :class:`ScriptModule` to save.
        f: A file-like object (has to implement write and flush) or a string
           containing a file name.
        _extra_files: Map from filename to contents which will be stored as part of `f`.

    .. note::
        torch.jit.save attempts to preserve the behavior of some operators
        across versions. For example, dividing two integer tensors in
        PyTorch 1.5 performed floor division, and if the module
        containing that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6
        its division behavior will be preserved. The same module saved in
        PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the
        behavior of division changed in 1.6, and 1.5 does not know how to
        replicate the 1.6 behavior.

    Example:
    .. testcode::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        m = torch.jit.script(MyModule())

        # Save to file
        torch.jit.save(m, 'scriptmodule.pt')
        # This line is equivalent to the previous
        m.save("scriptmodule.pt")

        # Save to io.BytesIO buffer
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)

        # Save with extra files
        extra_files = {'foo.txt': b'bar'}
        torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
    """
    # Record usage for telemetry before any I/O happens.
    log_torchscript_usage("save", model_id=_get_model_id(m))
    if _extra_files is None:
        _extra_files = {}
    if isinstance(f, (str, os.PathLike)):
        # Path destination: the C++ exporter writes the file directly.
        m.save(f, _extra_files=_extra_files)
    else:
        # File-like destination: serialize into memory, then hand the bytes
        # to the caller's stream.
        ret = m.save_to_buffer(_extra_files=_extra_files)
        f.write(ret)
def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
    r"""
    Load a :class:`ScriptModule` or :class:`ScriptFunction` previously saved with :func:`torch.jit.save <torch.jit.save>`.

    All previously saved modules, no matter their device, are first loaded onto CPU,
    and then are moved to the devices they were saved from. If this fails (e.g.
    because the run time system doesn't have certain devices), an exception is
    raised.

    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location (string or torch.device): A simplified version of
            ``map_location`` in `torch.jit.save` used to dynamically remap
            storages to an alternative set of devices.
        _extra_files (dictionary of filename to content): The extra
            filenames given in the map would be loaded and their content
            would be stored in the provided map.
        _restore_shapes (bool): Whether or not to retrace the module on load using stored inputs

    Returns:
        A :class:`ScriptModule` object.

    Example:
    .. testcode::

        import torch
        import io

        torch.jit.load('scriptmodule.pt')

        # Load ScriptModule from io.BytesIO object
        with open('scriptmodule.pt', 'rb') as f:
            buffer = io.BytesIO(f.read())

        # Load all tensors to the original device
        torch.jit.load(buffer)

        # Load all tensors onto CPU, using a device
        buffer.seek(0)
        torch.jit.load(buffer, map_location=torch.device('cpu'))

        # Load all tensors onto CPU, using a string
        buffer.seek(0)
        torch.jit.load(buffer, map_location='cpu')

        # Load with extra files.
        extra_files = {'foo.txt': ''}  # values will be replaced with data
        torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
        print(extra_files['foo.txt'])

    .. testoutput::
        :hide:

        ...

    .. testcleanup::

        import os
        os.remove("scriptmodule.pt")
    """
    if isinstance(f, (str, os.PathLike)):
        # Fail early with a clear Python-side message instead of a cryptic
        # error from the C++ importer.
        if not os.path.exists(f):  # type: ignore[type-var]
            raise ValueError(f"The provided filename {f} does not exist")  # type: ignore[str-bytes-safe]
        if os.path.isdir(f):
            raise ValueError(f"The provided filename {f} is a directory")  # type: ignore[str-bytes-safe]

    # Normalizes strings to torch.device and rejects invalid values.
    map_location = validate_map_location(map_location)
    if _extra_files is None:
        _extra_files = {}

    cu = torch._C.CompilationUnit()
    if isinstance(f, (str, os.PathLike)):
        cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes)  # type: ignore[call-arg]
    else:
        cpp_module = torch._C.import_ir_module_from_buffer(
            cu, f.read(), map_location, _extra_files, _restore_shapes
        )  # type: ignore[call-arg]

    # TODO: Pretty sure this approach loses ConstSequential status and such
    ret = wrap_cpp_module(cpp_module)
    log_torchscript_usage("load", model_id=_get_model_id(ret))
    return ret
def validate_map_location(map_location=None):
    """Normalize and validate a ``map_location`` argument.

    A string is converted to a ``torch.device``; ``None`` and ``torch.device``
    values pass through unchanged; any other type raises ``ValueError``.
    CUDA devices are additionally checked for availability.
    """
    if isinstance(map_location, str):
        map_location = torch.device(map_location)
    elif map_location is not None and not isinstance(map_location, torch.device):
        raise ValueError(
            "map_location should be either None, string or torch.device, "
            "but got type: " + str(type(map_location))
        )

    # str(None) == "None", so this branch only triggers for CUDA devices.
    if str(map_location).startswith("cuda"):
        validate_cuda_device(map_location)

    return map_location
def jit_module_from_flatbuffer(f):
    """Load a TorchScript module saved in flatbuffer format.

    ``f`` may be a path (``str``/``os.PathLike``) or a file-like object
    implementing ``read``; the resulting C++ module is wrapped for Python use.
    """
    if isinstance(f, (str, os.PathLike)):
        return wrap_cpp_module(torch._C._load_jit_module_from_file(os.fspath(f)))
    return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read()))
def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
    r"""
    Save an offline version of this module for use in a separate process.

    The saved module serializes all of the methods, submodules, parameters, and
    attributes of this module. It can be loaded into the C++ API using
    ``torch::jit::load_jit_module_from_file(filename)`` or into the Python API with
    :func:`torch.jit.jit_module_from_flatbuffer<torch.jit.jit_module_from_flatbuffer>`.

    To be able to save a module, it must not make any calls to native Python
    functions. This means that all submodules must be subclasses of
    :class:`ScriptModule` as well.

    .. DANGER::
        All modules, no matter their device, are always loaded onto the CPU
        during loading. This is different from :func:`torch.load`'s semantics
        and may change in the future.

    Args:
        m: A :class:`ScriptModule` to save.
        f: A string for file path
        _extra_files: Optional map from filename to contents stored alongside
            the module.

    Example:
    .. testcode::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        m = torch.jit.script(MyModule())

        # Save to file
        torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
    """
    extra_files = _extra_files
    if extra_files is None:
        extra_files = {}

    if isinstance(f, (str, os.PathLike)):
        # Path destination: the C++ serializer writes the file itself.
        f = os.fspath(f)
        torch._C._save_jit_module(m._c, f, extra_files)
    else:
        # File-like destination: serialize to bytes in memory, then write.
        s = torch._C._save_jit_module_to_bytes(m._c, extra_files)
        f.write(s)
def get_flatbuffer_module_info(path_or_file):
    r"""Get some information regarding a model file in flatbuffer format.

    Args:
        path_or_file: Either str, Path or file like object (BytesIO OK).
            If it's str or Path, we will read the file referenced by that
            path as Bytes.

    Returns:
        A dict with metadata on what that file contains, currently looks like
        this:
        {
            'bytecode_version': 4,  # int
            'operator_version': 4,  # int
            'function_names': {
                '__torch__.___torch_mangle_0.Foo.forward'}, # set
            'type_names': set(),  # set
            'opname_to_num_args': {'aten::linear': 3} # Dict[str, int]
        }
    """
    if not isinstance(path_or_file, (str, os.PathLike)):
        # Already a file-like object: consume its bytes directly.
        raw_bytes = path_or_file.read()
    else:
        with open(path_or_file, "rb") as fh:
            raw_bytes = fh.read()
    return torch._C._get_module_info_from_flatbuffer(raw_bytes)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,127 @@
# mypy: allow-untyped-defs
"""JIT-related state.
This module stores various pieces of Python-global state relating to the JIT.
This is not intended to be imported directly; please the exposed
functionalities in `torch.jit`.
"""
import os
import weakref
from typing import Any, Dict, Type
import torch
class EnabledProxy:
    """Stores whether the JIT is enabled or not.

    A bare ``bool`` could not be toggled in place by ``enable``/``disable``,
    so the flag lives on this wrapper to get reference semantics.
    """

    def __init__(self) -> None:
        self.enabled = self.parse_env(
            "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
        )

    def parse_env(self, name, default, true_message, false_message):
        """Interpret env var *name* as a boolean; unset returns *default*.

        The special values "1v"/"0v" also announce the choice on stdout.
        """
        raw = os.environ.get(name)
        if raw is None:
            return default
        lowered = raw.lower()
        if lowered in {"1", "true", "yes"}:
            return True
        if lowered in {"0", "false", "no"}:
            return False
        # "Verbose" variants: same meaning, plus a printed status line.
        if raw == "1v":
            print(true_message)
            return True
        if raw == "0v":
            print(false_message)
            return False
        raise ValueError(f"Unknown setting of {name}. Try using 0 or 1.")

    def __bool__(self):
        return self.enabled
# Module-wide JIT on/off switch; consult via ``bool(_enabled)``.
_enabled = EnabledProxy()


def disable():
    # Globally turn off TorchScript compilation.
    _enabled.enabled = False


def enable():
    # Globally re-enable TorchScript compilation.
    _enabled.enabled = True


# The Python CompilationUnit. All functions and modules defined in Python will
# live in here. It's defined in Python because doing in cpp creates static
# destruction order issues.
_python_cu = torch._C.CompilationUnit()
# python class => ScriptClass mapping
_script_classes: Dict[Type[Any], Type[Any]] = {}
# qualified TorchScript name => python class (inverse lookup)
_name_to_pyclass: Dict[str, Type[Any]] = {}


def _add_script_class(python_class, script_class):
    """Register *script_class* as the compiled form of *python_class*."""
    _script_classes[python_class] = script_class
    _name_to_pyclass[script_class.qualified_name()] = python_class


def _get_script_class(python_class):
    """Return the compiled class for *python_class*, or None if unregistered."""
    override = getattr(python_class, "_jit_override_qualname", None)
    if override is not None:
        # The class asked to be looked up under a different qualified name.
        python_class = _get_python_class(override)
    return _script_classes.get(python_class)


def _get_python_class(qualified_name):
    """Return the Python class registered under *qualified_name*, or None."""
    return _name_to_pyclass.get(qualified_name)


def _clear_class_state():
    """Forget every registered python/script class mapping."""
    _script_classes.clear()
    _name_to_pyclass.clear()
# Caching: we currently cache compilation of free functions and overloaded
# functions. Both caches key on a weak reference to the function object and
# map to the qualified name(s) of the compiled result(s), so cache entries
# die with the Python function they describe.
# In the future we could consider caching more types of objects so that
# aliasing is preserved across separate compilations of the same object.
_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()


def _try_get_jit_cached_overloads(key):
    """Return the cached compiled overloads for *key*, or None on a miss."""
    qual_names = _jit_function_overload_caching.get(key)
    if not qual_names:
        return None
    return [_python_cu.find_function(name) for name in qual_names]


def _set_jit_overload_cache(key, compiled_fns):
    """Remember the qualified names of *key*'s compiled overloads."""
    _jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns]


def _try_get_jit_cached_function(key):
    """Return the cached compiled function for *key*, or None on a miss."""
    if getattr(key, "__disable_jit_function_caching__", False) is True:
        return None
    qual_name = _jit_caching_layer.get(key)
    return _python_cu.find_function(qual_name) if qual_name else None


def _set_jit_function_cache(key, value):
    # only free functions currently supported
    assert isinstance(value, torch.jit.ScriptFunction)
    _jit_caching_layer[key] = value.qualified_name

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,551 @@
# mypy: allow-untyped-defs
import ast
import builtins
import dis
import enum
import inspect
import re
import typing
import warnings
from textwrap import dedent
from typing import Type
import torch
from torch._C import (
_GeneratorType,
AnyType,
AwaitType,
BoolType,
ComplexType,
DeviceObjType,
DictType,
EnumType,
FloatType,
FutureType,
InterfaceType,
IntType,
ListType,
NoneType,
NumberType,
OptionalType,
StreamObjType,
StringType,
TensorType,
TupleType,
UnionType,
)
from torch._jit_internal import ( # type: ignore[attr-defined]
_Await,
_qualified_name,
Any,
BroadcastingList1,
BroadcastingList2,
BroadcastingList3,
Dict,
Future,
is_await,
is_dict,
is_future,
is_ignored_fn,
is_list,
is_optional,
is_tuple,
is_union,
List,
Optional,
Tuple,
Union,
)
from torch._sources import get_source_lines_and_file
from ._state import _get_script_class
if torch.distributed.rpc.is_available():
from torch._C import RRefType
from torch._jit_internal import is_rref, RRef
from torch._ops import OpOverloadPacket
class Module:
    """Tiny namespace stand-in used when resolving names in type comments."""

    def __init__(self, name, members):
        # name: label used in error messages; members: dict of exposed names.
        self.name = name
        self.members = members

    def __getattr__(self, name):
        members = self.members
        if name in members:
            return members[name]
        raise RuntimeError(
            f"Module {self.name} has no member called {name}"
        )
class EvalEnv:
    """Name-lookup environment for evaluating `# type:` comment strings.

    Lookup order in ``__getitem__``: the fixed ``env`` table, then the
    caller-supplied resolution callback, then Python builtins (``None`` on a
    final miss).
    """

    # Names that are always resolvable inside a type comment.
    env = {
        "torch": Module("torch", {"Tensor": torch.Tensor}),
        "Tensor": torch.Tensor,
        "typing": Module("typing", {"Tuple": Tuple}),
        "Tuple": Tuple,
        "List": List,
        "Dict": Dict,
        "Optional": Optional,
        "Union": Union,
        "Future": Future,
        "Await": _Await,
    }

    def __init__(self, rcb):
        # rcb: optional resolution callback mapping a name to an object.
        self.rcb = rcb
        if torch.distributed.rpc.is_available():
            # RRef only exists when the RPC framework is compiled in.
            # NOTE: this mutates the class-level `env` table, not a copy.
            self.env["RRef"] = RRef

    def __getitem__(self, name):
        if name in self.env:
            return self.env[name]
        if self.rcb is not None:
            return self.rcb(name)
        return getattr(builtins, name, None)
def get_signature(fn, rcb, loc, is_method):
    """Resolve *fn*'s ``(param_types, return_type)``, or None if unannotated.

    Real Python annotations are tried first; failing that, a ``# type:``
    comment is parsed out of the function's source.
    """
    if isinstance(fn, OpOverloadPacket):
        # Operator packets carry their annotations on the underlying op.
        signature = try_real_annotations(fn.op, loc)
    else:
        signature = try_real_annotations(fn, loc)
    if signature is not None and is_method:
        # If this is a method, then the signature will include a type for
        # `self`, but type comments do not contain a `self`. So strip it
        # away here so everything is consistent (`inspect.ismethod` does
        # not work here since `fn` is unbound at this point)
        param_types, return_type = signature
        param_types = param_types[1:]
        signature = (param_types, return_type)

    if signature is None:
        type_line, source = None, None
        try:
            source = dedent("".join(get_source_lines_and_file(fn)[0]))
            type_line = get_type_line(source)
        except TypeError:
            pass
        # This might happen both because we failed to get the source of fn, or
        # because it didn't have any annotations.
        if type_line is not None:
            signature = parse_type_line(type_line, rcb, loc)

    return signature
def is_function_or_method(the_callable):
    # A stricter version of `inspect.isroutine` that does not pass for
    # built-in functions.
    return inspect.isfunction(the_callable) or inspect.ismethod(the_callable)


def is_vararg(the_callable):
    """Return True if *the_callable* accepts ``*args``."""
    if callable(the_callable) and not is_function_or_method(the_callable):  # noqa: B004
        # If `the_callable` is a class, de-sugar the call so we can still get
        # the signature
        the_callable = the_callable.__call__

    if not is_function_or_method(the_callable):
        return False
    return inspect.getfullargspec(the_callable).varargs is not None


def get_param_names(fn, n_args):
    """Best-effort parameter names of *fn*; falls back to "0".."n_args-1"."""
    if isinstance(fn, OpOverloadPacket):
        fn = fn.op

    if (
        callable(fn)
        and not is_function_or_method(fn)
        and is_function_or_method(fn.__call__)
    ):  # noqa: B004
        # De-sugar calls to classes
        fn = fn.__call__

    if not is_function_or_method(fn):
        # The `fn` was not a method or function (maybe a class with a __call__
        # method, so use a default param name list)
        return [str(i) for i in range(n_args)]

    if is_ignored_fn(fn):
        # Unwrap @torch.jit.ignore decoration to reach the real signature.
        fn = inspect.unwrap(fn)
    return inspect.getfullargspec(fn).args
def check_fn(fn, loc):
    """Reject *fn* unless its source is exactly one top-level function def."""
    # Make sure the function definition is not a class instantiation
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
    except (OSError, TypeError):
        # Source unavailable (built-in, REPL, ...): nothing to validate.
        return
    if source is None:
        return

    body = ast.parse(source).body
    if len(body) == 1 and isinstance(body[0], ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc,
            f"Cannot instantiate class '{body[0].name}' in a script function",
        )
    if len(body) != 1 or not isinstance(body[0], ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Expected a single top-level function"
        )
def _eval_no_call(stmt, glob, loc):
    """Evaluate statement as long as it does not contain any method/function calls."""
    code = compile(stmt, "", mode="eval")
    # Scan the compiled bytecode: any CALL-family opcode means the annotation
    # tried to invoke something, which we refuse to execute.
    if any("CALL" in ins.opname for ins in dis.get_instructions(code)):
        raise RuntimeError(
            f"Type annotation should not contain calls, but '{stmt}' does"
        )
    return eval(code, glob, loc)  # type: ignore[arg-type] # noqa: P204
def parse_type_line(type_line, rcb, loc):
    """Parse a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor
    """
    arg_ann_str, ret_ann_str = split_type_line(type_line)

    try:
        # Evaluate the argument-list text (calls are rejected) against the
        # EvalEnv lookup table, falling back to `rcb` for unknown names.
        arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb))
    except (NameError, SyntaxError) as e:
        raise RuntimeError(
            "Failed to parse the argument list of a type annotation"
        ) from e

    if not isinstance(arg_ann, tuple):
        # A single un-parenthesized argument evaluates to a bare object.
        arg_ann = (arg_ann,)

    try:
        ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb))
    except (NameError, SyntaxError) as e:
        raise RuntimeError(
            "Failed to parse the return type of a type annotation"
        ) from e

    arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
    return arg_types, ann_to_type(ret_ann, loc)
def get_type_line(source):
    """Try to find the line containing a comment with the type annotation."""
    type_comment = "# type:"

    numbered = list(enumerate(source.split("\n")))
    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.  They look like either:
    #   1) type: ignore
    #   2) type: ignore[rule-code]
    # and must sit at the end of the line, so they are excluded here.
    # (Extra backslash before the space avoids a lint check.)
    ignore_re = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
    type_lines = [
        (num, text)
        for num, text in numbered
        if type_comment in text and not ignore_re.search(text)
    ]

    if not type_lines:
        # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
        typo_re = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
        for num, text in numbered:
            if typo_re.search(text):
                raise RuntimeError(
                    "The annotation prefix in line "
                    + str(num)
                    + " is probably invalid.\nIt must be '# type:'"
                    + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa: B950
                    + "\nfor examples"
                )
        return None

    if len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_line = None
    parameter_type_lines = []
    for line_num, line in type_lines:
        if "# type: (...) -> " in line:
            return_line = (line_num, line)
            break
        elif type_comment in line:
            parameter_type_lines.append(line)
    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n"
            + "\n".join([line[1] for line in type_lines])
            + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"
        )

    # Splice the per-parameter types into the return line's "..." slot.
    parameter_types = ", ".join(
        line[line.find(type_comment) + len(type_comment) :].strip()
        for line in parameter_type_lines
    )
    return return_line[1].replace("...", parameter_types)
def split_type_line(type_line):
    """Split a `# type:` comment into its argument and return parts.

    For example, ``"# type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]"``
    becomes ``("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")``.
    """
    start_offset = len("# type:")
    arrow_pos = type_line.find("->")
    if arrow_pos == -1:
        raise RuntimeError(
            "Syntax error in type annotation (couldn't find `->`)"
        ) from None
    args_part = type_line[start_offset:arrow_pos].strip()
    ret_part = type_line[arrow_pos + 2 :].strip()
    return args_part, ret_part
def try_real_annotations(fn, loc):
    """Try to use the Py3.5+ annotation syntax to get the type."""
    try:
        # Note: anything annotated as `Optional[T]` will automatically
        # be returned as `Union[T, None]` per
        # https://github.com/python/typing/blob/master/src/typing.py#L850
        sig = inspect.signature(fn)
    except ValueError:
        return None

    annotations = [sig.return_annotation]
    annotations.extend(p.annotation for p in sig.parameters.values())
    if all(ann is sig.empty for ann in annotations):
        # Nothing annotated at all: caller falls back to type comments.
        return None

    arg_types = [ann_to_type(p.annotation, loc) for p in sig.parameters.values()]
    return_type = ann_to_type(sig.return_annotation, loc)
    return arg_types, return_type
# Finds common type for enum values belonging to an Enum class. If not all
# values have the same type, AnyType is returned.
def get_enum_value_type(e: Type[enum.Enum], loc):
    """Return the unified TorchScript IR type of the member values of *e*.

    Falls back to ``AnyType`` when the member values do not unify to a single
    type.  Raises ``ValueError`` for an enum class with no members.
    """
    enum_values: List[enum.Enum] = list(e)
    if not enum_values:
        # Bug fix: report the enum class itself. The previous message used
        # `e.__class__`, which is the metaclass (EnumMeta/EnumType) and never
        # names the offending enum.
        raise ValueError(f"No enum values defined for: '{e}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # If Enum values are of different types, an exception will be raised here.
    # Even though Python supports this case, we chose to not implement it to
    # avoid overcomplicate logic here for a rare use case. Please report a
    # feature request if you find it necessary.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res
def is_tensor(ann):
    """Return True if class *ann* should be treated as a Tensor annotation."""
    if issubclass(ann, torch.Tensor):
        return True

    legacy_tensor_types = (
        torch.LongTensor,
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
        torch.BoolTensor,
    )
    if issubclass(ann, legacy_tensor_types):
        # Dtype-specific legacy tensor classes are accepted, but the compiler
        # treats them as plain Tensors — warn so users know the dtype is not
        # enforced.
        warnings.warn(
            "TorchScript will treat type annotations of Tensor "
            "dtype-specific subtypes as if they are normal Tensors. "
            "dtype constraints are not enforced in compilation either."
        )
        return True

    return False
def _fake_rcb(inp):
    # Fallback resolution callback used when none is supplied: resolves nothing.
    return None
def try_ann_to_type(ann, loc, rcb=None):
    """Best-effort conversion of a Python annotation to a TorchScript IR type.

    Returns None when the annotation cannot be resolved; `ann_to_type` turns
    that into an error.  `rcb` is an optional name-resolution callback used
    only by the final `_resolve_type_from_object` fallback.
    """
    ann_args = typing.get_args(ann)  # always returns a tuple!

    # An unannotated parameter/return defaults to an inferred Tensor type.
    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann_args) == 1 and ann_args[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann_args])
    if is_list(ann):
        elem_type = try_ann_to_type(ann_args[0], loc)
        if elem_type:
            return ListType(elem_type)
        # NOTE: falls through (returning via the final fallback) when the
        # element type cannot be resolved.
    if is_dict(ann):
        key = try_ann_to_type(ann_args[0], loc)
        value = try_ann_to_type(ann_args[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}"
            )
        if value is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}"
            )
        return DictType(key, value)
    if is_optional(ann):
        # Optional[T] normalizes to Union[T, None]; pick the non-None arm.
        if issubclass(ann_args[1], type(None)):
            contained = ann_args[0]
        else:
            contained = ann_args[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
        assert valid_type, msg.format(repr(ann), repr(contained), repr(loc))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann_args) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in typing.get_args(ann):
            if a is None:
                inner.append(NoneType.get())
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
            assert maybe_type, msg.format(repr(ann), repr(maybe_type), repr(loc))
            inner.append(maybe_type)
        return UnionType(inner)  # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann_args[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann_args[0], loc))
    if is_await(ann):
        # A bare `Await` (no type argument) is treated as Await[Any].
        elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get()
        return AwaitType(elementType)
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int or ann is torch.SymInt:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Generator:
        return _GeneratorType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            # Enum class not yet compiled: compile it now and use the
            # compiled qualified name.
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    if rcb is None:
        rcb = _fake_rcb
    return torch._C._resolve_type_from_object(ann, loc, rcb)
def ann_to_type(ann, loc, rcb=None):
    """Convert annotation *ann* to a TorchScript IR type or raise ValueError."""
    result = try_ann_to_type(ann, loc, rcb)
    if result is None:
        raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
    return result
__all__ = [
"Any",
"List",
"BroadcastingList1",
"BroadcastingList2",
"BroadcastingList3",
"Tuple",
"is_tuple",
"is_list",
"Dict",
"is_dict",
"is_optional",
"is_union",
"TensorType",
"TupleType",
"FloatType",
"ComplexType",
"IntType",
"ListType",
"StringType",
"DictType",
"AnyType",
"Module",
# TODO: Consider not exporting these during wildcard import (reserve
# that for the types; for idiomatic typing code.)
"get_signature",
"check_fn",
"get_param_names",
"parse_type_line",
"get_type_line",
"split_type_line",
"try_real_annotations",
"try_ann_to_type",
"ann_to_type",
]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,34 @@
# mypy: allow-untyped-defs
from typing import List
from torch._C import _compile_graph_to_code_table, _generate_upgraders_graph
def format_bytecode(table):
    """Convert a bytecode table of (identifier, nested-tuple) entries into a
    dict mapping identifier -> the same structure with tuples turned into
    lists (so it can be serialized as YAML)."""

    def listify(content):
        # Recursively convert tuples to lists; every other value is kept as-is.
        if isinstance(content, tuple):
            return [listify(item) for item in content]
        return content

    return {entry[0]: listify(entry[1]) for entry in table}
def generate_upgraders_bytecode() -> List:
    """Compile every registered operator-upgrader graph and return a list of
    one-entry dicts mapping upgrader name -> formatted bytecode table."""
    upgraders_graph_map = _generate_upgraders_graph()
    return [
        {name: format_bytecode(_compile_graph_to_code_table(name, graph))}
        for name, graph in upgraders_graph_map.items()
    ]
if __name__ == "__main__":
raise RuntimeError("This file is not meant to be run directly")

View File

@ -0,0 +1,232 @@
# mypy: allow-untyped-defs
import os
import torch
from torch.jit._serialization import validate_map_location
def _load_for_lite_interpreter(f, map_location=None):
    r"""
    Load a :class:`LiteScriptModule` saved with :func:`torch.jit._save_for_lite_interpreter`.

    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location: a string or torch.device used to dynamically remap
            storages to an alternative set of devices.

    Returns:
        A :class:`LiteScriptModule` object.

    Example:
    .. testcode::

        import torch
        import io

        # Load LiteScriptModule from saved file path
        torch.jit._load_for_lite_interpreter('lite_script_module.pt')

        # Load LiteScriptModule from io.BytesIO object
        with open('lite_script_module.pt', 'rb') as f:
            buffer = io.BytesIO(f.read())

        # Load all tensors to the original device
        torch.jit.mobile._load_for_lite_interpreter(buffer)
    """
    if isinstance(f, (str, os.PathLike)):
        # Fail early with a clear Python-side message instead of a C++ error.
        if not os.path.exists(f):
            raise ValueError(f"The provided filename {f} does not exist")
        if os.path.isdir(f):
            raise ValueError(f"The provided filename {f} is a directory")

    # Normalizes strings to torch.device and rejects invalid values.
    map_location = validate_map_location(map_location)

    if isinstance(f, (str, os.PathLike)):
        cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
    else:
        cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
            f.read(), map_location
        )

    return LiteScriptModule(cpp_module)
class LiteScriptModule:
    """Python wrapper around a C++ mobile (lite interpreter) module handle."""

    def __init__(self, cpp_module):
        # The underlying torch._C mobile module; all calls delegate to it.
        self._c = cpp_module
        super().__init__()

    def __call__(self, *input):
        # Calling the module is the same as invoking ``forward``.
        return self.forward(*input)

    def find_method(self, method_name):
        """Look up a method by name on the underlying mobile module."""
        return self._c.find_method(method_name)

    def forward(self, *input):
        """Run the module's ``forward`` method with the given inputs."""
        return self._c.forward(input)

    def run_method(self, method_name, *input):
        """Run an arbitrary named method with the given inputs."""
        return self._c.run_method(method_name, input)
def _export_operator_list(module: LiteScriptModule):
    r"""Return a set of root operator names (with overload name) that are used by any method in this mobile module."""
    # Delegates to the C++ binding, which walks every method of the wrapped
    # mobile module and collects its root operators.
    return torch._C._export_operator_list(module._c)
def _get_model_bytecode_version(f_input) -> int:
r"""Take a file-like object to return an integer.
Args:
f_input: a file-like object (has to implement read, readline, tell, and seek),
or a string containing a file name
Returns:
version: An integer. If the integer is -1, the version is invalid. A warning
will show in the log.
Example:
.. testcode::
from torch.jit.mobile import _get_model_bytecode_version
# Get bytecode version from a saved file path
version = _get_model_bytecode_version("path/to/model.ptl")
"""
if isinstance(f_input, (str, os.PathLike)):
if not os.path.exists(f_input):
raise ValueError(f"The provided filename {f_input} does not exist")
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)):
return torch._C._get_model_bytecode_version(os.fspath(f_input))
else:
return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
def _get_mobile_model_contained_types(f_input) -> int:
r"""Take a file-like object and return a set of string, like ("int", "Optional").
Args:
f_input: a file-like object (has to implement read, readline, tell, and seek),
or a string containing a file name
Returns:
type_list: A set of string, like ("int", "Optional"). These are types used in bytecode.
Example:
.. testcode::
from torch.jit.mobile import _get_mobile_model_contained_types
# Get type list from a saved file path
type_list = _get_mobile_model_contained_types("path/to/model.ptl")
"""
if isinstance(f_input, (str, os.PathLike)):
if not os.path.exists(f_input):
raise ValueError(f"The provided filename {f_input} does not exist")
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)):
return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
else:
return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
def _backport_for_mobile(f_input, f_output, to_version):
r"""Take a input string containing a file name (file-like object) and a new destination to return a boolean.
Args:
f_input: a file-like object (has to implement read, readline, tell, and seek),
or a string containing a file name
f_output: path to new model destination
to_version: the expected output model bytecode version
Returns:
success: A boolean. If backport success, return true, otherwise false
"""
if isinstance(f_input, (str, os.PathLike)):
if not os.path.exists(f_input):
raise ValueError(f"The provided filename {f_input} does not exist")
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if (isinstance(f_input, (str, os.PathLike))) and (
isinstance(f_output, (str, os.PathLike))
):
return torch._C._backport_for_mobile(
os.fspath(f_input), os.fspath(f_output), to_version
)
else:
return torch._C._backport_for_mobile_from_buffer(
f_input.read(), str(f_output), to_version
)
def _backport_for_mobile_to_buffer(f_input, to_version):
r"""Take a string containing a file name (file-like object).
Args:
f_input: a file-like object (has to implement read, readline, tell, and seek),
or a string containing a file name
"""
if isinstance(f_input, (str, os.PathLike)):
if not os.path.exists(f_input):
raise ValueError(f"The provided filename {f_input} does not exist")
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)):
return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
else:
return torch._C._backport_for_mobile_from_buffer_to_buffer(
f_input.read(), to_version
)
def _get_model_ops_and_info(f_input):
r"""Retrieve the root (top level) operators of a model and their corresponding compatibility info.
These root operators can call other operators within them (traced ops), and
a root op can call many different traced ops depending on internal code paths in the root op.
These traced ops are not returned by this function. Those operators are abstracted into the
runtime as an implementation detail (and the traced ops themselves can also call other operators)
making retrieving them difficult and their value from this api negligible since they will differ
between which runtime version the model is run on. Because of this, there is a false positive this
api can't prevent in a compatibility usecase. All the root ops of a model are present in a
target runtime, but not all the traced ops are which prevents a model from being able to run.
Args:
f_input: a file-like object (has to implement read, readline, tell, and seek),
or a string containing a file name
Returns:
Operators and info: A Dictionary mapping strings (the qualified names of the root operators)
of the model to their OperatorInfo structs.
Example:
.. testcode::
from torch.jit.mobile import _get_model_ops_and_info
# Get bytecode version from a saved file path
ops_and_info = _get_model_ops_and_info("path/to/model.ptl")
"""
if isinstance(f_input, (str, os.PathLike)):
if not os.path.exists(f_input):
raise ValueError(f"The provided filename {f_input} does not exist")
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)):
return torch._C._get_model_ops_and_info(os.fspath(f_input))
else:
return torch._C._get_model_ops_and_info(f_input.read())

View File

@ -0,0 +1,100 @@
# mypy: allow-untyped-defs
import torch
class QuantizedLinear(torch.jit.ScriptModule):
    """Removed API stub: kept only so old code gets a clear migration error."""

    def __init__(self, other):
        raise RuntimeError(
            "torch.jit.QuantizedLinear is no longer supported. Please use "
            "torch.ao.nn.quantized.dynamic.Linear instead."
        )
# FP16 weights
class QuantizedLinearFP16(torch.jit.ScriptModule):
    """Removed API stub: kept only so old code gets a clear migration error."""

    def __init__(self, other):
        # Run the base initializer first (preserved from the original) before
        # raising the migration error.
        super().__init__()
        raise RuntimeError(
            "torch.jit.QuantizedLinearFP16 is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.Linear instead."
        )
# Quantized RNN cell implementations
class QuantizedRNNCellBase(torch.jit.ScriptModule):
    """Removed API stub: kept only so old code gets a clear migration error."""

    def __init__(self, other):
        raise RuntimeError(
            "torch.jit.QuantizedRNNCellBase is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
        )
class QuantizedRNNCell(QuantizedRNNCellBase):
    """Removed API stub: kept only so old code gets a clear migration error."""

    def __init__(self, other):
        raise RuntimeError(
            "torch.jit.QuantizedRNNCell is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
        )
class QuantizedLSTMCell(QuantizedRNNCellBase):
    """Removed API stub: kept only so old code gets a clear migration error."""

    def __init__(self, other):
        # BUG FIX: do not call super().__init__() here — the base class
        # unconditionally raises with an RNNCellBase-specific message, which
        # made this class's own LSTMCell deprecation message unreachable.
        raise RuntimeError(
            "torch.jit.QuantizedLSTMCell is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead."
        )
class QuantizedGRUCell(QuantizedRNNCellBase):
    """Removed API stub: kept only so old code gets a clear migration error."""

    def __init__(self, other):
        # BUG FIX: do not call super().__init__() here — the base class
        # unconditionally raises with an RNNCellBase-specific message, which
        # made this class's own GRUCell deprecation message unreachable.
        raise RuntimeError(
            "torch.jit.QuantizedGRUCell is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.GRUCell instead."
        )
class QuantizedRNNBase(torch.jit.ScriptModule):
    """Removed API stub: kept only so old code gets a clear migration error."""

    def __init__(self, other, dtype=torch.int8):
        raise RuntimeError(
            "torch.jit.QuantizedRNNBase is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic instead."
        )
class QuantizedLSTM(QuantizedRNNBase):
    """Removed API stub: kept only so old code gets a clear migration error."""

    def __init__(self, other, dtype):
        raise RuntimeError(
            "torch.jit.QuantizedLSTM is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.LSTM instead."
        )
class QuantizedGRU(QuantizedRNNBase):
    """Removed API stub: kept only so old code gets a clear migration error."""

    def __init__(self, *args, **kwargs):
        raise RuntimeError(
            "torch.jit.QuantizedGRU is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.GRU instead."
        )
def quantize_rnn_cell_modules(module):
    """Removed API stub: RNN-cell quantization moved to ``torch.ao``."""
    raise RuntimeError(
        "quantize_rnn_cell_modules function is no longer supported. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
def quantize_linear_modules(module, dtype=torch.int8):
    """Removed API stub: linear quantization moved to ``torch.ao``."""
    raise RuntimeError(
        "quantize_linear_modules function is no longer supported. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
def quantize_rnn_modules(module, dtype=torch.int8):
    """Removed API stub: RNN quantization moved to ``torch.ao``."""
    raise RuntimeError(
        "quantize_rnn_modules function is no longer supported. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )

View File

@ -0,0 +1,344 @@
# mypy: allow-untyped-defs
import inspect
import textwrap
import torch.jit
from torch.jit._builtins import _find_builtin
# this file is for generating documentation using sphinx autodoc
# > help(torch.jit.supported_ops) will also give a nice listed of the
# supported ops programmatically
def _hidden(name):
return name.startswith("_") and not name.startswith("__")
def _emit_type(type):
return str(type)
def _emit_arg(indent, i, arg):
v = f"{arg.name} : {_emit_type(arg.type)}"
default = arg.default_value
if default is not None:
v = f"{v}={str(default)}"
if i > 0:
v = f"\n{' ' * indent}{v}"
return v
def _emit_args(indent, arguments):
    """Join the formatted arguments with commas (one per line after the first)."""
    pieces = [_emit_arg(indent, i, arg) for i, arg in enumerate(arguments)]
    return ",".join(pieces)
def _emit_ret(ret):
return _emit_type(ret.type)
def _emit_rets(returns):
if len(returns) == 1:
return _emit_ret(returns[0])
return f"Tuple[{', '.join(_emit_ret(r) for r in returns)}]"
def _emit_schema(mod, name, schema, arg_start=0, padding=4):
    """Render a full operator schema line like ``mod.name(args) -> ret``.

    ``arg_start`` lets callers drop leading args (e.g. ``self``); ``padding``
    widens the continuation indent used for multi-line argument lists.
    """
    qualified_name = name if mod is None else f"{mod}.{name}"
    # Continuation lines are indented to sit under the opening parenthesis.
    arg_indent = len(qualified_name) + 1 + padding
    args = _emit_args(arg_indent, schema.arguments[arg_start:])
    rets = _emit_rets(schema.returns)
    return f"{qualified_name}({args}) -> {rets}"
def _get_tensor_ops():
    """Collect formatted schemas for public Tensor methods (self elided)."""

    def is_tensor_method(schema):
        # A Tensor method's first argument must be a Tensor named "self".
        if not schema.arguments:
            return False
        first = schema.arguments[0]
        return first.name == "self" and first.type.isSubtypeOf(
            torch._C.TensorType.get()
        )

    methods = []
    # Walk every public attribute of torch.Tensor and keep matching overloads.
    for elem in dir(torch.Tensor):
        if _hidden(elem):
            continue
        for schema in torch._C._jit_get_schemas_for_operator("aten::" + elem):
            if is_tensor_method(schema):
                methods.append(_emit_schema("Tensor", elem, schema, arg_start=1))
    return "Supported Tensor Methods", methods
def _get_nn_functional_ops():
    """Collect formatted schemas for scriptable ``torch.nn.functional`` ops
    plus builtins found in the modules known to contain many of them."""
    functions = []

    # Iterate over torch.nn.functional
    mod = torch.nn.functional
    name = mod.__name__
    for elem in dir(torch.nn.functional):
        attr = getattr(mod, elem)
        # NOTE(review): ``_hidden(elem[0])`` inspects only the FIRST character,
        # so every name starting with "_" (dunders included) is skipped here —
        # confirm this broader filter is intentional.
        if not inspect.isfunction(attr) or _hidden(elem[0]):
            # Ignore non-functions and internal methods
            continue

        attr_module = inspect.getmodule(attr)
        if not attr_module:
            raise RuntimeError(f"Module for {attr} not found")

        if "torch.nn.functional" not in attr_module.__name__:
            # Ignore functions from outside torch.nn.functional
            continue

        try:
            # compile fn, get schema
            scripted = torch.jit.script(attr)
            scripted_schema = scripted.schema
            functions.append(_emit_schema(name, elem, scripted_schema))
        except:  # noqa: B001,E722
            # Skip interpolate / boolean dispatched things
            pass

    # Iterate over modules that we know contain a lot of builtins
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        for elem in dir(mod):
            builtin = _find_builtin(getattr(mod, elem))
            if builtin is not None:
                schemas = torch._C._jit_get_schemas_for_operator(builtin)
                for schema in schemas:
                    # remove _tan but not __and__
                    if not _hidden(elem):
                        functions.append(_emit_schema(name, elem, schema))
    return "Supported PyTorch Functions", functions
def _get_builtins_helper():
    """Return the (fn, builtin_name) pairs worth documenting: named, public,
    and not living in ``torch._C``."""
    collected = []
    for fn, _builtin_name in torch.jit._builtins._builtin_ops:
        mod = inspect.getmodule(fn)
        # Skip typing constructs (no __name__) and anything without a module.
        if not hasattr(fn, "__name__") or not mod:
            continue
        # Skip internal-only entries.
        if (
            _hidden(fn.__name__)
            or _hidden(fn.__qualname__)
            or _hidden(mod.__name__)
            or "torch._C" in mod.__name__
        ):
            continue
        collected.append((fn, _builtin_name))
    return collected
def _is_math_fn(fn):
mod = inspect.getmodule(fn)
if not mod:
raise RuntimeError(f"Module for {fn} not found")
return mod.__name__ == "math"
def _get_torchscript_builtins():
    """Collect formatted schemas for the non-math TorchScript builtins."""
    functions = []
    non_math = [pair for pair in _get_builtins_helper() if not _is_math_fn(pair[0])]
    # Iterate over the specially added builtins
    for fn, _builtin_name in non_math:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is None:
            continue
        for schema in torch._C._jit_get_schemas_for_operator(builtin):
            functions.append(_emit_schema(mod.__name__, fn.__name__, schema))
    return "TorchScript Builtin Functions", functions
def _get_math_builtins():
    """Collect formatted schemas for ``math``-module builtins, skipping any
    overload that mentions Tensor (those appear in the tensor section)."""
    functions = []
    builtins = filter(lambda fn: _is_math_fn(fn[0]), _get_builtins_helper())
    builtins_list = list(builtins)
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins_list:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
                schema_str = _emit_schema(mod.__name__, fn.__name__, schema)
                if "Tensor" in schema_str:
                    # Skip Tensor ops that have the same name as math functions
                    # (they will show up in the tensor methods section)
                    continue
                # BUG FIX: previously this appended the raw ``schema`` object;
                # the doc emitter expects the formatted string built above.
                functions.append(schema_str)
    return "``math`` Module", functions
def _get_global_builtins():
    """Build the rst section describing Python builtins supported in
    TorchScript: ops with a real schema, schemaless ops (with an explanatory
    note), and builtins that dispatch to a magic method on script classes."""
    # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
    supported_builtins = [
        "print",
        "tuple",
        "float",
        "complex",
        "int",
        "bool",
        "str",
        "getattr",
        "hasattr",
        "isinstance",
        "len",
        "hex",
        "oct",
        "round",
        "hash",
        "min",
        "max",
        "abs",
        "all",
        "divmod",
        "list",
        "ord",
        "chr",
        "bin",
        "range",
        "zip",
        "enumerate",
        "sorted",
    ]
    # Builtins whose operator name differs from the Python name ("range" is
    # mapped to a nonexistent op so it lands in the schemaless table).
    op_renames = {
        "bool": "aten::Bool",
        "int": "aten::Int",
        "float": "aten::Float",
        "complex": "aten::Complex",
        "abs": "prim::abs",
        "max": "prim::max",
        "min": "prim::min",
        "range": "fake::does_not_exist",
    }
    # Notes shown for builtins that have no static schema.
    schemaless_op_explanations = {
        "print": "Print any value",
        "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known",
        "getattr": "Attribute name must be a literal string",
        "hasattr": "Attribute name must be a literal string",
        "isinstance": "Result is static",
        "zip": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "enumerate": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "range": "Can only be used as an iterator in a for loop",
    }
    # Builtins that forward to a magic method on TorchScript classes.
    magic_methods = [
        ("complex", "__complex__"),
        ("float", "__float__"),
        ("int", "__int__"),
        ("bool", "__bool__"),
        ("str", "__str__"),
        ("len", "__len__"),
        ("hex", "__hex__"),
        ("oct", "__oct__"),
    ]
    magic_methods_rows = []
    for fn, magic_method in magic_methods:
        magic_methods_rows.append(f'"{fn}", "``{magic_method}``"')
    schematized_ops = []
    schemaless_ops = []
    for fn in supported_builtins:
        op_name = f"aten::{fn}"
        if fn in op_renames:
            op_name = op_renames[fn]
        schemas = torch._C._jit_get_schemas_for_operator(op_name)
        for s in schemas:
            schematized_ops.append(_emit_schema(None, fn, s, padding=0))
        if len(schemas) > 0:
            # Blank separator line between ops in the schema listing.
            schematized_ops.append("")
        else:
            table_row = (
                f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"'
            )
            schemaless_ops.append(table_row)
    schematized_ops_str = "\n".join(schematized_ops)
    schemaless_ops_str = "\n".join(schemaless_ops)
    magic_methods_rows_str = "\n".join(magic_methods_rows)
    # Indent each body so it nests under its rst directive.
    schematized_ops_str = textwrap.indent(schematized_ops_str, "\t")
    schemaless_ops_str = textwrap.indent(schemaless_ops_str, "\t")
    magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, "\t")
    section = f"""
The functions in the following table are supported but do not have a static schema

.. csv-table::
    :header: "Function", "Note"

{schemaless_ops_str}

The following functions will use the corresponding magic method on :any:`TorchScript classes`

.. csv-table::
    :header: "Function", "Magic Method"

{magic_methods_rows_str}

These built-in functions use the schema

.. rst-class:: codeblock-height-limiter

::

{schematized_ops_str}
"""
    return "Python Built-in Functions", section
def _list_supported_ops():
    """Assemble the full rst body listing every supported-op category."""

    def emit_block(decls):
        joined = "".join(f" {d}\n\n" for d in decls)
        return f"\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{joined}\n"

    sections = []
    for gather in (
        _get_tensor_ops,
        _get_nn_functional_ops,
        _get_torchscript_builtins,
        _get_global_builtins,
        _get_math_builtins,
    ):
        header, items = gather()
        # Sanitize the header into an rst link target.
        link_target = (
            header.replace("`", "").replace("-", "").lower().replace(" ", "-")
        )
        if isinstance(items, str):
            # Pre-rendered section (e.g. the global builtins tables).
            rendered = f"{header}\n{'~' * len(header)}\n{items}\n"
        else:
            rendered = f"{header}\n{'~' * len(header)}\n{emit_block(items)}"
        sections.append(f".. _{link_target}:" + "\n\n" + rendered)
    return "".join(sections)
# Generate the module docstring at import time so sphinx autodoc and
# ``help(torch.jit.supported_ops)`` show the rendered op tables.
__doc__ = _list_supported_ops()

View File

@ -0,0 +1,78 @@
# mypy: allow-untyped-defs
from textwrap import dedent
from typing import Any, Dict
import torch.jit
def execWrapper(code, glob, loc):
    """Thin wrapper around :func:`exec` so call sites read uniformly."""
    exec(code, glob, loc)
def _gen_unsupported_methods_properties():
    """Classify public ``torch.Tensor`` attributes that TorchScript rejects.

    Scripts a tiny function calling each attribute; only attributes whose
    compilation fails with a "nonexistent attribute" error are recorded, and
    they are bucketed as methods or properties based on their repr.

    Returns:
        A pair of rst bullet-list strings: (methods, properties).
    """
    # Public attributes only — leading-underscore names are skipped.
    tensor_attrs = set(filter(lambda x: x[0] != "_", dir(torch.Tensor)))
    tensor = torch.tensor([2])
    funcs_template = dedent(
        """
    def func(x):
        return x.{op}()
    """
    )

    # Deprecated APIs are intentionally excluded from the generated docs.
    deprecated_apis = {
        "volatile",
        "resize",
        "reinforce",
        "new",
        "name",
        "map2_",
        "has_names",
        "grad_fn",
        "resize_as",
    }
    tensor_attrs = tensor_attrs - deprecated_apis

    properties = []
    methods = []
    sorted_tensor_attrs = sorted(tensor_attrs, key=lambda x: x.lower())
    for attr in sorted_tensor_attrs:
        funcs_str = funcs_template.format(op=attr)
        scope: Dict[str, Any] = {}
        execWrapper(funcs_str, globals(), scope)
        try:
            cu = torch.jit.CompilationUnit(funcs_str)
        except Exception as e:
            # Other compilation errors (e.g. wrong arity) do not prove the
            # attribute is missing, so they are skipped.
            if "nonexistent attribute" not in repr(e):
                continue
            attr_repr = repr(getattr(tensor, attr))
            if "bound method" in attr_repr or "built-in method" in attr_repr:
                methods.append(attr)
            else:
                properties.append(attr)

    mapped_methods = ("\t* :meth:`~torch.Tensor." + x + r"`" for x in methods)
    mapped_properties = ("\t* :attr:`~torch.Tensor." + x + r"`" for x in properties)
    return "\n".join(mapped_methods), "\n".join(mapped_properties)
def _list_unsupported_tensor_ops():
    """Render the rst body listing unsupported Tensor methods and properties."""
    header = """\n\n
Unsupported Tensor Methods
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
    properties_header = """

Unsupported Tensor Properties
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
    methods, properties = _gen_unsupported_methods_properties()
    parts = [header, "\n", methods, properties_header, "\n", properties]
    return "".join(parts)
# Generate the module docstring at import time so the documentation build
# picks up the current list of unsupported Tensor methods/properties.
__doc__ = _list_unsupported_tensor_ops()