I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,4 @@
# mypy: disable-error-code=attr-defined
from .core import unify, reify # noqa: F403
from .more import unifiable # noqa: F403
from .variable import var, isvar, vars, variables, Var # noqa: F403

View File

@ -0,0 +1,119 @@
# mypy: allow-untyped-defs
from collections.abc import Iterator # type: ignore[import]
from functools import partial
from .unification_tools import assoc # type: ignore[import]
from .utils import transitive_get as walk
from .variable import isvar
from .dispatch import dispatch
__all__ = ["reify", "unify"]
###############
# Reification #
###############
@dispatch(Iterator, dict)
def _reify(t, s):
return map(partial(reify, s=s), t)
# return (reify(arg, s) for arg in t)
_reify
@dispatch(tuple, dict) # type: ignore[no-redef]
def _reify(t, s):
return tuple(reify(iter(t), s))
_reify
@dispatch(list, dict) # type: ignore[no-redef]
def _reify(t, s):
return list(reify(iter(t), s))
_reify
@dispatch(dict, dict) # type: ignore[no-redef]
def _reify(d, s):
return {k: reify(v, s) for k, v in d.items()}
_reify
@dispatch(object, dict) # type: ignore[no-redef]
def _reify(o, s):
return o # catch all, just return the object
def reify(e, s):
""" Replace variables of expression with substitution
>>> # xdoctest: +SKIP
>>> x, y = var(), var()
>>> e = (1, x, (3, y))
>>> s = {x: 2, y: 4}
>>> reify(e, s)
(1, 2, (3, 4))
>>> e = {1: x, 3: (y, 5)}
>>> reify(e, s)
{1: 2, 3: (4, 5)}
"""
if isvar(e):
return reify(s[e], s) if e in s else e
return _reify(e, s)
###############
# Unification #
###############
seq = tuple, list, Iterator
@dispatch(seq, seq, dict)
def _unify(u, v, s):
if len(u) != len(v):
return False
for uu, vv in zip(u, v): # avoiding recursion
s = unify(uu, vv, s)
if s is False:
return False
return s
#
# @dispatch((set, frozenset), (set, frozenset), dict)
# def _unify(u, v, s):
# i = u & v
# u = u - i
# v = v - i
# return _unify(sorted(u), sorted(v), s)
#
#
# @dispatch(dict, dict, dict)
# def _unify(u, v, s):
# if len(u) != len(v):
# return False
# for key, uval in iteritems(u):
# if key not in v:
# return False
# s = unify(uval, v[key], s)
# if s is False:
# return False
# return s
#
#
# @dispatch(object, object, dict)
# def _unify(u, v, s):
# return False # catch all
@dispatch(object, object, dict)
def unify(u, v, s): # no check at the moment
""" Find substitution so that u == v while satisfying s
>>> x = var('x')
>>> unify((1, x), (1, 2), {})
{~x: 2}
"""
u = walk(u, s)
v = walk(v, s)
if u == v:
return s
if isvar(u):
return assoc(s, u, v)
if isvar(v):
return assoc(s, v, u)
return _unify(u, v, s)
unify
@dispatch(object, object) # type: ignore[no-redef]
def unify(u, v):
return unify(u, v, {})

View File

@ -0,0 +1,6 @@
from functools import partial
from .multipledispatch import dispatch # type: ignore[import]
namespace = {} # type: ignore[var-annotated]
dispatch = partial(dispatch, namespace=namespace)

View File

@ -0,0 +1,122 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .variable import isvar
from .utils import _toposort, freeze
from .unification_tools import groupby, first # type: ignore[import]
class Dispatcher:
def __init__(self, name):
self.name = name
self.funcs = {}
self.ordering = []
def add(self, signature, func):
self.funcs[freeze(signature)] = func
self.ordering = ordering(self.funcs)
def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
return func(*args, **kwargs)
def resolve(self, args):
n = len(args)
for signature in self.ordering:
if len(signature) != n:
continue
s = unify(freeze(args), signature)
if s is not False:
result = self.funcs[signature]
return result, s
raise NotImplementedError("No match found. \nKnown matches: "
+ str(self.ordering) + "\nInput: " + str(args))
def register(self, *signature):
def _(func):
self.add(signature, func)
return self
return _
class VarDispatcher(Dispatcher):
""" A dispatcher that calls functions with variable names
>>> # xdoctest: +SKIP
>>> d = VarDispatcher('d')
>>> x = var('x')
>>> @d.register('inc', x)
... def f(x):
... return x + 1
>>> @d.register('double', x)
... def f(x):
... return x * 2
>>> d('inc', 10)
11
>>> d('double', 10)
20
"""
def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
d = {k.token: v for k, v in s.items()}
return func(**d)
global_namespace = {} # type: ignore[var-annotated]
def match(*signature, **kwargs):
namespace = kwargs.get('namespace', global_namespace)
dispatcher = kwargs.get('Dispatcher', Dispatcher)
def _(func):
name = func.__name__
if name not in namespace:
namespace[name] = dispatcher(name)
d = namespace[name]
d.add(signature, func)
return d
return _
def supercedes(a, b):
""" ``a`` is a more specific match than ``b`` """
if isvar(b) and not isvar(a):
return True
s = unify(a, b)
if s is False:
return False
s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
if reify(a, s) == a:
return True
if reify(b, s) == b:
return False
# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
if supercedes(a, b):
if supercedes(b, a):
return tie_breaker(a) > tie_breaker(b)
else:
return True
return False
# Taken from multipledispatch
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
Topological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
edges = groupby(first, edges)
for s in signatures:
if s not in edges:
edges[s] = []
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
return _toposort(edges)

View File

@ -0,0 +1,118 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .dispatch import dispatch
def unifiable(cls):
""" Register standard unify and reify operations on class
This uses the type and __dict__ or __slots__ attributes to define the
nature of the term
See Also:
>>> # xdoctest: +SKIP
>>> class A(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
>>> unifiable(A)
<class 'unification.more.A'>
>>> x = var('x')
>>> a = A(1, 2)
>>> b = A(1, x)
>>> unify(a, b, {})
{~x: 2}
"""
_unify.add((cls, cls, dict), unify_object)
_reify.add((cls, dict), reify_object)
return cls
#########
# Reify #
#########
def reify_object(o, s):
""" Reify a Python object with a substitution
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
>>> f = Foo(1, x)
>>> print(f)
Foo(1, ~x)
>>> print(reify_object(f, {x: 2}))
Foo(1, 2)
"""
if hasattr(o, '__slots__'):
return _reify_object_slots(o, s)
else:
return _reify_object_dict(o, s)
def _reify_object_dict(o, s):
obj = object.__new__(type(o))
d = reify(o.__dict__, s)
if d == o.__dict__:
return o
obj.__dict__.update(d)
return obj
def _reify_object_slots(o, s):
attrs = [getattr(o, attr) for attr in o.__slots__]
new_attrs = reify(attrs, s)
if attrs == new_attrs:
return o
else:
newobj = object.__new__(type(o))
for slot, attr in zip(o.__slots__, new_attrs):
setattr(newobj, slot, attr)
return newobj
@dispatch(slice, dict)
def _reify(o, s):
""" Reify a Python ``slice`` object """
return slice(*reify((o.start, o.stop, o.step), s))
#########
# Unify #
#########
def unify_object(u, v, s):
""" Unify two Python objects
Unifies their type and ``__dict__`` attributes
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
>>> f = Foo(1, x)
>>> g = Foo(1, 2)
>>> unify_object(f, g, {})
{~x: 2}
"""
if type(u) != type(v):
return False
if hasattr(u, '__slots__'):
return unify([getattr(u, slot) for slot in u.__slots__],
[getattr(v, slot) for slot in v.__slots__],
s)
else:
return unify(u.__dict__, v.__dict__, s)
@dispatch(slice, slice, dict)
def _unify(u, v, s):
""" Unify a Python ``slice`` object """
return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)

View File

@ -0,0 +1,3 @@
from .core import dispatch
from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
MDNotImplementedError)

View File

@ -0,0 +1,121 @@
# mypy: allow-untyped-defs
from .utils import _toposort, groupby
from .variadic import isvariadic
import operator
__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature",
"edge", "ordering"]
class AmbiguityWarning(Warning):
pass
def supercedes(a, b):
""" A is consistent and strictly more specific than B """
if len(a) < len(b):
# only case is if a is empty and b is variadic
return not a and len(b) == 1 and isvariadic(b[-1])
elif len(a) == len(b):
return all(map(issubclass, a, b))
else:
# len(a) > len(b)
p1 = 0
p2 = 0
while p1 < len(a) and p2 < len(b):
cur_a = a[p1]
cur_b = b[p2]
if not (isvariadic(cur_a) or isvariadic(cur_b)):
if not issubclass(cur_a, cur_b):
return False
p1 += 1
p2 += 1
elif isvariadic(cur_a):
assert p1 == len(a) - 1
return p2 == len(b) - 1 and issubclass(cur_a, cur_b)
elif isvariadic(cur_b):
assert p2 == len(b) - 1
if not issubclass(cur_a, cur_b):
return False
p1 += 1
return p2 == len(b) - 1 and p1 == len(a)
def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
# Need to check for empty args
if not a:
return not b or isvariadic(b[0])
if not b:
return not a or isvariadic(a[0])
# Non-empty args check for mutual subclasses
if len(a) == len(b):
return all(issubclass(aa, bb) or issubclass(bb, aa)
for aa, bb in zip(a, b))
else:
p1 = 0
p2 = 0
while p1 < len(a) and p2 < len(b):
cur_a = a[p1]
cur_b = b[p2]
if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b):
return False
if not (isvariadic(cur_a) or isvariadic(cur_b)):
p1 += 1
p2 += 1
elif isvariadic(cur_a):
p2 += 1
elif isvariadic(cur_b):
p1 += 1
# We only need to check for variadic ends
# Variadic types are guaranteed to be the last element
return (isvariadic(cur_a) and p2 == len(b) or # type: ignore[possibly-undefined]
isvariadic(cur_b) and p1 == len(a)) # type: ignore[possibly-undefined]
def ambiguous(a, b):
""" A is consistent with B but neither is strictly more specific """
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
def ambiguities(signatures):
""" All signature pairs such that A is ambiguous with B """
signatures = list(map(tuple, signatures))
return {(a, b) for a in signatures for b in signatures
if hash(a) < hash(b)
and ambiguous(a, b)
and not any(supercedes(c, a) and supercedes(c, b)
for c in signatures)}
def super_signature(signatures):
""" A signature that would break ambiguities """
n = len(signatures[0])
assert all(len(s) == n for s in signatures)
return [max((type.mro(sig[i]) for sig in signatures), key=len)[0]
for i in range(n)]
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
# A either supercedes B and B does not supercede A or if B does then call
# tie_breaker
return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b))
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
Topological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
edges = groupby(operator.itemgetter(0), edges)
for s in signatures:
if s not in edges:
edges[s] = []
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined]
return _toposort(edges)

View File

@ -0,0 +1,84 @@
# mypy: allow-untyped-defs
import inspect
import sys
from .dispatcher import Dispatcher, MethodDispatcher
global_namespace = {} # type: ignore[var-annotated]
__all__ = ["dispatch", "ismethod"]
def dispatch(*types, **kwargs):
""" Dispatch function on the types of the inputs
Supports dispatch on all non-keyword arguments.
Collects implementations based on the function name. Ignores namespaces.
If ambiguous type signatures occur a warning is raised when the function is
defined suggesting the additional method to break the ambiguity.
Example:
>>> # xdoctest: +SKIP
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> # xdoctest: +SKIP
>>> f(3)
4
>>> f(3.0)
2.0
>>> # Specify an isolated namespace with the namespace keyword argument
>>> my_namespace = {}
>>> @dispatch(int, namespace=my_namespace)
... def foo(x):
... return x + 1
>>> # Dispatch on instance methods within classes
>>> class MyClass(object):
... @dispatch(list)
... def __init__(self, data):
... self.data = data
... @dispatch(int)
... def __init__(self, datum):
... self.data = [datum]
>>> MyClass([1, 2, 3]).data
[1, 2, 3]
>>> MyClass(3).data
[3]
"""
namespace = kwargs.get('namespace', global_namespace)
types = tuple(types)
def _df(func):
name = func.__name__
if ismethod(func):
dispatcher = inspect.currentframe().f_back.f_locals.get( # type: ignore[union-attr]
name, # type: ignore[union-attr]
MethodDispatcher(name),
)
else:
if name not in namespace:
namespace[name] = Dispatcher(name)
dispatcher = namespace[name]
dispatcher.add(types, func)
return dispatcher
return _df
def ismethod(func):
""" Is func a method?
Note that this has to work as the method is defined but before the class is
defined. At this stage methods look like functions.
"""
if hasattr(inspect, "signature"):
signature = inspect.signature(func)
return signature.parameters.get('self', None) is not None
else:
if sys.version_info.major < 3:
spec = inspect.getargspec(func) # type: ignore[attr-defined]
else:
spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment]
return spec and spec.args and spec.args[0] == 'self'

View File

@ -0,0 +1,427 @@
# mypy: allow-untyped-defs
from warnings import warn
import inspect
from typing_extensions import deprecated
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
from .variadic import Variadic, isvariadic
import itertools as itl
__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
"variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]
class MDNotImplementedError(NotImplementedError):
""" A NotImplementedError for multiple dispatch """
def ambiguity_warn(dispatcher, ambiguities):
""" Raise warning when ambiguity is detected
Parameters
----------
dispatcher : Dispatcher
The dispatcher on which the ambiguity was detected
ambiguities : set
Set of type signature pairs that are ambiguous within this dispatcher
See Also:
Dispatcher.add
warning_text
"""
warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
@deprecated(
"`halt_ordering` is deprecated, you can safely remove this call.",
category=FutureWarning,
)
def halt_ordering():
"""Deprecated interface to temporarily disable ordering."""
@deprecated(
"`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, "
"you should call the `reorder()` method on each dispatcher.",
category=FutureWarning,
)
def restart_ordering(on_ambiguity=ambiguity_warn):
"""Deprecated interface to temporarily resume ordering."""
def variadic_signature_matches_iter(types, full_signature):
"""Check if a set of input types matches a variadic signature.
Notes
-----
The algorithm is as follows:
Initialize the current signature to the first in the sequence
For each type in `types`:
If the current signature is variadic
If the type matches the signature
yield True
Else
Try to get the next signature
If no signatures are left we can't possibly have a match
so yield False
Else
yield True if the type matches the current signature
Get the next signature
"""
sigiter = iter(full_signature)
sig = next(sigiter)
for typ in types:
matches = issubclass(typ, sig)
yield matches
if not isvariadic(sig):
# we're not matching a variadic argument, so move to the next
# element in the signature
sig = next(sigiter)
else:
try:
sig = next(sigiter)
except StopIteration:
assert isvariadic(sig)
yield True
else:
# We have signature items left over, so all of our arguments
# haven't matched
yield False
def variadic_signature_matches(types, full_signature):
# No arguments always matches a variadic signature
assert full_signature
return all(variadic_signature_matches_iter(types, full_signature))
class Dispatcher:
""" Dispatch methods based on type signature
Use ``dispatch`` to add implementations
Examples
--------
>>> # xdoctest: +SKIP("bad import name")
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
"""
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'
def __init__(self, name, doc=None):
self.name = self.__name__ = name
self.funcs = {}
self.doc = doc
self._cache = {}
def register(self, *types, **kwargs):
""" register dispatcher with new implementation
>>> # xdoctest: +SKIP
>>> f = Dispatcher('f')
>>> @f.register(int)
... def inc(x):
... return x + 1
>>> @f.register(float)
... def dec(x):
... return x - 1
>>> @f.register(list)
... @f.register(tuple)
... def reverse(x):
... return x[::-1]
>>> f(1)
2
>>> f(1.0)
0.0
>>> f([1, 2, 3])
[3, 2, 1]
"""
def _df(func):
self.add(types, func, **kwargs) # type: ignore[call-arg]
return func
return _df
@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return sig.parameters.values()
@classmethod
def get_func_annotations(cls, func):
""" get annotations of function positional parameters
"""
params = cls.get_func_params(func)
if params:
Parameter = inspect.Parameter
params = (param for param in params
if param.kind in
(Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD))
annotations = tuple(
param.annotation
for param in params)
if all(ann is not Parameter.empty for ann in annotations):
return annotations
def add(self, signature, func):
""" Add new types/method pair to dispatcher
>>> # xdoctest: +SKIP
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D(1, 2)
3
>>> D(1, 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
>>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback
>>> # with a dispatcher/itself, and a set of ambiguous type signature pairs
>>> # as inputs. See ``ambiguity_warn`` for an example.
"""
# Handle annotations
if not signature:
annotations = self.get_func_annotations(func)
if annotations:
signature = annotations
# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func)
return
new_signature = []
for index, typ in enumerate(signature, start=1):
if not isinstance(typ, (type, list)):
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError(f"Tried to dispatch on non-type: {typ}\n"
f"In signature: <{str_sig}>\n"
f"In function: {self.name}")
# handle variadic signatures
if isinstance(typ, list):
if index != len(signature):
raise TypeError(
'Variadic signature must be the last element'
)
if len(typ) != 1:
raise TypeError(
'Variadic signature must contain exactly one element. '
'To use a variadic union type place the desired types '
'inside of a tuple, e.g., [(int, str)]'
)
new_signature.append(Variadic[typ[0]])
else:
new_signature.append(typ)
self.funcs[tuple(new_signature)] = func
self._cache.clear()
try:
del self._ordering
except AttributeError:
pass
@property
def ordering(self):
try:
return self._ordering
except AttributeError:
return self.reorder()
def reorder(self, on_ambiguity=ambiguity_warn):
self._ordering = od = ordering(self.funcs)
amb = ambiguities(self.funcs)
if amb:
on_ambiguity(self, amb)
return od
def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
try:
func = self._cache[types]
except KeyError as e:
func = self.dispatch(*types)
if not func:
raise NotImplementedError(
f'Could not find signature for {self.name}: <{str_signature(types)}>') from e
self._cache[types] = func
try:
return func(*args, **kwargs)
except MDNotImplementedError as e:
funcs = self.dispatch_iter(*types)
next(funcs) # burn first
for func in funcs:
try:
return func(*args, **kwargs)
except MDNotImplementedError:
pass
raise NotImplementedError(
"Matching functions for "
f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e
def __str__(self):
return f"<dispatched {self.name}>"
__repr__ = __str__
def dispatch(self, *types):
"""Determine appropriate implementation for this type signature
This method is internal. Users should call this object as a function.
Implementation resolution occurs within the ``__call__`` method.
>>> # xdoctest: +SKIP
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def inc(x):
... return x + 1
>>> implementation = inc.dispatch(int)
>>> implementation(3)
4
>>> print(inc.dispatch(float))
None
See Also:
``multipledispatch.conflict`` - module to determine resolution order
"""
if types in self.funcs:
return self.funcs[types]
try:
return next(self.dispatch_iter(*types))
except StopIteration:
return None
def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if len(signature) == n and all(map(issubclass, types, signature)):
result = self.funcs[signature]
yield result
elif len(signature) and isvariadic(signature[-1]):
if variadic_signature_matches(types, signature):
result = self.funcs[signature]
yield result
@deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning)
def resolve(self, types):
""" Determine appropriate implementation for this type signature
.. deprecated:: 0.4.4
Use ``dispatch(*types)`` instead
"""
return self.dispatch(*types)
def __getstate__(self):
return {'name': self.name,
'funcs': self.funcs}
def __setstate__(self, d):
self.name = d['name']
self.funcs = d['funcs']
self._ordering = ordering(self.funcs)
self._cache = {}
@property
def __doc__(self):
docs = [f"Multiply dispatched method: {self.name}"]
if self.doc:
docs.append(self.doc)
other = []
for sig in self.ordering[::-1]:
func = self.funcs[sig]
if func.__doc__:
s = f'Inputs: <{str_signature(sig)}>\n'
s += '-' * len(s) + '\n'
s += func.__doc__.strip()
docs.append(s)
else:
other.append(str_signature(sig))
if other:
docs.append('Other signatures:\n ' + '\n '.join(other))
return '\n\n'.join(docs)
def _help(self, *args):
return self.dispatch(*map(type, args)).__doc__
def help(self, *args, **kwargs):
""" Print docstring for the function corresponding to inputs """
print(self._help(*args))
def _source(self, *args):
func = self.dispatch(*map(type, args))
if not func:
raise TypeError("No function found")
return source(func)
def source(self, *args, **kwargs):
""" Print source code for the function corresponding to inputs """
print(self._source(*args))
def source(func):
s = f'File: {inspect.getsourcefile(func)}\n\n'
s = s + inspect.getsource(func)
return s
class MethodDispatcher(Dispatcher):
""" Dispatch methods based on type signature
See Also:
Dispatcher
"""
__slots__ = ('obj', 'cls')
@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return itl.islice(sig.parameters.values(), 1, None)
def __get__(self, instance, owner):
self.obj = instance
self.cls = owner
return self
def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
func = self.dispatch(*types)
if not func:
raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>')
return func(self.obj, *args, **kwargs)
def str_signature(sig):
""" String representation of type signature
>>> str_signature((int, float))
'int, float'
"""
return ', '.join(cls.__name__ for cls in sig)
def warning_text(name, amb):
""" The text for ambiguity warnings """
text = f"\nAmbiguities exist in dispatched function {name}\n\n"
text += "The following signatures may result in ambiguous behavior:\n"
for pair in amb:
text += "\t" + \
', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
text += "\n\nConsider making the following additions:\n\n"
text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
+ f')\ndef {name}(...)' for s in amb])
return text

View File

@ -0,0 +1,126 @@
# mypy: allow-untyped-defs
from collections import OrderedDict
__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
def raises(err, lamda):
try:
lamda()
return False
except err:
return True
def expand_tuples(L):
"""
>>> expand_tuples([1, (2, 3)])
[(1, 2), (1, 3)]
>>> expand_tuples([1, 2])
[(1, 2)]
"""
if not L:
return [()]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3]
>>> # Closely follows the wikipedia page [2]
>>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
>>> # Communications of the ACM
>>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = OrderedDict((k, set(val))
for k, val in incoming_edges.items())
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
L = []
while S:
n, _ = S.popitem()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S[m] = None
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def reverse_dict(d):
"""Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
input dict, it make the output of this function depend on the
dict order. So this function output order should be considered
as undeterministic.
"""
result = OrderedDict() # type: ignore[var-annotated]
for key in d:
for val in d[key]:
result[val] = result.get(val, ()) + (key,)
return result
# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def groupby(func, seq):
""" Group a collection by a key function
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
See Also:
``countby``
"""
d = OrderedDict() # type: ignore[var-annotated]
for item in seq:
key = func(item)
if key not in d:
d[key] = []
d[key].append(item)
return d
def typename(type):
"""Get the name of `type`.
Parameters
----------
type : Union[Type, Tuple[Type]]
Returns
-------
str
The name of `type` or a tuple of the names of the types in `type`.
Examples
--------
>>> typename(int)
'int'
>>> typename((int, float))
'(int, float)'
"""
try:
return type.__name__
except AttributeError:
if len(type) == 1:
return typename(*type)
return f"({', '.join(map(typename, type))})"

View File

@ -0,0 +1,92 @@
# mypy: allow-untyped-defs
from .utils import typename
__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
class VariadicSignatureType(type):
# checking if subclass is a subclass of self
def __subclasscheck__(cls, subclass):
other_type = (subclass.variadic_type if isvariadic(subclass)
else (subclass,))
return subclass is cls or all(
issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined]
)
def __eq__(cls, other):
"""
Return True if other has the same variadic type
Parameters
----------
other : object (type)
The object (type) to check
Returns
-------
bool
Whether or not `other` is equal to `self`
"""
return (isvariadic(other) and
set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined]
def __hash__(cls):
return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined]
def isvariadic(obj):
"""Check whether the type `obj` is variadic.
Parameters
----------
obj : type
The type to check
Returns
-------
bool
Whether or not `obj` is variadic
Examples
--------
>>> # xdoctest: +SKIP
>>> isvariadic(int)
False
>>> isvariadic(Variadic[int])
True
"""
return isinstance(obj, VariadicSignatureType)
class VariadicSignatureMeta(type):
"""A metaclass that overrides ``__getitem__`` on the class. This is used to
generate a new type for Variadic signatures. See the Variadic class for
examples of how this behaves.
"""
def __getitem__(cls, variadic_type):
if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
raise ValueError("Variadic types must be type or tuple of types"
" (Variadic[int] or Variadic[(int, float)]")
if not isinstance(variadic_type, tuple):
variadic_type = variadic_type,
return VariadicSignatureType(
f'Variadic[{typename(variadic_type)}]',
(),
dict(variadic_type=variadic_type, __slots__=())
)
class Variadic(metaclass=VariadicSignatureMeta):
"""A class whose getitem method can be used to generate a new type
representing a specific variadic signature.
Examples
--------
>>> # xdoctest: +SKIP
>>> Variadic[int] # any number of int arguments
<class 'multipledispatch.variadic.Variadic[int]'>
>>> Variadic[(int, str)] # any number of one of int or str arguments
<class 'multipledispatch.variadic.Variadic[(int, str)]'>
>>> issubclass(int, Variadic[int])
True
>>> issubclass(int, Variadic[(int, str)])
True
>>> issubclass(str, Variadic[(int, str)])
True
>>> issubclass(float, Variadic[(int, str)])
False
"""

View File

@ -0,0 +1,396 @@
# mypy: allow-untyped-defs
import collections
import operator
from functools import reduce
from collections.abc import Mapping
__all__ = ['merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
'valfilter', 'keyfilter', 'itemfilter',
'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in']
def _get_factory(f, kwargs):
factory = kwargs.pop('factory', dict)
if kwargs:
raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'")
return factory
def merge(*dicts, **kwargs):
""" Merge a collection of dictionaries
>>> merge({1: 'one'}, {2: 'two'})
{1: 'one', 2: 'two'}
Later dictionaries have precedence
>>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
{1: 2, 3: 3, 4: 4}
See Also:
merge_with
"""
if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
dicts = dicts[0]
factory = _get_factory(merge, kwargs)
rv = factory()
for d in dicts:
rv.update(d)
return rv
def merge_with(func, *dicts, **kwargs):
""" Merge dictionaries and apply function to combined values
A key may occur in more than one dict, and all values mapped from the key
will be passed to the function as a list, such as func([val1, val2, ...]).
>>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
{1: 11, 2: 22}
>>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP
{1: 1, 2: 2, 3: 30}
See Also:
merge
"""
if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
dicts = dicts[0]
factory = _get_factory(merge_with, kwargs)
result = factory()
for d in dicts:
for k, v in d.items():
if k not in result:
result[k] = [v]
else:
result[k].append(v)
return valmap(func, result, factory)
def valmap(func, d, factory=dict):
""" Apply function to values of dictionary
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
>>> valmap(sum, bills) # doctest: +SKIP
{'Alice': 65, 'Bob': 45}
See Also:
keymap
itemmap
"""
rv = factory()
rv.update(zip(d.keys(), map(func, d.values())))
return rv
def keymap(func, d, factory=dict):
""" Apply function to keys of dictionary
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
>>> keymap(str.lower, bills) # doctest: +SKIP
{'alice': [20, 15, 30], 'bob': [10, 35]}
See Also:
valmap
itemmap
"""
rv = factory()
rv.update(zip(map(func, d.keys()), d.values()))
return rv
def itemmap(func, d, factory=dict):
""" Apply function to items of dictionary
>>> accountids = {"Alice": 10, "Bob": 20}
>>> itemmap(reversed, accountids) # doctest: +SKIP
{10: "Alice", 20: "Bob"}
See Also:
keymap
valmap
"""
rv = factory()
rv.update(map(func, d.items()))
return rv
def valfilter(predicate, d, factory=dict):
""" Filter items in dictionary by value
>>> iseven = lambda x: x % 2 == 0
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
>>> valfilter(iseven, d)
{1: 2, 3: 4}
See Also:
keyfilter
itemfilter
valmap
"""
rv = factory()
for k, v in d.items():
if predicate(v):
rv[k] = v
return rv
def keyfilter(predicate, d, factory=dict):
""" Filter items in dictionary by key
>>> iseven = lambda x: x % 2 == 0
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
>>> keyfilter(iseven, d)
{2: 3, 4: 5}
See Also:
valfilter
itemfilter
keymap
"""
rv = factory()
for k, v in d.items():
if predicate(k):
rv[k] = v
return rv
def itemfilter(predicate, d, factory=dict):
""" Filter items in dictionary by item
>>> def isvalid(item):
... k, v = item
... return k % 2 == 0 and v < 4
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
>>> itemfilter(isvalid, d)
{2: 3}
See Also:
keyfilter
valfilter
itemmap
"""
rv = factory()
for item in d.items():
if predicate(item):
k, v = item
rv[k] = v
return rv
def assoc(d, key, value, factory=dict):
""" Return a new dict with new key value pair
New dict has d[key] set to value. Does not modify the initial dictionary.
>>> assoc({'x': 1}, 'x', 2)
{'x': 2}
>>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP
{'x': 1, 'y': 3}
"""
d2 = factory()
d2.update(d)
d2[key] = value
return d2
def dissoc(d, *keys, **kwargs):
""" Return a new dict with the given key(s) removed.
New dict has d[key] deleted for each supplied key.
Does not modify the initial dictionary.
>>> dissoc({'x': 1, 'y': 2}, 'y')
{'x': 1}
>>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
{}
>>> dissoc({'x': 1}, 'y') # Ignores missing keys
{'x': 1}
"""
factory = _get_factory(dissoc, kwargs)
d2 = factory()
if len(keys) < len(d) * .6:
d2.update(d)
for key in keys:
if key in d2:
del d2[key]
else:
remaining = set(d)
remaining.difference_update(keys)
for k in remaining:
d2[k] = d[k]
return d2
def assoc_in(d, keys, value, factory=dict):
""" Return a new dict with new, potentially nested, key value pair
>>> purchase = {'name': 'Alice',
... 'order': {'items': ['Apple', 'Orange'],
... 'costs': [0.50, 1.25]},
... 'credit card': '5555-1234-1234-1234'}
>>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
{'credit card': '5555-1234-1234-1234',
'name': 'Alice',
'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
"""
return update_in(d, keys, lambda x: value, value, factory)
def update_in(d, keys, func, default=None, factory=dict):
""" Update value in a (potentially) nested dictionary
inputs:
d - dictionary on which to operate
keys - list or tuple giving the location of the value to be changed in d
func - function to operate on that value
If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
original dictionary with v replaced by func(v), but does not mutate the
original dictionary.
If k0 is not a key in d, update_in creates nested dictionaries to the depth
specified by the keys, with the innermost value set to func(default).
>>> inc = lambda x: x + 1
>>> update_in({'a': 0}, ['a'], inc)
{'a': 1}
>>> transaction = {'name': 'Alice',
... 'purchase': {'items': ['Apple', 'Orange'],
... 'costs': [0.50, 1.25]},
... 'credit card': '5555-1234-1234-1234'}
>>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
{'credit card': '5555-1234-1234-1234',
'name': 'Alice',
'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
>>> # updating a value when k0 is not in d
>>> update_in({}, [1, 2, 3], str, default="bar")
{1: {2: {3: 'bar'}}}
>>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
{1: 'foo', 2: {3: {4: 1}}}
"""
ks = iter(keys)
k = next(ks)
rv = inner = factory()
rv.update(d)
for key in ks:
if k in d:
d = d[k]
dtemp = factory()
dtemp.update(d)
else:
d = dtemp = factory()
inner[k] = inner = dtemp
k = key
if k in d:
inner[k] = func(d[k])
else:
inner[k] = func(default)
return rv
def get_in(keys, coll, default=None, no_default=False):
""" Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
``no_default`` is specified, then it raises KeyError or IndexError.
``get_in`` is a generalization of ``operator.getitem`` for nested data
structures such as dictionaries and lists.
>>> transaction = {'name': 'Alice',
... 'purchase': {'items': ['Apple', 'Orange'],
... 'costs': [0.50, 1.25]},
... 'credit card': '5555-1234-1234-1234'}
>>> get_in(['purchase', 'items', 0], transaction)
'Apple'
>>> get_in(['name'], transaction)
'Alice'
>>> get_in(['purchase', 'total'], transaction)
>>> get_in(['purchase', 'items', 'apple'], transaction)
>>> get_in(['purchase', 'items', 10], transaction)
>>> get_in(['purchase', 'total'], transaction, 0)
0
>>> get_in(['y'], {}, no_default=True)
Traceback (most recent call last):
...
KeyError: 'y'
See Also:
itertoolz.get
operator.getitem
"""
try:
return reduce(operator.getitem, keys, coll)
except (KeyError, IndexError, TypeError):
if no_default:
raise
return default
def getter(index):
if isinstance(index, list):
if len(index) == 1:
index = index[0]
return lambda x: (x[index],)
elif index:
return operator.itemgetter(*index)
else:
return lambda x: ()
else:
return operator.itemgetter(index)
def groupby(key, seq):
""" Group a collection by a key function
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
Non-callable keys imply grouping on a member.
>>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
... {'name': 'Bob', 'gender': 'M'},
... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
{'F': [{'gender': 'F', 'name': 'Alice'}],
'M': [{'gender': 'M', 'name': 'Bob'},
{'gender': 'M', 'name': 'Charlie'}]}
Not to be confused with ``itertools.groupby``
See Also:
countby
"""
if not callable(key):
key = getter(key)
d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated]
for item in seq:
d[key(item)](item)
rv = {}
for k, v in d.items():
rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined]
return rv
def first(seq):
""" The first element in a sequence
>>> first('ABC')
'A'
"""
return next(iter(seq))

View File

@ -0,0 +1,106 @@
# mypy: allow-untyped-defs
__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
def hashable(x):
try:
hash(x)
return True
except TypeError:
return False
def transitive_get(key, d):
""" Transitive dict.get
>>> d = {1: 2, 2: 3, 3: 4}
>>> d.get(1)
2
>>> transitive_get(1, d)
4
"""
while hashable(key) and key in d:
key = d[key]
return key
def raises(err, lamda):
try:
lamda()
return False
except err:
return True
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> # xdoctest: +SKIP
>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3]
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = ({v for v in edges if v not in incoming_edges})
L = []
while S:
n = S.pop()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S.add(m)
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def reverse_dict(d):
"""Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
input dict, it make the output of this function depend on the
dict order. So this function output order should be considered
as undeterministic.
"""
result = {} # type: ignore[var-annotated]
for key in d:
for val in d[key]:
result[val] = result.get(val, ()) + (key,)
return result
def xfail(func):
try:
func()
raise Exception("XFailed test passed") # pragma:nocover # noqa: TRY002
except Exception:
pass
def freeze(d):
""" Freeze container to hashable form
>>> freeze(1)
1
>>> freeze([1, 2])
(1, 2)
>>> freeze({1: 2}) # doctest: +SKIP
frozenset([(1, 2)])
"""
if isinstance(d, dict):
return frozenset(map(freeze, d.items()))
if isinstance(d, set):
return frozenset(map(freeze, d))
if isinstance(d, (tuple, list)):
return tuple(map(freeze, d))
return d

View File

@ -0,0 +1,86 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from .utils import hashable
from .dispatch import dispatch
_global_logic_variables = set() # type: ignore[var-annotated]
_glv = _global_logic_variables
class Var:
""" Logic Variable """
_id = 1
def __new__(cls, *token):
if len(token) == 0:
token = f"_{Var._id}" # type: ignore[assignment]
Var._id += 1
elif len(token) == 1:
token = token[0]
obj = object.__new__(cls)
obj.token = token # type: ignore[attr-defined]
return obj
def __str__(self):
return "~" + str(self.token) # type: ignore[attr-defined]
__repr__ = __str__
def __eq__(self, other):
return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined]
def __hash__(self):
return hash((type(self), self.token)) # type: ignore[attr-defined]
def var():
return lambda *args: Var(*args)
def vars():
return lambda n: [var() for i in range(n)]
@dispatch(Var)
def isvar(v):
return True
isvar
@dispatch(object) # type: ignore[no-redef]
def isvar(o):
return not not _glv and hashable(o) and o in _glv
@contextmanager
def variables(*variables):
"""
Context manager for logic variables
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> from __future__ import with_statement
>>> with variables(1):
... print(isvar(1))
True
>>> print(isvar(1))
False
>>> # Normal approach
>>> from unification import unify
>>> x = var('x')
>>> unify(x, 1)
{~x: 1}
>>> # Context Manager approach
>>> with variables('x'):
... print(unify('x', 1))
{'x': 1}
"""
old_global_logic_variables = _global_logic_variables.copy()
_global_logic_variables.update(set(variables))
try:
yield
finally:
_global_logic_variables.clear()
_global_logic_variables.update(old_global_logic_variables)