I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,190 @@
# mypy: allow-untyped-defs
"""
This is a simple interpreter for Sympy expressions that dispatches to
classes following the torch._inductor.virtualized calling convention.
For directness, the interpreter takes the handler directly rather than
consulting the TLS. It does not use most of the methods on the full
handler; only those with corresponding Sympy expressions. To see an example
of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""
import functools
import logging
from typing import Any, Dict, Union
import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
import torch
from .functions import (
CeilToInt,
CleanDiv,
FloatPow,
FloatTrueDiv,
FloorDiv,
FloorToInt,
Identity,
IntTrueDiv,
IsNonOverlappingAndDenseIndicator,
Max,
Min,
Mod,
ModularIndexing,
PowByNatural,
PythonMod,
RoundDecimal,
RoundToInt,
ToFloat,
TruncToFloat,
TruncToInt,
Where,
)
log = logging.getLogger(__name__)
# TODO: Dedupe this with SYMPY_INTERP
@functools.lru_cache(None)
def handlers():
# TODO add CeilDiv (it doesn't appear in the index_expr)
# TODO default to some decompositions if the interpreter doesn't have them
# like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)
HANDLERS = {
sympy.Or: "or_",
sympy.And: "and_",
sympy.Eq: "eq",
sympy.Ne: "ne",
sympy.Lt: "lt",
sympy.Gt: "gt",
sympy.Le: "le",
sympy.Ge: "ge",
sympy.Not: "not_",
IntTrueDiv: "int_truediv",
FloatTrueDiv: "truediv",
FloorDiv: "floordiv",
CleanDiv: "floordiv", # TODO: hmm?
TruncToFloat: "trunc",
Where: "where",
sympy.Add: "add",
sympy.Mul: "mul",
FloatPow: "pow",
PowByNatural: "pow_by_natural",
# sympy simplifies x * x into Pow(x, 2), so we need to handle this.
# Do NOT use builtin Pow for floats
# TODO: There is a hazard here, if we have float * float it will
# also get turned into Pow(float, 2) but we don't want this because
# pow_by_natural is assumed to only be integers. Probably the fix is
# to add a FloatMul to impede this optimization
sympy.Pow: "pow_by_natural",
Mod: "mod",
PythonMod: "mod", # TODO: this is wrong
# TODO: Inductor can generate these, but it's ill-specified which
# semantics were intended here. Needs to be cleaned up along with
# FloorDiv in a bigger cleanup
sympy.Mod: "mod",
sympy.Abs: "abs",
sympy.log: "log",
sympy.exp: "exp",
sympy.Min: "minimum",
sympy.Max: "maximum",
Min: "minimum",
Max: "maximum",
ModularIndexing: "modular_indexing",
sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
sympy.Piecewise: "piecewise",
Identity: "identity",
IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
RoundDecimal: "round_decimal",
}
for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
HANDLERS[getattr(sympy, name)] = name
return HANDLERS
ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}
def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
# Special cases
if isinstance(expr, sympy.Pow) and isinstance(
expr.args[1], sympy.core.numbers.Half
):
return analysis.sqrt(args[0])
if isinstance(expr, ToFloat):
return analysis.to_dtype(args[0], torch.float64)
# These handlers are special because they take an extra dtype argument
# specifying what they should convert to, and we need to appropriately set
# this up when we convert from Sympy. A reasonable default when you
# are translating is to conservatively do int64, and then narrow these
# arguments later when you discover you can narrow the index range. But
# if you already know that 32-bit indexing is OK, you can directly do the
# sympy translation with index_dtype=torch.int32
INDEX_DTYPE_HANDLERS = {
TruncToInt: "trunc_to_int",
sympy.floor: "floor_to_int",
sympy.ceiling: "ceil_to_int",
FloorToInt: "floor_to_int",
CeilToInt: "ceil_to_int",
RoundToInt: "round_to_int",
}
if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
return getattr(analysis, handler_name)(*args, index_dtype)
if hasattr(expr.func, "_torch_handler_name"):
handler_name = expr.func._torch_handler_name
else:
handler_name = handlers()[expr.func]
handler = getattr(analysis, handler_name)
try:
if handler_name in ASSOCIATIVE_OPS:
assert len(args) > 1
acc = handler(args[0], args[1])
for i in range(2, len(args)):
acc = handler(acc, args[i])
log.debug("%s(%s) -> %s", handler_name, args, acc)
return acc
else:
r = handler(*args)
log.debug("%s(%s) -> %s", handler_name, args, r)
return r
except Exception:
log.warning("failed while executing %s(%s)", handler_name, args)
raise
def sympy_interp(
analysis,
env: Dict[sympy.Symbol, Any],
expr: Union[sympy.Expr, SympyBoolean],
*,
index_dtype=torch.int64,
):
# Handle base cases
dtype = None
if isinstance(expr, BooleanAtom):
dtype = torch.bool
elif isinstance(expr, sympy.Integer):
dtype = torch.int64
elif isinstance(expr, sympy.Number):
dtype = torch.double
if dtype is not None:
return analysis.constant(expr, dtype)
elif isinstance(expr, sympy.Symbol):
return env[expr]
# Recursive case
return _run_sympy_handler(
analysis,
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
expr,
index_dtype=index_dtype,
) # type: ignore[arg-type]

View File

@ -0,0 +1,397 @@
# mypy: allow-untyped-defs
import mpmath.libmp as mlib # type: ignore[import-untyped]
import sympy
from sympy import Expr
from sympy.core.decorators import _sympifyit
from sympy.core.expr import AtomicExpr
from sympy.core.numbers import Number
from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton
class IntInfinity(Number, metaclass=Singleton):
r"""Positive integer infinite quantity.
Integer infinity is a value in an extended integers which
is greater than all other integers. We distinguish it from
sympy's existing notion of infinity in that it reports that
it is_integer.
Infinity is a singleton, and can be accessed by ``S.IntInfinity``,
or can be imported as ``int_oo``.
"""
# NB: We can't actually mark this as infinite, as integer and infinite are
# inconsistent assumptions in sympy. We also report that we are complex,
# different from sympy.oo
is_integer = True
is_commutative = True
is_number = True
is_extended_real = True
is_comparable = True
is_extended_positive = True
is_prime = False
# Ensure we get dispatched to before plain numbers
_op_priority = 100.0
__slots__ = ()
def __new__(cls):
return AtomicExpr.__new__(cls)
def _sympystr(self, printer):
return "int_oo"
def _eval_subs(self, old, new):
if self == old:
return new
# We could do these, not sure about it
"""
def _eval_evalf(self, prec=None):
return Float('inf')
def evalf(self, prec=None, **options):
return self._eval_evalf(prec)
"""
@_sympifyit("other", NotImplemented)
def __add__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other in (S.Infinity, S.NegativeInfinity):
return other
if other in (S.NegativeIntInfinity, S.NaN):
return S.NaN
return self
return Number.__add__(self, other)
__radd__ = __add__
@_sympifyit("other", NotImplemented)
def __sub__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.Infinity:
return S.NegativeInfinity
if other is S.NegativeInfinity:
return S.Infinity
if other in (S.IntInfinity, S.NaN):
return S.NaN
return self
return Number.__sub__(self, other)
@_sympifyit("other", NotImplemented)
def __rsub__(self, other):
return (-self).__add__(other)
@_sympifyit("other", NotImplemented)
def __mul__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other.is_zero or other is S.NaN:
return S.NaN
if other.is_extended_positive:
return self
return S.NegativeIntInfinity
return Number.__mul__(self, other)
__rmul__ = __mul__
@_sympifyit("other", NotImplemented)
def __truediv__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other in (
S.Infinity,
S.IntInfinity,
S.NegativeInfinity,
S.NegativeIntInfinity,
S.NaN,
):
return S.NaN
if other.is_extended_nonnegative:
return S.Infinity # truediv produces float
return S.NegativeInfinity # truediv produces float
return Number.__truediv__(self, other)
def __abs__(self):
return S.IntInfinity
def __neg__(self):
return S.NegativeIntInfinity
def _eval_power(self, expt):
if expt.is_extended_positive:
return S.IntInfinity
if expt.is_extended_negative:
return S.Zero
if expt is S.NaN:
return S.NaN
if expt is S.ComplexInfinity:
return S.NaN
if expt.is_extended_real is False and expt.is_number:
from sympy.functions.elementary.complexes import re
expt_real = re(expt)
if expt_real.is_positive:
return S.ComplexInfinity
if expt_real.is_negative:
return S.Zero
if expt_real.is_zero:
return S.NaN
return self ** expt.evalf()
def _as_mpf_val(self, prec):
return mlib.finf
def __hash__(self):
return super().__hash__()
def __eq__(self, other):
return other is S.IntInfinity
def __ne__(self, other):
return other is not S.IntInfinity
def __gt__(self, other):
if other is S.Infinity:
return sympy.false # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.true
def __ge__(self, other):
if other is S.Infinity:
return sympy.false # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.true
def __lt__(self, other):
if other is S.Infinity:
return sympy.true # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.false
def __le__(self, other):
if other is S.Infinity:
return sympy.true # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.false
@_sympifyit("other", NotImplemented)
def __mod__(self, other):
if not isinstance(other, Expr):
return NotImplemented
return S.NaN
__rmod__ = __mod__
def floor(self):
return self
def ceiling(self):
return self
int_oo = S.IntInfinity
class NegativeIntInfinity(Number, metaclass=Singleton):
"""Negative integer infinite quantity.
NegativeInfinity is a singleton, and can be accessed
by ``S.NegativeInfinity``.
See Also
========
IntInfinity
"""
# Ensure we get dispatched to before plain numbers
_op_priority = 100.0
is_integer = True
is_extended_real = True
is_commutative = True
is_comparable = True
is_extended_negative = True
is_number = True
is_prime = False
__slots__ = ()
def __new__(cls):
return AtomicExpr.__new__(cls)
def _eval_subs(self, old, new):
if self == old:
return new
def _sympystr(self, printer):
return "-int_oo"
"""
def _eval_evalf(self, prec=None):
return Float('-inf')
def evalf(self, prec=None, **options):
return self._eval_evalf(prec)
"""
@_sympifyit("other", NotImplemented)
def __add__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.Infinity:
return S.Infinity
if other in (S.IntInfinity, S.NaN):
return S.NaN
return self
return Number.__add__(self, other)
__radd__ = __add__
@_sympifyit("other", NotImplemented)
def __sub__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.NegativeInfinity:
return S.Infinity
if other in (S.NegativeIntInfinity, S.NaN):
return S.NaN
return self
return Number.__sub__(self, other)
@_sympifyit("other", NotImplemented)
def __rsub__(self, other):
return (-self).__add__(other)
@_sympifyit("other", NotImplemented)
def __mul__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other.is_zero or other is S.NaN:
return S.NaN
if other.is_extended_positive:
return self
return S.IntInfinity
return Number.__mul__(self, other)
__rmul__ = __mul__
@_sympifyit("other", NotImplemented)
def __truediv__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other in (
S.Infinity,
S.IntInfinity,
S.NegativeInfinity,
S.NegativeIntInfinity,
S.NaN,
):
return S.NaN
if other.is_extended_nonnegative:
return self
return S.Infinity # truediv returns float
return Number.__truediv__(self, other)
def __abs__(self):
return S.IntInfinity
def __neg__(self):
return S.IntInfinity
def _eval_power(self, expt):
if expt.is_number:
if expt in (
S.NaN,
S.Infinity,
S.NegativeInfinity,
S.IntInfinity,
S.NegativeIntInfinity,
):
return S.NaN
if isinstance(expt, sympy.Integer) and expt.is_extended_positive:
if expt.is_odd:
return S.NegativeIntInfinity
else:
return S.IntInfinity
inf_part = S.IntInfinity**expt
s_part = S.NegativeOne**expt
if inf_part == 0 and s_part.is_finite:
return inf_part
if (
inf_part is S.ComplexInfinity
and s_part.is_finite
and not s_part.is_zero
):
return S.ComplexInfinity
return s_part * inf_part
def _as_mpf_val(self, prec):
return mlib.fninf
def __hash__(self):
return super().__hash__()
def __eq__(self, other):
return other is S.NegativeIntInfinity
def __ne__(self, other):
return other is not S.NegativeIntInfinity
def __gt__(self, other):
if other is S.NegativeInfinity:
return sympy.true # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.false
def __ge__(self, other):
if other is S.NegativeInfinity:
return sympy.true # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.false
def __lt__(self, other):
if other is S.NegativeInfinity:
return sympy.false # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.true
def __le__(self, other):
if other is S.NegativeInfinity:
return sympy.false # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.true
@_sympifyit("other", NotImplemented)
def __mod__(self, other):
if not isinstance(other, Expr):
return NotImplemented
return S.NaN
__rmod__ = __mod__
def floor(self):
return self
def ceiling(self):
return self
def as_powers_dict(self):
return {S.NegativeOne: 1, S.IntInfinity: 1}

View File

@ -0,0 +1,283 @@
# mypy: allow-untyped-defs
import math
import operator
import sympy
import torch
from torch.utils._sympy.functions import (
_keep_float,
FloatPow,
FloatTrueDiv,
FloorDiv,
IntTrueDiv,
Max,
Min,
Mod,
OpaqueUnaryFn_exp,
OpaqueUnaryFn_log,
OpaqueUnaryFn_sqrt,
PowByNatural,
RoundDecimal,
RoundToInt,
ToFloat,
TruncToInt,
)
# The sympy interpretation of operators. It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end. If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
@staticmethod
def constant(c, dtype):
return sympy.sympify(c)
@staticmethod
def or_(a, b):
return a | b
@staticmethod
def and_(a, b):
return a & b
@staticmethod
def eq(a, b):
if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
return sympy.Eq(a, b)
return a == b
@classmethod
def ne(cls, a, b):
return cls.not_(cls.eq(a, b))
@staticmethod
def lt(a, b):
return a < b
@staticmethod
def gt(a, b):
return a > b
@staticmethod
def le(a, b):
return a <= b
@staticmethod
def ge(a, b):
return a >= b
@staticmethod
def not_(a):
assert not isinstance(a, bool)
return ~a
@staticmethod
def reciprocal(x):
return FloatTrueDiv(1.0, x)
@staticmethod
def square(x):
return PowByNatural(x, 2)
@staticmethod
def trunc_to_int(x, dtype):
return TruncToInt(x)
@staticmethod
def ceil_to_int(x, dtype):
return sympy.ceiling(x)
@staticmethod
def floor_to_int(x, dtype):
return sympy.floor(x)
@staticmethod
def floor(x):
return _keep_float(sympy.floor)(x)
@staticmethod
def ceil(x):
return _keep_float(sympy.ceiling)(x)
@staticmethod
def to_dtype(x, dtype):
if dtype == torch.float64:
return ToFloat(x)
raise NotImplementedError(f"to_dtype {dtype} NYI")
@staticmethod
def mod(x, y):
return Mod(x, y)
@staticmethod
def abs(x):
return abs(x)
@staticmethod
def neg(x):
return -x
@staticmethod
def truediv(a, b):
return FloatTrueDiv(a, b)
@staticmethod
def int_truediv(a, b):
return IntTrueDiv(a, b)
@staticmethod
def floordiv(a, b):
return FloorDiv(a, b)
@staticmethod
def truncdiv(a, b):
raise NotImplementedError("TODO: truncdiv")
@staticmethod
def add(a, b):
return _keep_float(operator.add)(a, b)
@staticmethod
def mul(a, b):
return _keep_float(operator.mul)(a, b)
@staticmethod
def sub(a, b):
return _keep_float(operator.sub)(a, b)
@staticmethod
def exp(x):
return OpaqueUnaryFn_exp(x)
@staticmethod
def log(x):
return OpaqueUnaryFn_log(x)
@staticmethod
def sqrt(x):
return OpaqueUnaryFn_sqrt(x)
@staticmethod
def pow(a, b):
return _keep_float(FloatPow)(a, b)
@staticmethod
def pow_by_natural(a, b):
return PowByNatural(a, b)
@staticmethod
def minimum(a, b):
return Min(a, b)
@staticmethod
def maximum(a, b):
return Max(a, b)
@staticmethod
def round_to_int(a, dtype):
return RoundToInt(a)
@staticmethod
def round_decimal(a, b):
return RoundDecimal(a, b)
# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
# Python types and is FX traceable. Inheritance here is purely for code
# sharing (TODO: considering splitting out a BaseReferenceAnalysis).
class PythonReferenceAnalysis(ReferenceAnalysis):
@staticmethod
def constant(c, dtype):
if dtype is torch.int64:
return int(c)
elif dtype is torch.double:
return float(c)
elif dtype is torch.bool:
return bool(c)
else:
raise AssertionError(f"unrecognized dtype {dtype}")
@staticmethod
def not_(a):
return torch.sym_not(a)
@staticmethod
def floordiv(a, b):
return a // b
@staticmethod
def mod(x, y):
return x % y
@staticmethod
def truncdiv(a, b):
return a / b
@staticmethod
def to_dtype(x, dtype):
if dtype == torch.float64:
return torch.sym_float(x)
raise NotImplementedError(f"to_dtype {dtype} NYI")
@staticmethod
def exp(x):
raise AssertionError("exp is not valid shape sympy expr")
@staticmethod
def log(x):
raise AssertionError("log is not valid shape sympy expr")
@staticmethod
def sqrt(x):
return torch._sym_sqrt(x) # type: ignore[attr-defined]
@staticmethod
def minimum(a, b):
return torch.sym_min(a, b)
@staticmethod
def maximum(a, b):
return torch.sym_max(a, b)
@staticmethod
def floor_to_int(x, dtype):
return math.floor(x)
@staticmethod
def ceil_to_int(x, dtype):
return math.ceil(x)
@staticmethod
def floor(x):
return float(math.floor(x))
@staticmethod
def ceil(x):
return float(math.ceil(x))
@staticmethod
def truediv(a, b):
return a / b
@staticmethod
def pow(a, b):
return a**b
@staticmethod
def pow_by_natural(a, b):
# Pray that safe_pow is not needed here lol. In particular, this
# never participates in VR low/high ranges, so overflow should be
# unlikely
return a**b
@staticmethod
def round_to_int(a, dtype):
return round(a)
@staticmethod
def round_decimal(a, b):
return round(a, ndigits=b)

View File

@ -0,0 +1,96 @@
# mypy: allow-untyped-defs
import sympy
from sympy.multipledispatch import dispatch
__all__ = ["SingletonInt"]
class SingletonInt(sympy.AtomicExpr):
# This is probably not super important unless we are in multiple dispatch
# situations with other more exotic Expr types.
_op_priority = 99999
def __new__(cls, *args, coeff=None, **kwargs):
instance = super().__new__(cls, *args, **kwargs)
return instance
# The semantics of this class should match that of NestedIntSymNodeImpl in
# c10/core/NestedIntSymNodeImpl.h
def __init__(self, val, *, coeff=1):
self._val = val
self._coeff = coeff
super().__init__()
# See NOTE [ Inequalities with nested int ]
def _eval_Eq(self, other):
if (
isinstance(other, SingletonInt)
and other._val == self._val
and self._coeff == other._coeff
):
return sympy.true
else:
return sympy.false
# This is necessary so that calling expr.free_symbols on exprs that contain
# this Singleton does not error
@property
def free_symbols(self):
return set()
def __mul__(self, other):
if isinstance(other, SingletonInt):
raise ValueError(
"SingletonInt cannot be multiplied by another SingletonInt"
)
return SingletonInt(self._val, coeff=self._coeff * other)
def __rmul__(self, other):
if isinstance(other, SingletonInt):
raise ValueError(
"SingletonInt cannot be multiplied by another SingletonInt"
)
return SingletonInt(self._val, coeff=self._coeff * other)
# Make sure we promptly raise an error instead of falling back to building
# an expression tree. There are probably more ops, how can we be exhaustive?
def __add__(self, other):
raise NotImplementedError("NYI")
def __sub__(self, other):
raise NotImplementedError("NYI")
def __truediv__(self, other):
raise NotImplementedError("NYI")
def __floordiv__(self, other):
raise NotImplementedError("NYI")
def __mod__(self, other):
raise NotImplementedError("NYI")
# See NOTE [ Inequalities with nested int ]
@dispatch(sympy.Integer, SingletonInt)
def _eval_is_ge(a, b):
if a < 2:
return sympy.false
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
@dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef]
def _eval_is_ge(a, b): # noqa: F811
if b <= 2:
return sympy.true
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
@dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef]
def _eval_is_ge(a, b): # noqa: F811
if a._val == b._val:
if a._coeff >= b._coeff:
return sympy.true
else:
return sympy.false
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")

View File

@ -0,0 +1,175 @@
import logging
from typing import Dict, Optional, Tuple, Type
import sympy
from torch.utils._sympy.functions import FloorDiv
log = logging.getLogger(__name__)
_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = {
sympy.Eq: sympy.Eq,
sympy.Ne: sympy.Ne,
sympy.Ge: sympy.Le,
sympy.Gt: sympy.Lt,
sympy.Le: sympy.Ge,
sympy.Lt: sympy.Gt,
}
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
return _MIRROR_REL_OP.get(type, None)
# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
#
# Returns a tuple of:
# 1. The simplified expression
# 2. The expression on the right-hand side
#
# Returns 'None' if it can't reach a state where the only thing in the left
# hand side is 'thing'.
#
# 'trials': number of times 'try_solve' will try to isolate 'thing' to the
# left-hand side.
#
# 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into
# inequalities.
def try_solve(
expr: sympy.Basic,
thing: sympy.Basic,
trials: int = 5,
floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
mirror = mirror_rel_op(type(expr))
# Ignore unsupported expressions:
# - Those that are not relational operations
# - Those that don't have a mirror (just avoiding unexpected classes)
if not isinstance(expr, sympy.Rel) or mirror is None:
log.debug("expression with unsupported type: %s", type(expr))
return None
lhs_has_thing = expr.lhs.has(thing)
rhs_has_thing = expr.rhs.has(thing)
# Give up when 'thing' appears on both sides of the relational expression.
# That is because, as is, we assume the thing we are trying to isolate is
# only on the right-hand side.
if lhs_has_thing and rhs_has_thing:
log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
return None
# Try considering both LHS and RHS by mirroring the original expression:
# a < b ==> b > a
expressions = []
# Add each version of 'expr' if 'thing' is in its left-hand side.
if lhs_has_thing:
expressions.append(expr)
if rhs_has_thing:
expressions.append(mirror(expr.rhs, expr.lhs))
for e in expressions:
if e is None:
continue
assert isinstance(e, sympy.Rel)
for _ in range(trials):
trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality)
# Stop if there was no change in this trial.
if trial == e:
break
e = trial # type: ignore[assignment]
# Return if we were able to isolate 'thing' on the left-hand side.
if isinstance(e, sympy.Rel) and e.lhs == thing:
log.debug("solved: %s ---> %s", expr, e)
return e, e.rhs
return None
def _try_isolate_lhs(
e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
op = type(e)
if isinstance(e, sympy.Rel):
# Move any constants in the left-hand side to the right-hand side.
lhs_not_thing = (
sum(a for a in e.lhs.args if not a.has(thing))
if isinstance(e.lhs, sympy.Add)
else 0
)
e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing) # type: ignore[attr-defined]
# Divide both sides by the factors that don't contain thing.
if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
lhs, rhs = e.args
other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])
# If we can't tell whether 'other' is negative or positive, we do nothing.
# That is because we don't know whether we have mirror the operation or not.
if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None):
# Divide both sides by 'other'.
lhs = lhs / other
rhs = rhs / other
# If 'e' is an inequality and 'other' is negative, we have to
# mirror the expression.
if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
op = mirror_rel_op(op) # type: ignore[assignment]
assert op is not None
e = op(lhs, rhs)
################################################################################
# left-hand side is FloorDiv
################################################################################
#
# Given the expression: a // b op c
# where 'op' is a relational operation, these rules only work if:
# - b > 0
# - c is an integer
if (
floordiv_inequality
and isinstance(e, sympy.Rel)
and isinstance(e.lhs, FloorDiv)
and e.lhs.divisor.is_positive
and e.rhs.is_integer
):
# a // b == expr
# => a >= (b * expr) and a < (b * (expr + 1))
if isinstance(e, sympy.Eq):
numerator, denominator = e.lhs.args
return sympy.And(
sympy.Ge(numerator, (e.rhs * denominator)), # type: ignore[arg-type]
sympy.Lt(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type]
)
# a // b != expr
# => a < (b * expr) or a >= (b * (expr + 1))
if isinstance(e, sympy.Ne):
numerator, denominator = e.lhs.args
return sympy.Or(
sympy.Lt(numerator, (e.rhs * denominator)), # type: ignore[arg-type]
sympy.Ge(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type]
)
# The transformations below only work if b is positive.
# Note: we only have this information for constants.
# a // b > expr => a >= b * (expr + 1)
# a // b >= expr => a >= b * expr
if isinstance(e, (sympy.Gt, sympy.Ge)):
quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1) # type: ignore[arg-type]
return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type]
# a // b < expr => a < b * expr
# a // b <= expr => a < b * (expr + 1)
if isinstance(e, (sympy.Lt, sympy.Le)):
quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1) # type: ignore[arg-type]
return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type]
return e

View File

@ -0,0 +1,96 @@
# mypy: allow-untyped-defs
"""
This file contains canonical definitions for our symbol naming conventions,
across torch.fx.experimental.symbolic_shapes and torch._inductor. The
intention is:
1. To make it easily greppable where all the sites we use a prefix are
2. Make it possible to easily tell if we can introduce a new prefix without
introducing a conflict
You can occasionally test if prefixes have been hardcoded by renaming prefixes
in this file and seeing what breaks.
"""
from enum import auto, Enum
from typing import Sequence, Union
import sympy
class SymT(Enum):
SIZE = auto()
FLOAT = auto()
UNBACKED_INT = auto()
UNBACKED_FLOAT = auto()
# Inductor: The intermediates in inner_fn tmp0, one generated per ops call.
# If one of these shows up in an indexing expression, that means an
# indirect load is happening.
TMP = auto()
# Inductor: Placeholder variable that is later replaced with TMP
INDIRECT = auto()
# Inductor: Some size expressions are replaced with a precomputed size ps0
# which is computed host side, and then directly reused in the kernel, so
# we don't repeatedly recompute it on device.
PRECOMPUTED_SIZE = auto()
# Inductor: An indexing variable i0 in loops IR which ranges over non-reduced
# dim in the loop
INDEX = auto()
# Inductor: A reduction indexing r0 variable in loops IR which ranges over
# reduced dim in the loop
RINDEX = auto()
# Inductor: In templated kernels torch._inductor.kernel, we have a hook to
# store the final output and append epilogue fusions. To do this, we must
# know what the indexes the outputs range over. NB: These will also
# advertise as INDEX, this is... probably OK?
TEMPLATE_INDEX = auto()
# Inductor: iteration domain for blockIdx.x/blockIdx.y
XBLOCK = auto()
YBLOCK = auto()
# Inductor: this is used solely for dynamic_reshape_indexer
VIEW = auto()
# Alternate (non-modular) indexing used in halide kernels
HALIDE = auto()
# Invariant: there must not be a prefix which is a prefix of another string,
# as this introduces ambiguity
prefix_str = {
SymT.SIZE: "s", # integer
SymT.UNBACKED_INT: "u", # integer
# Prefix z here is chosen to avoid false aliasing in symbol_is_type test
# DO NOT add a "z" type. You also need to avoid conflicts on these
# prefixes but this is somewhat easier to manage
SymT.FLOAT: "zf",
SymT.UNBACKED_FLOAT: "zuf",
SymT.TMP: "tmp",
SymT.PRECOMPUTED_SIZE: "ps",
SymT.INDEX: "i",
SymT.RINDEX: "r",
SymT.TEMPLATE_INDEX: "idx",
SymT.XBLOCK: "x",
SymT.YBLOCK: "y",
SymT.INDIRECT: "indirect", # false aliasing?
SymT.VIEW: "view",
SymT.HALIDE: "h",
}
def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
# TODO: maybe put the assumptions here directly
return sympy.Symbol(f"{prefix_str[prefix]}{idx}", **kwargs)
# This type is a little wider than it should be, because free_symbols says
# that it contains Basic, rather than Symbol
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> bool:
assert isinstance(sym, sympy.Symbol)
name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK
if isinstance(prefix, SymT):
return name_str.startswith(prefix_str[prefix])
else:
return name_str.startswith(tuple(prefix_str[p] for p in prefix))
def free_symbol_is_type(e: sympy.Expr, prefix: SymT) -> bool:
return any(symbol_is_type(v, prefix) for v in e.free_symbols)

File diff suppressed because it is too large Load Diff