I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,166 @@
"""
Python polyfills for common builtins.
"""
# NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports.
# 2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py.
# Add it in the TYPE_CHECKING block below as well.
# mypy: allow-untyped-defs
from typing import Any, Callable, Sequence, TYPE_CHECKING
import torch
if TYPE_CHECKING:
# Load by torch._dynamo.polyfills.loader
# See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py
# Put the submodules here to avoid circular imports
from . import (
builtins as builtins,
functools as functools,
itertools as itertools,
os as os,
sys as sys,
)
def index(iterator, item, start=0, end=None):
from itertools import islice
for i, elem in islice(enumerate(iterator), start, end):
if item == elem:
return i
# This will not run in dynamo
raise ValueError(f"{item} is not in {type(iterator)}")
def repeat(item, count):
for i in range(count):
yield item
def radians(x):
import math
return math.pi / 180.0 * x
def accumulate_grad(x, new_grad):
new_grad = torch.clone(new_grad)
if x.grad is None:
x.grad = new_grad
else:
x.grad.add_(new_grad)
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
"""emulate `(1,2,3) > (1,2)` etc"""
for a, b in zip(left, right):
if a != b:
return op(a, b)
return op(len(left), len(right))
def set_isdisjoint(set1, set2):
for x in set1:
if x in set2:
return False
return True
def set_intersection(set1, set2):
intersection_set = set()
for x in set1:
if x in set2:
intersection_set.add(x)
return intersection_set
def set_union(set1, set2):
union_set = set1.copy()
for x in set2:
if x not in union_set:
union_set.add(x)
return union_set
def set_difference(set1, set2):
difference_set = set()
for x in set1:
if x not in set2:
difference_set.add(x)
return difference_set
def dropwhile(predicate, iterable):
# dropwhile(lambda x: x<5, [1,4,6,4,1]) -> 6 4 1
iterable = iter(iterable)
for x in iterable:
if not predicate(x):
yield x
break
yield from iterable
def zip_longest(*iterables, fillvalue=None):
# Create a list of iterators from the input iterables
iterators = [iter(it) for it in iterables]
result = []
while True:
row = []
active = False
for it in iterators:
try:
# Try to get the next item from the iterator
value = next(it)
row.append(value)
active = True
except StopIteration:
# If the iterator is exhausted, use the fillvalue
row.append(fillvalue)
if not active:
break
result.append(tuple(row))
return result
def getattr_and_trace(*args, **kwargs):
wrapper_obj = args[0]
attr_name = args[1]
fn = getattr(wrapper_obj, attr_name)
return fn(*args[2:], **kwargs)
def mapping_get(obj, key, value=None):
try:
return obj.__getitem__(key)
except KeyError:
return value
def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
obj = cls.__new__(cls, *args, **kwargs)
# Only call __init__ if the object is an instance of the class
# Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673
if isinstance(obj, cls):
obj.__init__(*args, **kwargs)
return obj
def foreach_lerp_inplace(self, end, weight):
# decompose foreach lerp into constituent ops, prevents a graph break due to
# converting a value to a scalar when arg[2] is a single tensor
result = torch._foreach_sub(end, self)
result = torch._foreach_mul(result, weight)
return torch._foreach_add_(self, result)
def foreach_pow_scalar(scalar, exps):
return torch._foreach_pow([scalar for _ in exps], exps)
def addcmul_inplace(self, tensor1, tensor2, value):
return self.add_(tensor1 * tensor2 * value)

View File

@ -0,0 +1,48 @@
"""
Python polyfills for builtins
"""
from __future__ import annotations
import builtins
from typing import Iterable, TypeVar
from ..decorators import substitute_in_graph
__all__ = [
"all",
"any",
"enumerate",
]
_T = TypeVar("_T")
@substitute_in_graph(builtins.all, can_constant_fold_through=True)
def all(iterable: Iterable[object], /) -> bool:
for elem in iterable:
if not elem:
return False
return True
@substitute_in_graph(builtins.any, can_constant_fold_through=True)
def any(iterable: Iterable[object], /) -> bool:
for elem in iterable:
if elem:
return True
return False
@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type]
def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]:
if not isinstance(start, int):
raise TypeError(
f"{type(start).__name__!r} object cannot be interpreted as an integer"
)
for x in iterable:
yield start, x
start += 1

View File

@ -0,0 +1,6 @@
"""
Python polyfills for functools
"""
__all__ = [] # type: ignore[var-annotated]

View File

@ -0,0 +1,85 @@
"""
Python polyfills for itertools
"""
from __future__ import annotations
import itertools
from typing import Iterable, Iterator, TypeVar
from ..decorators import substitute_in_graph
__all__ = [
"chain",
"chain_from_iterable",
"islice",
"tee",
]
_T = TypeVar("_T")
# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type]
def chain(*iterables: Iterable[_T]) -> Iterator[_T]:
for iterable in iterables:
yield from iterable
@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type]
def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
return itertools.chain(*iterable)
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice
@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type]
def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]:
s = slice(*args)
start = 0 if s.start is None else s.start
stop = s.stop
step = 1 if s.step is None else s.step
if start < 0 or (stop is not None and stop < 0) or step <= 0:
raise ValueError(
"Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.",
)
if stop is None:
# TODO: use indices = itertools.count() and merge implementation with the else branch
# when we support infinite iterators
next_i = start
for i, element in enumerate(iterable):
if i == next_i:
yield element
next_i += step
else:
indices = range(max(start, stop))
next_i = start
for i, element in zip(indices, iterable):
if i == next_i:
yield element
next_i += step
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
@substitute_in_graph(itertools.tee)
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
iterator = iter(iterable)
shared_link = [None, None]
def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def]
try:
while True:
if link[1] is None:
link[0] = next(iterator)
link[1] = [None, None]
value, link = link
yield value
except StopIteration:
return
return tuple(_tee(shared_link) for _ in range(n))

View File

@ -0,0 +1,35 @@
# Used to load and initialize polyfill handlers when importing torch._dynamo
# Please add a new import when adding a new polyfill module.
import importlib
from typing import Tuple, TYPE_CHECKING
from .. import polyfills, trace_rules
if TYPE_CHECKING:
from types import ModuleType
# See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
"builtins",
"functools",
"itertools",
"os",
"sys",
)
POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
importlib.import_module(f".{submodule}", package=polyfills.__name__)
for submodule in POLYFILLED_MODULE_NAMES
)
# Unregister the builtin functions from _builtin_function_ids to let them to be
# dispatched with the appropriate VariableTracker type. Otherwise, they will be
# dispatched with BuiltinVariable if present in _builtin_function_ids.
for polyfill_module in POLYFILLED_MODULES:
for polyfill_name in polyfill_module.__all__:
polyfill_handler = getattr(polyfill_module, polyfill_name)
original_fn = polyfill_handler.__torch_dynamo_original__
trace_rules._builtin_function_ids.remove(id(original_fn))

View File

@ -0,0 +1,36 @@
"""
Python polyfills for os
"""
from __future__ import annotations
import os
from typing import AnyStr
from ..decorators import substitute_in_graph
__all__ = ["fspath"]
# Copied from os.py in the standard library
@substitute_in_graph(os.fspath, can_constant_fold_through=True)
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
if isinstance(path, (str, bytes)):
return path
path_type = type(path)
try:
path_repr = path_type.__fspath__(path) # type: ignore[arg-type]
except AttributeError:
if hasattr(path_type, "__fspath__"):
raise
raise TypeError(
f"expected str, bytes or os.PathLike object, not {path_type.__name__}",
) from None
if isinstance(path_repr, (str, bytes)):
return path_repr # type: ignore[return-value]
raise TypeError(
f"expected {path_type.__name__}.__fspath__() to return str or bytes, "
f"not {type(path_repr).__name__}",
)

View File

@ -0,0 +1,6 @@
"""
Python polyfills for sys
"""
__all__ = [] # type: ignore[var-annotated]