I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,50 @@
""" Rewrite Rules
DISCLAIMER: This module is experimental. The interface is subject to change.
A rule is a function that transforms one expression into another
Rule :: Expr -> Expr
A strategy is a function that says how a rule should be applied to a syntax
tree. In general strategies take rules and produce a new rule
Strategy :: [Rules], Other-stuff -> Rule
This allows developers to separate a mathematical transformation from the
algorithmic details of applying that transformation. The goal is to separate
the work of mathematical programming from algorithmic programming.
Submodules
strategies.rl - some fundamental rules
strategies.core - generic non-SymPy specific strategies
strategies.traverse - strategies that traverse a SymPy tree
strategies.tools - some conglomerate strategies that do depend on SymPy
"""
from . import rl
from . import traverse
from .rl import rm_id, unpack, flatten, sort, glom, distribute, rebuild
from .util import new
from .core import (
condition, debug, chain, null_safe, do_one, exhaust, minimize, tryit)
from .tools import canon, typed
from . import branch
__all__ = [
'rl',
'traverse',
'rm_id', 'unpack', 'flatten', 'sort', 'glom', 'distribute', 'rebuild',
'new',
'condition', 'debug', 'chain', 'null_safe', 'do_one', 'exhaust',
'minimize', 'tryit',
'canon', 'typed',
'branch',
]

View File

@ -0,0 +1,14 @@
from . import traverse
from .core import (
condition, debug, multiplex, exhaust, notempty,
chain, onaction, sfilter, yieldify, do_one, identity)
from .tools import canon
__all__ = [
'traverse',
'condition', 'debug', 'multiplex', 'exhaust', 'notempty', 'chain',
'onaction', 'sfilter', 'yieldify', 'do_one', 'identity',
'canon',
]

View File

@ -0,0 +1,116 @@
""" Generic SymPy-Independent Strategies """
def identity(x):
yield x
def exhaust(brule):
""" Apply a branching rule repeatedly until it has no effect """
def exhaust_brl(expr):
seen = {expr}
for nexpr in brule(expr):
if nexpr not in seen:
seen.add(nexpr)
yield from exhaust_brl(nexpr)
if seen == {expr}:
yield expr
return exhaust_brl
def onaction(brule, fn):
def onaction_brl(expr):
for result in brule(expr):
if result != expr:
fn(brule, expr, result)
yield result
return onaction_brl
def debug(brule, file=None):
""" Print the input and output expressions at each rule application """
if not file:
from sys import stdout
file = stdout
def write(brl, expr, result):
file.write("Rule: %s\n" % brl.__name__)
file.write("In: %s\nOut: %s\n\n" % (expr, result))
return onaction(brule, write)
def multiplex(*brules):
""" Multiplex many branching rules into one """
def multiplex_brl(expr):
seen = set()
for brl in brules:
for nexpr in brl(expr):
if nexpr not in seen:
seen.add(nexpr)
yield nexpr
return multiplex_brl
def condition(cond, brule):
""" Only apply branching rule if condition is true """
def conditioned_brl(expr):
if cond(expr):
yield from brule(expr)
else:
pass
return conditioned_brl
def sfilter(pred, brule):
""" Yield only those results which satisfy the predicate """
def filtered_brl(expr):
yield from filter(pred, brule(expr))
return filtered_brl
def notempty(brule):
def notempty_brl(expr):
yielded = False
for nexpr in brule(expr):
yielded = True
yield nexpr
if not yielded:
yield expr
return notempty_brl
def do_one(*brules):
""" Execute one of the branching rules """
def do_one_brl(expr):
yielded = False
for brl in brules:
for nexpr in brl(expr):
yielded = True
yield nexpr
if yielded:
return
return do_one_brl
def chain(*brules):
"""
Compose a sequence of brules so that they apply to the expr sequentially
"""
def chain_brl(expr):
if not brules:
yield expr
return
head, tail = brules[0], brules[1:]
for nexpr in head(expr):
yield from chain(*tail)(nexpr)
return chain_brl
def yieldify(rl):
""" Turn a rule into a branching rule """
def brl(expr):
yield rl(expr)
return brl

View File

@ -0,0 +1,117 @@
from sympy.strategies.branch.core import (
exhaust, debug, multiplex, condition, notempty, chain, onaction, sfilter,
yieldify, do_one, identity)
def posdec(x):
if x > 0:
yield x - 1
else:
yield x
def branch5(x):
if 0 < x < 5:
yield x - 1
elif 5 < x < 10:
yield x + 1
elif x == 5:
yield x + 1
yield x - 1
else:
yield x
def even(x):
return x % 2 == 0
def inc(x):
yield x + 1
def one_to_n(n):
yield from range(n)
def test_exhaust():
brl = exhaust(branch5)
assert set(brl(3)) == {0}
assert set(brl(7)) == {10}
assert set(brl(5)) == {0, 10}
def test_debug():
from io import StringIO
file = StringIO()
rl = debug(posdec, file)
list(rl(5))
log = file.getvalue()
file.close()
assert posdec.__name__ in log
assert '5' in log
assert '4' in log
def test_multiplex():
brl = multiplex(posdec, branch5)
assert set(brl(3)) == {2}
assert set(brl(7)) == {6, 8}
assert set(brl(5)) == {4, 6}
def test_condition():
brl = condition(even, branch5)
assert set(brl(4)) == set(branch5(4))
assert set(brl(5)) == set()
def test_sfilter():
brl = sfilter(even, one_to_n)
assert set(brl(10)) == {0, 2, 4, 6, 8}
def test_notempty():
def ident_if_even(x):
if even(x):
yield x
brl = notempty(ident_if_even)
assert set(brl(4)) == {4}
assert set(brl(5)) == {5}
def test_chain():
assert list(chain()(2)) == [2] # identity
assert list(chain(inc, inc)(2)) == [4]
assert list(chain(branch5, inc)(4)) == [4]
assert set(chain(branch5, inc)(5)) == {5, 7}
assert list(chain(inc, branch5)(5)) == [7]
def test_onaction():
L = []
def record(fn, input, output):
L.append((input, output))
list(onaction(inc, record)(2))
assert L == [(2, 3)]
list(onaction(identity, record)(2))
assert L == [(2, 3)]
def test_yieldify():
yinc = yieldify(lambda x: x + 1)
assert list(yinc(3)) == [4]
def test_do_one():
def bad(expr):
raise ValueError
assert list(do_one(inc)(3)) == [4]
assert list(do_one(inc, bad)(3)) == [4]
assert list(do_one(inc, posdec)(3)) == [4]

View File

@ -0,0 +1,42 @@
from sympy.strategies.branch.tools import canon
from sympy.core.basic import Basic
from sympy.core.numbers import Integer
from sympy.core.singleton import S
def posdec(x):
if isinstance(x, Integer) and x > 0:
yield x - 1
else:
yield x
def branch5(x):
if isinstance(x, Integer):
if 0 < x < 5:
yield x - 1
elif 5 < x < 10:
yield x + 1
elif x == 5:
yield x + 1
yield x - 1
else:
yield x
def test_zero_ints():
expr = Basic(S(2), Basic(S(5), S(3)), S(8))
expected = {Basic(S(0), Basic(S(0), S(0)), S(0))}
brl = canon(posdec)
assert set(brl(expr)) == expected
def test_split5():
expr = Basic(S(2), Basic(S(5), S(3)), S(8))
expected = {
Basic(S(0), Basic(S(0), S(0)), S(10)),
Basic(S(0), Basic(S(10), S(0)), S(10))}
brl = canon(branch5)
assert set(brl(expr)) == expected

View File

@ -0,0 +1,53 @@
from sympy.core.basic import Basic
from sympy.core.numbers import Integer
from sympy.core.singleton import S
from sympy.strategies.branch.traverse import top_down, sall
from sympy.strategies.branch.core import do_one, identity
def inc(x):
if isinstance(x, Integer):
yield x + 1
def test_top_down_easy():
expr = Basic(S(1), S(2))
expected = Basic(S(2), S(3))
brl = top_down(inc)
assert set(brl(expr)) == {expected}
def test_top_down_big_tree():
expr = Basic(S(1), Basic(S(2)), Basic(S(3), Basic(S(4)), S(5)))
expected = Basic(S(2), Basic(S(3)), Basic(S(4), Basic(S(5)), S(6)))
brl = top_down(inc)
assert set(brl(expr)) == {expected}
def test_top_down_harder_function():
def split5(x):
if x == 5:
yield x - 1
yield x + 1
expr = Basic(Basic(S(5), S(6)), S(1))
expected = {Basic(Basic(S(4), S(6)), S(1)), Basic(Basic(S(6), S(6)), S(1))}
brl = top_down(split5)
assert set(brl(expr)) == expected
def test_sall():
expr = Basic(S(1), S(2))
expected = Basic(S(2), S(3))
brl = sall(inc)
assert list(brl(expr)) == [expected]
expr = Basic(S(1), S(2), Basic(S(3), S(4)))
expected = Basic(S(2), S(3), Basic(S(3), S(4)))
brl = sall(do_one(inc, identity))
assert list(brl(expr)) == [expected]

View File

@ -0,0 +1,12 @@
from .core import exhaust, multiplex
from .traverse import top_down
def canon(*rules):
""" Strategy for canonicalization
Apply each branching rule in a top-down fashion through the tree.
Multiplex through all branching rule traversals
Keep doing this until there is no change.
"""
return exhaust(multiplex(*map(top_down, rules)))

View File

@ -0,0 +1,25 @@
""" Branching Strategies to Traverse a Tree """
from itertools import product
from sympy.strategies.util import basic_fns
from .core import chain, identity, do_one
def top_down(brule, fns=basic_fns):
""" Apply a rule down a tree running it on the top nodes first """
return chain(do_one(brule, identity),
lambda expr: sall(top_down(brule, fns), fns)(expr))
def sall(brule, fns=basic_fns):
""" Strategic all - apply rule to args """
op, new, children, leaf = map(fns.get, ('op', 'new', 'children', 'leaf'))
def all_rl(expr):
if leaf(expr):
yield expr
else:
myop = op(expr)
argss = product(*map(brule, children(expr)))
for args in argss:
yield new(myop, *args)
return all_rl

View File

@ -0,0 +1,151 @@
""" Generic SymPy-Independent Strategies """
from __future__ import annotations
from collections.abc import Callable, Mapping
from typing import TypeVar
from sys import stdout
_S = TypeVar('_S')
_T = TypeVar('_T')
def identity(x: _T) -> _T:
return x
def exhaust(rule: Callable[[_T], _T]) -> Callable[[_T], _T]:
""" Apply a rule repeatedly until it has no effect """
def exhaustive_rl(expr: _T) -> _T:
new, old = rule(expr), expr
while new != old:
new, old = rule(new), new
return new
return exhaustive_rl
def memoize(rule: Callable[[_S], _T]) -> Callable[[_S], _T]:
"""Memoized version of a rule
Notes
=====
This cache can grow infinitely, so it is not recommended to use this
than ``functools.lru_cache`` unless you need very heavy computation.
"""
cache: dict[_S, _T] = {}
def memoized_rl(expr: _S) -> _T:
if expr in cache:
return cache[expr]
else:
result = rule(expr)
cache[expr] = result
return result
return memoized_rl
def condition(
cond: Callable[[_T], bool], rule: Callable[[_T], _T]
) -> Callable[[_T], _T]:
""" Only apply rule if condition is true """
def conditioned_rl(expr: _T) -> _T:
if cond(expr):
return rule(expr)
return expr
return conditioned_rl
def chain(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
"""
Compose a sequence of rules so that they apply to the expr sequentially
"""
def chain_rl(expr: _T) -> _T:
for rule in rules:
expr = rule(expr)
return expr
return chain_rl
def debug(rule, file=None):
""" Print out before and after expressions each time rule is used """
if file is None:
file = stdout
def debug_rl(*args, **kwargs):
expr = args[0]
result = rule(*args, **kwargs)
if result != expr:
file.write("Rule: %s\n" % rule.__name__)
file.write("In: %s\nOut: %s\n\n" % (expr, result))
return result
return debug_rl
def null_safe(rule: Callable[[_T], _T | None]) -> Callable[[_T], _T]:
""" Return original expr if rule returns None """
def null_safe_rl(expr: _T) -> _T:
result = rule(expr)
if result is None:
return expr
return result
return null_safe_rl
def tryit(rule: Callable[[_T], _T], exception) -> Callable[[_T], _T]:
""" Return original expr if rule raises exception """
def try_rl(expr: _T) -> _T:
try:
return rule(expr)
except exception:
return expr
return try_rl
def do_one(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
""" Try each of the rules until one works. Then stop. """
def do_one_rl(expr: _T) -> _T:
for rl in rules:
result = rl(expr)
if result != expr:
return result
return expr
return do_one_rl
def switch(
key: Callable[[_S], _T],
ruledict: Mapping[_T, Callable[[_S], _S]]
) -> Callable[[_S], _S]:
""" Select a rule based on the result of key called on the function """
def switch_rl(expr: _S) -> _S:
rl = ruledict.get(key(expr), identity)
return rl(expr)
return switch_rl
# XXX Untyped default argument for minimize function
# where python requires SupportsRichComparison type
def _identity(x):
return x
def minimize(
*rules: Callable[[_S], _T],
objective=_identity
) -> Callable[[_S], _T]:
""" Select result of rules that minimizes objective
>>> from sympy.strategies import minimize
>>> inc = lambda x: x + 1
>>> dec = lambda x: x - 1
>>> rl = minimize(inc, dec)
>>> rl(4)
3
>>> rl = minimize(inc, dec, objective=lambda x: -x) # maximize
>>> rl(4)
5
"""
def minrule(expr: _S) -> _T:
return min([rule(expr) for rule in rules], key=objective)
return minrule

View File

@ -0,0 +1,176 @@
""" Generic Rules for SymPy
This file assumes knowledge of Basic and little else.
"""
from sympy.utilities.iterables import sift
from .util import new
# Functions that create rules
def rm_id(isid, new=new):
""" Create a rule to remove identities.
isid - fn :: x -> Bool --- whether or not this element is an identity.
Examples
========
>>> from sympy.strategies import rm_id
>>> from sympy import Basic, S
>>> remove_zeros = rm_id(lambda x: x==0)
>>> remove_zeros(Basic(S(1), S(0), S(2)))
Basic(1, 2)
>>> remove_zeros(Basic(S(0), S(0))) # If only identites then we keep one
Basic(0)
See Also:
unpack
"""
def ident_remove(expr):
""" Remove identities """
ids = list(map(isid, expr.args))
if sum(ids) == 0: # No identities. Common case
return expr
elif sum(ids) != len(ids): # there is at least one non-identity
return new(expr.__class__,
*[arg for arg, x in zip(expr.args, ids) if not x])
else:
return new(expr.__class__, expr.args[0])
return ident_remove
def glom(key, count, combine):
""" Create a rule to conglomerate identical args.
Examples
========
>>> from sympy.strategies import glom
>>> from sympy import Add
>>> from sympy.abc import x
>>> key = lambda x: x.as_coeff_Mul()[1]
>>> count = lambda x: x.as_coeff_Mul()[0]
>>> combine = lambda cnt, arg: cnt * arg
>>> rl = glom(key, count, combine)
>>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
3*x + 5
Wait, how are key, count and combine supposed to work?
>>> key(2*x)
x
>>> count(2*x)
2
>>> combine(2, x)
2*x
"""
def conglomerate(expr):
""" Conglomerate together identical args x + x -> 2x """
groups = sift(expr.args, key)
counts = {k: sum(map(count, args)) for k, args in groups.items()}
newargs = [combine(cnt, mat) for mat, cnt in counts.items()]
if set(newargs) != set(expr.args):
return new(type(expr), *newargs)
else:
return expr
return conglomerate
def sort(key, new=new):
""" Create a rule to sort by a key function.
Examples
========
>>> from sympy.strategies import sort
>>> from sympy import Basic, S
>>> sort_rl = sort(str)
>>> sort_rl(Basic(S(3), S(1), S(2)))
Basic(1, 2, 3)
"""
def sort_rl(expr):
return new(expr.__class__, *sorted(expr.args, key=key))
return sort_rl
def distribute(A, B):
""" Turns an A containing Bs into a B of As
where A, B are container types
>>> from sympy.strategies import distribute
>>> from sympy import Add, Mul, symbols
>>> x, y = symbols('x,y')
>>> dist = distribute(Mul, Add)
>>> expr = Mul(2, x+y, evaluate=False)
>>> expr
2*(x + y)
>>> dist(expr)
2*x + 2*y
"""
def distribute_rl(expr):
for i, arg in enumerate(expr.args):
if isinstance(arg, B):
first, b, tail = expr.args[:i], expr.args[i], expr.args[i + 1:]
return B(*[A(*(first + (arg,) + tail)) for arg in b.args])
return expr
return distribute_rl
def subs(a, b):
""" Replace expressions exactly """
def subs_rl(expr):
if expr == a:
return b
else:
return expr
return subs_rl
# Functions that are rules
def unpack(expr):
""" Rule to unpack singleton args
>>> from sympy.strategies import unpack
>>> from sympy import Basic, S
>>> unpack(Basic(S(2)))
2
"""
if len(expr.args) == 1:
return expr.args[0]
else:
return expr
def flatten(expr, new=new):
""" Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e)) """
cls = expr.__class__
args = []
for arg in expr.args:
if arg.__class__ == cls:
args.extend(arg.args)
else:
args.append(arg)
return new(expr.__class__, *args)
def rebuild(expr):
""" Rebuild a SymPy tree.
Explanation
===========
This function recursively calls constructors in the expression tree.
This forces canonicalization and removes ugliness introduced by the use of
Basic.__new__
"""
if expr.is_Atom:
return expr
else:
return expr.func(*list(map(rebuild, expr.args)))

View File

@ -0,0 +1,118 @@
from __future__ import annotations
from sympy.core.singleton import S
from sympy.core.basic import Basic
from sympy.strategies.core import (
null_safe, exhaust, memoize, condition,
chain, tryit, do_one, debug, switch, minimize)
from io import StringIO
def posdec(x: int) -> int:
if x > 0:
return x - 1
return x
def inc(x: int) -> int:
return x + 1
def dec(x: int) -> int:
return x - 1
def test_null_safe():
def rl(expr: int) -> int | None:
if expr == 1:
return 2
return None
safe_rl = null_safe(rl)
assert rl(1) == safe_rl(1)
assert rl(3) is None
assert safe_rl(3) == 3
def test_exhaust():
sink = exhaust(posdec)
assert sink(5) == 0
assert sink(10) == 0
def test_memoize():
rl = memoize(posdec)
assert rl(5) == posdec(5)
assert rl(5) == posdec(5)
assert rl(-2) == posdec(-2)
def test_condition():
rl = condition(lambda x: x % 2 == 0, posdec)
assert rl(5) == 5
assert rl(4) == 3
def test_chain():
rl = chain(posdec, posdec)
assert rl(5) == 3
assert rl(1) == 0
def test_tryit():
def rl(expr: Basic) -> Basic:
assert False
safe_rl = tryit(rl, AssertionError)
assert safe_rl(S(1)) == S(1)
def test_do_one():
rl = do_one(posdec, posdec)
assert rl(5) == 4
def rl1(x: int) -> int:
if x == 1:
return 2
return x
def rl2(x: int) -> int:
if x == 2:
return 3
return x
rule = do_one(rl1, rl2)
assert rule(1) == 2
assert rule(rule(1)) == 3
def test_debug():
file = StringIO()
rl = debug(posdec, file)
rl(5)
log = file.getvalue()
file.close()
assert posdec.__name__ in log
assert '5' in log
assert '4' in log
def test_switch():
def key(x: int) -> int:
return x % 3
rl = switch(key, {0: inc, 1: dec})
assert rl(3) == 4
assert rl(4) == 3
assert rl(5) == 5
def test_minimize():
def key(x: int) -> int:
return -x
rl = minimize(inc, dec)
assert rl(4) == 3
rl = minimize(inc, dec, objective=key)
assert rl(4) == 5

View File

@ -0,0 +1,78 @@
from sympy.core.singleton import S
from sympy.strategies.rl import (
rm_id, glom, flatten, unpack, sort, distribute, subs, rebuild)
from sympy.core.basic import Basic
from sympy.core.add import Add
from sympy.core.mul import Mul
from sympy.core.symbol import symbols
from sympy.abc import x
def test_rm_id():
rmzeros = rm_id(lambda x: x == 0)
assert rmzeros(Basic(S(0), S(1))) == Basic(S(1))
assert rmzeros(Basic(S(0), S(0))) == Basic(S(0))
assert rmzeros(Basic(S(2), S(1))) == Basic(S(2), S(1))
def test_glom():
def key(x):
return x.as_coeff_Mul()[1]
def count(x):
return x.as_coeff_Mul()[0]
def newargs(cnt, arg):
return cnt * arg
rl = glom(key, count, newargs)
result = rl(Add(x, -x, 3 * x, 2, 3, evaluate=False))
expected = Add(3 * x, 5)
assert set(result.args) == set(expected.args)
def test_flatten():
assert flatten(Basic(S(1), S(2), Basic(S(3), S(4)))) == \
Basic(S(1), S(2), S(3), S(4))
def test_unpack():
assert unpack(Basic(S(2))) == 2
assert unpack(Basic(S(2), S(3))) == Basic(S(2), S(3))
def test_sort():
assert sort(str)(Basic(S(3), S(1), S(2))) == Basic(S(1), S(2), S(3))
def test_distribute():
class T1(Basic):
pass
class T2(Basic):
pass
distribute_t12 = distribute(T1, T2)
assert distribute_t12(T1(S(1), S(2), T2(S(3), S(4)), S(5))) == \
T2(T1(S(1), S(2), S(3), S(5)), T1(S(1), S(2), S(4), S(5)))
assert distribute_t12(T1(S(1), S(2), S(3))) == T1(S(1), S(2), S(3))
def test_distribute_add_mul():
x, y = symbols('x, y')
expr = Mul(2, Add(x, y), evaluate=False)
expected = Add(Mul(2, x), Mul(2, y))
distribute_mul = distribute(Mul, Add)
assert distribute_mul(expr) == expected
def test_subs():
rl = subs(1, 2)
assert rl(1) == 2
assert rl(3) == 3
def test_rebuild():
expr = Basic.__new__(Add, S(1), S(2))
assert rebuild(expr) == 3

View File

@ -0,0 +1,32 @@
from sympy.strategies.tools import subs, typed
from sympy.strategies.rl import rm_id
from sympy.core.basic import Basic
from sympy.core.singleton import S
def test_subs():
from sympy.core.symbol import symbols
a, b, c, d, e, f = symbols('a,b,c,d,e,f')
mapping = {a: d, d: a, Basic(e): Basic(f)}
expr = Basic(a, Basic(b, c), Basic(d, Basic(e)))
result = Basic(d, Basic(b, c), Basic(a, Basic(f)))
assert subs(mapping)(expr) == result
def test_subs_empty():
assert subs({})(Basic(S(1), S(2))) == Basic(S(1), S(2))
def test_typed():
class A(Basic):
pass
class B(Basic):
pass
rmzeros = rm_id(lambda x: x == S(0))
rmones = rm_id(lambda x: x == S(1))
remove_something = typed({A: rmzeros, B: rmones})
assert remove_something(A(S(0), S(1))) == A(S(1))
assert remove_something(B(S(0), S(1))) == B(S(0))

View File

@ -0,0 +1,84 @@
from sympy.strategies.traverse import (
top_down, bottom_up, sall, top_down_once, bottom_up_once, basic_fns)
from sympy.strategies.rl import rebuild
from sympy.strategies.util import expr_fns
from sympy.core.add import Add
from sympy.core.basic import Basic
from sympy.core.numbers import Integer
from sympy.core.singleton import S
from sympy.core.symbol import Str, Symbol
from sympy.abc import x, y, z
def zero_symbols(expression):
return S.Zero if isinstance(expression, Symbol) else expression
def test_sall():
zero_onelevel = sall(zero_symbols)
assert zero_onelevel(Basic(x, y, Basic(x, z))) == \
Basic(S(0), S(0), Basic(x, z))
def test_bottom_up():
_test_global_traversal(bottom_up)
_test_stop_on_non_basics(bottom_up)
def test_top_down():
_test_global_traversal(top_down)
_test_stop_on_non_basics(top_down)
def _test_global_traversal(trav):
zero_all_symbols = trav(zero_symbols)
assert zero_all_symbols(Basic(x, y, Basic(x, z))) == \
Basic(S(0), S(0), Basic(S(0), S(0)))
def _test_stop_on_non_basics(trav):
def add_one_if_can(expr):
try:
return expr + 1
except TypeError:
return expr
expr = Basic(S(1), Str('a'), Basic(S(2), Str('b')))
expected = Basic(S(2), Str('a'), Basic(S(3), Str('b')))
rl = trav(add_one_if_can)
assert rl(expr) == expected
class Basic2(Basic):
pass
def rl(x):
if x.args and not isinstance(x.args[0], Integer):
return Basic2(*x.args)
return x
def test_top_down_once():
top_rl = top_down_once(rl)
assert top_rl(Basic(S(1.0), S(2.0), Basic(S(3), S(4)))) == \
Basic2(S(1.0), S(2.0), Basic(S(3), S(4)))
def test_bottom_up_once():
bottom_rl = bottom_up_once(rl)
assert bottom_rl(Basic(S(1), S(2), Basic(S(3.0), S(4.0)))) == \
Basic(S(1), S(2), Basic2(S(3.0), S(4.0)))
def test_expr_fns():
expr = x + y**3
e = bottom_up(lambda v: v + 1, expr_fns)(expr)
b = bottom_up(lambda v: Basic.__new__(Add, v, S(1)), basic_fns)(expr)
assert rebuild(b) == e

View File

@ -0,0 +1,92 @@
from sympy.strategies.tree import treeapply, greedy, allresults, brute
from functools import partial, reduce
def inc(x):
return x + 1
def dec(x):
return x - 1
def double(x):
return 2 * x
def square(x):
return x**2
def add(*args):
return sum(args)
def mul(*args):
return reduce(lambda a, b: a * b, args, 1)
def test_treeapply():
tree = ([3, 3], [4, 1], 2)
assert treeapply(tree, {list: min, tuple: max}) == 3
assert treeapply(tree, {list: add, tuple: mul}) == 60
def test_treeapply_leaf():
assert treeapply(3, {}, leaf=lambda x: x**2) == 9
tree = ([3, 3], [4, 1], 2)
treep1 = ([4, 4], [5, 2], 3)
assert treeapply(tree, {list: min, tuple: max}, leaf=lambda x: x + 1) == \
treeapply(treep1, {list: min, tuple: max})
def test_treeapply_strategies():
from sympy.strategies import chain, minimize
join = {list: chain, tuple: minimize}
assert treeapply(inc, join) == inc
assert treeapply((inc, dec), join)(5) == minimize(inc, dec)(5)
assert treeapply([inc, dec], join)(5) == chain(inc, dec)(5)
tree = (inc, [dec, double]) # either inc or dec-then-double
assert treeapply(tree, join)(5) == 6
assert treeapply(tree, join)(1) == 0
maximize = partial(minimize, objective=lambda x: -x)
join = {list: chain, tuple: maximize}
fn = treeapply(tree, join)
assert fn(4) == 6 # highest value comes from the dec then double
assert fn(1) == 2 # highest value comes from the inc
def test_greedy():
tree = [inc, (dec, double)] # either inc or dec-then-double
fn = greedy(tree, objective=lambda x: -x)
assert fn(4) == 6 # highest value comes from the dec then double
assert fn(1) == 2 # highest value comes from the inc
tree = [inc, dec, [inc, dec, [(inc, inc), (dec, dec)]]]
lowest = greedy(tree)
assert lowest(10) == 8
highest = greedy(tree, objective=lambda x: -x)
assert highest(10) == 12
def test_allresults():
# square = lambda x: x**2
assert set(allresults(inc)(3)) == {inc(3)}
assert set(allresults([inc, dec])(3)) == {2, 4}
assert set(allresults((inc, dec))(3)) == {3}
assert set(allresults([inc, (dec, double)])(4)) == {5, 6}
def test_brute():
tree = ([inc, dec], square)
fn = brute(tree, lambda x: -x)
assert fn(2) == (2 + 1)**2
assert fn(-2) == (-2 - 1)**2
assert brute(inc)(1) == 2

View File

@ -0,0 +1,53 @@
from . import rl
from .core import do_one, exhaust, switch
from .traverse import top_down
def subs(d, **kwargs):
""" Full simultaneous exact substitution.
Examples
========
>>> from sympy.strategies.tools import subs
>>> from sympy import Basic, S
>>> mapping = {S(1): S(4), S(4): S(1), Basic(S(5)): Basic(S(6), S(7))}
>>> expr = Basic(S(1), Basic(S(2), S(3)), Basic(S(4), Basic(S(5))))
>>> subs(mapping)(expr)
Basic(4, Basic(2, 3), Basic(1, Basic(6, 7)))
"""
if d:
return top_down(do_one(*map(rl.subs, *zip(*d.items()))), **kwargs)
else:
return lambda x: x
def canon(*rules, **kwargs):
""" Strategy for canonicalization.
Explanation
===========
Apply each rule in a bottom_up fashion through the tree.
Do each one in turn.
Keep doing this until there is no change.
"""
return exhaust(top_down(exhaust(do_one(*rules)), **kwargs))
def typed(ruletypes):
""" Apply rules based on the expression type
inputs:
ruletypes -- a dict mapping {Type: rule}
Examples
========
>>> from sympy.strategies import rm_id, typed
>>> from sympy import Add, Mul
>>> rm_zeros = rm_id(lambda x: x==0)
>>> rm_ones = rm_id(lambda x: x==1)
>>> remove_idents = typed({Add: rm_zeros, Mul: rm_ones})
"""
return switch(type, ruletypes)

View File

@ -0,0 +1,37 @@
"""Strategies to Traverse a Tree."""
from sympy.strategies.util import basic_fns
from sympy.strategies.core import chain, do_one
def top_down(rule, fns=basic_fns):
"""Apply a rule down a tree running it on the top nodes first."""
return chain(rule, lambda expr: sall(top_down(rule, fns), fns)(expr))
def bottom_up(rule, fns=basic_fns):
"""Apply a rule down a tree running it on the bottom nodes first."""
return chain(lambda expr: sall(bottom_up(rule, fns), fns)(expr), rule)
def top_down_once(rule, fns=basic_fns):
"""Apply a rule down a tree - stop on success."""
return do_one(rule, lambda expr: sall(top_down(rule, fns), fns)(expr))
def bottom_up_once(rule, fns=basic_fns):
"""Apply a rule up a tree - stop on success."""
return do_one(lambda expr: sall(bottom_up(rule, fns), fns)(expr), rule)
def sall(rule, fns=basic_fns):
"""Strategic all - apply rule to args."""
op, new, children, leaf = map(fns.get, ('op', 'new', 'children', 'leaf'))
def all_rl(expr):
if leaf(expr):
return expr
else:
args = map(rule, children(expr))
return new(op(expr), *args)
return all_rl

View File

@ -0,0 +1,139 @@
from functools import partial
from sympy.strategies import chain, minimize
from sympy.strategies.core import identity
import sympy.strategies.branch as branch
from sympy.strategies.branch import yieldify
def treeapply(tree, join, leaf=identity):
""" Apply functions onto recursive containers (tree).
Explanation
===========
join - a dictionary mapping container types to functions
e.g. ``{list: minimize, tuple: chain}``
Keys are containers/iterables. Values are functions [a] -> a.
Examples
========
>>> from sympy.strategies.tree import treeapply
>>> tree = [(3, 2), (4, 1)]
>>> treeapply(tree, {list: max, tuple: min})
2
>>> add = lambda *args: sum(args)
>>> def mul(*args):
... total = 1
... for arg in args:
... total *= arg
... return total
>>> treeapply(tree, {list: mul, tuple: add})
25
"""
for typ in join:
if isinstance(tree, typ):
return join[typ](*map(partial(treeapply, join=join, leaf=leaf),
tree))
return leaf(tree)
def greedy(tree, objective=identity, **kwargs):
""" Execute a strategic tree. Select alternatives greedily
Trees
-----
Nodes in a tree can be either
function - a leaf
list - a selection among operations
tuple - a sequence of chained operations
Textual examples
----------------
Text: Run f, then run g, e.g. ``lambda x: g(f(x))``
Code: ``(f, g)``
Text: Run either f or g, whichever minimizes the objective
Code: ``[f, g]``
Textx: Run either f or g, whichever is better, then run h
Code: ``([f, g], h)``
Text: Either expand then simplify or try factor then foosimp. Finally print
Code: ``([(expand, simplify), (factor, foosimp)], print)``
Objective
---------
"Better" is determined by the objective keyword. This function makes
choices to minimize the objective. It defaults to the identity.
Examples
========
>>> from sympy.strategies.tree import greedy
>>> inc = lambda x: x + 1
>>> dec = lambda x: x - 1
>>> double = lambda x: 2*x
>>> tree = [inc, (dec, double)] # either inc or dec-then-double
>>> fn = greedy(tree)
>>> fn(4) # lowest value comes from the inc
5
>>> fn(1) # lowest value comes from dec then double
0
This function selects between options in a tuple. The result is chosen
that minimizes the objective function.
>>> fn = greedy(tree, objective=lambda x: -x) # maximize
>>> fn(4) # highest value comes from the dec then double
6
>>> fn(1) # highest value comes from the inc
2
Greediness
----------
This is a greedy algorithm. In the example:
([a, b], c) # do either a or b, then do c
the choice between running ``a`` or ``b`` is made without foresight to c
"""
optimize = partial(minimize, objective=objective)
return treeapply(tree, {list: optimize, tuple: chain}, **kwargs)
def allresults(tree, leaf=yieldify):
""" Execute a strategic tree. Return all possibilities.
Returns a lazy iterator of all possible results
Exhaustiveness
--------------
This is an exhaustive algorithm. In the example
([a, b], [c, d])
All of the results from
(a, c), (b, c), (a, d), (b, d)
are returned. This can lead to combinatorial blowup.
See sympy.strategies.greedy for details on input
"""
return treeapply(tree, {list: branch.multiplex, tuple: branch.chain},
leaf=leaf)
def brute(tree, objective=identity, **kwargs):
return lambda expr: min(tuple(allresults(tree, **kwargs)(expr)),
key=objective)

View File

@ -0,0 +1,17 @@
from sympy.core.basic import Basic
new = Basic.__new__
def assoc(d, k, v):
d = d.copy()
d[k] = v
return d
basic_fns = {'op': type,
'new': Basic.__new__,
'leaf': lambda x: not isinstance(x, Basic) or x.is_Atom,
'children': lambda x: x.args}
expr_fns = assoc(basic_fns, 'new', lambda op, *args: op(*args))