I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,62 @@
from sympy.multipledispatch.conflict import (supercedes, ordering, ambiguities,
ambiguous, super_signature, consistent)
class A: pass
class B(A): pass
class C: pass
def test_supercedes():
assert supercedes([B], [A])
assert supercedes([B, A], [A, A])
assert not supercedes([B, A], [A, B])
assert not supercedes([A], [B])
def test_consistent():
assert consistent([A], [A])
assert consistent([B], [B])
assert not consistent([A], [C])
assert consistent([A, B], [A, B])
assert consistent([B, A], [A, B])
assert not consistent([B, A], [B])
assert not consistent([B, A], [B, C])
def test_super_signature():
assert super_signature([[A]]) == [A]
assert super_signature([[A], [B]]) == [B]
assert super_signature([[A, B], [B, A]]) == [B, B]
assert super_signature([[A, A, B], [A, B, A], [B, A, A]]) == [B, B, B]
def test_ambiguous():
assert not ambiguous([A], [A])
assert not ambiguous([A], [B])
assert not ambiguous([B], [B])
assert not ambiguous([A, B], [B, B])
assert ambiguous([A, B], [B, A])
def test_ambiguities():
signatures = [[A], [B], [A, B], [B, A], [A, C]]
expected = {((A, B), (B, A))}
result = ambiguities(signatures)
assert set(map(frozenset, expected)) == set(map(frozenset, result))
signatures = [[A], [B], [A, B], [B, A], [A, C], [B, B]]
expected = set()
result = ambiguities(signatures)
assert set(map(frozenset, expected)) == set(map(frozenset, result))
def test_ordering():
signatures = [[A, A], [A, B], [B, A], [B, B], [A, C]]
ord = ordering(signatures)
assert ord[0] == (B, B) or ord[0] == (A, C)
assert ord[-1] == (A, A) or ord[-1] == (A, C)
def test_type_mro():
assert super_signature([[object], [type]]) == [type]

View File

@ -0,0 +1,213 @@
from __future__ import annotations
from typing import Any
from sympy.multipledispatch import dispatch
from sympy.multipledispatch.conflict import AmbiguityWarning
from sympy.testing.pytest import raises, warns
from functools import partial
test_namespace: dict[str, Any] = {}
orig_dispatch = dispatch
dispatch = partial(dispatch, namespace=test_namespace)
def test_singledispatch():
@dispatch(int)
def f(x): # noqa:F811
return x + 1
@dispatch(int)
def g(x): # noqa:F811
return x + 2
@dispatch(float) # noqa:F811
def f(x): # noqa:F811
return x - 1
assert f(1) == 2
assert g(1) == 3
assert f(1.0) == 0
assert raises(NotImplementedError, lambda: f('hello'))
def test_multipledispatch():
@dispatch(int, int)
def f(x, y): # noqa:F811
return x + y
@dispatch(float, float) # noqa:F811
def f(x, y): # noqa:F811
return x - y
assert f(1, 2) == 3
assert f(1.0, 2.0) == -1.0
class A: pass
class B: pass
class C(A): pass
class D(C): pass
class E(C): pass
def test_inheritance():
@dispatch(A)
def f(x): # noqa:F811
return 'a'
@dispatch(B) # noqa:F811
def f(x): # noqa:F811
return 'b'
assert f(A()) == 'a'
assert f(B()) == 'b'
assert f(C()) == 'a'
def test_inheritance_and_multiple_dispatch():
@dispatch(A, A)
def f(x, y): # noqa:F811
return type(x), type(y)
@dispatch(A, B) # noqa:F811
def f(x, y): # noqa:F811
return 0
assert f(A(), A()) == (A, A)
assert f(A(), C()) == (A, C)
assert f(A(), B()) == 0
assert f(C(), B()) == 0
assert raises(NotImplementedError, lambda: f(B(), B()))
def test_competing_solutions():
@dispatch(A)
def h(x): # noqa:F811
return 1
@dispatch(C) # noqa:F811
def h(x): # noqa:F811
return 2
assert h(D()) == 2
def test_competing_multiple():
@dispatch(A, B)
def h(x, y): # noqa:F811
return 1
@dispatch(C, B) # noqa:F811
def h(x, y): # noqa:F811
return 2
assert h(D(), B()) == 2
def test_competing_ambiguous():
test_namespace = {}
dispatch = partial(orig_dispatch, namespace=test_namespace)
@dispatch(A, C)
def f(x, y): # noqa:F811
return 2
with warns(AmbiguityWarning, test_stacklevel=False):
@dispatch(C, A) # noqa:F811
def f(x, y): # noqa:F811
return 2
assert f(A(), C()) == f(C(), A()) == 2
# assert raises(Warning, lambda : f(C(), C()))
def test_caching_correct_behavior():
@dispatch(A)
def f(x): # noqa:F811
return 1
assert f(C()) == 1
@dispatch(C)
def f(x): # noqa:F811
return 2
assert f(C()) == 2
def test_union_types():
@dispatch((A, C))
def f(x): # noqa:F811
return 1
assert f(A()) == 1
assert f(C()) == 1
def test_namespaces():
ns1 = {}
ns2 = {}
def foo(x):
return 1
foo1 = orig_dispatch(int, namespace=ns1)(foo)
def foo(x):
return 2
foo2 = orig_dispatch(int, namespace=ns2)(foo)
assert foo1(0) == 1
assert foo2(0) == 2
"""
Fails
def test_dispatch_on_dispatch():
@dispatch(A)
@dispatch(C)
def q(x): # noqa:F811
return 1
assert q(A()) == 1
assert q(C()) == 1
"""
def test_methods():
class Foo:
@dispatch(float)
def f(self, x): # noqa:F811
return x - 1
@dispatch(int) # noqa:F811
def f(self, x): # noqa:F811
return x + 1
@dispatch(int)
def g(self, x): # noqa:F811
return x + 3
foo = Foo()
assert foo.f(1) == 2
assert foo.f(1.0) == 0.0
assert foo.g(1) == 4
def test_methods_multiple_dispatch():
class Foo:
@dispatch(A, A)
def f(x, y): # noqa:F811
return 1
@dispatch(A, C) # noqa:F811
def f(x, y): # noqa:F811
return 2
foo = Foo()
assert foo.f(A(), A()) == 1
assert foo.f(A(), C()) == 2
assert foo.f(C(), C()) == 2

View File

@ -0,0 +1,284 @@
from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError,
MethodDispatcher, halt_ordering,
restart_ordering,
ambiguity_register_error_ignore_dup)
from sympy.testing.pytest import raises, warns
def identity(x):
return x
def inc(x):
return x + 1
def dec(x):
return x - 1
def test_dispatcher():
f = Dispatcher('f')
f.add((int,), inc)
f.add((float,), dec)
with warns(DeprecationWarning, test_stacklevel=False):
assert f.resolve((int,)) == inc
assert f.dispatch(int) is inc
assert f(1) == 2
assert f(1.0) == 0.0
def test_union_types():
f = Dispatcher('f')
f.register((int, float))(inc)
assert f(1) == 2
assert f(1.0) == 2.0
def test_dispatcher_as_decorator():
f = Dispatcher('f')
@f.register(int)
def inc(x): # noqa:F811
return x + 1
@f.register(float) # noqa:F811
def inc(x): # noqa:F811
return x - 1
assert f(1) == 2
assert f(1.0) == 0.0
def test_register_instance_method():
class Test:
__init__ = MethodDispatcher('f')
@__init__.register(list)
def _init_list(self, data):
self.data = data
@__init__.register(object)
def _init_obj(self, datum):
self.data = [datum]
a = Test(3)
b = Test([3])
assert a.data == b.data
def test_on_ambiguity():
f = Dispatcher('f')
def identity(x): return x
ambiguities = [False]
def on_ambiguity(dispatcher, amb):
ambiguities[0] = True
f.add((object, object), identity, on_ambiguity=on_ambiguity)
assert not ambiguities[0]
f.add((object, float), identity, on_ambiguity=on_ambiguity)
assert not ambiguities[0]
f.add((float, object), identity, on_ambiguity=on_ambiguity)
assert ambiguities[0]
def test_raise_error_on_non_class():
f = Dispatcher('f')
assert raises(TypeError, lambda: f.add((1,), inc))
def test_docstring():
def one(x, y):
""" Docstring number one """
return x + y
def two(x, y):
""" Docstring number two """
return x + y
def three(x, y):
return x + y
master_doc = 'Doc of the multimethod itself'
f = Dispatcher('f', doc=master_doc)
f.add((object, object), one)
f.add((int, int), two)
f.add((float, float), three)
assert one.__doc__.strip() in f.__doc__
assert two.__doc__.strip() in f.__doc__
assert f.__doc__.find(one.__doc__.strip()) < \
f.__doc__.find(two.__doc__.strip())
assert 'object, object' in f.__doc__
assert master_doc in f.__doc__
def test_help():
def one(x, y):
""" Docstring number one """
return x + y
def two(x, y):
""" Docstring number two """
return x + y
def three(x, y):
""" Docstring number three """
return x + y
master_doc = 'Doc of the multimethod itself'
f = Dispatcher('f', doc=master_doc)
f.add((object, object), one)
f.add((int, int), two)
f.add((float, float), three)
assert f._help(1, 1) == two.__doc__
assert f._help(1.0, 2.0) == three.__doc__
def test_source():
def one(x, y):
""" Docstring number one """
return x + y
def two(x, y):
""" Docstring number two """
return x - y
master_doc = 'Doc of the multimethod itself'
f = Dispatcher('f', doc=master_doc)
f.add((int, int), one)
f.add((float, float), two)
assert 'x + y' in f._source(1, 1)
assert 'x - y' in f._source(1.0, 1.0)
def test_source_raises_on_missing_function():
f = Dispatcher('f')
assert raises(TypeError, lambda: f.source(1))
def test_halt_method_resolution():
g = [0]
def on_ambiguity(a, b):
g[0] += 1
f = Dispatcher('f')
halt_ordering()
def func(*args):
pass
f.add((int, object), func)
f.add((object, int), func)
assert g == [0]
restart_ordering(on_ambiguity=on_ambiguity)
assert g == [1]
assert set(f.ordering) == {(int, object), (object, int)}
def test_no_implementations():
f = Dispatcher('f')
assert raises(NotImplementedError, lambda: f('hello'))
def test_register_stacking():
f = Dispatcher('f')
@f.register(list)
@f.register(tuple)
def rev(x):
return x[::-1]
assert f((1, 2, 3)) == (3, 2, 1)
assert f([1, 2, 3]) == [3, 2, 1]
assert raises(NotImplementedError, lambda: f('hello'))
assert rev('hello') == 'olleh'
def test_dispatch_method():
f = Dispatcher('f')
@f.register(list)
def rev(x):
return x[::-1]
@f.register(int, int)
def add(x, y):
return x + y
class MyList(list):
pass
assert f.dispatch(list) is rev
assert f.dispatch(MyList) is rev
assert f.dispatch(int, int) is add
def test_not_implemented():
f = Dispatcher('f')
@f.register(object)
def _(x):
return 'default'
@f.register(int)
def _(x):
if x % 2 == 0:
return 'even'
else:
raise MDNotImplementedError()
assert f('hello') == 'default' # default behavior
assert f(2) == 'even' # specialized behavior
assert f(3) == 'default' # fall bac to default behavior
assert raises(NotImplementedError, lambda: f(1, 2))
def test_not_implemented_error():
f = Dispatcher('f')
@f.register(float)
def _(a):
raise MDNotImplementedError()
assert raises(NotImplementedError, lambda: f(1.0))
def test_ambiguity_register_error_ignore_dup():
f = Dispatcher('f')
class A:
pass
class B(A):
pass
class C(A):
pass
# suppress warning for registering ambiguous signal
f.add((A, B), lambda x,y: None, ambiguity_register_error_ignore_dup)
f.add((B, A), lambda x,y: None, ambiguity_register_error_ignore_dup)
f.add((A, C), lambda x,y: None, ambiguity_register_error_ignore_dup)
f.add((C, A), lambda x,y: None, ambiguity_register_error_ignore_dup)
# raises error if ambiguous signal is passed
assert raises(NotImplementedError, lambda: f(B(), C()))