I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,11 @@
from .core import dispatch
from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
MDNotImplementedError)
__version__ = '0.4.9'
__all__ = [
'dispatch',
'Dispatcher', 'halt_ordering', 'restart_ordering', 'MDNotImplementedError',
]

View File

@ -0,0 +1,68 @@
from .utils import _toposort, groupby
class AmbiguityWarning(Warning):
pass
def supercedes(a, b):
""" A is consistent and strictly more specific than B """
return len(a) == len(b) and all(map(issubclass, a, b))
def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
return (len(a) == len(b) and
all(issubclass(aa, bb) or issubclass(bb, aa)
for aa, bb in zip(a, b)))
def ambiguous(a, b):
""" A is consistent with B but neither is strictly more specific """
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
def ambiguities(signatures):
""" All signature pairs such that A is ambiguous with B """
signatures = list(map(tuple, signatures))
return {(a, b) for a in signatures for b in signatures
if hash(a) < hash(b)
and ambiguous(a, b)
and not any(supercedes(c, a) and supercedes(c, b)
for c in signatures)}
def super_signature(signatures):
""" A signature that would break ambiguities """
n = len(signatures[0])
assert all(len(s) == n for s in signatures)
return [max([type.mro(sig[i]) for sig in signatures], key=len)[0]
for i in range(n)]
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
if supercedes(a, b):
if supercedes(b, a):
return tie_breaker(a) > tie_breaker(b)
else:
return True
return False
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
Topoological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
edges = groupby(lambda x: x[0], edges)
for s in signatures:
if s not in edges:
edges[s] = []
edges = {k: [b for a, b in v] for k, v in edges.items()}
return _toposort(edges)

View File

@ -0,0 +1,83 @@
from __future__ import annotations
from typing import Any
import inspect
from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn
# XXX: This parameter to dispatch isn't documented and isn't used anywhere in
# sympy. Maybe it should just be removed.
global_namespace: dict[str, Any] = {}
def dispatch(*types, namespace=global_namespace, on_ambiguity=ambiguity_warn):
""" Dispatch function on the types of the inputs
Supports dispatch on all non-keyword arguments.
Collects implementations based on the function name. Ignores namespaces.
If ambiguous type signatures occur a warning is raised when the function is
defined suggesting the additional method to break the ambiguity.
Examples
--------
>>> from sympy.multipledispatch import dispatch
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x): # noqa: F811
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
Specify an isolated namespace with the namespace keyword argument
>>> my_namespace = dict()
>>> @dispatch(int, namespace=my_namespace)
... def foo(x):
... return x + 1
Dispatch on instance methods within classes
>>> class MyClass(object):
... @dispatch(list)
... def __init__(self, data):
... self.data = data
... @dispatch(int)
... def __init__(self, datum): # noqa: F811
... self.data = [datum]
"""
types = tuple(types)
def _(func):
name = func.__name__
if ismethod(func):
dispatcher = inspect.currentframe().f_back.f_locals.get(
name,
MethodDispatcher(name))
else:
if name not in namespace:
namespace[name] = Dispatcher(name)
dispatcher = namespace[name]
dispatcher.add(types, func, on_ambiguity=on_ambiguity)
return dispatcher
return _
def ismethod(func):
""" Is func a method?
Note that this has to work as the method is defined but before the class is
defined. At this stage methods look like functions.
"""
signature = inspect.signature(func)
return signature.parameters.get('self', None) is not None

View File

@ -0,0 +1,413 @@
from __future__ import annotations
from warnings import warn
import inspect
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
import itertools as itl
class MDNotImplementedError(NotImplementedError):
""" A NotImplementedError for multiple dispatch """
### Functions for on_ambiguity
def ambiguity_warn(dispatcher, ambiguities):
""" Raise warning when ambiguity is detected
Parameters
----------
dispatcher : Dispatcher
The dispatcher on which the ambiguity was detected
ambiguities : set
Set of type signature pairs that are ambiguous within this dispatcher
See Also:
Dispatcher.add
warning_text
"""
warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
class RaiseNotImplementedError:
"""Raise ``NotImplementedError`` when called."""
def __init__(self, dispatcher):
self.dispatcher = dispatcher
def __call__(self, *args, **kwargs):
types = tuple(type(a) for a in args)
raise NotImplementedError(
"Ambiguous signature for %s: <%s>" % (
self.dispatcher.name, str_signature(types)
))
def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
"""
If super signature for ambiguous types is duplicate types, ignore it.
Else, register instance of ``RaiseNotImplementedError`` for ambiguous types.
Parameters
----------
dispatcher : Dispatcher
The dispatcher on which the ambiguity was detected
ambiguities : set
Set of type signature pairs that are ambiguous within this dispatcher
See Also:
Dispatcher.add
ambiguity_warn
"""
for amb in ambiguities:
signature = tuple(super_signature(amb))
if len(set(signature)) == 1:
continue
dispatcher.add(
signature, RaiseNotImplementedError(dispatcher),
on_ambiguity=ambiguity_register_error_ignore_dup
)
###
_unresolved_dispatchers: set[Dispatcher] = set()
_resolve = [True]
def halt_ordering():
_resolve[0] = False
def restart_ordering(on_ambiguity=ambiguity_warn):
_resolve[0] = True
while _unresolved_dispatchers:
dispatcher = _unresolved_dispatchers.pop()
dispatcher.reorder(on_ambiguity=on_ambiguity)
class Dispatcher:
""" Dispatch methods based on type signature
Use ``dispatch`` to add implementations
Examples
--------
>>> from sympy.multipledispatch import dispatch
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x): # noqa: F811
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
"""
__slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'
def __init__(self, name, doc=None):
self.name = self.__name__ = name
self.funcs = {}
self._cache = {}
self.ordering = []
self.doc = doc
def register(self, *types, **kwargs):
""" Register dispatcher with new implementation
>>> from sympy.multipledispatch.dispatcher import Dispatcher
>>> f = Dispatcher('f')
>>> @f.register(int)
... def inc(x):
... return x + 1
>>> @f.register(float)
... def dec(x):
... return x - 1
>>> @f.register(list)
... @f.register(tuple)
... def reverse(x):
... return x[::-1]
>>> f(1)
2
>>> f(1.0)
0.0
>>> f([1, 2, 3])
[3, 2, 1]
"""
def _(func):
self.add(types, func, **kwargs)
return func
return _
@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return sig.parameters.values()
@classmethod
def get_func_annotations(cls, func):
""" Get annotations of function positional parameters
"""
params = cls.get_func_params(func)
if params:
Parameter = inspect.Parameter
params = (param for param in params
if param.kind in
(Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD))
annotations = tuple(
param.annotation
for param in params)
if not any(ann is Parameter.empty for ann in annotations):
return annotations
def add(self, signature, func, on_ambiguity=ambiguity_warn):
""" Add new types/method pair to dispatcher
>>> from sympy.multipledispatch import Dispatcher
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D(1, 2)
3
>>> D(1, 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
When ``add`` detects a warning it calls the ``on_ambiguity`` callback
with a dispatcher/itself, and a set of ambiguous type signature pairs
as inputs. See ``ambiguity_warn`` for an example.
"""
# Handle annotations
if not signature:
annotations = self.get_func_annotations(func)
if annotations:
signature = annotations
# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func, on_ambiguity)
return
for typ in signature:
if not isinstance(typ, type):
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError("Tried to dispatch on non-type: %s\n"
"In signature: <%s>\n"
"In function: %s" %
(typ, str_sig, self.name))
self.funcs[signature] = func
self.reorder(on_ambiguity=on_ambiguity)
self._cache.clear()
def reorder(self, on_ambiguity=ambiguity_warn):
if _resolve[0]:
self.ordering = ordering(self.funcs)
amb = ambiguities(self.funcs)
if amb:
on_ambiguity(self, amb)
else:
_unresolved_dispatchers.add(self)
def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
try:
func = self._cache[types]
except KeyError:
func = self.dispatch(*types)
if not func:
raise NotImplementedError(
'Could not find signature for %s: <%s>' %
(self.name, str_signature(types)))
self._cache[types] = func
try:
return func(*args, **kwargs)
except MDNotImplementedError:
funcs = self.dispatch_iter(*types)
next(funcs) # burn first
for func in funcs:
try:
return func(*args, **kwargs)
except MDNotImplementedError:
pass
raise NotImplementedError("Matching functions for "
"%s: <%s> found, but none completed successfully"
% (self.name, str_signature(types)))
def __str__(self):
return "<dispatched %s>" % self.name
__repr__ = __str__
def dispatch(self, *types):
""" Deterimine appropriate implementation for this type signature
This method is internal. Users should call this object as a function.
Implementation resolution occurs within the ``__call__`` method.
>>> from sympy.multipledispatch import dispatch
>>> @dispatch(int)
... def inc(x):
... return x + 1
>>> implementation = inc.dispatch(int)
>>> implementation(3)
4
>>> print(inc.dispatch(float))
None
See Also:
``sympy.multipledispatch.conflict`` - module to determine resolution order
"""
if types in self.funcs:
return self.funcs[types]
try:
return next(self.dispatch_iter(*types))
except StopIteration:
return None
def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if len(signature) == n and all(map(issubclass, types, signature)):
result = self.funcs[signature]
yield result
def resolve(self, types):
""" Deterimine appropriate implementation for this type signature
.. deprecated:: 0.4.4
Use ``dispatch(*types)`` instead
"""
warn("resolve() is deprecated, use dispatch(*types)",
DeprecationWarning)
return self.dispatch(*types)
def __getstate__(self):
return {'name': self.name,
'funcs': self.funcs}
def __setstate__(self, d):
self.name = d['name']
self.funcs = d['funcs']
self.ordering = ordering(self.funcs)
self._cache = {}
@property
def __doc__(self):
docs = ["Multiply dispatched method: %s" % self.name]
if self.doc:
docs.append(self.doc)
other = []
for sig in self.ordering[::-1]:
func = self.funcs[sig]
if func.__doc__:
s = 'Inputs: <%s>\n' % str_signature(sig)
s += '-' * len(s) + '\n'
s += func.__doc__.strip()
docs.append(s)
else:
other.append(str_signature(sig))
if other:
docs.append('Other signatures:\n ' + '\n '.join(other))
return '\n\n'.join(docs)
def _help(self, *args):
return self.dispatch(*map(type, args)).__doc__
def help(self, *args, **kwargs):
""" Print docstring for the function corresponding to inputs """
print(self._help(*args))
def _source(self, *args):
func = self.dispatch(*map(type, args))
if not func:
raise TypeError("No function found")
return source(func)
def source(self, *args, **kwargs):
""" Print source code for the function corresponding to inputs """
print(self._source(*args))
def source(func):
s = 'File: %s\n\n' % inspect.getsourcefile(func)
s = s + inspect.getsource(func)
return s
class MethodDispatcher(Dispatcher):
""" Dispatch methods based on type signature
See Also:
Dispatcher
"""
@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return itl.islice(sig.parameters.values(), 1, None)
def __get__(self, instance, owner):
self.obj = instance
self.cls = owner
return self
def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
func = self.dispatch(*types)
if not func:
raise NotImplementedError('Could not find signature for %s: <%s>' %
(self.name, str_signature(types)))
return func(self.obj, *args, **kwargs)
def str_signature(sig):
""" String representation of type signature
>>> from sympy.multipledispatch.dispatcher import str_signature
>>> str_signature((int, float))
'int, float'
"""
return ', '.join(cls.__name__ for cls in sig)
def warning_text(name, amb):
""" The text for ambiguity warnings """
text = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
text += "The following signatures may result in ambiguous behavior:\n"
for pair in amb:
text += "\t" + \
', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
text += "\n\nConsider making the following additions:\n\n"
text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
+ ')\ndef %s(...)' % name for s in amb])
return text

View File

@ -0,0 +1,62 @@
from sympy.multipledispatch.conflict import (supercedes, ordering, ambiguities,
ambiguous, super_signature, consistent)
class A: pass
class B(A): pass
class C: pass
def test_supercedes():
assert supercedes([B], [A])
assert supercedes([B, A], [A, A])
assert not supercedes([B, A], [A, B])
assert not supercedes([A], [B])
def test_consistent():
assert consistent([A], [A])
assert consistent([B], [B])
assert not consistent([A], [C])
assert consistent([A, B], [A, B])
assert consistent([B, A], [A, B])
assert not consistent([B, A], [B])
assert not consistent([B, A], [B, C])
def test_super_signature():
assert super_signature([[A]]) == [A]
assert super_signature([[A], [B]]) == [B]
assert super_signature([[A, B], [B, A]]) == [B, B]
assert super_signature([[A, A, B], [A, B, A], [B, A, A]]) == [B, B, B]
def test_ambiguous():
assert not ambiguous([A], [A])
assert not ambiguous([A], [B])
assert not ambiguous([B], [B])
assert not ambiguous([A, B], [B, B])
assert ambiguous([A, B], [B, A])
def test_ambiguities():
signatures = [[A], [B], [A, B], [B, A], [A, C]]
expected = {((A, B), (B, A))}
result = ambiguities(signatures)
assert set(map(frozenset, expected)) == set(map(frozenset, result))
signatures = [[A], [B], [A, B], [B, A], [A, C], [B, B]]
expected = set()
result = ambiguities(signatures)
assert set(map(frozenset, expected)) == set(map(frozenset, result))
def test_ordering():
signatures = [[A, A], [A, B], [B, A], [B, B], [A, C]]
ord = ordering(signatures)
assert ord[0] == (B, B) or ord[0] == (A, C)
assert ord[-1] == (A, A) or ord[-1] == (A, C)
def test_type_mro():
assert super_signature([[object], [type]]) == [type]

View File

@ -0,0 +1,213 @@
from __future__ import annotations
from typing import Any
from sympy.multipledispatch import dispatch
from sympy.multipledispatch.conflict import AmbiguityWarning
from sympy.testing.pytest import raises, warns
from functools import partial
test_namespace: dict[str, Any] = {}
orig_dispatch = dispatch
dispatch = partial(dispatch, namespace=test_namespace)
def test_singledispatch():
@dispatch(int)
def f(x): # noqa:F811
return x + 1
@dispatch(int)
def g(x): # noqa:F811
return x + 2
@dispatch(float) # noqa:F811
def f(x): # noqa:F811
return x - 1
assert f(1) == 2
assert g(1) == 3
assert f(1.0) == 0
assert raises(NotImplementedError, lambda: f('hello'))
def test_multipledispatch():
@dispatch(int, int)
def f(x, y): # noqa:F811
return x + y
@dispatch(float, float) # noqa:F811
def f(x, y): # noqa:F811
return x - y
assert f(1, 2) == 3
assert f(1.0, 2.0) == -1.0
class A: pass
class B: pass
class C(A): pass
class D(C): pass
class E(C): pass
def test_inheritance():
@dispatch(A)
def f(x): # noqa:F811
return 'a'
@dispatch(B) # noqa:F811
def f(x): # noqa:F811
return 'b'
assert f(A()) == 'a'
assert f(B()) == 'b'
assert f(C()) == 'a'
def test_inheritance_and_multiple_dispatch():
@dispatch(A, A)
def f(x, y): # noqa:F811
return type(x), type(y)
@dispatch(A, B) # noqa:F811
def f(x, y): # noqa:F811
return 0
assert f(A(), A()) == (A, A)
assert f(A(), C()) == (A, C)
assert f(A(), B()) == 0
assert f(C(), B()) == 0
assert raises(NotImplementedError, lambda: f(B(), B()))
def test_competing_solutions():
@dispatch(A)
def h(x): # noqa:F811
return 1
@dispatch(C) # noqa:F811
def h(x): # noqa:F811
return 2
assert h(D()) == 2
def test_competing_multiple():
@dispatch(A, B)
def h(x, y): # noqa:F811
return 1
@dispatch(C, B) # noqa:F811
def h(x, y): # noqa:F811
return 2
assert h(D(), B()) == 2
def test_competing_ambiguous():
test_namespace = {}
dispatch = partial(orig_dispatch, namespace=test_namespace)
@dispatch(A, C)
def f(x, y): # noqa:F811
return 2
with warns(AmbiguityWarning, test_stacklevel=False):
@dispatch(C, A) # noqa:F811
def f(x, y): # noqa:F811
return 2
assert f(A(), C()) == f(C(), A()) == 2
# assert raises(Warning, lambda : f(C(), C()))
def test_caching_correct_behavior():
@dispatch(A)
def f(x): # noqa:F811
return 1
assert f(C()) == 1
@dispatch(C)
def f(x): # noqa:F811
return 2
assert f(C()) == 2
def test_union_types():
@dispatch((A, C))
def f(x): # noqa:F811
return 1
assert f(A()) == 1
assert f(C()) == 1
def test_namespaces():
ns1 = {}
ns2 = {}
def foo(x):
return 1
foo1 = orig_dispatch(int, namespace=ns1)(foo)
def foo(x):
return 2
foo2 = orig_dispatch(int, namespace=ns2)(foo)
assert foo1(0) == 1
assert foo2(0) == 2
"""
Fails
def test_dispatch_on_dispatch():
@dispatch(A)
@dispatch(C)
def q(x): # noqa:F811
return 1
assert q(A()) == 1
assert q(C()) == 1
"""
def test_methods():
class Foo:
@dispatch(float)
def f(self, x): # noqa:F811
return x - 1
@dispatch(int) # noqa:F811
def f(self, x): # noqa:F811
return x + 1
@dispatch(int)
def g(self, x): # noqa:F811
return x + 3
foo = Foo()
assert foo.f(1) == 2
assert foo.f(1.0) == 0.0
assert foo.g(1) == 4
def test_methods_multiple_dispatch():
class Foo:
@dispatch(A, A)
def f(x, y): # noqa:F811
return 1
@dispatch(A, C) # noqa:F811
def f(x, y): # noqa:F811
return 2
foo = Foo()
assert foo.f(A(), A()) == 1
assert foo.f(A(), C()) == 2
assert foo.f(C(), C()) == 2

View File

@ -0,0 +1,284 @@
from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError,
MethodDispatcher, halt_ordering,
restart_ordering,
ambiguity_register_error_ignore_dup)
from sympy.testing.pytest import raises, warns
def identity(x):
return x
def inc(x):
return x + 1
def dec(x):
return x - 1
def test_dispatcher():
f = Dispatcher('f')
f.add((int,), inc)
f.add((float,), dec)
with warns(DeprecationWarning, test_stacklevel=False):
assert f.resolve((int,)) == inc
assert f.dispatch(int) is inc
assert f(1) == 2
assert f(1.0) == 0.0
def test_union_types():
f = Dispatcher('f')
f.register((int, float))(inc)
assert f(1) == 2
assert f(1.0) == 2.0
def test_dispatcher_as_decorator():
f = Dispatcher('f')
@f.register(int)
def inc(x): # noqa:F811
return x + 1
@f.register(float) # noqa:F811
def inc(x): # noqa:F811
return x - 1
assert f(1) == 2
assert f(1.0) == 0.0
def test_register_instance_method():
class Test:
__init__ = MethodDispatcher('f')
@__init__.register(list)
def _init_list(self, data):
self.data = data
@__init__.register(object)
def _init_obj(self, datum):
self.data = [datum]
a = Test(3)
b = Test([3])
assert a.data == b.data
def test_on_ambiguity():
f = Dispatcher('f')
def identity(x): return x
ambiguities = [False]
def on_ambiguity(dispatcher, amb):
ambiguities[0] = True
f.add((object, object), identity, on_ambiguity=on_ambiguity)
assert not ambiguities[0]
f.add((object, float), identity, on_ambiguity=on_ambiguity)
assert not ambiguities[0]
f.add((float, object), identity, on_ambiguity=on_ambiguity)
assert ambiguities[0]
def test_raise_error_on_non_class():
f = Dispatcher('f')
assert raises(TypeError, lambda: f.add((1,), inc))
def test_docstring():
def one(x, y):
""" Docstring number one """
return x + y
def two(x, y):
""" Docstring number two """
return x + y
def three(x, y):
return x + y
master_doc = 'Doc of the multimethod itself'
f = Dispatcher('f', doc=master_doc)
f.add((object, object), one)
f.add((int, int), two)
f.add((float, float), three)
assert one.__doc__.strip() in f.__doc__
assert two.__doc__.strip() in f.__doc__
assert f.__doc__.find(one.__doc__.strip()) < \
f.__doc__.find(two.__doc__.strip())
assert 'object, object' in f.__doc__
assert master_doc in f.__doc__
def test_help():
def one(x, y):
""" Docstring number one """
return x + y
def two(x, y):
""" Docstring number two """
return x + y
def three(x, y):
""" Docstring number three """
return x + y
master_doc = 'Doc of the multimethod itself'
f = Dispatcher('f', doc=master_doc)
f.add((object, object), one)
f.add((int, int), two)
f.add((float, float), three)
assert f._help(1, 1) == two.__doc__
assert f._help(1.0, 2.0) == three.__doc__
def test_source():
def one(x, y):
""" Docstring number one """
return x + y
def two(x, y):
""" Docstring number two """
return x - y
master_doc = 'Doc of the multimethod itself'
f = Dispatcher('f', doc=master_doc)
f.add((int, int), one)
f.add((float, float), two)
assert 'x + y' in f._source(1, 1)
assert 'x - y' in f._source(1.0, 1.0)
def test_source_raises_on_missing_function():
f = Dispatcher('f')
assert raises(TypeError, lambda: f.source(1))
def test_halt_method_resolution():
g = [0]
def on_ambiguity(a, b):
g[0] += 1
f = Dispatcher('f')
halt_ordering()
def func(*args):
pass
f.add((int, object), func)
f.add((object, int), func)
assert g == [0]
restart_ordering(on_ambiguity=on_ambiguity)
assert g == [1]
assert set(f.ordering) == {(int, object), (object, int)}
def test_no_implementations():
f = Dispatcher('f')
assert raises(NotImplementedError, lambda: f('hello'))
def test_register_stacking():
f = Dispatcher('f')
@f.register(list)
@f.register(tuple)
def rev(x):
return x[::-1]
assert f((1, 2, 3)) == (3, 2, 1)
assert f([1, 2, 3]) == [3, 2, 1]
assert raises(NotImplementedError, lambda: f('hello'))
assert rev('hello') == 'olleh'
def test_dispatch_method():
f = Dispatcher('f')
@f.register(list)
def rev(x):
return x[::-1]
@f.register(int, int)
def add(x, y):
return x + y
class MyList(list):
pass
assert f.dispatch(list) is rev
assert f.dispatch(MyList) is rev
assert f.dispatch(int, int) is add
def test_not_implemented():
f = Dispatcher('f')
@f.register(object)
def _(x):
return 'default'
@f.register(int)
def _(x):
if x % 2 == 0:
return 'even'
else:
raise MDNotImplementedError()
assert f('hello') == 'default' # default behavior
assert f(2) == 'even' # specialized behavior
assert f(3) == 'default' # fall bac to default behavior
assert raises(NotImplementedError, lambda: f(1, 2))
def test_not_implemented_error():
f = Dispatcher('f')
@f.register(float)
def _(a):
raise MDNotImplementedError()
assert raises(NotImplementedError, lambda: f(1.0))
def test_ambiguity_register_error_ignore_dup():
f = Dispatcher('f')
class A:
pass
class B(A):
pass
class C(A):
pass
# suppress warning for registering ambiguous signal
f.add((A, B), lambda x,y: None, ambiguity_register_error_ignore_dup)
f.add((B, A), lambda x,y: None, ambiguity_register_error_ignore_dup)
f.add((A, C), lambda x,y: None, ambiguity_register_error_ignore_dup)
f.add((C, A), lambda x,y: None, ambiguity_register_error_ignore_dup)
# raises error if ambiguous signal is passed
assert raises(NotImplementedError, lambda: f(B(), C()))

View File

@ -0,0 +1,105 @@
from collections import OrderedDict
def expand_tuples(L):
"""
>>> from sympy.multipledispatch.utils import expand_tuples
>>> expand_tuples([1, (2, 3)])
[(1, 2), (1, 3)]
>>> expand_tuples([1, 2])
[(1, 2)]
"""
if not L:
return [()]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> from sympy.multipledispatch.utils import _toposort
>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3]
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] https://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
L = []
while S:
n, _ = S.popitem()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S[m] = None
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def reverse_dict(d):
"""Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
input dict, it make the output of this function depend on the
dict order. So this function output order should be considered
as undeterministic.
"""
result = {}
for key in d:
for val in d[key]:
result[val] = result.get(val, ()) + (key, )
return result
# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def groupby(func, seq):
""" Group a collection by a key function
>>> from sympy.multipledispatch.utils import groupby
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
See Also:
``countby``
"""
d = {}
for item in seq:
key = func(item)
if key not in d:
d[key] = []
d[key].append(item)
return d