I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

Binary file not shown.

View File

@ -0,0 +1,39 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch._functorch.deprecated import (
combine_state_for_ensemble,
functionalize,
grad,
grad_and_value,
hessian,
jacfwd,
jacrev,
jvp,
make_functional,
make_functional_with_buffers,
vjp,
vmap,
)
# utilities. Maybe these should go in their own namespace in the future?
from torch._functorch.make_functional import (
FunctionalModule,
FunctionalModuleWithBuffers,
)
# Was never documented
from torch._functorch.python_key import make_fx
# Top-level APIs. Please think carefully before adding something to the
# top-level namespace:
# - private helper functions should go into torch._functorch
# - very experimental things should go into functorch.experimental
# - compilation related things should go into functorch.compile
__version__ = torch.__version__
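A minimal usage sketch of the top-level API re-exported above (a sketch only; grad and vmap are the deprecated aliases imported from torch._functorch.deprecated, and f is a hypothetical scalar-valued function used purely for illustration):

import torch
import functorch

def f(x):
    # hypothetical scalar-valued function, for illustration only
    return torch.sin(x).sum()

x = torch.randn(3)
grads = functorch.grad(f)(x)          # gradient of f with respect to x, same shape as x
sines = functorch.vmap(torch.sin)(x)  # vectorized map over the leading dimension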

View File

@ -0,0 +1,8 @@
# This file has moved under torch/_functorch. It is not public API.
# If you are not a PyTorch developer and you are relying on the following
# imports, please file an issue.
from torch._functorch.aot_autograd import (
aot_autograd_decompositions,
KNOWN_TYPES,
PytreeThunk,
)

View File

@ -0,0 +1,7 @@
# This file has moved under torch/_functorch. It is not public API.
# If you are not a PyTorch developer and you are relying on the following
# imports, please file an issue.
from torch._functorch.eager_transforms import (
_assert_wrapped_functional,
_unwrap_functional_tensor,
)

View File

@ -0,0 +1,4 @@
# This file has moved under torch/_functorch. It is not public API.
# If you are not a PyTorch developer and you are relying on the following
# imports, please file an issue.
from torch._functorch.make_functional import _swap_state

View File

@ -0,0 +1,16 @@
# This file has moved under torch/_functorch. It is not public API.
# If you are not a PyTorch developer and you are relying on the following
# imports, please file an issue.
from torch._functorch.vmap import (
_add_batch_dim,
_broadcast_to_and_flatten,
_create_batched_inputs,
_get_name,
_process_batched_inputs,
_remove_batch_dim,
_unwrap_batched,
_validate_and_get_batch_size,
Tensor,
tree_flatten,
tree_unflatten,
)

View File

@ -0,0 +1,30 @@
from torch._functorch import config
from torch._functorch.aot_autograd import (
aot_function,
aot_module,
aot_module_simplified,
compiled_function,
compiled_module,
get_aot_compilation_context,
get_aot_graph_name,
get_graph_being_compiled,
make_boxed_compiler,
make_boxed_func,
)
from torch._functorch.compilers import (
debug_compile,
default_decompositions,
draw_graph_compile,
memory_efficient_fusion,
nnc_jit,
nop,
print_compile,
ts_compile,
)
from torch._functorch.fx_minifier import minifier
from torch._functorch.partitioners import (
default_partition,
draw_graph,
min_cut_rematerialization_partition,
)
from torch._functorch.python_key import pythonkey_decompose

View File

@ -0,0 +1,181 @@
import dis
import inspect
from typing import Sequence, Union
import functorch._C
import torch
from functorch._C import dim as _C
from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type
_C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
class DimensionMismatchError(Exception):
pass
class DimensionBindError(Exception):
pass
from . import op_properties
# use dict to avoid writing C++ bindings for set
pointwise = dict.fromkeys(op_properties.pointwise, True)
use_c = True
if not use_c:
from . import reference
class _Tensor:
# fast path around slow wrapping/unwrapping logic for simple queries used
# by the implementation...
@property
def dims(self):
return tuple(d for d in self._levels if isinstance(d, Dim))
def dim(self):
return self.ndim
if use_c:
__torch_function__ = classmethod(_C.__torch_function__)
expand = _C._instancemethod(_C.expand)
else:
__torch_function__ = reference.__torch_function__
expand = reference.expand
index = _C._instancemethod(_C.index)
def __repr__(self):
tensor, levels, ndim = self._tensor, self._levels, self.ndim
return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
TensorLike = (_Tensor, torch.Tensor)
class Dim(_C.Dim, _Tensor):
# note that _C.Dim comes before _Tensor because we want the Dim API for things like size to take precedence.
# Tensor defines format, but we want to print Dims with special formatting
__format__ = object.__format__
class Tensor(_Tensor, _C.Tensor):
if not use_c:
from_batched = staticmethod(_C.Tensor_from_batched)
from_positional = staticmethod(_C.Tensor_from_positional)
sum = _C._instancemethod(_C.Tensor_sum)
def cat(tensors, dim, new_dim):
n = dims()
return stack(tensors, n, dim).index([n, dim], new_dim)
if use_c:
_wrap = _C._wrap
def _def(name, *args, **kwargs):
orig = getattr(torch.Tensor, name)
setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))
t__getitem__ = _C._instancemethod(_C.__getitem__)
stack = _C.stack
split = _C._instancemethod(_C.split)
else:
_wrap, _def = reference._wrap, reference._def
t__getitem__ = reference.t__getitem__
stack = reference.stack
split = reference.split
# note: there is no python reference
t__setitem__ = _C._instancemethod(_C.__setitem__)
# this is patched in the C API because otherwise torch.Tensor will
# no longer be considered a sequence and things will break
# torch.Tensor.__getitem__ = t__getitem__
_Tensor.__getitem__ = t__getitem__
# torch.Tensor.__setitem__ = t__setitem__
_Tensor.__setitem__ = t__setitem__
torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
del _Tensor.ndim
if use_c:
_Tensor.order = _C._instancemethod(_C.order)
else:
_Tensor.order = reference.positional
_def("mean")
_def("sum")
_def("all")
_def("amax")
_def("amin")
_def("aminmax")
_def("any")
_def("count_nonzero")
_def("logsumexp")
_def("nanmean")
_def("nansum")
_def("prod")
_def("std", keepdim_offset=2)
_def("var", keepdim_offset=2)
_def("max", single_dim=True)
_def("min", single_dim=True)
_def("argmax", single_dim=True)
_def("argmin", single_dim=True)
_def("kthvalue", single_dim=True)
_def("median", single_dim=True)
_def("nanmedian", single_dim=True)
_def("mode", single_dim=True)
_def("sort", reduce=False)
_def("argsort", reduce=False)
_def("unbind", single_dim=True)
_def("chunk", dim_offset=1, reduce=False)
_def("cummax", single_dim=True, reduce=False)
_def("cummin", single_dim=True, reduce=False)
_def("cumprod", single_dim=True, reduce=False)
_def("cumprod_", single_dim=True, reduce=False)
_def("cumsum", single_dim=True, reduce=False)
_def("cumsum_", single_dim=True, reduce=False)
_def("logcumsumexp", single_dim=True, reduce=False)
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
_def("softmax", single_dim=True, reduce=False)
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
# stuff to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce
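A rough sketch of the first-class dimension API assembled in this module (illustrative only; the names batch and feature are arbitrary):

import torch
from functorch.dim import dims

batch, feature = dims(2)             # create two first-class Dim objects
x = torch.randn(4, 3)
xd = x[batch, feature]               # bind the dims to x's positional dimensions
print(batch.size, feature.size)      # 4 3
y = (xd * 2).order(batch, feature)   # back to a positional tensor of shape (4, 3)
total = xd.sum(batch)                # reduce over a first-class dim; feature remains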

View File

@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
_enabled = False
@contextmanager
def _enable_layers(dims):
global _enabled
assert not _enabled
input = sorted((d._level, d.size) for d in dims if not isinstance(d, int))
n = len(input)
try:
_vmap_add_layers(input)
_enabled = True
yield
finally:
_enabled = False
_vmap_remove_layers(n)

View File

@ -0,0 +1,77 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple
class DelayedMulTensor(_Tensor):
def __init__(self, lhs, rhs):
self._lhs, self._rhs = lhs, rhs
self._data = None
self._levels_data = None
self._has_device = lhs._has_device or rhs._has_device
self._batchtensor_data = None
self._tensor_data = None
@property
def _levels(self):
if self._levels_data is None:
levels = llist(self._lhs._levels)
for l in self._rhs._levels:
if l not in levels:
levels.append(l)
self._levels_data = ltuple(levels)
return self._levels_data
@property
def _batchtensor(self):
if self._batchtensor_data is None:
with _enable_layers(self._levels):
print("bt multiply fallback")
self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
return self._batchtensor_data
@property
def _tensor(self):
if self._tensor_data is None:
self._tensor_data = Tensor.from_batched(
self._batchtensor, self._has_device
)._tensor
return self._tensor_data
@property
def ndim(self):
return self._batchtensor.ndim
@property
def dims(self):
return ltuple(super().dims)
def sum(self, dim):
dims = _dims(dim, 0, False, False)
n = ord("a")
all_levels = self._levels
def to_char(d):
return chr(n + all_levels.index(d))
plhs, levelslhs = self._lhs._tensor, self._lhs._levels
prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
new_dims = tuple(d for d in self.dims if d not in dims)
new_levels = [l for l in self._levels if l not in dims]
fmt = "".join(
[
*(to_char(d) for d in levelslhs),
",",
*(to_char(d) for d in levelsrhs),
"->",
*(to_char(d) for d in new_levels),
]
)
result_data = torch.einsum(fmt, (plhs, prhs))
return Tensor.from_positional(result_data, new_levels, True)
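A hedged sketch of what this class enables: when both operands carry only first-class dims, the multiply is deferred, and a following sum over a dim is fused into a single einsum (names are illustrative):

import torch
from functorch.dim import dims

i, j = dims(2)
a = torch.randn(4, 5)[i, j]  # only first-class dims remain, so positional ndim == 0
b = torch.randn(4, 5)[i, j]
out = (a * b).sum(j)         # a * b builds a DelayedMulTensor; sum(j) emits one einsum over j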

View File

@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dis
import inspect
from dataclasses import dataclass
from typing import Union
from . import DimList
_vmap_levels = []
@dataclass
class LevelInfo:
level: int
alive: bool = True
class Dim:
def __init__(self, name: str, size: Union[None, int] = None):
self.name = name
self._size = None
self._vmap_level = None
if size is not None:
self.size = size
def __del__(self):
if self._vmap_level is not None:
_vmap_active_levels[self._vmap_stack].alive = False # noqa: F821
while (
not _vmap_levels[-1].alive
and current_level() == _vmap_levels[-1].level # noqa: F821
):
_vmap_decrement_nesting() # noqa: F821
_vmap_levels.pop()
@property
def size(self):
assert self.is_bound
return self._size
@size.setter
def size(self, size: int):
from . import DimensionBindError
if self._size is None:
self._size = size
self._vmap_level = _vmap_increment_nesting(size, "same") # noqa: F821
self._vmap_stack = len(_vmap_levels)
_vmap_levels.append(LevelInfo(self._vmap_level))
elif self._size != size:
raise DimensionBindError(
f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
)
@property
def is_bound(self):
return self._size is not None
def __repr__(self):
return self.name
def extract_name(inst):
assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
return inst.argval
_cache = {}
def dims(lists=0):
frame = inspect.currentframe()
assert frame is not None
calling_frame = frame.f_back
assert calling_frame is not None
code, lasti = calling_frame.f_code, calling_frame.f_lasti
key = (code, lasti)
if key not in _cache:
first = lasti // 2 + 1
instructions = list(dis.get_instructions(calling_frame.f_code))
unpack = instructions[first]
if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
# just a single dim, not a list
name = unpack.argval
ctor = Dim if lists == 0 else DimList
_cache[key] = lambda: ctor(name=name)
else:
assert unpack.opname == "UNPACK_SEQUENCE"
ndims = unpack.argval
names = tuple(
extract_name(instructions[first + 1 + i]) for i in range(ndims)
)
first_list = len(names) - lists
_cache[key] = lambda: tuple(
Dim(n) if i < first_list else DimList(name=n)
for i, n in enumerate(names)
)
return _cache[key]()
def _dim_set(positional, arg):
def convert(a):
if isinstance(a, Dim):
return a
else:
assert isinstance(a, int)
return positional[a]
if arg is None:
return positional
elif not isinstance(arg, (Dim, int)):
return tuple(convert(a) for a in arg)
else:
return (convert(arg),)
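A short sketch of how this reference dims() maps assignment targets to names and dim lists, following the bytecode inspection above:

x = dims()                  # a single Dim named "x"
i, j = dims()               # two Dims named "i" and "j"
i, j, rest = dims(lists=1)  # two Dims plus a trailing DimList named "rest"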

View File

@ -0,0 +1,42 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import signal
import subprocess
from contextlib import contextmanager
@contextmanager
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
pid = os.getpid()
if not os.path.exists(magic_trace_cache):
print(f"Downloading magic_trace to: {magic_trace_cache}")
subprocess.run(
[
"wget",
"-O",
magic_trace_cache,
"-q",
"https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace",
]
)
subprocess.run(["chmod", "+x", magic_trace_cache])
args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
while True:
x = p.stderr.readline()
print(x)
if "Attached" in x:
break
try:
yield
finally:
p.send_signal(signal.SIGINT)
r = p.wait()
print(p.stderr.read())
p.stderr.close()
if r != 0:
raise ValueError(f"magic_trace exited abnormally: {r}")

View File

@ -0,0 +1,312 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
# pointwise operators can go through a faster pathway
tensor_magic_methods = ["add", ""]
pointwise_magic_methods_with_reverse = (
"add",
"sub",
"mul",
"floordiv",
"div",
"truediv",
"mod",
"pow",
"lshift",
"rshift",
"and",
"or",
"xor",
)
pointwise_magic_methods = (
*(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
"eq",
"gt",
"le",
"lt",
"ge",
"gt",
"ne",
"neg",
"pos",
"abs",
"invert",
"iadd",
"isub",
"imul",
"ifloordiv",
"idiv",
"itruediv",
"imod",
"ipow",
"ilshift",
"irshift",
"iand",
"ior",
"ixor",
"int",
"long",
"float",
"complex",
)
pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
pointwise = (
*(getattr(torch.Tensor, m) for m in pointwise_methods),
torch.nn.functional.dropout,
torch.where,
torch.Tensor.abs,
torch.abs,
torch.Tensor.acos,
torch.acos,
torch.Tensor.acosh,
torch.acosh,
torch.Tensor.add,
torch.add,
torch.Tensor.addcdiv,
torch.addcdiv,
torch.Tensor.addcmul,
torch.addcmul,
torch.Tensor.addr,
torch.addr,
torch.Tensor.angle,
torch.angle,
torch.Tensor.asin,
torch.asin,
torch.Tensor.asinh,
torch.asinh,
torch.Tensor.atan,
torch.atan,
torch.Tensor.atan2,
torch.atan2,
torch.Tensor.atanh,
torch.atanh,
torch.Tensor.bitwise_and,
torch.bitwise_and,
torch.Tensor.bitwise_left_shift,
torch.bitwise_left_shift,
torch.Tensor.bitwise_not,
torch.bitwise_not,
torch.Tensor.bitwise_or,
torch.bitwise_or,
torch.Tensor.bitwise_right_shift,
torch.bitwise_right_shift,
torch.Tensor.bitwise_xor,
torch.bitwise_xor,
torch.Tensor.ceil,
torch.ceil,
torch.celu,
torch.nn.functional.celu,
torch.Tensor.clamp,
torch.clamp,
torch.Tensor.clamp_max,
torch.clamp_max,
torch.Tensor.clamp_min,
torch.clamp_min,
torch.Tensor.copysign,
torch.copysign,
torch.Tensor.cos,
torch.cos,
torch.Tensor.cosh,
torch.cosh,
torch.Tensor.deg2rad,
torch.deg2rad,
torch.Tensor.digamma,
torch.digamma,
torch.Tensor.div,
torch.div,
torch.dropout,
torch.nn.functional.dropout,
torch.nn.functional.elu,
torch.Tensor.eq,
torch.eq,
torch.Tensor.erf,
torch.erf,
torch.Tensor.erfc,
torch.erfc,
torch.Tensor.erfinv,
torch.erfinv,
torch.Tensor.exp,
torch.exp,
torch.Tensor.exp2,
torch.exp2,
torch.Tensor.expm1,
torch.expm1,
torch.feature_dropout,
torch.Tensor.float_power,
torch.float_power,
torch.Tensor.floor,
torch.floor,
torch.Tensor.floor_divide,
torch.floor_divide,
torch.Tensor.fmod,
torch.fmod,
torch.Tensor.frac,
torch.frac,
torch.Tensor.frexp,
torch.frexp,
torch.Tensor.gcd,
torch.gcd,
torch.Tensor.ge,
torch.ge,
torch.nn.functional.gelu,
torch.nn.functional.glu,
torch.Tensor.gt,
torch.gt,
torch.Tensor.hardshrink,
torch.hardshrink,
torch.nn.functional.hardshrink,
torch.nn.functional.hardsigmoid,
torch.nn.functional.hardswish,
torch.nn.functional.hardtanh,
torch.Tensor.heaviside,
torch.heaviside,
torch.Tensor.hypot,
torch.hypot,
torch.Tensor.i0,
torch.i0,
torch.Tensor.igamma,
torch.igamma,
torch.Tensor.igammac,
torch.igammac,
torch.Tensor.isclose,
torch.isclose,
torch.Tensor.isfinite,
torch.isfinite,
torch.Tensor.isinf,
torch.isinf,
torch.Tensor.isnan,
torch.isnan,
torch.Tensor.isneginf,
torch.isneginf,
torch.Tensor.isposinf,
torch.isposinf,
torch.Tensor.isreal,
torch.isreal,
torch.Tensor.kron,
torch.kron,
torch.Tensor.lcm,
torch.lcm,
torch.Tensor.ldexp,
torch.ldexp,
torch.Tensor.le,
torch.le,
torch.nn.functional.leaky_relu,
torch.Tensor.lerp,
torch.lerp,
torch.Tensor.lgamma,
torch.lgamma,
torch.Tensor.log,
torch.log,
torch.Tensor.log10,
torch.log10,
torch.Tensor.log1p,
torch.log1p,
torch.Tensor.log2,
torch.log2,
torch.nn.functional.logsigmoid,
torch.Tensor.logical_and,
torch.logical_and,
torch.Tensor.logical_not,
torch.logical_not,
torch.Tensor.logical_or,
torch.logical_or,
torch.Tensor.logical_xor,
torch.logical_xor,
torch.Tensor.logit,
torch.logit,
torch.Tensor.lt,
torch.lt,
torch.Tensor.maximum,
torch.maximum,
torch.Tensor.minimum,
torch.minimum,
torch.nn.functional.mish,
torch.Tensor.mvlgamma,
torch.mvlgamma,
torch.Tensor.nan_to_num,
torch.nan_to_num,
torch.Tensor.ne,
torch.ne,
torch.Tensor.neg,
torch.neg,
torch.Tensor.nextafter,
torch.nextafter,
torch.Tensor.outer,
torch.outer,
torch.polar,
torch.Tensor.polygamma,
torch.polygamma,
torch.Tensor.positive,
torch.positive,
torch.Tensor.pow,
torch.pow,
torch.Tensor.prelu,
torch.prelu,
torch.nn.functional.prelu,
torch.Tensor.rad2deg,
torch.rad2deg,
torch.Tensor.reciprocal,
torch.reciprocal,
torch.Tensor.relu,
torch.relu,
torch.nn.functional.relu,
torch.nn.functional.relu6,
torch.Tensor.remainder,
torch.remainder,
torch.Tensor.round,
torch.round,
torch.rrelu,
torch.nn.functional.rrelu,
torch.Tensor.rsqrt,
torch.rsqrt,
torch.rsub,
torch.selu,
torch.nn.functional.selu,
torch.Tensor.sgn,
torch.sgn,
torch.Tensor.sigmoid,
torch.sigmoid,
torch.nn.functional.sigmoid,
torch.Tensor.sign,
torch.sign,
torch.Tensor.signbit,
torch.signbit,
torch.nn.functional.silu,
torch.Tensor.sin,
torch.sin,
torch.Tensor.sinc,
torch.sinc,
torch.Tensor.sinh,
torch.sinh,
torch.nn.functional.softplus,
torch.nn.functional.softshrink,
torch.Tensor.sqrt,
torch.sqrt,
torch.Tensor.square,
torch.square,
torch.Tensor.sub,
torch.sub,
torch.Tensor.tan,
torch.tan,
torch.Tensor.tanh,
torch.tanh,
torch.nn.functional.tanh,
torch.threshold,
torch.nn.functional.threshold,
torch.trapz,
torch.Tensor.true_divide,
torch.true_divide,
torch.Tensor.trunc,
torch.trunc,
torch.Tensor.xlogy,
torch.xlogy,
torch.rand_like,
)

View File

@ -0,0 +1,645 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# reference python implementations for C ops
import torch
from functorch._C import dim as _C
from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map
DimList = _C.DimList
import operator
from functools import reduce
# the C implementation uses a dict to avoid writing C++ bindings for set; here a plain set suffices
pointwise = set(op_properties.pointwise)
def prod(x):
return reduce(operator.mul, x, 1)
def _wrap_dim(d, N, keepdim):
from . import Dim
if isinstance(d, Dim):
assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
return d
elif d >= 0:
return d - N
else:
return d
def _dims(d, N, keepdim, single_dim):
from . import Dim
if isinstance(d, (Dim, int)):
return ltuple((_wrap_dim(d, N, keepdim),))
assert not single_dim, f"expected a single dimension or int but found: {d}"
return ltuple(_wrap_dim(x, N, keepdim) for x in d)
def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
from . import DimensionMismatchError
not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
if len(not_bound) == 1:
idx, d = not_bound[0]
rhs_so_far = prod(r.size for r in rhs if r.is_bound)
if lhs_size % rhs_so_far != 0:
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(
f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
)
new_size = lhs_size // rhs_so_far
d.size = new_size
elif len(not_bound) > 1:
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(
f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
)
else:
rhs_size = prod(r.size for r in rhs)
if lhs_size != rhs_size:
raise DimensionMismatchError(
f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
)
def _tensor_levels(inp):
from . import _Tensor
if isinstance(inp, _Tensor):
return inp._tensor, llist(inp._levels), inp._has_device
else:
return inp, llist(range(-inp.ndim, 0)), True
def _match_levels(v, from_levels, to_levels):
view = []
permute = []
requires_view = False
size = v.size()
for t in to_levels:
try:
idx = from_levels.index(t)
permute.append(idx)
view.append(size[idx])
except ValueError:
view.append(1)
requires_view = True
if permute != list(range(len(permute))):
v = v.permute(*permute)
if requires_view:
v = v.view(*view)
return v
# make a single dimension positional but do not permute it,
# used for multi-tensor operators where the dim being acted on
# should not physically move if possible
def _positional_no_permute(self, dim, expand_dim=False):
from . import Tensor
ptensor, levels = self._tensor, llist(self._levels)
try:
idx = levels.index(dim)
except ValueError:
if not expand_dim:
raise
idx = 0
ptensor = ptensor.expand(dim.size, *ptensor.size())
levels.insert(0, 0)
idx_batched = 0
for i in range(idx):
if isinstance(levels[i], int):
levels[i] -= 1
idx_batched += 1
levels[idx] = -idx_batched - 1
return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
def seq(a, b):
from . import Dim
if isinstance(a, Dim) != isinstance(b, Dim):
return False
if isinstance(a, Dim):
return a is b
else:
return a == b
class isin:
def __contains__(self, item):
for x in self:
if seq(item, x):
return True
return False
def index(self, item):
for i, x in enumerate(self):
if seq(item, x):
return i
raise ValueError
class llist(isin, list):
pass
class ltuple(isin, tuple):
pass
empty_dict = {}
@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
from . import _Tensor, Tensor, TensorLike
from .delayed_mul_tensor import DelayedMulTensor
if orig is torch.Tensor.__mul__:
lhs, rhs = args
if (
isinstance(lhs, _Tensor)
and isinstance(rhs, _Tensor)
and lhs.ndim == 0
and rhs.ndim == 0
):
return DelayedMulTensor(lhs, rhs)
all_dims = llist()
flat_args, unflatten = tree_flatten((args, kwargs))
device_holding_tensor = None
for f in flat_args:
if isinstance(f, _Tensor):
if f._has_device:
device_holding_tensor = f._batchtensor
for d in f.dims:
if d not in all_dims:
all_dims.append(d)
def unwrap(t):
if isinstance(t, _Tensor):
r = t._batchtensor
if device_holding_tensor is not None and not t._has_device:
r = r.to(device=device_holding_tensor.device)
return r
return t
if orig in pointwise:
result_levels = llist()
arg_levels = llist()
to_expand = []
for i, f in enumerate(flat_args):
if isinstance(f, TensorLike):
ptensor, levels, _ = _tensor_levels(f)
if (
isinstance(f, _Tensor)
and not f._has_device
and device_holding_tensor is not None
):
ptensor = ptensor.to(device=device_holding_tensor.device)
flat_args[i] = ptensor
for l in levels:
if l not in result_levels:
result_levels.append(l)
to_expand.append((i, levels))
for i, levels in to_expand:
flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
args, kwargs = unflatten(flat_args)
result = orig(*args, **kwargs)
def wrap(t):
if isinstance(t, TensorLike):
return Tensor.from_positional(
t, result_levels, device_holding_tensor is not None
)
return t
return tree_map(wrap, result)
else:
def wrap(t):
if isinstance(t, TensorLike):
return Tensor.from_batched(t, device_holding_tensor is not None)
return t
with _enable_layers(all_dims):
print(f"batch_tensor for {orig}")
args, kwargs = unflatten(unwrap(f) for f in flat_args)
result = orig(*args, **kwargs)
# print("END", orig)
return tree_map(wrap, result)
def positional(self, *dims):
from . import Dim, DimensionBindError, Tensor
ptensor, levels = self._tensor, llist(self._levels)
flat_dims = llist()
view = []
needs_view = False
ndim = self.ndim
for d in dims:
if isinstance(d, DimList):
flat_dims.extend(d)
view.extend(e.size for e in d)
elif isinstance(d, Dim):
flat_dims.append(d)
view.append(d.size)
elif isinstance(d, int):
d = _wrap_dim(d, ndim, False)
flat_dims.append(d)
view.append(ptensor.size(d))
else:
flat_dims.extend(d)
view.append(prod(e.size for e in d))
needs_view = True
permute = list(range(len(levels)))
nflat = len(flat_dims)
for i, d in enumerate(flat_dims):
try:
idx = levels.index(d)
except ValueError as e:
raise DimensionBindError(
f"tensor of dimensions {self.dims} does not contain dim {d}"
) from e
p = permute[idx]
del levels[idx]
del permute[idx]
levels.insert(i, 0)
permute.insert(i, p)
ptensor = ptensor.permute(*permute)
seen = 0
for i in range(len(levels) - 1, -1, -1):
if isinstance(levels[i], int):
seen += 1
levels[i] = -seen
result = Tensor.from_positional(ptensor, levels, self._has_device)
if needs_view:
result = result.reshape(*view, *result.size()[len(flat_dims) :])
return result
def _contains_dim(input):
from . import Dim
for i in input:
if isinstance(i, Dim):
return True
def expand(self, *sizes):
if not _contains_dim(sizes):
return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
dims = sizes
sizes = [d.size for d in dims] + [-1] * self.ndim
self = self.expand(*sizes)
return self[dims]
_not_present = object()
def _getarg(name, offset, args, kwargs, default):
if len(args) > offset:
return args[offset]
return kwargs.get(name, default)
def _patcharg(name, offset, args, kwargs, value):
if len(args) > offset:
args[offset] = value
else:
kwargs[name] = value
def _wrap(
orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
from . import Dim, Tensor, TensorLike
def fn(self, *args, **kwargs):
dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
with _enable_layers(self.dims):
print(f"dim fallback batch_tensor for {orig}")
return Tensor.from_batched(
orig(self._batchtensor, *args, **kwargs), self._has_device
)
keepdim = (
_getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
)
t, levels = self._tensor, llist(self._levels)
dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
dim_indices = tuple(levels.index(d) for d in dims)
if reduce and not keepdim:
new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
else:
new_levels = levels
if len(dim_indices) == 1:
dim_indices = dim_indices[
0
] # so that dims that really only take a single argument work...
args = list(args)
_patcharg(dim_name, dim_offset, args, kwargs, dim_indices)
def wrap(t):
if isinstance(t, TensorLike):
return Tensor.from_positional(t, new_levels, self._has_device)
return t
with _enable_layers(new_levels):
print(f"dim used batch_tensor for {orig}")
r = orig(t, *args, **kwargs)
return tree_map(wrap, r)
return fn
def _def(name, *args, **kwargs):
from . import _Tensor
orig = getattr(torch.Tensor, name)
setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
no_slice = slice(None)
_orig_getitem = torch.Tensor.__getitem__
class dim_tracker:
def __init__(self) -> None:
self.dims = llist()
self.count = []
def record(self, d):
if d not in self.dims:
self.dims.append(d)
self.count.append(1)
def __getitem__(self, d):
return self.count[self.dims.index(d)]
def t__getitem__(self, input):
from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike
# * bail to the original __getitem__ if we have a single non-Dim tensor, or a non-tensor
# * locate ... or an unbound tensor list, and determine its size, bind dim list
# (remember that None does not count toward the total dim count)
# * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
# produce the re-view if needed
# * for each single-use dim index, replace with no_slice and mark that it will be added
# (keep track of whether we have to call super)
# * call super if needed
# * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
# this handles bool indexing, as well as some other simple cases.
is_simple = (
not isinstance(input, Dim)
and not isinstance(input, (tuple, list))
and
# WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
not (isinstance(input, TensorLike) and input.ndim == 0)
)
if is_simple:
if isinstance(self, _Tensor):
return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
else:
return _orig_getitem(self, input)
# can further optimize this case
if not isinstance(input, tuple):
input = [input]
else:
input = list(input)
dims_indexed = 0
expanding_object = None
dimlists = []
for i, s in enumerate(input):
if s is ... or isinstance(s, DimList) and not s.is_bound:
if expanding_object is not None:
msg = (
"at most one ... or unbound dimension list can exist in indexing list but"
f" found 2 at offsets {i} and {expanding_object}"
)
raise DimensionBindError(msg)
expanding_object = i
if isinstance(s, DimList):
dims_indexed += len(s) if s.is_bound else 0
dimlists.append(i)
elif s is not None and s is not ...:
dims_indexed += 1
ndim = self.ndim
if dims_indexed > ndim:
raise IndexError(
f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
)
if expanding_object is not None:
expanding_ndims = ndim - dims_indexed
obj = input[expanding_object]
if obj is ...:
input[expanding_object : expanding_object + 1] = [
no_slice
] * expanding_ndims
else:
obj.bind_len(expanding_ndims)
# flatten the dimlists into the indexing
for i in reversed(dimlists):
input[i : i + 1] = input[i]
dims_indexed = 0
requires_view = False
size = self.size()
view_sizes = []
dims_seen = dim_tracker()
def add_dims(t):
if not isinstance(t, _Tensor):
return
for d in t.dims:
dims_seen.record(d)
add_dims(self)
dim_packs = []
for i, idx in enumerate(input):
if idx is None:
input[i] = no_slice
view_sizes.append(1)
requires_view = True
else:
sz = size[dims_indexed]
if isinstance(idx, Dim):
idx.size = sz
dims_seen.record(idx)
view_sizes.append(sz)
elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
for d in idx:
dims_seen.record(d)
_bind_dims_to_size(sz, idx, f"offset {i}")
view_sizes.extend(d.size for d in idx)
requires_view = True
dim_packs.append(i)
else:
add_dims(idx)
view_sizes.append(sz)
dims_indexed += 1
if requires_view:
self = self.view(*view_sizes)
for i in reversed(dim_packs):
input[i : i + 1] = input[i]
# currently:
# input is flat, containing either Dim, or Tensor, or something valid for standard indexing
# self may have first-class dims as well.
# to index:
# drop the first class dims from self, they just become direct indices of their positions
# figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
# these dimensions will appear and need to be bound at the first place a tensor occurs
if isinstance(self, _Tensor):
ptensor_self, levels = self._tensor, list(self._levels)
# indices to ptensor rather than self which has first-class dimensions
input_it = iter(input)
flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
has_device = self._has_device
to_pad = 0
else:
ptensor_self, flat_inputs = self, input
to_pad = ptensor_self.ndim - len(flat_inputs)
has_device = True
result_levels = []
index_levels = []
tensor_insert_point = None
to_expand = {}
requires_getindex = False
for i, inp in enumerate(flat_inputs):
if isinstance(inp, Dim) and dims_seen[inp] == 1:
flat_inputs[i] = no_slice
result_levels.append(inp)
elif isinstance(inp, TensorLike):
requires_getindex = True
if tensor_insert_point is None:
tensor_insert_point = len(result_levels)
ptensor, levels, _ = _tensor_levels(inp)
to_expand[i] = levels
flat_inputs[i] = ptensor
for l in levels:
if l not in index_levels:
index_levels.append(l)
else:
requires_getindex = True
result_levels.append(0)
if tensor_insert_point is not None:
result_levels[tensor_insert_point:tensor_insert_point] = index_levels
for i, levels in to_expand.items():
flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)
if requires_getindex:
result = _orig_getitem(ptensor_self, flat_inputs)
else:
result = ptensor_self
next_positional = -1
if to_pad > 0:
result_levels.extend([0] * to_pad)
for i, r in enumerate(reversed(result_levels)):
if isinstance(r, int):
result_levels[-1 - i] = next_positional
next_positional -= 1
return Tensor.from_positional(result, result_levels, has_device)
# XXX - dim is optional and can be the outer-most dimension...
def stack(tensors, new_dim, dim=0, out=None):
if isinstance(dim, int):
return torch.stack(tensors, dim, out).index(dim, new_dim)
index = None
if out is not None:
out, index = _positional_no_permute(out, dim, expand_dim=True)
ptensors = []
for t in tensors:
pt, pi = _positional_no_permute(t, dim, expand_dim=True)
if index is not None and pi != index:
pt = pt.move_dim(pi, index)
else:
index = pi
ptensors.append(pt)
pr = torch.stack(ptensors, index, out=out)
return pr.index((index, index + 1), (new_dim, dim))
_orig_split = torch.Tensor.split
def split(self, split_size_or_sections, dim=0):
from . import _Tensor, Dim
if isinstance(split_size_or_sections, int) or any(
isinstance(t, int) for t in split_size_or_sections
):
if isinstance(dim, Dim):
raise ValueError(
"when dim is specified as a Dim object, split sizes must also be dimensions."
)
return _orig_split(self, split_size_or_sections, dim=dim)
if isinstance(dim, Dim):
assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
self, dim = _positional_no_permute(self, dim)
size = self.size(dim)
total_bound_size = 0
unbound = []
sizes = []
for i, d in enumerate(split_size_or_sections):
if d.is_bound:
sizes.append(d.size)
total_bound_size += d.size
else:
sizes.append(0)
unbound.append(i)
if unbound:
assert (
total_bound_size <= size
), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
remaining_size = size - total_bound_size
chunk_size = -(-remaining_size // len(unbound))
for u in unbound:
sz = min(chunk_size, remaining_size)
split_size_or_sections[u].size = sz
sizes[u] = sz
remaining_size -= sz
else:
assert (
total_bound_size == size
), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
return tuple(
t.index(dim, d)
for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
)
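A hedged sketch of the Dim-aware split defined above: when the split sizes are unbound Dims, their sizes are inferred from the dimension being split (names are arbitrary):

import torch
from functorch.dim import dims

a, b = dims(2)
x = torch.randn(10, 4)
first, second = x.split([a, b], dim=0)  # a.size and b.size are inferred (5 and 5 here)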

View File

@ -0,0 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch._C import dim
tree_flatten = dim.tree_flatten
def tree_map(fn, tree):
vs, unflatten = tree_flatten(tree)
return unflatten(fn(v) for v in vs)

View File

@ -0,0 +1,72 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from types import (
BuiltinMethodType,
FunctionType,
GetSetDescriptorType,
MethodDescriptorType,
WrapperDescriptorType,
)
from functorch._C import dim as _C
_wrap_method = _C._wrap_method
FUNC_TYPES = (
FunctionType,
MethodDescriptorType,
BuiltinMethodType,
WrapperDescriptorType,
)
PROPERTY_TYPES = (GetSetDescriptorType, property)
def _py_wrap_method(orig, __torch_function__):
def impl(*args, **kwargs):
return __torch_function__(orig, None, args, kwargs)
return impl
def wrap_type(use_c, to_patch, pattern, __torch_function__):
if use_c:
wrap_method = _wrap_method
else:
wrap_method = _py_wrap_method
all = {}
for t in reversed(pattern.mro()[:-1]): # skip object
all.update(t.__dict__)
def wrap_attr(orig):
return property(wrap_method(orig.__get__, __torch_function__))
for name, obj in all.items():
if name in (
"__dict__",
"__new__",
"__init__",
"__repr__",
"__weakref__",
"__doc__",
"__module__",
"__dir__",
):
continue
# skip things that have been overloaded
# things that come from object like `__eq__` still need to be patched, however.
if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
object, name, None
):
continue
if isinstance(obj, FUNC_TYPES):
setattr(to_patch, name, wrap_method(obj, __torch_function__))
elif isinstance(obj, PROPERTY_TYPES):
setattr(to_patch, name, wrap_attr(obj))

View File

@ -0,0 +1,4 @@
from .rearrange import rearrange
__all__ = ["rearrange"]

View File

@ -0,0 +1,303 @@
"""Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py.
MIT License
Copyright (c) 2018 Alex Rogozhnikov
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import annotations
import keyword
import warnings
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
_ellipsis: str = "\u2026" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
class AnonymousAxis:
"""Used by `ParsedExpression` to represent an axis with a size (> 1), but no associated identifier.
Note: Different instances of this class are not equal to each other, even if they have the same value.
"""
def __init__(self, value: str) -> None:
self.value = int(value)
if self.value < 1:
raise ValueError(
f"Anonymous axis should have positive length, not {self.value}"
)
def __repr__(self) -> str:
return f"{self.value}-axis"
class ParsedExpression:
"""Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""
def __init__(
self,
expression: str,
*,
allow_underscore: bool = False,
allow_duplicates: bool = False,
) -> None:
"""Parse the expression and store relevant metadata.
Args:
expression (str): the `einops`-pattern to parse
allow_underscore (bool): whether to allow axis identifier names to begin with an underscore
allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression
"""
self.has_ellipsis: bool = False
self.has_ellipsis_parenthesized: Optional[bool] = None
self.identifiers: Set[Union[str, AnonymousAxis]] = set()
# that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
self.has_non_unitary_anonymous_axes: bool = False
# composition keeps structure of composite axes, see how different corner cases are handled in tests
self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
if "." in expression:
if "..." not in expression:
raise ValueError(
"Expression may contain dots only inside ellipsis (...)"
)
if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
raise ValueError(
"Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
)
expression = expression.replace("...", _ellipsis)
self.has_ellipsis = True
bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
def add_axis_name(x: str) -> None:
if x in self.identifiers:
if not (allow_underscore and x == "_") and not allow_duplicates:
raise ValueError(
f"Indexing expression contains duplicate dimension '{x}'"
)
if x == _ellipsis:
self.identifiers.add(_ellipsis)
if bracket_group is None:
self.composition.append(_ellipsis)
self.has_ellipsis_parenthesized = False
else:
bracket_group.append(_ellipsis)
self.has_ellipsis_parenthesized = True
else:
is_number = str.isdecimal(x)
if is_number and int(x) == 1:
# handling the case of anonymous axis of length 1
if bracket_group is None:
self.composition.append([])
else:
pass # no need to think about 1s inside parenthesis
return
is_axis_name, reason = self.check_axis_name_return_reason(
x, allow_underscore=allow_underscore
)
if not (is_number or is_axis_name):
raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
axis_name: Union[str, AnonymousAxis] = (
AnonymousAxis(x) if is_number else x
)
self.identifiers.add(axis_name)
if is_number:
self.has_non_unitary_anonymous_axes = True
if bracket_group is None:
self.composition.append([axis_name])
else:
bracket_group.append(axis_name)
current_identifier = None
for char in expression:
if char in "() ":
if current_identifier is not None:
add_axis_name(current_identifier)
current_identifier = None
if char == "(":
if bracket_group is not None:
raise ValueError(
"Axis composition is one-level (brackets inside brackets not allowed)"
)
bracket_group = []
elif char == ")":
if bracket_group is None:
raise ValueError("Brackets are not balanced")
self.composition.append(bracket_group)
bracket_group = None
elif str.isalnum(char) or char in ["_", _ellipsis]:
if current_identifier is None:
current_identifier = char
else:
current_identifier += char
else:
raise ValueError(f"Unknown character '{char}'")
if bracket_group is not None:
raise ValueError(f"Imbalanced parentheses in expression: '{expression}'")
if current_identifier is not None:
add_axis_name(current_identifier)
@staticmethod
def check_axis_name_return_reason(
name: str, allow_underscore: bool = False
) -> Tuple[bool, str]:
"""Check if the given axis name is valid, and a message explaining why if not.
Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
Args:
name (str): the axis name to check
allow_underscore (bool): whether axis names are allowed to start with an underscore
Returns:
Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
"""
if not str.isidentifier(name):
return False, "not a valid python identifier"
elif name[0] == "_" or name[-1] == "_":
if name == "_" and allow_underscore:
return True, ""
return False, "axis name should should not start or end with underscore"
else:
if keyword.iskeyword(name):
warnings.warn(
f"It is discouraged to use axes names that are keywords: {name}",
RuntimeWarning,
)
if name in ["axis"]:
warnings.warn(
"It is discouraged to use 'axis' as an axis name and will raise an error in future",
FutureWarning,
)
return True, ""
@staticmethod
def check_axis_name(name: str) -> bool:
"""Check if the name is a valid axis name.
Args:
name (str): the axis name to check
Returns:
bool: whether the axis name is valid
"""
is_valid, _ = ParsedExpression.check_axis_name_return_reason(name)
return is_valid
def parse_pattern(
pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
"""Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
Args:
pattern (str): the `einops`-style rearrangement pattern
axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
Returns:
Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
"""
# adapted from einops.einops._prepare_transformation_recipe
# https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
try:
left_str, right_str = pattern.split("->")
except ValueError:
raise ValueError("Pattern must contain a single '->' separator") from None
if _ellipsis in axes_lengths:
raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")
left = ParsedExpression(left_str)
right = ParsedExpression(right_str)
if not left.has_ellipsis and right.has_ellipsis:
raise ValueError(
f"Ellipsis found in right side, but not left side of a pattern {pattern}"
)
if left.has_ellipsis and left.has_ellipsis_parenthesized:
raise ValueError(
f"Ellipsis is parenthesis in the left side is not allowed: {pattern}"
)
return left, right
def validate_rearrange_expressions(
left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
) -> None:
"""Perform expression validations that are specific to the `rearrange` operation.
Args:
left (ParsedExpression): left-hand side expression
right (ParsedExpression): right-hand side expression
axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
"""
for length in axes_lengths.values():
if (length_type := type(length)) is not int:
raise TypeError(
f"rearrange axis lengths must be integers, got: {length_type}"
)
if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
raise ValueError("rearrange only supports unnamed axes of size 1")
difference = set.symmetric_difference(left.identifiers, right.identifiers)
if len(difference) > 0:
raise ValueError(
f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
)
unmatched_axes = axes_lengths.keys() - left.identifiers
if len(unmatched_axes) > 0:
raise ValueError(
f"Identifiers not found in rearrange expression: {unmatched_axes}"
)
def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
"""Convert a collection of strings representing first class dims into a comma-separated string.
Args:
collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert
Returns:
str: the comma-separated string
Examples:
>>> comma_separate(('d0',))
'd0'
>>> comma_separate(('d0', 'd1', 'd2', 'd3'))
'd0, d1, d2, d3'
>>> comma_separate([('d1', 'd4')])
'(d1, d4)'
>>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
'(d0,), (), (d1,), (d2,), (d3, d4)'
"""
return ", ".join(
item
if isinstance(item, str)
else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
for item in collection
)
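To make the parsing concrete, a small sketch of what ParsedExpression produces for a typical pattern side, derived from the logic above:

expr = ParsedExpression("b c (h w)")
# expr.identifiers  -> {"b", "c", "h", "w"}
# expr.composition  -> [["b"], ["c"], ["h", "w"]]
# expr.has_ellipsis -> False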

View File

@ -0,0 +1,208 @@
from __future__ import annotations
import functools
from typing import Callable, Dict, List, Sequence, Tuple, Union
import torch
from functorch._C import dim as _C
from ._parsing import (
_ellipsis,
AnonymousAxis,
comma_separate,
parse_pattern,
validate_rearrange_expressions,
)
__all__ = ["rearrange"]
dims = _C.dims
@functools.lru_cache(256)
def _create_rearrange_callable(
tensor_ndim: int, pattern: str, **axes_lengths: int
) -> Callable[[torch.Tensor], torch.Tensor]:
r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.
Since an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and
specified axes lengths, this function can be memoized.
Args:
tensor_ndim (int): the number of dimensions in the tensor to rearrange
pattern (str): the `einops`-style rearrangement pattern
axes_lengths (int): any additional length specifications for dimensions
Returns:
Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement
"""
left, right = parse_pattern(pattern, axes_lengths)
validate_rearrange_expressions(left, right, axes_lengths)
n_anon_dims = sum(not dim for dim in left.composition)
if left.has_ellipsis:
n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
n_named_dims = len(left.identifiers) - 1
if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim:
raise ValueError(
f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
f"dimensions in the tensor ({tensor_ndim})"
)
else:
n_ellipsis_dims = 0
n_named_dims = len(left.identifiers)
if (pattern_ndim := len(left.composition)) != tensor_ndim:
raise ValueError(
f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
f"the tensor ({tensor_ndim})"
)
n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims
if n_dims == 0:
# an identity rearrangement on a 0-dimension tensor
return lambda tensor: tensor
first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
anon_axes: List[AnonymousAxis] = []
# map the left-hand side identifiers to strings representing first class dims
dims_i = 0
for dimension in left.composition:
if isinstance(dimension, list):
for identifier in dimension:
# non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
assert isinstance(identifier, str)
identifier_dim_map[identifier] = (first_class_dims[dims_i],)
dims_i += 1
if not dimension:
# unitary anonymous axis
anon_axis = AnonymousAxis("1")
identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
anon_axes.append(anon_axis)
dimension.append(anon_axis)
dims_i += 1
elif dimension == _ellipsis:
identifier = _ellipsis
identifier_dim_map[identifier] = tuple(
first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
)
dims_i += n_ellipsis_dims
else:
raise ValueError(f"Unexpected dimension: {dimension}")
def composition_to_dims(
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
) -> List[Union[str, Tuple[str, ...]]]:
"""Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
class dims."""
dim_composition: List[Union[str, Tuple[str, ...]]] = []
for dimension in composition:
if isinstance(dimension, list):
dim_composition.append(
tuple(
dim
for identifier in dimension
for dim in identifier_dim_map[identifier]
)
)
elif dimension == _ellipsis:
dim_composition.extend(identifier_dim_map[_ellipsis])
else:
raise ValueError(f"Unexpected dimension: {dimension}")
return dim_composition
left_dims = composition_to_dims(left.composition)
right_dims = composition_to_dims(right.composition)
anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
specified_lengths = tuple(
(identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
)
custom_rearrange_callable_name = "do_rearrange"
custom_rearrange_callable_code = (
(
f"def {custom_rearrange_callable_name}(tensor):\n"
f" {comma_separate(first_class_dims)} = dims({n_dims})\n"
)
+ (
"".join(
f" {dim}.size = {length}\n" for (dim, length) in specified_lengths
)
if specified_lengths
else ""
)
+ f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
+ (
f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
if anon_dims
else " return tensor\n"
)
)
exec(custom_rearrange_callable_code)
return locals()[custom_rearrange_callable_name]
def rearrange(
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
pattern: str,
**axes_lengths: int,
) -> torch.Tensor:
r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
stack, concatenate and other operations.
See: https://einops.rocks/api/rearrange/
Args:
tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange
pattern (str): the rearrangement pattern
axes_lengths (int): any additional length specifications for dimensions
Returns:
Tensor: the rearranged tensor
Examples:
>>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
>>> images = torch.randn((32, 30, 40, 3))
>>> # stack along first (batch) axis, output is a single array
>>> rearrange(images, 'b h w c -> b h w c').shape
torch.Size([32, 30, 40, 3])
>>> # concatenate images along height (vertical axis), 960 = 32 * 30
>>> rearrange(images, 'b h w c -> (b h) w c').shape
torch.Size([960, 40, 3])
>>> # concatenated images along horizontal axis, 1280 = 32 * 40
>>> rearrange(images, 'b h w c -> h (b w) c').shape
torch.Size([30, 1280, 3])
>>> # reordered axes to "b c h w" format for deep learning
>>> rearrange(images, 'b h w c -> b c h w').shape
torch.Size([32, 3, 30, 40])
>>> # flattened each image into a vector, 3600 = 30 * 40 * 3
>>> rearrange(images, 'b h w c -> b (c h w)').shape
torch.Size([32, 3600])
>>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
>>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
torch.Size([128, 15, 20, 3])
>>> # space-to-depth operation
>>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
torch.Size([32, 15, 20, 12])
"""
if not isinstance(tensor, torch.Tensor):
tensor = torch.stack(tensor)
rearrange_callable = _create_rearrange_callable(
tensor.ndim, pattern, **axes_lengths
)
return rearrange_callable(tensor)
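For reference, the source that _create_rearrange_callable generates for a call such as rearrange(x, 'b (h w) -> b h w', h=2) on a 2-D tensor looks roughly like the following (reconstructed from the string-building logic above and reformatted for readability):

def do_rearrange(tensor):
    d0, d1, d2 = dims(3)
    d1.size = 2
    tensor = tensor[(d0,), (d1, d2)].order((d0,), (d1,), (d2,))
    return tensor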

View File

@ -0,0 +1,5 @@
# PyTorch forward-mode is not mature yet
from functorch import functionalize
from torch._functorch.apis import chunk_vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import hessian, jacfwd, jvp

View File

@ -0,0 +1,7 @@
from torch import cond # noqa: F401
from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
from torch._higher_order_ops.map import ( # noqa: F401
_stack_pytree,
_unstack_pytree,
map,
)

View File

@ -0,0 +1 @@
from torch._ops import HigherOrderOperator # noqa: F401