I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,20 @@
"""
Unified place for determining if external dependencies are installed or not.
You should import all external modules using the import_module() function.
For example
>>> from sympy.external import import_module
>>> numpy = import_module('numpy')
If the resulting library is not installed, or if the installed version
is less than a given minimum version, the function will return None.
Otherwise, it will return the library. See the docstring of
import_module() for more information.
"""
from sympy.external.importtools import import_module
__all__ = ['import_module']

View File

@ -0,0 +1,341 @@
import os
from ctypes import c_long, sizeof
from functools import reduce
from typing import Tuple as tTuple, Type
from warnings import warn
from sympy.external import import_module
from .pythonmpq import PythonMPQ
from .ntheory import (
bit_scan1 as python_bit_scan1,
bit_scan0 as python_bit_scan0,
remove as python_remove,
factorial as python_factorial,
sqrt as python_sqrt,
sqrtrem as python_sqrtrem,
gcd as python_gcd,
lcm as python_lcm,
gcdext as python_gcdext,
is_square as python_is_square,
invert as python_invert,
legendre as python_legendre,
jacobi as python_jacobi,
kronecker as python_kronecker,
iroot as python_iroot,
is_fermat_prp as python_is_fermat_prp,
is_euler_prp as python_is_euler_prp,
is_strong_prp as python_is_strong_prp,
is_fibonacci_prp as python_is_fibonacci_prp,
is_lucas_prp as python_is_lucas_prp,
is_selfridge_prp as python_is_selfridge_prp,
is_strong_lucas_prp as python_is_strong_lucas_prp,
is_strong_selfridge_prp as python_is_strong_selfridge_prp,
is_bpsw_prp as python_is_bpsw_prp,
is_strong_bpsw_prp as python_is_strong_bpsw_prp,
)
__all__ = [
# GROUND_TYPES is either 'gmpy' or 'python' depending on which is used. If
# gmpy is installed then it will be used unless the environment variable
# SYMPY_GROUND_TYPES is set to something other than 'auto', 'gmpy', or
# 'gmpy2'.
'GROUND_TYPES',
# If HAS_GMPY is 0, no supported version of gmpy is available. Otherwise,
# HAS_GMPY will be 2 for gmpy2 if GROUND_TYPES is 'gmpy'. It used to be
# possible for HAS_GMPY to be 1 for gmpy but gmpy is no longer supported.
'HAS_GMPY',
# SYMPY_INTS is a tuple containing the base types for valid integer types.
# This is either (int,) or (int, type(mpz(0))) depending on GROUND_TYPES.
'SYMPY_INTS',
# MPQ is either gmpy.mpq or the Python equivalent from
# sympy.external.pythonmpq
'MPQ',
# MPZ is either gmpy.mpz or int.
'MPZ',
'bit_scan1',
'bit_scan0',
'remove',
'factorial',
'sqrt',
'is_square',
'sqrtrem',
'gcd',
'lcm',
'gcdext',
'invert',
'legendre',
'jacobi',
'kronecker',
'iroot',
'is_fermat_prp',
'is_euler_prp',
'is_strong_prp',
'is_fibonacci_prp',
'is_lucas_prp',
'is_selfridge_prp',
'is_strong_lucas_prp',
'is_strong_selfridge_prp',
'is_bpsw_prp',
'is_strong_bpsw_prp',
]
#
# Tested python-flint version. Future versions might work but we will only use
# them if explicitly requested by SYMPY_GROUND_TYPES=flint.
#
_PYTHON_FLINT_VERSION_NEEDED = ["0.6", "0.7", "0.8", "0.9"]
def _flint_version_okay(flint_version):
major, minor = flint_version.split('.')[:2]
flint_ver = f'{major}.{minor}'
return flint_ver in _PYTHON_FLINT_VERSION_NEEDED
#
# We will only use gmpy2 >= 2.0.0
#
_GMPY2_MIN_VERSION = '2.0.0'
def _get_flint(sympy_ground_types):
if sympy_ground_types not in ('auto', 'flint'):
return None
try:
import flint
# Earlier versions of python-flint may not have __version__.
from flint import __version__ as _flint_version
except ImportError:
if sympy_ground_types == 'flint':
warn("SYMPY_GROUND_TYPES was set to flint but python-flint is not "
"installed. Falling back to other ground types.")
return None
if _flint_version_okay(_flint_version):
return flint
elif sympy_ground_types == 'auto':
return None
else:
warn(f"Using python-flint {_flint_version} because SYMPY_GROUND_TYPES "
f"is set to flint but this version of SymPy is only tested "
f"with python-flint versions {_PYTHON_FLINT_VERSION_NEEDED}.")
return flint
def _get_gmpy2(sympy_ground_types):
if sympy_ground_types not in ('auto', 'gmpy', 'gmpy2'):
return None
gmpy = import_module('gmpy2', min_module_version=_GMPY2_MIN_VERSION,
module_version_attr='version', module_version_attr_call_args=())
if sympy_ground_types != 'auto' and gmpy is None:
warn("gmpy2 library is not installed, switching to 'python' ground types")
return gmpy
#
# SYMPY_GROUND_TYPES can be flint, gmpy, gmpy2, python or auto (default)
#
_SYMPY_GROUND_TYPES = os.environ.get('SYMPY_GROUND_TYPES', 'auto').lower()
_flint = None
_gmpy = None
#
# First handle auto-detection of flint/gmpy2. We will prefer flint if available
# or otherwise gmpy2 if available and then lastly the python types.
#
if _SYMPY_GROUND_TYPES in ('auto', 'flint'):
_flint = _get_flint(_SYMPY_GROUND_TYPES)
if _flint is not None:
_SYMPY_GROUND_TYPES = 'flint'
else:
_SYMPY_GROUND_TYPES = 'auto'
if _SYMPY_GROUND_TYPES in ('auto', 'gmpy', 'gmpy2'):
_gmpy = _get_gmpy2(_SYMPY_GROUND_TYPES)
if _gmpy is not None:
_SYMPY_GROUND_TYPES = 'gmpy'
else:
_SYMPY_GROUND_TYPES = 'python'
if _SYMPY_GROUND_TYPES not in ('flint', 'gmpy', 'python'):
warn("SYMPY_GROUND_TYPES environment variable unrecognised. "
"Should be 'auto', 'flint', 'gmpy', 'gmpy2' or 'python'.")
_SYMPY_GROUND_TYPES = 'python'
#
# At this point _SYMPY_GROUND_TYPES is either flint, gmpy or python. The blocks
# below define the values exported by this module in each case.
#
#
# In gmpy2 and flint, there are functions that take a long (or unsigned long)
# argument. That is, it is not possible to input a value larger than that.
#
LONG_MAX = (1 << (8*sizeof(c_long) - 1)) - 1
#
# Type checkers are confused by what SYMPY_INTS is. There may be a better type
# hint for this like Type[Integral] or something.
#
SYMPY_INTS: tTuple[Type, ...]
if _SYMPY_GROUND_TYPES == 'gmpy':
assert _gmpy is not None
flint = None
gmpy = _gmpy
HAS_GMPY = 2
GROUND_TYPES = 'gmpy'
SYMPY_INTS = (int, type(gmpy.mpz(0)))
MPZ = gmpy.mpz
MPQ = gmpy.mpq
bit_scan1 = gmpy.bit_scan1
bit_scan0 = gmpy.bit_scan0
remove = gmpy.remove
factorial = gmpy.fac
sqrt = gmpy.isqrt
is_square = gmpy.is_square
sqrtrem = gmpy.isqrt_rem
gcd = gmpy.gcd
lcm = gmpy.lcm
gcdext = gmpy.gcdext
invert = gmpy.invert
legendre = gmpy.legendre
jacobi = gmpy.jacobi
kronecker = gmpy.kronecker
def iroot(x, n):
# In the latest gmpy2, the threshold for n is ULONG_MAX,
# but adjust to the older one.
if n <= LONG_MAX:
return gmpy.iroot(x, n)
return python_iroot(x, n)
is_fermat_prp = gmpy.is_fermat_prp
is_euler_prp = gmpy.is_euler_prp
is_strong_prp = gmpy.is_strong_prp
is_fibonacci_prp = gmpy.is_fibonacci_prp
is_lucas_prp = gmpy.is_lucas_prp
is_selfridge_prp = gmpy.is_selfridge_prp
is_strong_lucas_prp = gmpy.is_strong_lucas_prp
is_strong_selfridge_prp = gmpy.is_strong_selfridge_prp
is_bpsw_prp = gmpy.is_bpsw_prp
is_strong_bpsw_prp = gmpy.is_strong_bpsw_prp
elif _SYMPY_GROUND_TYPES == 'flint':
assert _flint is not None
flint = _flint
gmpy = None
HAS_GMPY = 0
GROUND_TYPES = 'flint'
SYMPY_INTS = (int, flint.fmpz) # type: ignore
MPZ = flint.fmpz # type: ignore
MPQ = flint.fmpq # type: ignore
bit_scan1 = python_bit_scan1
bit_scan0 = python_bit_scan0
remove = python_remove
factorial = python_factorial
def sqrt(x):
return flint.fmpz(x).isqrt()
def is_square(x):
if x < 0:
return False
return flint.fmpz(x).sqrtrem()[1] == 0
def sqrtrem(x):
return flint.fmpz(x).sqrtrem()
def gcd(*args):
return reduce(flint.fmpz.gcd, args, flint.fmpz(0))
def lcm(*args):
return reduce(flint.fmpz.lcm, args, flint.fmpz(1))
gcdext = python_gcdext
invert = python_invert
legendre = python_legendre
def jacobi(x, y):
if y <= 0 or not y % 2:
raise ValueError("y should be an odd positive integer")
return flint.fmpz(x).jacobi(y)
kronecker = python_kronecker
def iroot(x, n):
if n <= LONG_MAX:
y = flint.fmpz(x).root(n)
return y, y**n == x
return python_iroot(x, n)
is_fermat_prp = python_is_fermat_prp
is_euler_prp = python_is_euler_prp
is_strong_prp = python_is_strong_prp
is_fibonacci_prp = python_is_fibonacci_prp
is_lucas_prp = python_is_lucas_prp
is_selfridge_prp = python_is_selfridge_prp
is_strong_lucas_prp = python_is_strong_lucas_prp
is_strong_selfridge_prp = python_is_strong_selfridge_prp
is_bpsw_prp = python_is_bpsw_prp
is_strong_bpsw_prp = python_is_strong_bpsw_prp
elif _SYMPY_GROUND_TYPES == 'python':
flint = None
gmpy = None
HAS_GMPY = 0
GROUND_TYPES = 'python'
SYMPY_INTS = (int,)
MPZ = int
MPQ = PythonMPQ
bit_scan1 = python_bit_scan1
bit_scan0 = python_bit_scan0
remove = python_remove
factorial = python_factorial
sqrt = python_sqrt
is_square = python_is_square
sqrtrem = python_sqrtrem
gcd = python_gcd
lcm = python_lcm
gcdext = python_gcdext
invert = python_invert
legendre = python_legendre
jacobi = python_jacobi
kronecker = python_kronecker
iroot = python_iroot
is_fermat_prp = python_is_fermat_prp
is_euler_prp = python_is_euler_prp
is_strong_prp = python_is_strong_prp
is_fibonacci_prp = python_is_fibonacci_prp
is_lucas_prp = python_is_lucas_prp
is_selfridge_prp = python_is_selfridge_prp
is_strong_lucas_prp = python_is_strong_lucas_prp
is_strong_selfridge_prp = python_is_strong_selfridge_prp
is_bpsw_prp = python_is_bpsw_prp
is_strong_bpsw_prp = python_is_strong_bpsw_prp
else:
assert False

View File

@ -0,0 +1,187 @@
"""Tools to assist importing optional external modules."""
import sys
import re
# Override these in the module to change the default warning behavior.
# For example, you might set both to False before running the tests so that
# warnings are not printed to the console, or set both to True for debugging.
WARN_NOT_INSTALLED = None # Default is False
WARN_OLD_VERSION = None # Default is True
def __sympy_debug():
# helper function from sympy/__init__.py
# We don't just import SYMPY_DEBUG from that file because we don't want to
# import all of SymPy just to use this module.
import os
debug_str = os.getenv('SYMPY_DEBUG', 'False')
if debug_str in ('True', 'False'):
return eval(debug_str)
else:
raise RuntimeError("unrecognized value for SYMPY_DEBUG: %s" %
debug_str)
if __sympy_debug():
WARN_OLD_VERSION = True
WARN_NOT_INSTALLED = True
_component_re = re.compile(r'(\d+ | [a-z]+ | \.)', re.VERBOSE)
def version_tuple(vstring):
# Parse a version string to a tuple e.g. '1.2' -> (1, 2)
# Simplified from distutils.version.LooseVersion which was deprecated in
# Python 3.10.
components = []
for x in _component_re.split(vstring):
if x and x != '.':
try:
x = int(x)
except ValueError:
pass
components.append(x)
return tuple(components)
def import_module(module, min_module_version=None, min_python_version=None,
warn_not_installed=None, warn_old_version=None,
module_version_attr='__version__', module_version_attr_call_args=None,
import_kwargs={}, catch=()):
"""
Import and return a module if it is installed.
If the module is not installed, it returns None.
A minimum version for the module can be given as the keyword argument
min_module_version. This should be comparable against the module version.
By default, module.__version__ is used to get the module version. To
override this, set the module_version_attr keyword argument. If the
attribute of the module to get the version should be called (e.g.,
module.version()), then set module_version_attr_call_args to the args such
that module.module_version_attr(*module_version_attr_call_args) returns the
module's version.
If the module version is less than min_module_version using the Python <
comparison, None will be returned, even if the module is installed. You can
use this to keep from importing an incompatible older version of a module.
You can also specify a minimum Python version by using the
min_python_version keyword argument. This should be comparable against
sys.version_info.
If the keyword argument warn_not_installed is set to True, the function will
emit a UserWarning when the module is not installed.
If the keyword argument warn_old_version is set to True, the function will
emit a UserWarning when the library is installed, but cannot be imported
because of the min_module_version or min_python_version options.
Note that because of the way warnings are handled, a warning will be
emitted for each module only once. You can change the default warning
behavior by overriding the values of WARN_NOT_INSTALLED and WARN_OLD_VERSION
in sympy.external.importtools. By default, WARN_NOT_INSTALLED is False and
WARN_OLD_VERSION is True.
This function uses __import__() to import the module. To pass additional
options to __import__(), use the import_kwargs keyword argument. For
example, to import a submodule A.B, you must pass a nonempty fromlist option
to __import__. See the docstring of __import__().
This catches ImportError to determine if the module is not installed. To
catch additional errors, pass them as a tuple to the catch keyword
argument.
Examples
========
>>> from sympy.external import import_module
>>> numpy = import_module('numpy')
>>> numpy = import_module('numpy', min_python_version=(2, 7),
... warn_old_version=False)
>>> numpy = import_module('numpy', min_module_version='1.5',
... warn_old_version=False) # numpy.__version__ is a string
>>> # gmpy does not have __version__, but it does have gmpy.version()
>>> gmpy = import_module('gmpy', min_module_version='1.14',
... module_version_attr='version', module_version_attr_call_args=(),
... warn_old_version=False)
>>> # To import a submodule, you must pass a nonempty fromlist to
>>> # __import__(). The values do not matter.
>>> p3 = import_module('mpl_toolkits.mplot3d',
... import_kwargs={'fromlist':['something']})
>>> # matplotlib.pyplot can raise RuntimeError when the display cannot be opened
>>> matplotlib = import_module('matplotlib',
... import_kwargs={'fromlist':['pyplot']}, catch=(RuntimeError,))
"""
# keyword argument overrides default, and global variable overrides
# keyword argument.
warn_old_version = (WARN_OLD_VERSION if WARN_OLD_VERSION is not None
else warn_old_version or True)
warn_not_installed = (WARN_NOT_INSTALLED if WARN_NOT_INSTALLED is not None
else warn_not_installed or False)
import warnings
# Check Python first so we don't waste time importing a module we can't use
if min_python_version:
if sys.version_info < min_python_version:
if warn_old_version:
warnings.warn("Python version is too old to use %s "
"(%s or newer required)" % (
module, '.'.join(map(str, min_python_version))),
UserWarning, stacklevel=2)
return
try:
mod = __import__(module, **import_kwargs)
## there's something funny about imports with matplotlib and py3k. doing
## from matplotlib import collections
## gives python's stdlib collections module. explicitly re-importing
## the module fixes this.
from_list = import_kwargs.get('fromlist', ())
for submod in from_list:
if submod == 'collections' and mod.__name__ == 'matplotlib':
__import__(module + '.' + submod)
except ImportError:
if warn_not_installed:
warnings.warn("%s module is not installed" % module, UserWarning,
stacklevel=2)
return
except catch as e:
if warn_not_installed:
warnings.warn(
"%s module could not be used (%s)" % (module, repr(e)),
stacklevel=2)
return
if min_module_version:
modversion = getattr(mod, module_version_attr)
if module_version_attr_call_args is not None:
modversion = modversion(*module_version_attr_call_args)
if version_tuple(modversion) < version_tuple(min_module_version):
if warn_old_version:
# Attempt to create a pretty string version of the version
if isinstance(min_module_version, str):
verstr = min_module_version
elif isinstance(min_module_version, (tuple, list)):
verstr = '.'.join(map(str, min_module_version))
else:
# Either don't know what this is. Hopefully
# it's something that has a nice str version, like an int.
verstr = str(min_module_version)
warnings.warn("%s version is too old to use "
"(%s or newer required)" % (module, verstr),
UserWarning, stacklevel=2)
return
return mod

View File

@ -0,0 +1,637 @@
# sympy.external.ntheory
#
# This module provides pure Python implementations of some number theory
# functions that are alternately used from gmpy2 if it is installed.
import sys
import math
import mpmath.libmp as mlib
_small_trailing = [0] * 256
for j in range(1, 8):
_small_trailing[1 << j :: 1 << (j + 1)] = [j] * (1 << (7 - j))
def bit_scan1(x, n=0):
if not x:
return
x = abs(x >> n)
low_byte = x & 0xFF
if low_byte:
return _small_trailing[low_byte] + n
t = 8 + n
x >>= 8
# 2**m is quick for z up through 2**30
z = x.bit_length() - 1
if x == 1 << z:
return z + t
if z < 300:
# fixed 8-byte reduction
while not x & 0xFF:
x >>= 8
t += 8
else:
# binary reduction important when there might be a large
# number of trailing 0s
p = z >> 1
while not x & 0xFF:
while x & ((1 << p) - 1):
p >>= 1
x >>= p
t += p
return t + _small_trailing[x & 0xFF]
def bit_scan0(x, n=0):
return bit_scan1(x + (1 << n), n)
def remove(x, f):
if f < 2:
raise ValueError("factor must be > 1")
if x == 0:
return 0, 0
if f == 2:
b = bit_scan1(x)
return x >> b, b
m = 0
y, rem = divmod(x, f)
while not rem:
x = y
m += 1
if m > 5:
pow_list = [f**2]
while pow_list:
_f = pow_list[-1]
y, rem = divmod(x, _f)
if not rem:
m += 1 << len(pow_list)
x = y
pow_list.append(_f**2)
else:
pow_list.pop()
y, rem = divmod(x, f)
return x, m
def factorial(x):
"""Return x!."""
return int(mlib.ifac(int(x)))
def sqrt(x):
"""Integer square root of x."""
return int(mlib.isqrt(int(x)))
def sqrtrem(x):
"""Integer square root of x and remainder."""
s, r = mlib.sqrtrem(int(x))
return (int(s), int(r))
if sys.version_info[:2] >= (3, 9):
# As of Python 3.9 these can take multiple arguments
gcd = math.gcd
lcm = math.lcm
else:
# Until python 3.8 is no longer supported
from functools import reduce
def gcd(*args):
"""gcd of multiple integers."""
return reduce(math.gcd, args, 0)
def lcm(*args):
"""lcm of multiple integers."""
if 0 in args:
return 0
return reduce(lambda x, y: x*y//math.gcd(x, y), args, 1)
def _sign(n):
if n < 0:
return -1, -n
return 1, n
def gcdext(a, b):
if not a or not b:
g = abs(a) or abs(b)
if not g:
return (0, 0, 0)
return (g, a // g, b // g)
x_sign, a = _sign(a)
y_sign, b = _sign(b)
x, r = 1, 0
y, s = 0, 1
while b:
q, c = divmod(a, b)
a, b = b, c
x, r = r, x - q*r
y, s = s, y - q*s
return (a, x * x_sign, y * y_sign)
def is_square(x):
"""Return True if x is a square number."""
if x < 0:
return False
# Note that the possible values of y**2 % n for a given n are limited.
# For example, when n=4, y**2 % n can only take 0 or 1.
# In other words, if x % 4 is 2 or 3, then x is not a square number.
# Mathematically, it determines if it belongs to the set {y**2 % n},
# but implementationally, it can be realized as a logical conjunction
# with an n-bit integer.
# see https://mersenneforum.org/showpost.php?p=110896
# def magic(n):
# s = {y**2 % n for y in range(n)}
# s = set(range(n)) - s
# return sum(1 << bit for bit in s)
# >>> print(hex(magic(128)))
# 0xfdfdfdedfdfdfdecfdfdfdedfdfcfdec
# >>> print(hex(magic(99)))
# 0x5f6f9ffb6fb7ddfcb75befdec
# >>> print(hex(magic(91)))
# 0x6fd1bfcfed5f3679d3ebdec
# >>> print(hex(magic(85)))
# 0xdef9ae771ffe3b9d67dec
if 0xfdfdfdedfdfdfdecfdfdfdedfdfcfdec & (1 << (x & 127)):
return False # e.g. 2, 3
m = x % 765765 # 765765 = 99 * 91 * 85
if 0x5f6f9ffb6fb7ddfcb75befdec & (1 << (m % 99)):
return False # e.g. 17, 68
if 0x6fd1bfcfed5f3679d3ebdec & (1 << (m % 91)):
return False # e.g. 97, 388
if 0xdef9ae771ffe3b9d67dec & (1 << (m % 85)):
return False # e.g. 793, 1408
return mlib.sqrtrem(int(x))[1] == 0
def invert(x, m):
"""Modular inverse of x modulo m.
Returns y such that x*y == 1 mod m.
Uses ``math.pow`` but reproduces the behaviour of ``gmpy2.invert``
which raises ZeroDivisionError if no inverse exists.
"""
try:
return pow(x, -1, m)
except ValueError:
raise ZeroDivisionError("invert() no inverse exists")
def legendre(x, y):
"""Legendre symbol (x / y).
Following the implementation of gmpy2,
the error is raised only when y is an even number.
"""
if y <= 0 or not y % 2:
raise ValueError("y should be an odd prime")
x %= y
if not x:
return 0
if pow(x, (y - 1) // 2, y) == 1:
return 1
return -1
def jacobi(x, y):
"""Jacobi symbol (x / y)."""
if y <= 0 or not y % 2:
raise ValueError("y should be an odd positive integer")
x %= y
if not x:
return int(y == 1)
if y == 1 or x == 1:
return 1
if gcd(x, y) != 1:
return 0
j = 1
while x != 0:
while x % 2 == 0 and x > 0:
x >>= 1
if y % 8 in [3, 5]:
j = -j
x, y = y, x
if x % 4 == y % 4 == 3:
j = -j
x %= y
return j
def kronecker(x, y):
"""Kronecker symbol (x / y)."""
if gcd(x, y) != 1:
return 0
if y == 0:
return 1
sign = -1 if y < 0 and x < 0 else 1
y = abs(y)
s = bit_scan1(y)
y >>= s
if s % 2 and x % 8 in [3, 5]:
sign = -sign
return sign * jacobi(x, y)
def iroot(y, n):
if y < 0:
raise ValueError("y must be nonnegative")
if n < 1:
raise ValueError("n must be positive")
if y in (0, 1):
return y, True
if n == 1:
return y, True
if n == 2:
x, rem = mlib.sqrtrem(y)
return int(x), not rem
if n >= y.bit_length():
return 1, False
# Get initial estimate for Newton's method. Care must be taken to
# avoid overflow
try:
guess = int(y**(1./n) + 0.5)
except OverflowError:
exp = math.log2(y)/n
if exp > 53:
shift = int(exp - 53)
guess = int(2.0**(exp - shift) + 1) << shift
else:
guess = int(2.0**exp)
if guess > 2**50:
# Newton iteration
xprev, x = -1, guess
while 1:
t = x**(n - 1)
xprev, x = x, ((n - 1)*x + y//t)//n
if abs(x - xprev) < 2:
break
else:
x = guess
# Compensate
t = x**n
while t < y:
x += 1
t = x**n
while t > y:
x -= 1
t = x**n
return x, t == y
def is_fermat_prp(n, a):
if a < 2:
raise ValueError("is_fermat_prp() requires 'a' greater than or equal to 2")
if n < 1:
raise ValueError("is_fermat_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
a %= n
if gcd(n, a) != 1:
raise ValueError("is_fermat_prp() requires gcd(n,a) == 1")
return pow(a, n - 1, n) == 1
def is_euler_prp(n, a):
if a < 2:
raise ValueError("is_euler_prp() requires 'a' greater than or equal to 2")
if n < 1:
raise ValueError("is_euler_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
a %= n
if gcd(n, a) != 1:
raise ValueError("is_euler_prp() requires gcd(n,a) == 1")
return pow(a, n >> 1, n) == jacobi(a, n) % n
def _is_strong_prp(n, a):
s = bit_scan1(n - 1)
a = pow(a, n >> s, n)
if a == 1 or a == n - 1:
return True
for _ in range(s - 1):
a = pow(a, 2, n)
if a == n - 1:
return True
if a == 1:
return False
return False
def is_strong_prp(n, a):
if a < 2:
raise ValueError("is_strong_prp() requires 'a' greater than or equal to 2")
if n < 1:
raise ValueError("is_strong_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
a %= n
if gcd(n, a) != 1:
raise ValueError("is_strong_prp() requires gcd(n,a) == 1")
return _is_strong_prp(n, a)
def _lucas_sequence(n, P, Q, k):
r"""Return the modular Lucas sequence (U_k, V_k, Q_k).
Explanation
===========
Given a Lucas sequence defined by P, Q, returns the kth values for
U and V, along with Q^k, all modulo n. This is intended for use with
possibly very large values of n and k, where the combinatorial functions
would be completely unusable.
.. math ::
U_k = \begin{cases}
0 & \text{if } k = 0\\
1 & \text{if } k = 1\\
PU_{k-1} - QU_{k-2} & \text{if } k > 1
\end{cases}\\
V_k = \begin{cases}
2 & \text{if } k = 0\\
P & \text{if } k = 1\\
PV_{k-1} - QV_{k-2} & \text{if } k > 1
\end{cases}
The modular Lucas sequences are used in numerous places in number theory,
especially in the Lucas compositeness tests and the various n + 1 proofs.
Parameters
==========
n : int
n is an odd number greater than or equal to 3
P : int
Q : int
D determined by D = P**2 - 4*Q is non-zero
k : int
k is a nonnegative integer
Returns
=======
U, V, Qk : (int, int, int)
`(U_k \bmod{n}, V_k \bmod{n}, Q^k \bmod{n})`
Examples
========
>>> from sympy.external.ntheory import _lucas_sequence
>>> N = 10**2000 + 4561
>>> sol = U, V, Qk = _lucas_sequence(N, 3, 1, N//2); sol
(0, 2, 1)
References
==========
.. [1] https://en.wikipedia.org/wiki/Lucas_sequence
"""
if k == 0:
return (0, 2, 1)
D = P**2 - 4*Q
U = 1
V = P
Qk = Q % n
if Q == 1:
# Optimization for extra strong tests.
for b in bin(k)[3:]:
U = (U*V) % n
V = (V*V - 2) % n
if b == "1":
U, V = U*P + V, V*P + U*D
if U & 1:
U += n
if V & 1:
V += n
U, V = U >> 1, V >> 1
elif P == 1 and Q == -1:
# Small optimization for 50% of Selfridge parameters.
for b in bin(k)[3:]:
U = (U*V) % n
if Qk == 1:
V = (V*V - 2) % n
else:
V = (V*V + 2) % n
Qk = 1
if b == "1":
# new_U = (U + V) // 2
# new_V = (5*U + V) // 2 = 2*U + new_U
U, V = U + V, U << 1
if U & 1:
U += n
U >>= 1
V += U
Qk = -1
Qk %= n
elif P == 1:
for b in bin(k)[3:]:
U = (U*V) % n
V = (V*V - 2*Qk) % n
Qk *= Qk
if b == "1":
# new_U = (U + V) // 2
# new_V = new_U - 2*Q*U
U, V = U + V, (Q*U) << 1
if U & 1:
U += n
U >>= 1
V = U - V
Qk *= Q
Qk %= n
else:
# The general case with any P and Q.
for b in bin(k)[3:]:
U = (U*V) % n
V = (V*V - 2*Qk) % n
Qk *= Qk
if b == "1":
U, V = U*P + V, V*P + U*D
if U & 1:
U += n
if V & 1:
V += n
U, V = U >> 1, V >> 1
Qk *= Q
Qk %= n
return (U % n, V % n, Qk)
def is_fibonacci_prp(n, p, q):
d = p**2 - 4*q
if d == 0 or p <= 0 or q not in [1, -1]:
raise ValueError("invalid values for p,q in is_fibonacci_prp()")
if n < 1:
raise ValueError("is_fibonacci_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
return _lucas_sequence(n, p, q, n)[1] == p % n
def is_lucas_prp(n, p, q):
d = p**2 - 4*q
if d == 0:
raise ValueError("invalid values for p,q in is_lucas_prp()")
if n < 1:
raise ValueError("is_lucas_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
if gcd(n, q*d) not in [1, n]:
raise ValueError("is_lucas_prp() requires gcd(n,2*q*D) == 1")
return _lucas_sequence(n, p, q, n - jacobi(d, n))[0] == 0
def _is_selfridge_prp(n):
"""Lucas compositeness test with the Selfridge parameters for n.
Explanation
===========
The Lucas compositeness test checks whether n is a prime number.
The test can be run with arbitrary parameters ``P`` and ``Q``, which also change the performance of the test.
So, which parameters are most effective for running the Lucas compositeness test?
As an algorithm for determining ``P`` and ``Q``, Selfridge proposed method A [1]_ page 1401
(Since two methods were proposed, referred to simply as A and B in the paper,
we will refer to one of them as "method A").
method A fixes ``P = 1``. Then, ``D`` defined by ``D = P**2 - 4Q`` is varied from 5, -7, 9, -11, 13, and so on,
with the first ``D`` being ``jacobi(D, n) == -1``. Once ``D`` is determined,
``Q`` is determined to be ``(P**2 - D)//4``.
References
==========
.. [1] Robert Baillie, Samuel S. Wagstaff, Lucas Pseudoprimes,
Math. Comp. Vol 35, Number 152 (1980), pp. 1391-1417,
https://doi.org/10.1090%2FS0025-5718-1980-0583518-6
http://mpqs.free.fr/LucasPseudoprimes.pdf
"""
for D in range(5, 1_000_000, 2):
if D & 2: # if D % 4 == 3
D = -D
j = jacobi(D, n)
if j == -1:
return _lucas_sequence(n, 1, (1-D) // 4, n + 1)[0] == 0
if j == 0 and D % n:
return False
# When j == -1 is hard to find, suspect a square number
if D == 13 and is_square(n):
return False
raise ValueError("appropriate value for D cannot be found in is_selfridge_prp()")
def is_selfridge_prp(n):
if n < 1:
raise ValueError("is_selfridge_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
return _is_selfridge_prp(n)
def is_strong_lucas_prp(n, p, q):
D = p**2 - 4*q
if D == 0:
raise ValueError("invalid values for p,q in is_strong_lucas_prp()")
if n < 1:
raise ValueError("is_selfridge_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
if gcd(n, q*D) not in [1, n]:
raise ValueError("is_strong_lucas_prp() requires gcd(n,2*q*D) == 1")
j = jacobi(D, n)
s = bit_scan1(n - j)
U, V, Qk = _lucas_sequence(n, p, q, (n - j) >> s)
if U == 0 or V == 0:
return True
for _ in range(s - 1):
V = (V*V - 2*Qk) % n
if V == 0:
return True
Qk = pow(Qk, 2, n)
return False
def _is_strong_selfridge_prp(n):
for D in range(5, 1_000_000, 2):
if D & 2: # if D % 4 == 3
D = -D
j = jacobi(D, n)
if j == -1:
s = bit_scan1(n + 1)
U, V, Qk = _lucas_sequence(n, 1, (1-D) // 4, (n + 1) >> s)
if U == 0 or V == 0:
return True
for _ in range(s - 1):
V = (V*V - 2*Qk) % n
if V == 0:
return True
Qk = pow(Qk, 2, n)
return False
if j == 0 and D % n:
return False
# When j == -1 is hard to find, suspect a square number
if D == 13 and is_square(n):
return False
raise ValueError("appropriate value for D cannot be found in is_strong_selfridge_prp()")
def is_strong_selfridge_prp(n):
if n < 1:
raise ValueError("is_strong_selfridge_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
return _is_strong_selfridge_prp(n)
def is_bpsw_prp(n):
if n < 1:
raise ValueError("is_bpsw_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
return _is_strong_prp(n, 2) and _is_selfridge_prp(n)
def is_strong_bpsw_prp(n):
if n < 1:
raise ValueError("is_strong_bpsw_prp() requires 'n' be greater than 0")
if n == 1:
return False
if n % 2 == 0:
return n == 2
return _is_strong_prp(n, 2) and _is_strong_selfridge_prp(n)

View File

@ -0,0 +1,341 @@
"""
PythonMPQ: Rational number type based on Python integers.
This class is intended as a pure Python fallback for when gmpy2 is not
installed. If gmpy2 is installed then its mpq type will be used instead. The
mpq type is around 20x faster. We could just use the stdlib Fraction class
here but that is slower:
from fractions import Fraction
from sympy.external.pythonmpq import PythonMPQ
nums = range(1000)
dens = range(5, 1005)
rats = [Fraction(n, d) for n, d in zip(nums, dens)]
sum(rats) # <--- 24 milliseconds
rats = [PythonMPQ(n, d) for n, d in zip(nums, dens)]
sum(rats) # <--- 7 milliseconds
Both mpq and Fraction have some awkward features like the behaviour of
division with // and %:
>>> from fractions import Fraction
>>> Fraction(2, 3) % Fraction(1, 4)
1/6
For the QQ domain we do not want this behaviour because there should be no
remainder when dividing rational numbers. SymPy does not make use of this
aspect of mpq when gmpy2 is installed. Since this class is a fallback for that
case we do not bother implementing e.g. __mod__ so that we can be sure we
are not using it when gmpy2 is installed either.
"""
import operator
from math import gcd
from decimal import Decimal
from fractions import Fraction
import sys
from typing import Tuple as tTuple, Type
# Used for __hash__
_PyHASH_MODULUS = sys.hash_info.modulus
_PyHASH_INF = sys.hash_info.inf
class PythonMPQ:
"""Rational number implementation that is intended to be compatible with
gmpy2's mpq.
Also slightly faster than fractions.Fraction.
PythonMPQ should be treated as immutable although no effort is made to
prevent mutation (since that might slow down calculations).
"""
__slots__ = ('numerator', 'denominator')
def __new__(cls, numerator, denominator=None):
"""Construct PythonMPQ with gcd computation and checks"""
if denominator is not None:
#
# PythonMPQ(n, d): require n and d to be int and d != 0
#
if isinstance(numerator, int) and isinstance(denominator, int):
# This is the slow part:
divisor = gcd(numerator, denominator)
numerator //= divisor
denominator //= divisor
return cls._new_check(numerator, denominator)
else:
#
# PythonMPQ(q)
#
# Here q can be PythonMPQ, int, Decimal, float, Fraction or str
#
if isinstance(numerator, int):
return cls._new(numerator, 1)
elif isinstance(numerator, PythonMPQ):
return cls._new(numerator.numerator, numerator.denominator)
# Let Fraction handle Decimal/float conversion and str parsing
if isinstance(numerator, (Decimal, float, str)):
numerator = Fraction(numerator)
if isinstance(numerator, Fraction):
return cls._new(numerator.numerator, numerator.denominator)
#
# Reject everything else. This is more strict than mpq which allows
# things like mpq(Fraction, Fraction) or mpq(Decimal, any). The mpq
# behaviour is somewhat inconsistent so we choose to accept only a
# more strict subset of what mpq allows.
#
raise TypeError("PythonMPQ() requires numeric or string argument")
@classmethod
def _new_check(cls, numerator, denominator):
"""Construct PythonMPQ, check divide by zero and canonicalize signs"""
if not denominator:
raise ZeroDivisionError(f'Zero divisor {numerator}/{denominator}')
elif denominator < 0:
numerator = -numerator
denominator = -denominator
return cls._new(numerator, denominator)
@classmethod
def _new(cls, numerator, denominator):
"""Construct PythonMPQ efficiently (no checks)"""
obj = super().__new__(cls)
obj.numerator = numerator
obj.denominator = denominator
return obj
def __int__(self):
"""Convert to int (truncates towards zero)"""
p, q = self.numerator, self.denominator
if p < 0:
return -(-p//q)
return p//q
def __float__(self):
"""Convert to float (approximately)"""
return self.numerator / self.denominator
def __bool__(self):
"""True/False if nonzero/zero"""
return bool(self.numerator)
def __eq__(self, other):
"""Compare equal with PythonMPQ, int, float, Decimal or Fraction"""
if isinstance(other, PythonMPQ):
return (self.numerator == other.numerator
and self.denominator == other.denominator)
elif isinstance(other, self._compatible_types):
return self.__eq__(PythonMPQ(other))
else:
return NotImplemented
def __hash__(self):
"""hash - same as mpq/Fraction"""
try:
dinv = pow(self.denominator, -1, _PyHASH_MODULUS)
except ValueError:
hash_ = _PyHASH_INF
else:
hash_ = hash(hash(abs(self.numerator)) * dinv)
result = hash_ if self.numerator >= 0 else -hash_
return -2 if result == -1 else result
def __reduce__(self):
"""Deconstruct for pickling"""
return type(self), (self.numerator, self.denominator)
def __str__(self):
"""Convert to string"""
if self.denominator != 1:
return f"{self.numerator}/{self.denominator}"
else:
return f"{self.numerator}"
def __repr__(self):
"""Convert to string"""
return f"MPQ({self.numerator},{self.denominator})"
def _cmp(self, other, op):
"""Helper for lt/le/gt/ge"""
if not isinstance(other, self._compatible_types):
return NotImplemented
lhs = self.numerator * other.denominator
rhs = other.numerator * self.denominator
return op(lhs, rhs)
def __lt__(self, other):
"""self < other"""
return self._cmp(other, operator.lt)
def __le__(self, other):
"""self <= other"""
return self._cmp(other, operator.le)
def __gt__(self, other):
"""self > other"""
return self._cmp(other, operator.gt)
def __ge__(self, other):
"""self >= other"""
return self._cmp(other, operator.ge)
def __abs__(self):
"""abs(q)"""
return self._new(abs(self.numerator), self.denominator)
def __pos__(self):
"""+q"""
return self
def __neg__(self):
"""-q"""
return self._new(-self.numerator, self.denominator)
def __add__(self, other):
"""q1 + q2"""
if isinstance(other, PythonMPQ):
#
# This is much faster than the naive method used in the stdlib
# fractions module. Not sure where this method comes from
# though...
#
# Compare timings for something like:
# nums = range(1000)
# rats = [PythonMPQ(n, d) for n, d in zip(nums[:-5], nums[5:])]
# sum(rats) # <-- time this
#
ap, aq = self.numerator, self.denominator
bp, bq = other.numerator, other.denominator
g = gcd(aq, bq)
if g == 1:
p = ap*bq + aq*bp
q = bq*aq
else:
q1, q2 = aq//g, bq//g
p, q = ap*q2 + bp*q1, q1*q2
g2 = gcd(p, g)
p, q = (p // g2), q * (g // g2)
elif isinstance(other, int):
p = self.numerator + self.denominator * other
q = self.denominator
else:
return NotImplemented
return self._new(p, q)
def __radd__(self, other):
"""z1 + q2"""
if isinstance(other, int):
p = self.numerator + self.denominator * other
q = self.denominator
return self._new(p, q)
else:
return NotImplemented
def __sub__(self ,other):
"""q1 - q2"""
if isinstance(other, PythonMPQ):
ap, aq = self.numerator, self.denominator
bp, bq = other.numerator, other.denominator
g = gcd(aq, bq)
if g == 1:
p = ap*bq - aq*bp
q = bq*aq
else:
q1, q2 = aq//g, bq//g
p, q = ap*q2 - bp*q1, q1*q2
g2 = gcd(p, g)
p, q = (p // g2), q * (g // g2)
elif isinstance(other, int):
p = self.numerator - self.denominator*other
q = self.denominator
else:
return NotImplemented
return self._new(p, q)
def __rsub__(self, other):
"""z1 - q2"""
if isinstance(other, int):
p = self.denominator * other - self.numerator
q = self.denominator
return self._new(p, q)
else:
return NotImplemented
def __mul__(self, other):
"""q1 * q2"""
if isinstance(other, PythonMPQ):
ap, aq = self.numerator, self.denominator
bp, bq = other.numerator, other.denominator
x1 = gcd(ap, bq)
x2 = gcd(bp, aq)
p, q = ((ap//x1)*(bp//x2), (aq//x2)*(bq//x1))
elif isinstance(other, int):
x = gcd(other, self.denominator)
p = self.numerator*(other//x)
q = self.denominator//x
else:
return NotImplemented
return self._new(p, q)
def __rmul__(self, other):
"""z1 * q2"""
if isinstance(other, int):
x = gcd(self.denominator, other)
p = self.numerator*(other//x)
q = self.denominator//x
return self._new(p, q)
else:
return NotImplemented
def __pow__(self, exp):
"""q ** z"""
p, q = self.numerator, self.denominator
if exp < 0:
p, q, exp = q, p, -exp
return self._new_check(p**exp, q**exp)
def __truediv__(self, other):
"""q1 / q2"""
if isinstance(other, PythonMPQ):
ap, aq = self.numerator, self.denominator
bp, bq = other.numerator, other.denominator
x1 = gcd(ap, bp)
x2 = gcd(bq, aq)
p, q = ((ap//x1)*(bq//x2), (aq//x2)*(bp//x1))
elif isinstance(other, int):
x = gcd(other, self.numerator)
p = self.numerator//x
q = self.denominator*(other//x)
else:
return NotImplemented
return self._new_check(p, q)
def __rtruediv__(self, other):
"""z / q"""
if isinstance(other, int):
x = gcd(self.numerator, other)
p = self.denominator*(other//x)
q = self.numerator//x
return self._new_check(p, q)
else:
return NotImplemented
_compatible_types: tTuple[Type, ...] = ()
#
# These are the types that PythonMPQ will interoperate with for operations
# and comparisons such as ==, + etc. We define this down here so that we can
# include PythonMPQ in the list as well.
#
PythonMPQ._compatible_types = (PythonMPQ, int, Decimal, Fraction)

View File

View File

@ -0,0 +1,313 @@
import sympy
import tempfile
import os
from sympy.core.mod import Mod
from sympy.core.relational import Eq
from sympy.core.symbol import symbols
from sympy.external import import_module
from sympy.tensor import IndexedBase, Idx
from sympy.utilities.autowrap import autowrap, ufuncify, CodeWrapError
from sympy.testing.pytest import skip
numpy = import_module('numpy', min_module_version='1.6.1')
Cython = import_module('Cython', min_module_version='0.15.1')
f2py = import_module('numpy.f2py', import_kwargs={'fromlist': ['f2py']})
f2pyworks = False
if f2py:
try:
autowrap(symbols('x'), 'f95', 'f2py')
except (CodeWrapError, ImportError, OSError):
f2pyworks = False
else:
f2pyworks = True
a, b, c = symbols('a b c')
n, m, d = symbols('n m d', integer=True)
A, B, C = symbols('A B C', cls=IndexedBase)
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', d)
def has_module(module):
"""
Return True if module exists, otherwise run skip().
module should be a string.
"""
# To give a string of the module name to skip(), this function takes a
# string. So we don't waste time running import_module() more than once,
# just map the three modules tested here in this dict.
modnames = {'numpy': numpy, 'Cython': Cython, 'f2py': f2py}
if modnames[module]:
if module == 'f2py' and not f2pyworks:
skip("Couldn't run f2py.")
return True
skip("Couldn't import %s." % module)
#
# test runners used by several language-backend combinations
#
def runtest_autowrap_twice(language, backend):
f = autowrap((((a + b)/c)**5).expand(), language, backend)
g = autowrap((((a + b)/c)**4).expand(), language, backend)
# check that autowrap updates the module name. Else, g gives the same as f
assert f(1, -2, 1) == -1.0
assert g(1, -2, 1) == 1.0
def runtest_autowrap_trace(language, backend):
has_module('numpy')
trace = autowrap(A[i, i], language, backend)
assert trace(numpy.eye(100)) == 100
def runtest_autowrap_matrix_vector(language, backend):
has_module('numpy')
x, y = symbols('x y', cls=IndexedBase)
expr = Eq(y[i], A[i, j]*x[j])
mv = autowrap(expr, language, backend)
# compare with numpy's dot product
M = numpy.random.rand(10, 20)
x = numpy.random.rand(20)
y = numpy.dot(M, x)
assert numpy.sum(numpy.abs(y - mv(M, x))) < 1e-13
def runtest_autowrap_matrix_matrix(language, backend):
has_module('numpy')
expr = Eq(C[i, j], A[i, k]*B[k, j])
matmat = autowrap(expr, language, backend)
# compare with numpy's dot product
M1 = numpy.random.rand(10, 20)
M2 = numpy.random.rand(20, 15)
M3 = numpy.dot(M1, M2)
assert numpy.sum(numpy.abs(M3 - matmat(M1, M2))) < 1e-13
def runtest_ufuncify(language, backend):
has_module('numpy')
a, b, c = symbols('a b c')
fabc = ufuncify([a, b, c], a*b + c, backend=backend)
facb = ufuncify([a, c, b], a*b + c, backend=backend)
grid = numpy.linspace(-2, 2, 50)
b = numpy.linspace(-5, 4, 50)
c = numpy.linspace(-1, 1, 50)
expected = grid*b + c
numpy.testing.assert_allclose(fabc(grid, b, c), expected)
numpy.testing.assert_allclose(facb(grid, c, b), expected)
def runtest_issue_10274(language, backend):
expr = (a - b + c)**(13)
tmp = tempfile.mkdtemp()
f = autowrap(expr, language, backend, tempdir=tmp,
helpers=('helper', a - b + c, (a, b, c)))
assert f(1, 1, 1) == 1
for file in os.listdir(tmp):
if not (file.startswith("wrapped_code_") and file.endswith(".c")):
continue
with open(tmp + '/' + file) as fil:
lines = fil.readlines()
assert lines[0] == "/******************************************************************************\n"
assert "Code generated with SymPy " + sympy.__version__ in lines[1]
assert lines[2:] == [
" * *\n",
" * See http://www.sympy.org/ for more information. *\n",
" * *\n",
" * This file is part of 'autowrap' *\n",
" ******************************************************************************/\n",
"#include " + '"' + file[:-1]+ 'h"' + "\n",
"#include <math.h>\n",
"\n",
"double helper(double a, double b, double c) {\n",
"\n",
" double helper_result;\n",
" helper_result = a - b + c;\n",
" return helper_result;\n",
"\n",
"}\n",
"\n",
"double autofunc(double a, double b, double c) {\n",
"\n",
" double autofunc_result;\n",
" autofunc_result = pow(helper(a, b, c), 13);\n",
" return autofunc_result;\n",
"\n",
"}\n",
]
def runtest_issue_15337(language, backend):
has_module('numpy')
# NOTE : autowrap was originally designed to only accept an iterable for
# the kwarg "helpers", but in issue 10274 the user mistakenly thought that
# if there was only a single helper it did not need to be passed via an
# iterable that wrapped the helper tuple. There were no tests for this
# behavior so when the code was changed to accept a single tuple it broke
# the original behavior. These tests below ensure that both now work.
a, b, c, d, e = symbols('a, b, c, d, e')
expr = (a - b + c - d + e)**13
exp_res = (1. - 2. + 3. - 4. + 5.)**13
f = autowrap(expr, language, backend, args=(a, b, c, d, e),
helpers=('f1', a - b + c, (a, b, c)))
numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res)
f = autowrap(expr, language, backend, args=(a, b, c, d, e),
helpers=(('f1', a - b, (a, b)), ('f2', c - d, (c, d))))
numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res)
def test_issue_15230():
has_module('f2py')
x, y = symbols('x, y')
expr = Mod(x, 3.0) - Mod(y, -2.0)
f = autowrap(expr, args=[x, y], language='F95')
exp_res = float(expr.xreplace({x: 3.5, y: 2.7}).evalf())
assert abs(f(3.5, 2.7) - exp_res) < 1e-14
x, y = symbols('x, y', integer=True)
expr = Mod(x, 3) - Mod(y, -2)
f = autowrap(expr, args=[x, y], language='F95')
assert f(3, 2) == expr.xreplace({x: 3, y: 2})
#
# tests of language-backend combinations
#
# f2py
def test_wrap_twice_f95_f2py():
has_module('f2py')
runtest_autowrap_twice('f95', 'f2py')
def test_autowrap_trace_f95_f2py():
has_module('f2py')
runtest_autowrap_trace('f95', 'f2py')
def test_autowrap_matrix_vector_f95_f2py():
has_module('f2py')
runtest_autowrap_matrix_vector('f95', 'f2py')
def test_autowrap_matrix_matrix_f95_f2py():
has_module('f2py')
runtest_autowrap_matrix_matrix('f95', 'f2py')
def test_ufuncify_f95_f2py():
has_module('f2py')
runtest_ufuncify('f95', 'f2py')
def test_issue_15337_f95_f2py():
has_module('f2py')
runtest_issue_15337('f95', 'f2py')
# Cython
def test_wrap_twice_c_cython():
has_module('Cython')
runtest_autowrap_twice('C', 'cython')
def test_autowrap_trace_C_Cython():
has_module('Cython')
runtest_autowrap_trace('C99', 'cython')
def test_autowrap_matrix_vector_C_cython():
has_module('Cython')
runtest_autowrap_matrix_vector('C99', 'cython')
def test_autowrap_matrix_matrix_C_cython():
has_module('Cython')
runtest_autowrap_matrix_matrix('C99', 'cython')
def test_ufuncify_C_Cython():
has_module('Cython')
runtest_ufuncify('C99', 'cython')
def test_issue_10274_C_cython():
has_module('Cython')
runtest_issue_10274('C89', 'cython')
def test_issue_15337_C_cython():
has_module('Cython')
runtest_issue_15337('C89', 'cython')
def test_autowrap_custom_printer():
has_module('Cython')
from sympy.core.numbers import pi
from sympy.utilities.codegen import C99CodeGen
from sympy.printing.c import C99CodePrinter
class PiPrinter(C99CodePrinter):
def _print_Pi(self, expr):
return "S_PI"
printer = PiPrinter()
gen = C99CodeGen(printer=printer)
gen.preprocessor_statements.append('#include "shortpi.h"')
expr = pi * a
expected = (
'#include "%s"\n'
'#include <math.h>\n'
'#include "shortpi.h"\n'
'\n'
'double autofunc(double a) {\n'
'\n'
' double autofunc_result;\n'
' autofunc_result = S_PI*a;\n'
' return autofunc_result;\n'
'\n'
'}\n'
)
tmpdir = tempfile.mkdtemp()
# write a trivial header file to use in the generated code
with open(os.path.join(tmpdir, 'shortpi.h'), 'w') as f:
f.write('#define S_PI 3.14')
func = autowrap(expr, backend='cython', tempdir=tmpdir, code_gen=gen)
assert func(4.2) == 3.14 * 4.2
# check that the generated code is correct
for filename in os.listdir(tmpdir):
if filename.startswith('wrapped_code') and filename.endswith('.c'):
with open(os.path.join(tmpdir, filename)) as f:
lines = f.readlines()
expected = expected % filename.replace('.c', '.h')
assert ''.join(lines[7:]) == expected
# Numpy
def test_ufuncify_numpy():
# This test doesn't use Cython, but if Cython works, then there is a valid
# C compiler, which is needed.
has_module('Cython')
runtest_ufuncify('C99', 'numpy')

View File

@ -0,0 +1,379 @@
# This tests the compilation and execution of the source code generated with
# utilities.codegen. The compilation takes place in a temporary directory that
# is removed after the test. By default the test directory is always removed,
# but this behavior can be changed by setting the environment variable
# SYMPY_TEST_CLEAN_TEMP to:
# export SYMPY_TEST_CLEAN_TEMP=always : the default behavior.
# export SYMPY_TEST_CLEAN_TEMP=success : only remove the directories of working tests.
# export SYMPY_TEST_CLEAN_TEMP=never : never remove the directories with the test code.
# When a directory is not removed, the necessary information is printed on
# screen to find the files that belong to the (failed) tests. If a test does
# not fail, py.test captures all the output and you will not see the directories
# corresponding to the successful tests. Use the --nocapture option to see all
# the output.
# All tests below have a counterpart in utilities/test/test_codegen.py. In the
# latter file, the resulting code is compared with predefined strings, without
# compilation or execution.
# All the generated Fortran code should conform with the Fortran 95 standard,
# and all the generated C code should be ANSI C, which facilitates the
# incorporation in various projects. The tests below assume that the binary cc
# is somewhere in the path and that it can compile ANSI C code.
from sympy.abc import x, y, z
from sympy.external import import_module
from sympy.testing.pytest import skip
from sympy.utilities.codegen import codegen, make_routine, get_code_generator
import sys
import os
import tempfile
import subprocess
pyodide_js = import_module('pyodide_js')
# templates for the main program that will test the generated code.
main_template = {}
main_template['F95'] = """
program main
include "codegen.h"
integer :: result;
result = 0
%(statements)s
call exit(result)
end program
"""
main_template['C89'] = """
#include "codegen.h"
#include <stdio.h>
#include <math.h>
int main() {
int result = 0;
%(statements)s
return result;
}
"""
main_template['C99'] = main_template['C89']
# templates for the numerical tests
numerical_test_template = {}
numerical_test_template['C89'] = """
if (fabs(%(call)s)>%(threshold)s) {
printf("Numerical validation failed: %(call)s=%%e threshold=%(threshold)s\\n", %(call)s);
result = -1;
}
"""
numerical_test_template['C99'] = numerical_test_template['C89']
numerical_test_template['F95'] = """
if (abs(%(call)s)>%(threshold)s) then
write(6,"('Numerical validation failed:')")
write(6,"('%(call)s=',e15.5,'threshold=',e15.5)") %(call)s, %(threshold)s
result = -1;
end if
"""
# command sequences for supported compilers
compile_commands = {}
compile_commands['cc'] = [
"cc -c codegen.c -o codegen.o",
"cc -c main.c -o main.o",
"cc main.o codegen.o -lm -o test.exe"
]
compile_commands['gfortran'] = [
"gfortran -c codegen.f90 -o codegen.o",
"gfortran -ffree-line-length-none -c main.f90 -o main.o",
"gfortran main.o codegen.o -o test.exe"
]
compile_commands['g95'] = [
"g95 -c codegen.f90 -o codegen.o",
"g95 -ffree-line-length-huge -c main.f90 -o main.o",
"g95 main.o codegen.o -o test.exe"
]
compile_commands['ifort'] = [
"ifort -c codegen.f90 -o codegen.o",
"ifort -c main.f90 -o main.o",
"ifort main.o codegen.o -o test.exe"
]
combinations_lang_compiler = [
('C89', 'cc'),
('C99', 'cc'),
('F95', 'ifort'),
('F95', 'gfortran'),
('F95', 'g95')
]
def try_run(commands):
"""Run a series of commands and only return True if all ran fine."""
if pyodide_js:
return False
with open(os.devnull, 'w') as null:
for command in commands:
retcode = subprocess.call(command, stdout=null, shell=True,
stderr=subprocess.STDOUT)
if retcode != 0:
return False
return True
def run_test(label, routines, numerical_tests, language, commands, friendly=True):
"""A driver for the codegen tests.
This driver assumes that a compiler ifort is present in the PATH and that
ifort is (at least) a Fortran 90 compiler. The generated code is written in
a temporary directory, together with a main program that validates the
generated code. The test passes when the compilation and the validation
run correctly.
"""
# Check input arguments before touching the file system
language = language.upper()
assert language in main_template
assert language in numerical_test_template
# Check that environment variable makes sense
clean = os.getenv('SYMPY_TEST_CLEAN_TEMP', 'always').lower()
if clean not in ('always', 'success', 'never'):
raise ValueError("SYMPY_TEST_CLEAN_TEMP must be one of the following: 'always', 'success' or 'never'.")
# Do all the magic to compile, run and validate the test code
# 1) prepare the temporary working directory, switch to that dir
work = tempfile.mkdtemp("_sympy_%s_test" % language, "%s_" % label)
oldwork = os.getcwd()
os.chdir(work)
# 2) write the generated code
if friendly:
# interpret the routines as a name_expr list and call the friendly
# function codegen
codegen(routines, language, "codegen", to_files=True)
else:
code_gen = get_code_generator(language, "codegen")
code_gen.write(routines, "codegen", to_files=True)
# 3) write a simple main program that links to the generated code, and that
# includes the numerical tests
test_strings = []
for fn_name, args, expected, threshold in numerical_tests:
call_string = "%s(%s)-(%s)" % (
fn_name, ",".join(str(arg) for arg in args), expected)
if language == "F95":
call_string = fortranize_double_constants(call_string)
threshold = fortranize_double_constants(str(threshold))
test_strings.append(numerical_test_template[language] % {
"call": call_string,
"threshold": threshold,
})
if language == "F95":
f_name = "main.f90"
elif language.startswith("C"):
f_name = "main.c"
else:
raise NotImplementedError(
"FIXME: filename extension unknown for language: %s" % language)
with open(f_name, "w") as f:
f.write(
main_template[language] % {'statements': "".join(test_strings)})
# 4) Compile and link
compiled = try_run(commands)
# 5) Run if compiled
if compiled:
executed = try_run(["./test.exe"])
else:
executed = False
# 6) Clean up stuff
if clean == 'always' or (clean == 'success' and compiled and executed):
def safe_remove(filename):
if os.path.isfile(filename):
os.remove(filename)
safe_remove("codegen.f90")
safe_remove("codegen.c")
safe_remove("codegen.h")
safe_remove("codegen.o")
safe_remove("main.f90")
safe_remove("main.c")
safe_remove("main.o")
safe_remove("test.exe")
os.chdir(oldwork)
os.rmdir(work)
else:
print("TEST NOT REMOVED: %s" % work, file=sys.stderr)
os.chdir(oldwork)
# 7) Do the assertions in the end
assert compiled, "failed to compile %s code with:\n%s" % (
language, "\n".join(commands))
assert executed, "failed to execute %s code from:\n%s" % (
language, "\n".join(commands))
def fortranize_double_constants(code_string):
"""
Replaces every literal float with literal doubles
"""
import re
pattern_exp = re.compile(r'\d+(\.)?\d*[eE]-?\d+')
pattern_float = re.compile(r'\d+\.\d*(?!\d*d)')
def subs_exp(matchobj):
return re.sub('[eE]', 'd', matchobj.group(0))
def subs_float(matchobj):
return "%sd0" % matchobj.group(0)
code_string = pattern_exp.sub(subs_exp, code_string)
code_string = pattern_float.sub(subs_float, code_string)
return code_string
def is_feasible(language, commands):
# This test should always work, otherwise the compiler is not present.
routine = make_routine("test", x)
numerical_tests = [
("test", ( 1.0,), 1.0, 1e-15),
("test", (-1.0,), -1.0, 1e-15),
]
try:
run_test("is_feasible", [routine], numerical_tests, language, commands,
friendly=False)
return True
except AssertionError:
return False
valid_lang_commands = []
invalid_lang_compilers = []
for lang, compiler in combinations_lang_compiler:
commands = compile_commands[compiler]
if is_feasible(lang, commands):
valid_lang_commands.append((lang, commands))
else:
invalid_lang_compilers.append((lang, compiler))
# We test all language-compiler combinations, just to report what is skipped
def test_C89_cc():
if ("C89", 'cc') in invalid_lang_compilers:
skip("`cc' command didn't work as expected (C89)")
def test_C99_cc():
if ("C99", 'cc') in invalid_lang_compilers:
skip("`cc' command didn't work as expected (C99)")
def test_F95_ifort():
if ("F95", 'ifort') in invalid_lang_compilers:
skip("`ifort' command didn't work as expected")
def test_F95_gfortran():
if ("F95", 'gfortran') in invalid_lang_compilers:
skip("`gfortran' command didn't work as expected")
def test_F95_g95():
if ("F95", 'g95') in invalid_lang_compilers:
skip("`g95' command didn't work as expected")
# Here comes the actual tests
def test_basic_codegen():
numerical_tests = [
("test", (1.0, 6.0, 3.0), 21.0, 1e-15),
("test", (-1.0, 2.0, -2.5), -2.5, 1e-15),
]
name_expr = [("test", (x + y)*z)]
for lang, commands in valid_lang_commands:
run_test("basic_codegen", name_expr, numerical_tests, lang, commands)
def test_intrinsic_math1_codegen():
# not included: log10
from sympy.core.evalf import N
from sympy.functions import ln
from sympy.functions.elementary.exponential import log
from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
from sympy.functions.elementary.integers import (ceiling, floor)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
name_expr = [
("test_fabs", abs(x)),
("test_acos", acos(x)),
("test_asin", asin(x)),
("test_atan", atan(x)),
("test_cos", cos(x)),
("test_cosh", cosh(x)),
("test_log", log(x)),
("test_ln", ln(x)),
("test_sin", sin(x)),
("test_sinh", sinh(x)),
("test_sqrt", sqrt(x)),
("test_tan", tan(x)),
("test_tanh", tanh(x)),
]
numerical_tests = []
for name, expr in name_expr:
for xval in 0.2, 0.5, 0.8:
expected = N(expr.subs(x, xval))
numerical_tests.append((name, (xval,), expected, 1e-14))
for lang, commands in valid_lang_commands:
if lang.startswith("C"):
name_expr_C = [("test_floor", floor(x)), ("test_ceil", ceiling(x))]
else:
name_expr_C = []
run_test("intrinsic_math1", name_expr + name_expr_C,
numerical_tests, lang, commands)
def test_instrinsic_math2_codegen():
# not included: frexp, ldexp, modf, fmod
from sympy.core.evalf import N
from sympy.functions.elementary.trigonometric import atan2
name_expr = [
("test_atan2", atan2(x, y)),
("test_pow", x**y),
]
numerical_tests = []
for name, expr in name_expr:
for xval, yval in (0.2, 1.3), (0.5, -0.2), (0.8, 0.8):
expected = N(expr.subs(x, xval).subs(y, yval))
numerical_tests.append((name, (xval, yval), expected, 1e-14))
for lang, commands in valid_lang_commands:
run_test("intrinsic_math2", name_expr, numerical_tests, lang, commands)
def test_complicated_codegen():
from sympy.core.evalf import N
from sympy.functions.elementary.trigonometric import (cos, sin, tan)
name_expr = [
("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
]
numerical_tests = []
for name, expr in name_expr:
for xval, yval, zval in (0.2, 1.3, -0.3), (0.5, -0.2, 0.0), (0.8, 2.1, 0.8):
expected = N(expr.subs(x, xval).subs(y, yval).subs(z, zval))
numerical_tests.append((name, (xval, yval, zval), expected, 1e-12))
for lang, commands in valid_lang_commands:
run_test(
"complicated_codegen", name_expr, numerical_tests, lang, commands)

View File

@ -0,0 +1,12 @@
from sympy.external.gmpy import LONG_MAX, iroot
from sympy.testing.pytest import raises
def test_iroot():
assert iroot(2, LONG_MAX) == (1, False)
assert iroot(2, LONG_MAX + 1) == (1, False)
for x in range(3):
assert iroot(x, 1) == (x, True)
raises(ValueError, lambda: iroot(-1, 1))
raises(ValueError, lambda: iroot(0, 0))
raises(ValueError, lambda: iroot(0, -1))

View File

@ -0,0 +1,40 @@
from sympy.external import import_module
from sympy.testing.pytest import warns
# fixes issue that arose in addressing issue 6533
def test_no_stdlib_collections():
'''
make sure we get the right collections when it is not part of a
larger list
'''
import collections
matplotlib = import_module('matplotlib',
import_kwargs={'fromlist': ['cm', 'collections']},
min_module_version='1.1.0', catch=(RuntimeError,))
if matplotlib:
assert collections != matplotlib.collections
def test_no_stdlib_collections2():
'''
make sure we get the right collections when it is not part of a
larger list
'''
import collections
matplotlib = import_module('matplotlib',
import_kwargs={'fromlist': ['collections']},
min_module_version='1.1.0', catch=(RuntimeError,))
if matplotlib:
assert collections != matplotlib.collections
def test_no_stdlib_collections3():
'''make sure we get the right collections with no catch'''
import collections
matplotlib = import_module('matplotlib',
import_kwargs={'fromlist': ['cm', 'collections']},
min_module_version='1.1.0')
if matplotlib:
assert collections != matplotlib.collections
def test_min_module_version_python3_basestring_error():
with warns(UserWarning):
import_module('mpmath', min_module_version='1000.0.1')

View File

@ -0,0 +1,307 @@
from itertools import permutations
from sympy.external.ntheory import (bit_scan1, remove, bit_scan0, is_fermat_prp,
is_euler_prp, is_strong_prp, gcdext, _lucas_sequence,
is_fibonacci_prp, is_lucas_prp, is_selfridge_prp,
is_strong_lucas_prp, is_strong_selfridge_prp,
is_bpsw_prp, is_strong_bpsw_prp)
from sympy.testing.pytest import raises
def test_bit_scan1():
assert bit_scan1(0) is None
assert bit_scan1(1) == 0
assert bit_scan1(-1) == 0
assert bit_scan1(2) == 1
assert bit_scan1(7) == 0
assert bit_scan1(-7) == 0
for i in range(100):
assert bit_scan1(1 << i) == i
assert bit_scan1((1 << i) * 31337) == i
for i in range(500):
n = (1 << 500) + (1 << i)
assert bit_scan1(n) == i
assert bit_scan1(1 << 1000001) == 1000001
assert bit_scan1((1 << 273956)*7**37) == 273956
# issue 12709
for i in range(1, 10):
big = 1 << i
assert bit_scan1(-big) == bit_scan1(big)
def test_bit_scan0():
assert bit_scan0(-1) is None
assert bit_scan0(0) == 0
assert bit_scan0(1) == 1
assert bit_scan0(-2) == 0
def test_remove():
raises(ValueError, lambda: remove(1, 1))
assert remove(0, 3) == (0, 0)
for f in range(2, 10):
for y in range(2, 1000):
for z in [1, 17, 101, 1009]:
assert remove(z*f**y, f) == (z, y)
def test_gcdext():
assert gcdext(0, 0) == (0, 0, 0)
assert gcdext(3, 0) == (3, 1, 0)
assert gcdext(0, 4) == (4, 0, 1)
for n in range(1, 10):
assert gcdext(n, 1) == gcdext(-n, 1) == (1, 0, 1)
assert gcdext(n, -1) == gcdext(-n, -1) == (1, 0, -1)
assert gcdext(n, n) == gcdext(-n, n) == (n, 0, 1)
assert gcdext(n, -n) == gcdext(-n, -n) == (n, 0, -1)
for n in range(2, 10):
assert gcdext(1, n) == gcdext(1, -n) == (1, 1, 0)
assert gcdext(-1, n) == gcdext(-1, -n) == (1, -1, 0)
for a, b in permutations([2**5, 3, 5, 7**2, 11], 2):
g, x, y = gcdext(a, b)
assert g == a*x + b*y == 1
def test_is_fermat_prp():
# invalid input
raises(ValueError, lambda: is_fermat_prp(0, 10))
raises(ValueError, lambda: is_fermat_prp(5, 1))
# n = 1
assert not is_fermat_prp(1, 3)
# n is prime
assert is_fermat_prp(2, 4)
assert is_fermat_prp(3, 2)
assert is_fermat_prp(11, 3)
assert is_fermat_prp(2**31-1, 5)
# A001567
pseudorpime = [341, 561, 645, 1105, 1387, 1729, 1905, 2047,
2465, 2701, 2821, 3277, 4033, 4369, 4371, 4681]
for n in pseudorpime:
assert is_fermat_prp(n, 2)
# A020136
pseudorpime = [15, 85, 91, 341, 435, 451, 561, 645, 703, 1105,
1247, 1271, 1387, 1581, 1695, 1729, 1891, 1905]
for n in pseudorpime:
assert is_fermat_prp(n, 4)
def test_is_euler_prp():
# invalid input
raises(ValueError, lambda: is_euler_prp(0, 10))
raises(ValueError, lambda: is_euler_prp(5, 1))
# n = 1
assert not is_euler_prp(1, 3)
# n is prime
assert is_euler_prp(2, 4)
assert is_euler_prp(3, 2)
assert is_euler_prp(11, 3)
assert is_euler_prp(2**31-1, 5)
# A047713
pseudorpime = [561, 1105, 1729, 1905, 2047, 2465, 3277, 4033,
4681, 6601, 8321, 8481, 10585, 12801, 15841]
for n in pseudorpime:
assert is_euler_prp(n, 2)
# A048950
pseudorpime = [121, 703, 1729, 1891, 2821, 3281, 7381, 8401,
8911, 10585, 12403, 15457, 15841, 16531, 18721]
for n in pseudorpime:
assert is_euler_prp(n, 3)
def test_is_strong_prp():
# invalid input
raises(ValueError, lambda: is_strong_prp(0, 10))
raises(ValueError, lambda: is_strong_prp(5, 1))
# n = 1
assert not is_strong_prp(1, 3)
# n is prime
assert is_strong_prp(2, 4)
assert is_strong_prp(3, 2)
assert is_strong_prp(11, 3)
assert is_strong_prp(2**31-1, 5)
# A001262
pseudorpime = [2047, 3277, 4033, 4681, 8321, 15841, 29341,
42799, 49141, 52633, 65281, 74665, 80581]
for n in pseudorpime:
assert is_strong_prp(n, 2)
# A020229
pseudorpime = [121, 703, 1891, 3281, 8401, 8911, 10585, 12403,
16531, 18721, 19345, 23521, 31621, 44287, 47197]
for n in pseudorpime:
assert is_strong_prp(n, 3)
def test_lucas_sequence():
def lucas_u(P, Q, length):
array = [0] * length
array[1] = 1
for k in range(2, length):
array[k] = P * array[k - 1] - Q * array[k - 2]
return array
def lucas_v(P, Q, length):
array = [0] * length
array[0] = 2
array[1] = P
for k in range(2, length):
array[k] = P * array[k - 1] - Q * array[k - 2]
return array
length = 20
for P in range(-10, 10):
for Q in range(-10, 10):
D = P**2 - 4*Q
if D == 0:
continue
us = lucas_u(P, Q, length)
vs = lucas_v(P, Q, length)
for n in range(3, 100, 2):
for k in range(length):
U, V, Qk = _lucas_sequence(n, P, Q, k)
assert U == us[k] % n
assert V == vs[k] % n
assert pow(Q, k, n) == Qk
def test_is_fibonacci_prp():
# invalid input
raises(ValueError, lambda: is_fibonacci_prp(3, 2, 1))
raises(ValueError, lambda: is_fibonacci_prp(3, -5, 1))
raises(ValueError, lambda: is_fibonacci_prp(3, 5, 2))
raises(ValueError, lambda: is_fibonacci_prp(0, 5, -1))
# n = 1
assert not is_fibonacci_prp(1, 3, 1)
# n is prime
assert is_fibonacci_prp(2, 5, 1)
assert is_fibonacci_prp(3, 6, -1)
assert is_fibonacci_prp(11, 7, 1)
assert is_fibonacci_prp(2**31-1, 8, -1)
# A005845
pseudorpime = [705, 2465, 2737, 3745, 4181, 5777, 6721,
10877, 13201, 15251, 24465, 29281, 34561]
for n in pseudorpime:
assert is_fibonacci_prp(n, 1, -1)
def test_is_lucas_prp():
# invalid input
raises(ValueError, lambda: is_lucas_prp(3, 2, 1))
raises(ValueError, lambda: is_lucas_prp(0, 5, -1))
raises(ValueError, lambda: is_lucas_prp(15, 3, 1))
# n = 1
assert not is_lucas_prp(1, 3, 1)
# n is prime
assert is_lucas_prp(2, 5, 2)
assert is_lucas_prp(3, 6, -1)
assert is_lucas_prp(11, 7, 5)
assert is_lucas_prp(2**31-1, 8, -3)
# A081264
pseudorpime = [323, 377, 1891, 3827, 4181, 5777, 6601, 6721,
8149, 10877, 11663, 13201, 13981, 15251, 17119]
for n in pseudorpime:
assert is_lucas_prp(n, 1, -1)
def test_is_selfridge_prp():
# invalid input
raises(ValueError, lambda: is_selfridge_prp(0))
# n = 1
assert not is_selfridge_prp(1)
# n is prime
assert is_selfridge_prp(2)
assert is_selfridge_prp(3)
assert is_selfridge_prp(11)
assert is_selfridge_prp(2**31-1)
# A217120
pseudorpime = [323, 377, 1159, 1829, 3827, 5459, 5777, 9071,
9179, 10877, 11419, 11663, 13919, 14839, 16109]
for n in pseudorpime:
assert is_selfridge_prp(n)
def test_is_strong_lucas_prp():
# invalid input
raises(ValueError, lambda: is_strong_lucas_prp(3, 2, 1))
raises(ValueError, lambda: is_strong_lucas_prp(0, 5, -1))
raises(ValueError, lambda: is_strong_lucas_prp(15, 3, 1))
# n = 1
assert not is_strong_lucas_prp(1, 3, 1)
# n is prime
assert is_strong_lucas_prp(2, 5, 2)
assert is_strong_lucas_prp(3, 6, -1)
assert is_strong_lucas_prp(11, 7, 5)
assert is_strong_lucas_prp(2**31-1, 8, -3)
def test_is_strong_selfridge_prp():
# invalid input
raises(ValueError, lambda: is_strong_selfridge_prp(0))
# n = 1
assert not is_strong_selfridge_prp(1)
# n is prime
assert is_strong_selfridge_prp(2)
assert is_strong_selfridge_prp(3)
assert is_strong_selfridge_prp(11)
assert is_strong_selfridge_prp(2**31-1)
# A217255
pseudorpime = [5459, 5777, 10877, 16109, 18971, 22499, 24569,
25199, 40309, 58519, 75077, 97439, 100127, 113573]
for n in pseudorpime:
assert is_strong_selfridge_prp(n)
def test_is_bpsw_prp():
# invalid input
raises(ValueError, lambda: is_bpsw_prp(0))
# n = 1
assert not is_bpsw_prp(1)
# n is prime
assert is_bpsw_prp(2)
assert is_bpsw_prp(3)
assert is_bpsw_prp(11)
assert is_bpsw_prp(2**31-1)
def test_is_strong_bpsw_prp():
# invalid input
raises(ValueError, lambda: is_strong_bpsw_prp(0))
# n = 1
assert not is_strong_bpsw_prp(1)
# n is prime
assert is_strong_bpsw_prp(2)
assert is_strong_bpsw_prp(3)
assert is_strong_bpsw_prp(11)
assert is_strong_bpsw_prp(2**31-1)

View File

@ -0,0 +1,335 @@
# This testfile tests SymPy <-> NumPy compatibility
# Don't test any SymPy features here. Just pure interaction with NumPy.
# Always write regular SymPy tests for anything, that can be tested in pure
# Python (without numpy). Here we test everything, that a user may need when
# using SymPy with NumPy
from sympy.external.importtools import version_tuple
from sympy.external import import_module
numpy = import_module('numpy')
if numpy:
array, matrix, ndarray = numpy.array, numpy.matrix, numpy.ndarray
else:
#bin/test will not execute any tests now
disabled = True
from sympy.core.numbers import (Float, Integer, Rational)
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.trigonometric import sin
from sympy.matrices.dense import (Matrix, list2numpy, matrix2numpy, symarray)
from sympy.utilities.lambdify import lambdify
import sympy
import mpmath
from sympy.abc import x, y, z
from sympy.utilities.decorator import conserve_mpmath_dps
from sympy.utilities.exceptions import ignore_warnings
from sympy.testing.pytest import raises
# first, systematically check, that all operations are implemented and don't
# raise an exception
def test_systematic_basic():
def s(sympy_object, numpy_array):
_ = [sympy_object + numpy_array,
numpy_array + sympy_object,
sympy_object - numpy_array,
numpy_array - sympy_object,
sympy_object * numpy_array,
numpy_array * sympy_object,
sympy_object / numpy_array,
numpy_array / sympy_object,
sympy_object ** numpy_array,
numpy_array ** sympy_object]
x = Symbol("x")
y = Symbol("y")
sympy_objs = [
Rational(2, 3),
Float("1.3"),
x,
y,
pow(x, y)*y,
Integer(5),
Float(5.5),
]
numpy_objs = [
array([1]),
array([3, 8, -1]),
array([x, x**2, Rational(5)]),
array([x/y*sin(y), 5, Rational(5)]),
]
for x in sympy_objs:
for y in numpy_objs:
s(x, y)
# now some random tests, that test particular problems and that also
# check that the results of the operations are correct
def test_basics():
one = Rational(1)
zero = Rational(0)
assert array(1) == array(one)
assert array([one]) == array([one])
assert array([x]) == array([x])
assert array(x) == array(Symbol("x"))
assert array(one + x) == array(1 + x)
X = array([one, zero, zero])
assert (X == array([one, zero, zero])).all()
assert (X == array([one, 0, 0])).all()
def test_arrays():
one = Rational(1)
zero = Rational(0)
X = array([one, zero, zero])
Y = one*X
X = array([Symbol("a") + Rational(1, 2)])
Y = X + X
assert Y == array([1 + 2*Symbol("a")])
Y = Y + 1
assert Y == array([2 + 2*Symbol("a")])
Y = X - X
assert Y == array([0])
def test_conversion1():
a = list2numpy([x**2, x])
#looks like an array?
assert isinstance(a, ndarray)
assert a[0] == x**2
assert a[1] == x
assert len(a) == 2
#yes, it's the array
def test_conversion2():
a = 2*list2numpy([x**2, x])
b = list2numpy([2*x**2, 2*x])
assert (a == b).all()
one = Rational(1)
zero = Rational(0)
X = list2numpy([one, zero, zero])
Y = one*X
X = list2numpy([Symbol("a") + Rational(1, 2)])
Y = X + X
assert Y == array([1 + 2*Symbol("a")])
Y = Y + 1
assert Y == array([2 + 2*Symbol("a")])
Y = X - X
assert Y == array([0])
def test_list2numpy():
assert (array([x**2, x]) == list2numpy([x**2, x])).all()
def test_Matrix1():
m = Matrix([[x, x**2], [5, 2/x]])
assert (array(m.subs(x, 2)) == array([[2, 4], [5, 1]])).all()
m = Matrix([[sin(x), x**2], [5, 2/x]])
assert (array(m.subs(x, 2)) == array([[sin(2), 4], [5, 1]])).all()
def test_Matrix2():
m = Matrix([[x, x**2], [5, 2/x]])
with ignore_warnings(PendingDeprecationWarning):
assert (matrix(m.subs(x, 2)) == matrix([[2, 4], [5, 1]])).all()
m = Matrix([[sin(x), x**2], [5, 2/x]])
with ignore_warnings(PendingDeprecationWarning):
assert (matrix(m.subs(x, 2)) == matrix([[sin(2), 4], [5, 1]])).all()
def test_Matrix3():
a = array([[2, 4], [5, 1]])
assert Matrix(a) == Matrix([[2, 4], [5, 1]])
assert Matrix(a) != Matrix([[2, 4], [5, 2]])
a = array([[sin(2), 4], [5, 1]])
assert Matrix(a) == Matrix([[sin(2), 4], [5, 1]])
assert Matrix(a) != Matrix([[sin(0), 4], [5, 1]])
def test_Matrix4():
with ignore_warnings(PendingDeprecationWarning):
a = matrix([[2, 4], [5, 1]])
assert Matrix(a) == Matrix([[2, 4], [5, 1]])
assert Matrix(a) != Matrix([[2, 4], [5, 2]])
with ignore_warnings(PendingDeprecationWarning):
a = matrix([[sin(2), 4], [5, 1]])
assert Matrix(a) == Matrix([[sin(2), 4], [5, 1]])
assert Matrix(a) != Matrix([[sin(0), 4], [5, 1]])
def test_Matrix_sum():
M = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]])
with ignore_warnings(PendingDeprecationWarning):
m = matrix([[2, 3, 4], [x, 5, 6], [x, y, z**2]])
assert M + m == Matrix([[3, 5, 7], [2*x, y + 5, x + 6], [2*y + x, y - 50, z*x + z**2]])
assert m + M == Matrix([[3, 5, 7], [2*x, y + 5, x + 6], [2*y + x, y - 50, z*x + z**2]])
assert M + m == M.add(m)
def test_Matrix_mul():
M = Matrix([[1, 2, 3], [x, y, x]])
with ignore_warnings(PendingDeprecationWarning):
m = matrix([[2, 4], [x, 6], [x, z**2]])
assert M*m == Matrix([
[ 2 + 5*x, 16 + 3*z**2],
[2*x + x*y + x**2, 4*x + 6*y + x*z**2],
])
assert m*M == Matrix([
[ 2 + 4*x, 4 + 4*y, 6 + 4*x],
[ 7*x, 2*x + 6*y, 9*x],
[x + x*z**2, 2*x + y*z**2, 3*x + x*z**2],
])
a = array([2])
assert a[0] * M == 2 * M
assert M * a[0] == 2 * M
def test_Matrix_array():
class matarray:
def __array__(self, dtype=object, copy=None):
if copy is not None and not copy:
raise TypeError("Cannot implement copy=False when converting Matrix to ndarray")
from numpy import array
return array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
matarr = matarray()
assert Matrix(matarr) == Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def test_matrix2numpy():
a = matrix2numpy(Matrix([[1, x**2], [3*sin(x), 0]]))
assert isinstance(a, ndarray)
assert a.shape == (2, 2)
assert a[0, 0] == 1
assert a[0, 1] == x**2
assert a[1, 0] == 3*sin(x)
assert a[1, 1] == 0
def test_matrix2numpy_conversion():
a = Matrix([[1, 2, sin(x)], [x**2, x, Rational(1, 2)]])
b = array([[1, 2, sin(x)], [x**2, x, Rational(1, 2)]])
assert (matrix2numpy(a) == b).all()
assert matrix2numpy(a).dtype == numpy.dtype('object')
c = matrix2numpy(Matrix([[1, 2], [10, 20]]), dtype='int8')
d = matrix2numpy(Matrix([[1, 2], [10, 20]]), dtype='float64')
assert c.dtype == numpy.dtype('int8')
assert d.dtype == numpy.dtype('float64')
def test_issue_3728():
assert (Rational(1, 2)*array([2*x, 0]) == array([x, 0])).all()
assert (Rational(1, 2) + array(
[2*x, 0]) == array([2*x + Rational(1, 2), Rational(1, 2)])).all()
assert (Float("0.5")*array([2*x, 0]) == array([Float("1.0")*x, 0])).all()
assert (Float("0.5") + array(
[2*x, 0]) == array([2*x + Float("0.5"), Float("0.5")])).all()
@conserve_mpmath_dps
def test_lambdify():
mpmath.mp.dps = 16
sin02 = mpmath.mpf("0.198669330795061215459412627")
f = lambdify(x, sin(x), "numpy")
prec = 1e-15
assert -prec < f(0.2) - sin02 < prec
# if this succeeds, it can't be a numpy function
if version_tuple(numpy.__version__) >= version_tuple('1.17'):
with raises(TypeError):
f(x)
else:
with raises(AttributeError):
f(x)
def test_lambdify_matrix():
f = lambdify(x, Matrix([[x, 2*x], [1, 2]]), [{'ImmutableMatrix': numpy.array}, "numpy"])
assert (f(1) == array([[1, 2], [1, 2]])).all()
def test_lambdify_matrix_multi_input():
M = sympy.Matrix([[x**2, x*y, x*z],
[y*x, y**2, y*z],
[z*x, z*y, z**2]])
f = lambdify((x, y, z), M, [{'ImmutableMatrix': numpy.array}, "numpy"])
xh, yh, zh = 1.0, 2.0, 3.0
expected = array([[xh**2, xh*yh, xh*zh],
[yh*xh, yh**2, yh*zh],
[zh*xh, zh*yh, zh**2]])
actual = f(xh, yh, zh)
assert numpy.allclose(actual, expected)
def test_lambdify_matrix_vec_input():
X = sympy.DeferredVector('X')
M = Matrix([
[X[0]**2, X[0]*X[1], X[0]*X[2]],
[X[1]*X[0], X[1]**2, X[1]*X[2]],
[X[2]*X[0], X[2]*X[1], X[2]**2]])
f = lambdify(X, M, [{'ImmutableMatrix': numpy.array}, "numpy"])
Xh = array([1.0, 2.0, 3.0])
expected = array([[Xh[0]**2, Xh[0]*Xh[1], Xh[0]*Xh[2]],
[Xh[1]*Xh[0], Xh[1]**2, Xh[1]*Xh[2]],
[Xh[2]*Xh[0], Xh[2]*Xh[1], Xh[2]**2]])
actual = f(Xh)
assert numpy.allclose(actual, expected)
def test_lambdify_transl():
from sympy.utilities.lambdify import NUMPY_TRANSLATIONS
for sym, mat in NUMPY_TRANSLATIONS.items():
assert sym in sympy.__dict__
assert mat in numpy.__dict__
def test_symarray():
"""Test creation of numpy arrays of SymPy symbols."""
import numpy as np
import numpy.testing as npt
syms = symbols('_0,_1,_2')
s1 = symarray("", 3)
s2 = symarray("", 3)
npt.assert_array_equal(s1, np.array(syms, dtype=object))
assert s1[0] == s2[0]
a = symarray('a', 3)
b = symarray('b', 3)
assert not(a[0] == b[0])
asyms = symbols('a_0,a_1,a_2')
npt.assert_array_equal(a, np.array(asyms, dtype=object))
# Multidimensional checks
a2d = symarray('a', (2, 3))
assert a2d.shape == (2, 3)
a00, a12 = symbols('a_0_0,a_1_2')
assert a2d[0, 0] == a00
assert a2d[1, 2] == a12
a3d = symarray('a', (2, 3, 2))
assert a3d.shape == (2, 3, 2)
a000, a120, a121 = symbols('a_0_0_0,a_1_2_0,a_1_2_1')
assert a3d[0, 0, 0] == a000
assert a3d[1, 2, 0] == a120
assert a3d[1, 2, 1] == a121
def test_vectorize():
assert (numpy.vectorize(
sin)([1, 2, 3]) == numpy.array([sin(1), sin(2), sin(3)])).all()

View File

@ -0,0 +1,176 @@
"""
test_pythonmpq.py
Test the PythonMPQ class for consistency with gmpy2's mpq type. If gmpy2 is
installed run the same tests for both.
"""
from fractions import Fraction
from decimal import Decimal
import pickle
from typing import Callable, List, Tuple, Type
from sympy.testing.pytest import raises
from sympy.external.pythonmpq import PythonMPQ
#
# If gmpy2 is installed then run the tests for both mpq and PythonMPQ.
# That should ensure consistency between the implementation here and mpq.
#
rational_types: List[Tuple[Callable, Type, Callable, Type]]
rational_types = [(PythonMPQ, PythonMPQ, int, int)]
try:
from gmpy2 import mpq, mpz
rational_types.append((mpq, type(mpq(1)), mpz, type(mpz(1))))
except ImportError:
pass
def test_PythonMPQ():
#
# Test PythonMPQ and also mpq if gmpy/gmpy2 is installed.
#
for Q, TQ, Z, TZ in rational_types:
def check_Q(q):
assert isinstance(q, TQ)
assert isinstance(q.numerator, TZ)
assert isinstance(q.denominator, TZ)
return q.numerator, q.denominator
# Check construction from different types
assert check_Q(Q(3)) == (3, 1)
assert check_Q(Q(3, 5)) == (3, 5)
assert check_Q(Q(Q(3, 5))) == (3, 5)
assert check_Q(Q(0.5)) == (1, 2)
assert check_Q(Q('0.5')) == (1, 2)
assert check_Q(Q(Fraction(3, 5))) == (3, 5)
# https://github.com/aleaxit/gmpy/issues/327
if Q is PythonMPQ:
assert check_Q(Q(Decimal('0.6'))) == (3, 5)
# Invalid types
raises(TypeError, lambda: Q([]))
raises(TypeError, lambda: Q([], []))
# Check normalisation of signs
assert check_Q(Q(2, 3)) == (2, 3)
assert check_Q(Q(-2, 3)) == (-2, 3)
assert check_Q(Q(2, -3)) == (-2, 3)
assert check_Q(Q(-2, -3)) == (2, 3)
# Check gcd calculation
assert check_Q(Q(12, 8)) == (3, 2)
# __int__/__float__
assert int(Q(5, 3)) == 1
assert int(Q(-5, 3)) == -1
assert float(Q(5, 2)) == 2.5
assert float(Q(-5, 2)) == -2.5
# __str__/__repr__
assert str(Q(2, 1)) == "2"
assert str(Q(1, 2)) == "1/2"
if Q is PythonMPQ:
assert repr(Q(2, 1)) == "MPQ(2,1)"
assert repr(Q(1, 2)) == "MPQ(1,2)"
else:
assert repr(Q(2, 1)) == "mpq(2,1)"
assert repr(Q(1, 2)) == "mpq(1,2)"
# __bool__
assert bool(Q(1, 2)) is True
assert bool(Q(0)) is False
# __eq__/__ne__
assert (Q(2, 3) == Q(2, 3)) is True
assert (Q(2, 3) == Q(2, 5)) is False
assert (Q(2, 3) != Q(2, 3)) is False
assert (Q(2, 3) != Q(2, 5)) is True
# __hash__
assert hash(Q(3, 5)) == hash(Fraction(3, 5))
# __reduce__
q = Q(2, 3)
assert pickle.loads(pickle.dumps(q)) == q
# __ge__/__gt__/__le__/__lt__
assert (Q(1, 3) < Q(2, 3)) is True
assert (Q(2, 3) < Q(2, 3)) is False
assert (Q(2, 3) < Q(1, 3)) is False
assert (Q(-2, 3) < Q(1, 3)) is True
assert (Q(1, 3) < Q(-2, 3)) is False
assert (Q(1, 3) <= Q(2, 3)) is True
assert (Q(2, 3) <= Q(2, 3)) is True
assert (Q(2, 3) <= Q(1, 3)) is False
assert (Q(-2, 3) <= Q(1, 3)) is True
assert (Q(1, 3) <= Q(-2, 3)) is False
assert (Q(1, 3) > Q(2, 3)) is False
assert (Q(2, 3) > Q(2, 3)) is False
assert (Q(2, 3) > Q(1, 3)) is True
assert (Q(-2, 3) > Q(1, 3)) is False
assert (Q(1, 3) > Q(-2, 3)) is True
assert (Q(1, 3) >= Q(2, 3)) is False
assert (Q(2, 3) >= Q(2, 3)) is True
assert (Q(2, 3) >= Q(1, 3)) is True
assert (Q(-2, 3) >= Q(1, 3)) is False
assert (Q(1, 3) >= Q(-2, 3)) is True
# __abs__/__pos__/__neg__
assert abs(Q(2, 3)) == abs(Q(-2, 3)) == Q(2, 3)
assert +Q(2, 3) == Q(2, 3)
assert -Q(2, 3) == Q(-2, 3)
# __add__/__radd__
assert Q(2, 3) + Q(5, 7) == Q(29, 21)
assert Q(2, 3) + 1 == Q(5, 3)
assert 1 + Q(2, 3) == Q(5, 3)
raises(TypeError, lambda: [] + Q(1))
raises(TypeError, lambda: Q(1) + [])
# __sub__/__rsub__
assert Q(2, 3) - Q(5, 7) == Q(-1, 21)
assert Q(2, 3) - 1 == Q(-1, 3)
assert 1 - Q(2, 3) == Q(1, 3)
raises(TypeError, lambda: [] - Q(1))
raises(TypeError, lambda: Q(1) - [])
# __mul__/__rmul__
assert Q(2, 3) * Q(5, 7) == Q(10, 21)
assert Q(2, 3) * 1 == Q(2, 3)
assert 1 * Q(2, 3) == Q(2, 3)
raises(TypeError, lambda: [] * Q(1))
raises(TypeError, lambda: Q(1) * [])
# __pow__/__rpow__
assert Q(2, 3) ** 2 == Q(4, 9)
assert Q(2, 3) ** 1 == Q(2, 3)
assert Q(-2, 3) ** 2 == Q(4, 9)
assert Q(-2, 3) ** -1 == Q(-3, 2)
if Q is PythonMPQ:
raises(TypeError, lambda: 1 ** Q(2, 3))
raises(TypeError, lambda: Q(1, 4) ** Q(1, 2))
raises(TypeError, lambda: [] ** Q(1))
raises(TypeError, lambda: Q(1) ** [])
# __div__/__rdiv__
assert Q(2, 3) / Q(5, 7) == Q(14, 15)
assert Q(2, 3) / 1 == Q(2, 3)
assert 1 / Q(2, 3) == Q(3, 2)
raises(TypeError, lambda: [] / Q(1))
raises(TypeError, lambda: Q(1) / [])
raises(ZeroDivisionError, lambda: Q(1, 2) / Q(0))
# __divmod__
if Q is PythonMPQ:
raises(TypeError, lambda: Q(2, 3) // Q(1, 3))
raises(TypeError, lambda: Q(2, 3) % Q(1, 3))
raises(TypeError, lambda: 1 // Q(1, 3))
raises(TypeError, lambda: 1 % Q(1, 3))
raises(TypeError, lambda: Q(2, 3) // 1)
raises(TypeError, lambda: Q(2, 3) % 1)

View File

@ -0,0 +1,35 @@
# This testfile tests SymPy <-> SciPy compatibility
# Don't test any SymPy features here. Just pure interaction with SciPy.
# Always write regular SymPy tests for anything, that can be tested in pure
# Python (without scipy). Here we test everything, that a user may need when
# using SymPy with SciPy
from sympy.external import import_module
scipy = import_module('scipy')
if not scipy:
#bin/test will not execute any tests now
disabled = True
from sympy.functions.special.bessel import jn_zeros
def eq(a, b, tol=1e-6):
for x, y in zip(a, b):
if not (abs(x - y) < tol):
return False
return True
def test_jn_zeros():
assert eq(jn_zeros(0, 4, method="scipy"),
[3.141592, 6.283185, 9.424777, 12.566370])
assert eq(jn_zeros(1, 4, method="scipy"),
[4.493409, 7.725251, 10.904121, 14.066193])
assert eq(jn_zeros(2, 4, method="scipy"),
[5.763459, 9.095011, 12.322940, 15.514603])
assert eq(jn_zeros(3, 4, method="scipy"),
[6.987932, 10.417118, 13.698023, 16.923621])
assert eq(jn_zeros(4, 4, method="scipy"),
[8.182561, 11.704907, 15.039664, 18.301255])