I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,181 @@
import dis
import inspect
from typing import Sequence, Union
import functorch._C
import torch
from functorch._C import dim as _C
from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type
_C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
class DimensionMismatchError(Exception):
pass
class DimensionBindError(Exception):
pass
from . import op_properties
# use dict to avoid writing C++ bindings for set
pointwise = dict.fromkeys(op_properties.pointwise, True)
use_c = True
if not use_c:
from . import reference
class _Tensor:
# fast path around slow wrapping/unwrapping logic for simple queries used
# by the implementation...
@property
def dims(self):
return tuple(d for d in self._levels if isinstance(d, Dim))
def dim(self):
return self.ndim
if use_c:
__torch_function__ = classmethod(_C.__torch_function__)
expand = _C._instancemethod(_C.expand)
else:
__torch_function__ = reference.__torch_function__
expand = reference.expand
index = _C._instancemethod(_C.index)
def __repr__(self):
tensor, levels, ndim = self._tensor, self._levels, self.ndim
return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
TensorLike = (_Tensor, torch.Tensor)
class Dim(_C.Dim, _Tensor):
# note that _C.Dim comes before _Tensor because we want the Dim API for things like size to take precedence.
# Tensor defines format, but we want to print Dims with special formatting
__format__ = object.__format__
class Tensor(_Tensor, _C.Tensor):
if not use_c:
from_batched = staticmethod(_C.Tensor_from_batched)
from_positional = staticmethod(_C.Tensor_from_positional)
sum = _C._instancemethod(_C.Tensor_sum)
def cat(tensors, dim, new_dim):
n = dims()
return stack(tensors, n, dim).index([n, dim], new_dim)
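# Hedged usage sketch for cat (names illustrative, assumes the _C extension is built):
#
#   i, j = dims()
#   a = torch.rand(2, 5)[i, j]
#   b = torch.rand(2, 5)[i, j]
#   k = dims()
#   c = cat([a, b], j, k)   # c carries dims (i, k), with k.size == 10
#
# stack introduces a fresh dim n of size len(tensors); index([n, dim], new_dim)
# then flattens n and the original dim into the new one.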
if use_c:
_wrap = _C._wrap
def _def(name, *args, **kwargs):
orig = getattr(torch.Tensor, name)
setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))
t__getitem__ = _C._instancemethod(_C.__getitem__)
stack = _C.stack
split = _C._instancemethod(_C.split)
else:
_wrap, _def = reference._wrap, reference._def
t__getitem__ = reference.t__getitem__
stack = reference.stack
split = reference.split
# note: there is no Python reference implementation of __setitem__
t__setitem__ = _C._instancemethod(_C.__setitem__)
# this is patched in the C API because otherwise torch.Tensor will
# no longer be considered a sequence and things will break
# torch.Tensor.__getitem__ = t__getitem__
_Tensor.__getitem__ = t__getitem__
# torch.Tensor.__setitem__ = t__setitem__
_Tensor.__setitem__ = t__setitem__
torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
del _Tensor.ndim
if use_c:
_Tensor.order = _C._instancemethod(_C.order)
else:
_Tensor.order = reference.positional
_def("mean")
_def("sum")
_def("all")
_def("amax")
_def("amin")
_def("aminmax")
_def("any")
_def("count_nonzero")
_def("logsumexp")
_def("nanmean")
_def("nansum")
_def("prod")
_def("std", keepdim_offset=2)
_def("var", keepdim_offset=2)
_def("max", single_dim=True)
_def("min", single_dim=True)
_def("argmax", single_dim=True)
_def("argmin", single_dim=True)
_def("kthvalue", single_dim=True)
_def("median", single_dim=True)
_def("nanmedian", single_dim=True)
_def("mode", single_dim=True)
_def("sort", reduce=False)
_def("argsort", reduce=False)
_def("unbind", single_dim=True)
_def("chunk", dim_offset=1, reduce=False)
_def("cummax", single_dim=True, reduce=False)
_def("cummin", single_dim=True, reduce=False)
_def("cumprod", single_dim=True, reduce=False)
_def("cumprod_", single_dim=True, reduce=False)
_def("cumsum", single_dim=True, reduce=False)
_def("cumsum_", single_dim=True, reduce=False)
_def("logcumsumexp", single_dim=True, reduce=False)
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
_def("softmax", single_dim=True, reduce=False)
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
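# The _def(...) calls above let these torch.Tensor reductions accept a
# first-class Dim for their `dim` argument. A hedged sketch of the intended
# usage (illustrative only):
#
#   i, j = dims()
#   x = torch.rand(3, 4)[i, j]
#   total = x.sum(j)       # reduces over j; the result still carries dim i
#   probs = x.softmax(j)   # softmax over the first-class dim j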
# stuff to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce

View File

@@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
_enabled = False
@contextmanager
def _enable_layers(dims):
global _enabled
assert not _enabled
input = sorted((d._level, d.size) for d in dims if not isinstance(d, int))
n = len(input)
try:
_vmap_add_layers(input)
_enabled = True
yield
finally:
_enabled = False
_vmap_remove_layers(n)
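# Hedged usage sketch: push vmap layers for a set of bound first-class dims so
# that fallback ops run through the batching rules, then pop them on exit
# (the block cannot be nested because of the `_enabled` guard):
#
#   with _enable_layers([i, j]):      # i, j are bound Dim objects
#       out = lhs_batched * rhs_batched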

View File

@@ -0,0 +1,77 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple
class DelayedMulTensor(_Tensor):
def __init__(self, lhs, rhs):
self._lhs, self._rhs = lhs, rhs
self._data = None
self._levels_data = None
self._has_device = lhs._has_device or rhs._has_device
self._batchtensor_data = None
self._tensor_data = None
@property
def _levels(self):
if self._levels_data is None:
levels = llist(self._lhs._levels)
for l in self._rhs._levels:
if l not in levels:
levels.append(l)
self._levels_data = ltuple(levels)
return self._levels_data
@property
def _batchtensor(self):
if self._batchtensor_data is None:
with _enable_layers(self._levels):
print("bt multiply fallback")
self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
return self._batchtensor_data
@property
def _tensor(self):
if self._tensor_data is None:
self._tensor_data = Tensor.from_batched(
self._batchtensor, self._has_device
)._tensor
return self._tensor_data
@property
def ndim(self):
return self._batchtensor.ndim
@property
def dims(self):
return ltuple(super().dims)
def sum(self, dim):
dims = _dims(dim, 0, False, False)
n = ord("a")
all_levels = self._levels
def to_char(d):
return chr(n + all_levels.index(d))
plhs, levelslhs = self._lhs._tensor, self._lhs._levels
prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
new_dims = tuple(d for d in self.dims if d not in dims)
new_levels = [l for l in self._levels if l not in dims]
fmt = "".join(
[
*(to_char(d) for d in levelslhs),
",",
*(to_char(d) for d in levelsrhs),
"->",
*(to_char(d) for d in new_levels),
]
)
result_data = torch.einsum(fmt, (plhs, prhs))
return Tensor.from_positional(result_data, new_levels, True)
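# Worked example of the einsum format built above (illustrative): with lhs
# levels (i, j), rhs levels (j, k), and sum over j, the union (i, j, k) maps to
# "a", "b", "c", so fmt == "ab,bc->ac" and the delayed elementwise multiply
# followed by the reduction lowers to a single matrix-multiply-like contraction.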

View File

@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dis
import inspect
from dataclasses import dataclass
from typing import Union
from . import DimList
_vmap_levels = []
@dataclass
class LevelInfo:
level: int
alive: bool = True
class Dim:
def __init__(self, name: str, size: Union[None, int] = None):
self.name = name
self._size = None
self._vmap_level = None
if size is not None:
self.size = size
def __del__(self):
if self._vmap_level is not None:
_vmap_active_levels[self._vmap_stack].alive = False # noqa: F821
while (
not _vmap_levels[-1].alive
and current_level() == _vmap_levels[-1].level # noqa: F821
):
_vmap_decrement_nesting() # noqa: F821
_vmap_levels.pop()
@property
def size(self):
assert self.is_bound
return self._size
@size.setter
def size(self, size: int):
from . import DimensionBindError
if self._size is None:
self._size = size
self._vmap_level = _vmap_increment_nesting(size, "same") # noqa: F821
self._vmap_stack = len(_vmap_levels)
_vmap_levels.append(LevelInfo(self._vmap_level))
elif self._size != size:
raise DimensionBindError(
f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
)
@property
def is_bound(self):
return self._size is not None
def __repr__(self):
return self.name
def extract_name(inst):
assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
return inst.argval
_cache = {}
def dims(lists=0):
frame = inspect.currentframe()
assert frame is not None
calling_frame = frame.f_back
assert calling_frame is not None
code, lasti = calling_frame.f_code, calling_frame.f_lasti
key = (code, lasti)
if key not in _cache:
first = lasti // 2 + 1
instructions = list(dis.get_instructions(calling_frame.f_code))
unpack = instructions[first]
if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
# just a single dim, not a list
name = unpack.argval
ctor = Dim if lists == 0 else DimList
_cache[key] = lambda: ctor(name=name)
else:
assert unpack.opname == "UNPACK_SEQUENCE"
ndims = unpack.argval
names = tuple(
extract_name(instructions[first + 1 + i]) for i in range(ndims)
)
first_list = len(names) - lists
_cache[key] = lambda: tuple(
Dim(n) if i < first_list else DimList(name=n)
for i, n in enumerate(names)
)
return _cache[key]()
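# Sketch of how dims() resolves names (assumes CPython wordcode, as the
# `lasti // 2 + 1` arithmetic above does): it disassembles the caller and reads
# the names about to be stored, e.g.
#
#   i = dims()                 # STORE_FAST "i"  -> Dim("i")
#   i, j = dims()              # UNPACK_SEQUENCE -> (Dim("i"), Dim("j"))
#   i, j, ks = dims(lists=1)   # the last `lists` names become DimList objects
#
# The (code, lasti) key caches the constructor, so repeated calls from the same
# call site skip the disassembly.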
def _dim_set(positional, arg):
def convert(a):
if isinstance(a, Dim):
return a
else:
assert isinstance(a, int)
return positional[a]
if arg is None:
return positional
elif not isinstance(arg, (Dim, int)):
return tuple(convert(a) for a in arg)
else:
return (convert(arg),)

View File

@@ -0,0 +1,42 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import signal
import subprocess
from contextlib import contextmanager
@contextmanager
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
pid = os.getpid()
if not os.path.exists(magic_trace_cache):
print(f"Downloading magic_trace to: {magic_trace_cache}")
subprocess.run(
[
"wget",
"-O",
magic_trace_cache,
"-q",
"https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace",
]
)
subprocess.run(["chmod", "+x", magic_trace_cache])
args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
while True:
x = p.stderr.readline()
print(x)
if "Attached" in x:
break
try:
yield
finally:
p.send_signal(signal.SIGINT)
r = p.wait()
print(p.stderr.read())
p.stderr.close()
if r != 0:
raise ValueError(f"magic_trace exited abnormally: {r}")

View File

@@ -0,0 +1,312 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
# pointwise operators can go through a faster pathway
tensor_magic_methods = ["add", ""]
pointwise_magic_methods_with_reverse = (
"add",
"sub",
"mul",
"floordiv",
"div",
"truediv",
"mod",
"pow",
"lshift",
"rshift",
"and",
"or",
"xor",
)
pointwise_magic_methods = (
*(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
"eq",
"gt",
"le",
"lt",
"ge",
"gt",
"ne",
"neg",
"pos",
"abs",
"invert",
"iadd",
"isub",
"imul",
"ifloordiv",
"idiv",
"itruediv",
"imod",
"ipow",
"ilshift",
"irshift",
"iand",
"ior",
"ixor",
"int",
"long",
"float",
"complex",
)
pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
pointwise = (
*(getattr(torch.Tensor, m) for m in pointwise_methods),
torch.nn.functional.dropout,
torch.where,
torch.Tensor.abs,
torch.abs,
torch.Tensor.acos,
torch.acos,
torch.Tensor.acosh,
torch.acosh,
torch.Tensor.add,
torch.add,
torch.Tensor.addcdiv,
torch.addcdiv,
torch.Tensor.addcmul,
torch.addcmul,
torch.Tensor.addr,
torch.addr,
torch.Tensor.angle,
torch.angle,
torch.Tensor.asin,
torch.asin,
torch.Tensor.asinh,
torch.asinh,
torch.Tensor.atan,
torch.atan,
torch.Tensor.atan2,
torch.atan2,
torch.Tensor.atanh,
torch.atanh,
torch.Tensor.bitwise_and,
torch.bitwise_and,
torch.Tensor.bitwise_left_shift,
torch.bitwise_left_shift,
torch.Tensor.bitwise_not,
torch.bitwise_not,
torch.Tensor.bitwise_or,
torch.bitwise_or,
torch.Tensor.bitwise_right_shift,
torch.bitwise_right_shift,
torch.Tensor.bitwise_xor,
torch.bitwise_xor,
torch.Tensor.ceil,
torch.ceil,
torch.celu,
torch.nn.functional.celu,
torch.Tensor.clamp,
torch.clamp,
torch.Tensor.clamp_max,
torch.clamp_max,
torch.Tensor.clamp_min,
torch.clamp_min,
torch.Tensor.copysign,
torch.copysign,
torch.Tensor.cos,
torch.cos,
torch.Tensor.cosh,
torch.cosh,
torch.Tensor.deg2rad,
torch.deg2rad,
torch.Tensor.digamma,
torch.digamma,
torch.Tensor.div,
torch.div,
torch.dropout,
torch.nn.functional.dropout,
torch.nn.functional.elu,
torch.Tensor.eq,
torch.eq,
torch.Tensor.erf,
torch.erf,
torch.Tensor.erfc,
torch.erfc,
torch.Tensor.erfinv,
torch.erfinv,
torch.Tensor.exp,
torch.exp,
torch.Tensor.exp2,
torch.exp2,
torch.Tensor.expm1,
torch.expm1,
torch.feature_dropout,
torch.Tensor.float_power,
torch.float_power,
torch.Tensor.floor,
torch.floor,
torch.Tensor.floor_divide,
torch.floor_divide,
torch.Tensor.fmod,
torch.fmod,
torch.Tensor.frac,
torch.frac,
torch.Tensor.frexp,
torch.frexp,
torch.Tensor.gcd,
torch.gcd,
torch.Tensor.ge,
torch.ge,
torch.nn.functional.gelu,
torch.nn.functional.glu,
torch.Tensor.gt,
torch.gt,
torch.Tensor.hardshrink,
torch.hardshrink,
torch.nn.functional.hardshrink,
torch.nn.functional.hardsigmoid,
torch.nn.functional.hardswish,
torch.nn.functional.hardtanh,
torch.Tensor.heaviside,
torch.heaviside,
torch.Tensor.hypot,
torch.hypot,
torch.Tensor.i0,
torch.i0,
torch.Tensor.igamma,
torch.igamma,
torch.Tensor.igammac,
torch.igammac,
torch.Tensor.isclose,
torch.isclose,
torch.Tensor.isfinite,
torch.isfinite,
torch.Tensor.isinf,
torch.isinf,
torch.Tensor.isnan,
torch.isnan,
torch.Tensor.isneginf,
torch.isneginf,
torch.Tensor.isposinf,
torch.isposinf,
torch.Tensor.isreal,
torch.isreal,
torch.Tensor.kron,
torch.kron,
torch.Tensor.lcm,
torch.lcm,
torch.Tensor.ldexp,
torch.ldexp,
torch.Tensor.le,
torch.le,
torch.nn.functional.leaky_relu,
torch.Tensor.lerp,
torch.lerp,
torch.Tensor.lgamma,
torch.lgamma,
torch.Tensor.log,
torch.log,
torch.Tensor.log10,
torch.log10,
torch.Tensor.log1p,
torch.log1p,
torch.Tensor.log2,
torch.log2,
torch.nn.functional.logsigmoid,
torch.Tensor.logical_and,
torch.logical_and,
torch.Tensor.logical_not,
torch.logical_not,
torch.Tensor.logical_or,
torch.logical_or,
torch.Tensor.logical_xor,
torch.logical_xor,
torch.Tensor.logit,
torch.logit,
torch.Tensor.lt,
torch.lt,
torch.Tensor.maximum,
torch.maximum,
torch.Tensor.minimum,
torch.minimum,
torch.nn.functional.mish,
torch.Tensor.mvlgamma,
torch.mvlgamma,
torch.Tensor.nan_to_num,
torch.nan_to_num,
torch.Tensor.ne,
torch.ne,
torch.Tensor.neg,
torch.neg,
torch.Tensor.nextafter,
torch.nextafter,
torch.Tensor.outer,
torch.outer,
torch.polar,
torch.Tensor.polygamma,
torch.polygamma,
torch.Tensor.positive,
torch.positive,
torch.Tensor.pow,
torch.pow,
torch.Tensor.prelu,
torch.prelu,
torch.nn.functional.prelu,
torch.Tensor.rad2deg,
torch.rad2deg,
torch.Tensor.reciprocal,
torch.reciprocal,
torch.Tensor.relu,
torch.relu,
torch.nn.functional.relu,
torch.nn.functional.relu6,
torch.Tensor.remainder,
torch.remainder,
torch.Tensor.round,
torch.round,
torch.rrelu,
torch.nn.functional.rrelu,
torch.Tensor.rsqrt,
torch.rsqrt,
torch.rsub,
torch.selu,
torch.nn.functional.selu,
torch.Tensor.sgn,
torch.sgn,
torch.Tensor.sigmoid,
torch.sigmoid,
torch.nn.functional.sigmoid,
torch.Tensor.sign,
torch.sign,
torch.Tensor.signbit,
torch.signbit,
torch.nn.functional.silu,
torch.Tensor.sin,
torch.sin,
torch.Tensor.sinc,
torch.sinc,
torch.Tensor.sinh,
torch.sinh,
torch.nn.functional.softplus,
torch.nn.functional.softshrink,
torch.Tensor.sqrt,
torch.sqrt,
torch.Tensor.square,
torch.square,
torch.Tensor.sub,
torch.sub,
torch.Tensor.tan,
torch.tan,
torch.Tensor.tanh,
torch.tanh,
torch.nn.functional.tanh,
torch.threshold,
torch.nn.functional.threshold,
torch.trapz,
torch.Tensor.true_divide,
torch.true_divide,
torch.Tensor.trunc,
torch.trunc,
torch.Tensor.xlogy,
torch.xlogy,
torch.rand_like,
)

View File

@@ -0,0 +1,645 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# reference python implementations for C ops
import torch
from functorch._C import dim as _C
from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map
DimList = _C.DimList
import operator
from functools import reduce
# the Python reference can use a plain set; the dict in __init__ only exists to avoid writing C++ bindings for a set
pointwise = set(op_properties.pointwise)
def prod(x):
return reduce(operator.mul, x, 1)
def _wrap_dim(d, N, keepdim):
from . import Dim
if isinstance(d, Dim):
assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
return d
elif d >= 0:
return d - N
else:
return d
def _dims(d, N, keepdim, single_dim):
from . import Dim
if isinstance(d, (Dim, int)):
return ltuple((_wrap_dim(d, N, keepdim),))
assert not single_dim, f"expected a single dimension or int but found: {d}"
return ltuple(_wrap_dim(x, N, keepdim) for x in d)
def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
from . import DimensionMismatchError
not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
if len(not_bound) == 1:
idx, d = not_bound[0]
rhs_so_far = prod(r.size for r in rhs if r.is_bound)
if lhs_size % rhs_so_far != 0:
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(
f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
)
new_size = lhs_size // rhs_so_far
d.size = new_size
elif len(not_bound) > 1:
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(
f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
)
else:
rhs_size = prod(r.size for r in rhs)
if lhs_size != rhs_size:
raise DimensionMismatchError(
f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
)
def _tensor_levels(inp):
from . import _Tensor
if isinstance(inp, _Tensor):
return inp._tensor, llist(inp._levels), inp._has_device
else:
return inp, llist(range(-inp.ndim, 0)), True
def _match_levels(v, from_levels, to_levels):
view = []
permute = []
requires_view = False
size = v.size()
for t in to_levels:
try:
idx = from_levels.index(t)
permute.append(idx)
view.append(size[idx])
except ValueError:
view.append(1)
requires_view = True
if permute != list(range(len(permute))):
v = v.permute(*permute)
if requires_view:
v = v.view(*view)
return v
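# Worked example (illustrative): with from_levels = [i, -1] on a tensor of
# shape (2, 5) and to_levels = [i, j, -1], level j is missing from the source,
# so the tensor is viewed as (2, 1, 5): present levels keep their size and are
# permuted into place, missing levels become broadcastable size-1 axes.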
# make a single dimension positional but do not permute it,
# used to do multi-tensor operators where the dim being acted on
# should not physically move if possible
def _positional_no_permute(self, dim, expand_dim=False):
from . import Tensor
ptensor, levels = self._tensor, llist(self._levels)
try:
idx = levels.index(dim)
except ValueError:
if not expand_dim:
raise
idx = 0
ptensor = ptensor.expand(dim.size, *ptensor.size())
levels.insert(0, 0)
idx_batched = 0
for i in range(idx):
if isinstance(levels[i], int):
levels[i] -= 1
idx_batched += 1
levels[idx] = -idx_batched - 1
return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
def seq(a, b):
from . import Dim
if isinstance(a, Dim) != isinstance(b, Dim):
return False
if isinstance(a, Dim):
return a is b
else:
return a == b
class isin:
def __contains__(self, item):
for x in self:
if seq(item, x):
return True
return False
def index(self, item):
for i, x in enumerate(self):
if seq(item, x):
return i
raise ValueError
class llist(isin, list):
pass
class ltuple(isin, tuple):
pass
empty_dict = {}
@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
from . import _Tensor, Tensor, TensorLike
from .delayed_mul_tensor import DelayedMulTensor
if orig is torch.Tensor.__mul__:
lhs, rhs = args
if (
isinstance(lhs, _Tensor)
and isinstance(rhs, _Tensor)
and lhs.ndim == 0
and rhs.ndim == 0
):
return DelayedMulTensor(lhs, rhs)
all_dims = llist()
flat_args, unflatten = tree_flatten((args, kwargs))
device_holding_tensor = None
for f in flat_args:
if isinstance(f, _Tensor):
if f._has_device:
device_holding_tensor = f._batchtensor
for d in f.dims:
if d not in all_dims:
all_dims.append(d)
def unwrap(t):
if isinstance(t, _Tensor):
r = t._batchtensor
if device_holding_tensor is not None and not t._has_device:
r = r.to(device=device_holding_tensor.device)
return r
return t
if orig in pointwise:
result_levels = llist()
arg_levels = llist()
to_expand = []
for i, f in enumerate(flat_args):
if isinstance(f, TensorLike):
ptensor, levels, _ = _tensor_levels(f)
if (
isinstance(f, _Tensor)
and not f._has_device
and device_holding_tensor is not None
):
ptensor = ptensor.to(device=device_holding_tensor.device)
flat_args[i] = ptensor
for l in levels:
if l not in result_levels:
result_levels.append(l)
to_expand.append((i, levels))
for i, levels in to_expand:
flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
args, kwargs = unflatten(flat_args)
result = orig(*args, **kwargs)
def wrap(t):
if isinstance(t, TensorLike):
return Tensor.from_positional(
t, result_levels, device_holding_tensor is not None
)
return t
return tree_map(wrap, result)
else:
def wrap(t):
if isinstance(t, TensorLike):
return Tensor.from_batched(t, device_holding_tensor is not None)
return t
with _enable_layers(all_dims):
print(f"batch_tensor for {orig}")
args, kwargs = unflatten(unwrap(f) for f in flat_args)
result = orig(*args, **kwargs)
# print("END", orig)
return tree_map(wrap, result)
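# Pointwise sketch (illustrative): operands are flattened, moved to a common
# device, and aligned on the union of their first-class dims via _match_levels
# before the original op runs, so dims broadcast by identity:
#
#   i, j = dims()
#   row = torch.rand(3)[i]   # carries dim i
#   col = torch.rand(4)[j]   # carries dim j
#   both = row + col         # result carries dims (i, j)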
def positional(self, *dims):
from . import Dim, DimensionBindError, Tensor
ptensor, levels = self._tensor, llist(self._levels)
flat_dims = llist()
view = []
needs_view = False
ndim = self.ndim
for d in dims:
if isinstance(d, DimList):
flat_dims.extend(d)
view.extend(e.size for e in d)
elif isinstance(d, Dim):
flat_dims.append(d)
view.append(d.size)
elif isinstance(d, int):
d = _wrap_dim(d, ndim, False)
flat_dims.append(d)
view.append(ptensor.size(d))
else:
flat_dims.extend(d)
view.append(prod(e.size for e in d))
needs_view = True
permute = list(range(len(levels)))
nflat = len(flat_dims)
for i, d in enumerate(flat_dims):
try:
idx = levels.index(d)
except ValueError as e:
raise DimensionBindError(
f"tensor of dimensions {self.dims} does not contain dim {d}"
) from e
p = permute[idx]
del levels[idx]
del permute[idx]
levels.insert(i, 0)
permute.insert(i, p)
ptensor = ptensor.permute(*permute)
seen = 0
for i in range(len(levels) - 1, -1, -1):
if isinstance(levels[i], int):
seen += 1
levels[i] = -seen
result = Tensor.from_positional(ptensor, levels, self._has_device)
if needs_view:
result = result.reshape(*view, *result.size()[len(flat_dims) :])
return result
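# Usage sketch (illustrative): `positional` is exposed as Tensor.order in the
# public API; it turns first-class dims back into ordinary positional dims:
#
#   i, j = dims()
#   x = torch.rand(3, 4)[i, j]
#   y = x.order(j, i)   # j, i become the leading positional dims -> shape (4, 3)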
def _contains_dim(input):
from . import Dim
for i in input:
if isinstance(i, Dim):
return True
def expand(self, *sizes):
if not _contains_dim(sizes):
return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
dims = sizes
sizes = [d.size for d in dims] + [-1] * self.ndim
self = self.expand(*sizes)
return self[dims]
_not_present = object()
def _getarg(name, offset, args, kwargs, default):
if len(args) > offset:
return args[offset]
return kwargs.get(name, default)
def _patcharg(name, offset, args, kwargs, value):
if len(args) > offset:
args[offset] = value
else:
kwargs[name] = value
def _wrap(
orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
from . import Dim, Tensor, TensorLike
def fn(self, *args, **kwargs):
dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
with _enable_layers(self.dims):
print(f"dim fallback batch_tensor for {orig}")
return Tensor.from_batched(
orig(self._batchtensor, *args, **kwargs), self._has_device
)
keepdim = (
_getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
)
t, levels = self._tensor, llist(self._levels)
dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
dim_indices = tuple(levels.index(d) for d in dims)
if reduce and not keepdim:
new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
else:
new_levels = levels
if len(dim_indices) == 1:
dim_indices = dim_indices[
0
] # so that dims that really only take a single argument work...
args = list(args)
_patcharg(dim_name, dim_offset, args, kwargs, dim_indices)
def wrap(t):
if isinstance(t, TensorLike):
return Tensor.from_positional(t, new_levels, self._has_device)
return t
with _enable_layers(new_levels):
print(f"dim used batch_tensor for {orig}")
r = orig(t, *args, **kwargs)
return tree_map(wrap, r)
return fn
def _def(name, *args, **kwargs):
from . import _Tensor
orig = getattr(torch.Tensor, name)
setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
no_slice = slice(None)
_orig_getitem = torch.Tensor.__getitem__
class dim_tracker:
def __init__(self) -> None:
self.dims = llist()
self.count = []
def record(self, d):
if d not in self.dims:
self.dims.append(d)
self.count.append(1)
def __getitem__(self, d):
return self.count[self.dims.index(d)]
def t__getitem__(self, input):
from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike
# * bail to the original __getitem__ if we have a single non-Dim tensor, or a non-tensor
# * locate ... or an unbound tensor list, and determine its size, bind dim list
# (remember that None does not count to the total dim count)
# * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
# produce the re-view if needed
# * for each single-use dim index, replace with no_slice and mark that it will be added
# (keep track of whether we have to call super)
# * call super if needed
# * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
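# Worked example (illustrative) of the steps above:
#   i, j = dims()
#   y = torch.rand(2, 3, 4)[i, :, j]
# i binds to 2 and j to 4; both are single-use, so they are replaced by
# no_slice for the underlying __getitem__ call and re-attached as first-class
# dims of the result, which keeps one positional dim of size 3.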
# this handles bool indexing, as well as some other simple cases.
is_simple = (
not isinstance(input, Dim)
and not isinstance(input, (tuple, list))
and
# WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
not (isinstance(input, TensorLike) and input.ndim == 0)
)
if is_simple:
if isinstance(self, _Tensor):
return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
else:
return _orig_getitem(self, input)
# can further optimize this case
if not isinstance(input, tuple):
input = [input]
else:
input = list(input)
dims_indexed = 0
expanding_object = None
dimlists = []
for i, s in enumerate(input):
if s is ... or isinstance(s, DimList) and not s.is_bound:
if expanding_object is not None:
msg = (
"at most one ... or unbound dimension list can exist in indexing list but"
f" found 2 at offsets {i} and {expanding_object}"
)
raise DimensionBindError(msg)
expanding_object = i
if isinstance(s, DimList):
dims_indexed += len(s) if s.is_bound else 0
dimlists.append(i)
elif s is not None and s is not ...:
dims_indexed += 1
ndim = self.ndim
if dims_indexed > ndim:
raise IndexError(
f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
)
if expanding_object is not None:
expanding_ndims = ndim - dims_indexed
obj = input[expanding_object]
if obj is ...:
input[expanding_object : expanding_object + 1] = [
no_slice
] * expanding_ndims
else:
obj.bind_len(expanding_ndims)
# flatten the dimslists into the indexing
for i in reversed(dimlists):
input[i : i + 1] = input[i]
dims_indexed = 0
requires_view = False
size = self.size()
view_sizes = []
dims_seen = dim_tracker()
def add_dims(t):
if not isinstance(t, _Tensor):
return
for d in t.dims:
dims_seen.record(d)
add_dims(self)
dim_packs = []
for i, idx in enumerate(input):
if idx is None:
input[i] = no_slice
view_sizes.append(1)
requires_view = True
else:
sz = size[dims_indexed]
if isinstance(idx, Dim):
idx.size = sz
dims_seen.record(idx)
view_sizes.append(sz)
elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
for d in idx:
dims_seen.record(d)
_bind_dims_to_size(sz, idx, f"offset {i}")
view_sizes.extend(d.size for d in idx)
requires_view = True
dim_packs.append(i)
else:
add_dims(idx)
view_sizes.append(sz)
dims_indexed += 1
if requires_view:
self = self.view(*view_sizes)
for i in reversed(dim_packs):
input[i : i + 1] = input[i]
# currently:
# input is flat, containing either Dim, or Tensor, or something valid for standard indexing
# self may have first-class dims as well.
# to index:
# drop the first class dims from self, they just become direct indices of their positions
# figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
# these dimensions will appear and need to be bound at the first place a tensor occurs
if isinstance(self, _Tensor):
ptensor_self, levels = self._tensor, list(self._levels)
# indices to ptensor rather than self which has first-class dimensions
input_it = iter(input)
flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
has_device = self._has_device
to_pad = 0
else:
ptensor_self, flat_inputs = self, input
to_pad = ptensor_self.ndim - len(flat_inputs)
has_device = True
result_levels = []
index_levels = []
tensor_insert_point = None
to_expand = {}
requires_getindex = False
for i, inp in enumerate(flat_inputs):
if isinstance(inp, Dim) and dims_seen[inp] == 1:
flat_inputs[i] = no_slice
result_levels.append(inp)
elif isinstance(inp, TensorLike):
requires_getindex = True
if tensor_insert_point is None:
tensor_insert_point = len(result_levels)
ptensor, levels, _ = _tensor_levels(inp)
to_expand[i] = levels
flat_inputs[i] = ptensor
for l in levels:
if l not in index_levels:
index_levels.append(l)
else:
requires_getindex = True
result_levels.append(0)
if tensor_insert_point is not None:
result_levels[tensor_insert_point:tensor_insert_point] = index_levels
for i, levels in to_expand.items():
flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)
if requires_getindex:
result = _orig_getitem(ptensor_self, flat_inputs)
else:
result = ptensor_self
next_positional = -1
if to_pad > 0:
result_levels.extend([0] * to_pad)
for i, r in enumerate(reversed(result_levels)):
if isinstance(r, int):
result_levels[-1 - i] = next_positional
next_positional -= 1
return Tensor.from_positional(result, result_levels, has_device)
# XXX - dim is optional and can be the outer-most dimension...
def stack(tensors, new_dim, dim=0, out=None):
if isinstance(dim, int):
return torch.stack(tensors, dim, out=out).index(dim, new_dim)
index = None
if out is not None:
out, index = _positional_no_permute(out, dim, expand_dim=True)
ptensors = []
for t in tensors:
pt, pi = _positional_no_permute(t, dim, expand_dim=True)
if index is not None and pi != index:
pt = pt.move_dim(pi, index)
else:
index = pi
ptensors.append(pt)
pr = torch.stack(ptensors, index, out=out)
return pr.index((index, index + 1), (new_dim, dim))
_orig_split = torch.Tensor.split
def split(self, split_size_or_sections, dim=0):
from . import _Tensor, Dim
if isinstance(split_size_or_sections, int) or any(
isinstance(t, int) for t in split_size_or_sections
):
if isinstance(dim, Dim):
raise ValueError(
"when dim is specified as a Dim object, split sizes must also be dimensions."
)
return _orig_split(self, split_size_or_sections, dim=dim)
if isinstance(dim, Dim):
assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
self, dim = _positional_no_permute(self, dim)
size = self.size(dim)
total_bound_size = 0
unbound = []
sizes = []
for i, d in enumerate(split_size_or_sections):
if d.is_bound:
sizes.append(d.size)
total_bound_size += d.size
else:
sizes.append(0)
unbound.append(i)
if unbound:
assert (
total_bound_size <= size
), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
remaining_size = size - total_bound_size
chunk_size = -(-remaining_size // len(unbound))
for u in unbound:
sz = min(chunk_size, remaining_size)
split_size_or_sections[u].size = sz
sizes[u] = sz
remaining_size -= sz
else:
assert (
total_bound_size == size
), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
return tuple(
t.index(dim, d)
for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
)
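# Hedged usage sketch (names illustrative): splitting along a first-class dim
# with unbound output dims infers their sizes, similar to an even split:
#
#   i, j = dims()
#   x = torch.rand(10, 3)[i, j]
#   a, b = dims()
#   xa, xb = x.split((a, b), dim=i)   # a and b each bind to size 5
#   # xa carries dims (a, j); xb carries dims (b, j)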

View File

@@ -0,0 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch._C import dim
tree_flatten = dim.tree_flatten
def tree_map(fn, tree):
vs, unflatten = tree_flatten(tree)
return unflatten(fn(v) for v in vs)

View File

@@ -0,0 +1,72 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from types import (
BuiltinMethodType,
FunctionType,
GetSetDescriptorType,
MethodDescriptorType,
WrapperDescriptorType,
)
from functorch._C import dim as _C
_wrap_method = _C._wrap_method
FUNC_TYPES = (
FunctionType,
MethodDescriptorType,
BuiltinMethodType,
WrapperDescriptorType,
)
PROPERTY_TYPES = (GetSetDescriptorType, property)
def _py_wrap_method(orig, __torch_function__):
def impl(*args, **kwargs):
return __torch_function__(orig, None, args, kwargs)
return impl
def wrap_type(use_c, to_patch, pattern, __torch_function__):
if use_c:
wrap_method = _wrap_method
else:
wrap_method = _py_wrap_method
all = {}
for t in reversed(pattern.mro()[:-1]): # skip object
all.update(t.__dict__)
def wrap_attr(orig):
return property(wrap_method(orig.__get__, __torch_function__))
for name, obj in all.items():
if name in (
"__dict__",
"__new__",
"__init__",
"__repr__",
"__weakref__",
"__doc__",
"__module__",
"__dir__",
):
continue
# skip things that have been overloaded
# things that come from object like `__eq__` still need to be patched, however.
if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
object, name, None
):
continue
if isinstance(obj, FUNC_TYPES):
setattr(to_patch, name, wrap_method(obj, __torch_function__))
elif isinstance(obj, PROPERTY_TYPES):
setattr(to_patch, name, wrap_attr(obj))