I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,468 @@
__version__ = '1.3.0'
from .usertools import monitor, timing
from .ctx_fp import FPContext
from .ctx_mp import MPContext
from .ctx_iv import MPIntervalContext
fp = FPContext()
mp = MPContext()
iv = MPIntervalContext()
fp._mp = mp
mp._mp = mp
iv._mp = mp
mp._fp = fp
fp._fp = fp
mp._iv = iv
fp._iv = iv
iv._iv = iv
# XXX: extremely bad pickle hack
from . import ctx_mp as _ctx_mp
_ctx_mp._mpf_module.mpf = mp.mpf
_ctx_mp._mpf_module.mpc = mp.mpc
make_mpf = mp.make_mpf
make_mpc = mp.make_mpc
extraprec = mp.extraprec
extradps = mp.extradps
workprec = mp.workprec
workdps = mp.workdps
autoprec = mp.autoprec
maxcalls = mp.maxcalls
memoize = mp.memoize
mag = mp.mag
bernfrac = mp.bernfrac
qfrom = mp.qfrom
mfrom = mp.mfrom
kfrom = mp.kfrom
taufrom = mp.taufrom
qbarfrom = mp.qbarfrom
ellipfun = mp.ellipfun
jtheta = mp.jtheta
kleinj = mp.kleinj
eta = mp.eta
qp = mp.qp
qhyper = mp.qhyper
qgamma = mp.qgamma
qfac = mp.qfac
nint_distance = mp.nint_distance
plot = mp.plot
cplot = mp.cplot
splot = mp.splot
odefun = mp.odefun
jacobian = mp.jacobian
findroot = mp.findroot
multiplicity = mp.multiplicity
isinf = mp.isinf
isnan = mp.isnan
isnormal = mp.isnormal
isint = mp.isint
isfinite = mp.isfinite
almosteq = mp.almosteq
nan = mp.nan
rand = mp.rand
absmin = mp.absmin
absmax = mp.absmax
fraction = mp.fraction
linspace = mp.linspace
arange = mp.arange
mpmathify = convert = mp.convert
mpc = mp.mpc
mpi = iv._mpi
nstr = mp.nstr
nprint = mp.nprint
chop = mp.chop
fneg = mp.fneg
fadd = mp.fadd
fsub = mp.fsub
fmul = mp.fmul
fdiv = mp.fdiv
fprod = mp.fprod
quad = mp.quad
quadgl = mp.quadgl
quadts = mp.quadts
quadosc = mp.quadosc
quadsubdiv = mp.quadsubdiv
invertlaplace = mp.invertlaplace
invlaptalbot = mp.invlaptalbot
invlapstehfest = mp.invlapstehfest
invlapdehoog = mp.invlapdehoog
pslq = mp.pslq
identify = mp.identify
findpoly = mp.findpoly
richardson = mp.richardson
shanks = mp.shanks
levin = mp.levin
cohen_alt = mp.cohen_alt
nsum = mp.nsum
nprod = mp.nprod
difference = mp.difference
diff = mp.diff
diffs = mp.diffs
diffs_prod = mp.diffs_prod
diffs_exp = mp.diffs_exp
diffun = mp.diffun
differint = mp.differint
taylor = mp.taylor
pade = mp.pade
polyval = mp.polyval
polyroots = mp.polyroots
fourier = mp.fourier
fourierval = mp.fourierval
sumem = mp.sumem
sumap = mp.sumap
chebyfit = mp.chebyfit
limit = mp.limit
matrix = mp.matrix
eye = mp.eye
diag = mp.diag
zeros = mp.zeros
ones = mp.ones
hilbert = mp.hilbert
randmatrix = mp.randmatrix
swap_row = mp.swap_row
extend = mp.extend
norm = mp.norm
mnorm = mp.mnorm
lu_solve = mp.lu_solve
lu = mp.lu
qr = mp.qr
unitvector = mp.unitvector
inverse = mp.inverse
residual = mp.residual
qr_solve = mp.qr_solve
cholesky = mp.cholesky
cholesky_solve = mp.cholesky_solve
det = mp.det
cond = mp.cond
hessenberg = mp.hessenberg
schur = mp.schur
eig = mp.eig
eig_sort = mp.eig_sort
eigsy = mp.eigsy
eighe = mp.eighe
eigh = mp.eigh
svd_r = mp.svd_r
svd_c = mp.svd_c
svd = mp.svd
gauss_quadrature = mp.gauss_quadrature
expm = mp.expm
sqrtm = mp.sqrtm
powm = mp.powm
logm = mp.logm
sinm = mp.sinm
cosm = mp.cosm
mpf = mp.mpf
j = mp.j
exp = mp.exp
expj = mp.expj
expjpi = mp.expjpi
ln = mp.ln
im = mp.im
re = mp.re
inf = mp.inf
ninf = mp.ninf
sign = mp.sign
eps = mp.eps
pi = mp.pi
ln2 = mp.ln2
ln10 = mp.ln10
phi = mp.phi
e = mp.e
euler = mp.euler
catalan = mp.catalan
khinchin = mp.khinchin
glaisher = mp.glaisher
apery = mp.apery
degree = mp.degree
twinprime = mp.twinprime
mertens = mp.mertens
ldexp = mp.ldexp
frexp = mp.frexp
fsum = mp.fsum
fdot = mp.fdot
sqrt = mp.sqrt
cbrt = mp.cbrt
exp = mp.exp
ln = mp.ln
log = mp.log
log10 = mp.log10
power = mp.power
cos = mp.cos
sin = mp.sin
tan = mp.tan
cosh = mp.cosh
sinh = mp.sinh
tanh = mp.tanh
acos = mp.acos
asin = mp.asin
atan = mp.atan
asinh = mp.asinh
acosh = mp.acosh
atanh = mp.atanh
sec = mp.sec
csc = mp.csc
cot = mp.cot
sech = mp.sech
csch = mp.csch
coth = mp.coth
asec = mp.asec
acsc = mp.acsc
acot = mp.acot
asech = mp.asech
acsch = mp.acsch
acoth = mp.acoth
cospi = mp.cospi
sinpi = mp.sinpi
sinc = mp.sinc
sincpi = mp.sincpi
cos_sin = mp.cos_sin
cospi_sinpi = mp.cospi_sinpi
fabs = mp.fabs
re = mp.re
im = mp.im
conj = mp.conj
floor = mp.floor
ceil = mp.ceil
nint = mp.nint
frac = mp.frac
root = mp.root
nthroot = mp.nthroot
hypot = mp.hypot
fmod = mp.fmod
ldexp = mp.ldexp
frexp = mp.frexp
sign = mp.sign
arg = mp.arg
phase = mp.phase
polar = mp.polar
rect = mp.rect
degrees = mp.degrees
radians = mp.radians
atan2 = mp.atan2
fib = mp.fib
fibonacci = mp.fibonacci
lambertw = mp.lambertw
zeta = mp.zeta
altzeta = mp.altzeta
gamma = mp.gamma
rgamma = mp.rgamma
factorial = mp.factorial
fac = mp.fac
fac2 = mp.fac2
beta = mp.beta
betainc = mp.betainc
psi = mp.psi
#psi0 = mp.psi0
#psi1 = mp.psi1
#psi2 = mp.psi2
#psi3 = mp.psi3
polygamma = mp.polygamma
digamma = mp.digamma
#trigamma = mp.trigamma
#tetragamma = mp.tetragamma
#pentagamma = mp.pentagamma
harmonic = mp.harmonic
bernoulli = mp.bernoulli
bernfrac = mp.bernfrac
stieltjes = mp.stieltjes
hurwitz = mp.hurwitz
dirichlet = mp.dirichlet
bernpoly = mp.bernpoly
eulerpoly = mp.eulerpoly
eulernum = mp.eulernum
polylog = mp.polylog
clsin = mp.clsin
clcos = mp.clcos
gammainc = mp.gammainc
gammaprod = mp.gammaprod
binomial = mp.binomial
rf = mp.rf
ff = mp.ff
hyper = mp.hyper
hyp0f1 = mp.hyp0f1
hyp1f1 = mp.hyp1f1
hyp1f2 = mp.hyp1f2
hyp2f1 = mp.hyp2f1
hyp2f2 = mp.hyp2f2
hyp2f0 = mp.hyp2f0
hyp2f3 = mp.hyp2f3
hyp3f2 = mp.hyp3f2
hyperu = mp.hyperu
hypercomb = mp.hypercomb
meijerg = mp.meijerg
appellf1 = mp.appellf1
appellf2 = mp.appellf2
appellf3 = mp.appellf3
appellf4 = mp.appellf4
hyper2d = mp.hyper2d
bihyper = mp.bihyper
erf = mp.erf
erfc = mp.erfc
erfi = mp.erfi
erfinv = mp.erfinv
npdf = mp.npdf
ncdf = mp.ncdf
expint = mp.expint
e1 = mp.e1
ei = mp.ei
li = mp.li
ci = mp.ci
si = mp.si
chi = mp.chi
shi = mp.shi
fresnels = mp.fresnels
fresnelc = mp.fresnelc
airyai = mp.airyai
airybi = mp.airybi
airyaizero = mp.airyaizero
airybizero = mp.airybizero
scorergi = mp.scorergi
scorerhi = mp.scorerhi
ellipk = mp.ellipk
ellipe = mp.ellipe
ellipf = mp.ellipf
ellippi = mp.ellippi
elliprc = mp.elliprc
elliprj = mp.elliprj
elliprf = mp.elliprf
elliprd = mp.elliprd
elliprg = mp.elliprg
agm = mp.agm
jacobi = mp.jacobi
chebyt = mp.chebyt
chebyu = mp.chebyu
legendre = mp.legendre
legenp = mp.legenp
legenq = mp.legenq
hermite = mp.hermite
pcfd = mp.pcfd
pcfu = mp.pcfu
pcfv = mp.pcfv
pcfw = mp.pcfw
gegenbauer = mp.gegenbauer
laguerre = mp.laguerre
spherharm = mp.spherharm
besselj = mp.besselj
j0 = mp.j0
j1 = mp.j1
besseli = mp.besseli
bessely = mp.bessely
besselk = mp.besselk
besseljzero = mp.besseljzero
besselyzero = mp.besselyzero
hankel1 = mp.hankel1
hankel2 = mp.hankel2
struveh = mp.struveh
struvel = mp.struvel
angerj = mp.angerj
webere = mp.webere
lommels1 = mp.lommels1
lommels2 = mp.lommels2
whitm = mp.whitm
whitw = mp.whitw
ber = mp.ber
bei = mp.bei
ker = mp.ker
kei = mp.kei
coulombc = mp.coulombc
coulombf = mp.coulombf
coulombg = mp.coulombg
barnesg = mp.barnesg
superfac = mp.superfac
hyperfac = mp.hyperfac
loggamma = mp.loggamma
siegeltheta = mp.siegeltheta
siegelz = mp.siegelz
grampoint = mp.grampoint
zetazero = mp.zetazero
riemannr = mp.riemannr
primepi = mp.primepi
primepi2 = mp.primepi2
primezeta = mp.primezeta
bell = mp.bell
polyexp = mp.polyexp
expm1 = mp.expm1
log1p = mp.log1p
powm1 = mp.powm1
unitroots = mp.unitroots
cyclotomic = mp.cyclotomic
mangoldt = mp.mangoldt
secondzeta = mp.secondzeta
nzeros = mp.nzeros
backlunds = mp.backlunds
lerchphi = mp.lerchphi
stirling1 = mp.stirling1
stirling2 = mp.stirling2
squarew = mp.squarew
trianglew = mp.trianglew
sawtoothw = mp.sawtoothw
unit_triangle = mp.unit_triangle
sigmoid = mp.sigmoid
# be careful when changing this name, don't use test*!
def runtests():
"""
Run all mpmath tests and print output.
"""
import os.path
from inspect import getsourcefile
from .tests import runtests as tests
testdir = os.path.dirname(os.path.abspath(getsourcefile(tests)))
importdir = os.path.abspath(testdir + '/../..')
tests.testit(importdir, testdir)
def doctests(filter=[]):
import sys
from timeit import default_timer as clock
for i, arg in enumerate(sys.argv):
if '__init__.py' in arg:
filter = [sn for sn in sys.argv[i+1:] if not sn.startswith("-")]
break
import doctest
globs = globals().copy()
for obj in globs: #sorted(globs.keys()):
if filter:
if not sum([pat in obj for pat in filter]):
continue
sys.stdout.write(str(obj) + " ")
sys.stdout.flush()
t1 = clock()
doctest.run_docstring_examples(globs[obj], {}, verbose=("-v" in sys.argv))
t2 = clock()
print(round(t2-t1, 3))
if __name__ == '__main__':
doctests()

View File

@ -0,0 +1,6 @@
from . import calculus
# XXX: hack to set methods
from . import approximation
from . import differentiation
from . import extrapolation
from . import polynomials

View File

@ -0,0 +1,246 @@
from ..libmp.backend import xrange
from .calculus import defun
#----------------------------------------------------------------------------#
# Approximation methods #
#----------------------------------------------------------------------------#
# The Chebyshev approximation formula is given at:
# http://mathworld.wolfram.com/ChebyshevApproximationFormula.html
# The only major changes in the following code is that we return the
# expanded polynomial coefficients instead of Chebyshev coefficients,
# and that we automatically transform [a,b] -> [-1,1] and back
# for convenience.
# Coefficient in Chebyshev approximation
def chebcoeff(ctx,f,a,b,j,N):
s = ctx.mpf(0)
h = ctx.mpf(0.5)
for k in range(1, N+1):
t = ctx.cospi((k-h)/N)
s += f(t*(b-a)*h + (b+a)*h) * ctx.cospi(j*(k-h)/N)
return 2*s/N
# Generate Chebyshev polynomials T_n(ax+b) in expanded form
def chebT(ctx, a=1, b=0):
Tb = [1]
yield Tb
Ta = [b, a]
while 1:
yield Ta
# Recurrence: T[n+1](ax+b) = 2*(ax+b)*T[n](ax+b) - T[n-1](ax+b)
Tmp = [0] + [2*a*t for t in Ta]
for i, c in enumerate(Ta): Tmp[i] += 2*b*c
for i, c in enumerate(Tb): Tmp[i] -= c
Ta, Tb = Tmp, Ta
@defun
def chebyfit(ctx, f, interval, N, error=False):
r"""
Computes a polynomial of degree `N-1` that approximates the
given function `f` on the interval `[a, b]`. With ``error=True``,
:func:`~mpmath.chebyfit` also returns an accurate estimate of the
maximum absolute error; that is, the maximum value of
`|f(x) - P(x)|` for `x \in [a, b]`.
:func:`~mpmath.chebyfit` uses the Chebyshev approximation formula,
which gives a nearly optimal solution: that is, the maximum
error of the approximating polynomial is very close to
the smallest possible for any polynomial of the same degree.
Chebyshev approximation is very useful if one needs repeated
evaluation of an expensive function, such as function defined
implicitly by an integral or a differential equation. (For
example, it could be used to turn a slow mpmath function
into a fast machine-precision version of the same.)
**Examples**
Here we use :func:`~mpmath.chebyfit` to generate a low-degree approximation
of `f(x) = \cos(x)`, valid on the interval `[1, 2]`::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> poly, err = chebyfit(cos, [1, 2], 5, error=True)
>>> nprint(poly)
[0.00291682, 0.146166, -0.732491, 0.174141, 0.949553]
>>> nprint(err, 12)
1.61351758081e-5
The polynomial can be evaluated using ``polyval``::
>>> nprint(polyval(poly, 1.6), 12)
-0.0291858904138
>>> nprint(cos(1.6), 12)
-0.0291995223013
Sampling the true error at 1000 points shows that the error
estimate generated by ``chebyfit`` is remarkably good::
>>> error = lambda x: abs(cos(x) - polyval(poly, x))
>>> nprint(max([error(1+n/1000.) for n in range(1000)]), 12)
1.61349954245e-5
**Choice of degree**
The degree `N` can be set arbitrarily high, to obtain an
arbitrarily good approximation. As a rule of thumb, an
`N`-term Chebyshev approximation is good to `N/(b-a)` decimal
places on a unit interval (although this depends on how
well-behaved `f` is). The cost grows accordingly: ``chebyfit``
evaluates the function `(N^2)/2` times to compute the
coefficients and an additional `N` times to estimate the error.
**Possible issues**
One should be careful to use a sufficiently high working
precision both when calling ``chebyfit`` and when evaluating
the resulting polynomial, as the polynomial is sometimes
ill-conditioned. It is for example difficult to reach
15-digit accuracy when evaluating the polynomial using
machine precision floats, no matter the theoretical
accuracy of the polynomial. (The option to return the
coefficients in Chebyshev form should be made available
in the future.)
It is important to note the Chebyshev approximation works
poorly if `f` is not smooth. A function containing singularities,
rapid oscillation, etc can be approximated more effectively by
multiplying it by a weight function that cancels out the
nonsmooth features, or by dividing the interval into several
segments.
"""
a, b = ctx._as_points(interval)
orig = ctx.prec
try:
ctx.prec = orig + int(N**0.5) + 20
c = [chebcoeff(ctx,f,a,b,k,N) for k in range(N)]
d = [ctx.zero] * N
d[0] = -c[0]/2
h = ctx.mpf(0.5)
T = chebT(ctx, ctx.mpf(2)/(b-a), ctx.mpf(-1)*(b+a)/(b-a))
for (k, Tk) in zip(range(N), T):
for i in range(len(Tk)):
d[i] += c[k]*Tk[i]
d = d[::-1]
# Estimate maximum error
err = ctx.zero
for k in range(N):
x = ctx.cos(ctx.pi*k/N) * (b-a)*h + (b+a)*h
err = max(err, abs(f(x) - ctx.polyval(d, x)))
finally:
ctx.prec = orig
if error:
return d, +err
else:
return d
@defun
def fourier(ctx, f, interval, N):
r"""
Computes the Fourier series of degree `N` of the given function
on the interval `[a, b]`. More precisely, :func:`~mpmath.fourier` returns
two lists `(c, s)` of coefficients (the cosine series and sine
series, respectively), such that
.. math ::
f(x) \sim \sum_{k=0}^N
c_k \cos(k m x) + s_k \sin(k m x)
where `m = 2 \pi / (b-a)`.
Note that many texts define the first coefficient as `2 c_0` instead
of `c_0`. The easiest way to evaluate the computed series correctly
is to pass it to :func:`~mpmath.fourierval`.
**Examples**
The function `f(x) = x` has a simple Fourier series on the standard
interval `[-\pi, \pi]`. The cosine coefficients are all zero (because
the function has odd symmetry), and the sine coefficients are
rational numbers::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> c, s = fourier(lambda x: x, [-pi, pi], 5)
>>> nprint(c)
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>>> nprint(s)
[0.0, 2.0, -1.0, 0.666667, -0.5, 0.4]
This computes a Fourier series of a nonsymmetric function on
a nonstandard interval::
>>> I = [-1, 1.5]
>>> f = lambda x: x**2 - 4*x + 1
>>> cs = fourier(f, I, 4)
>>> nprint(cs[0])
[0.583333, 1.12479, -1.27552, 0.904708, -0.441296]
>>> nprint(cs[1])
[0.0, -2.6255, 0.580905, 0.219974, -0.540057]
It is instructive to plot a function along with its truncated
Fourier series::
>>> plot([f, lambda x: fourierval(cs, I, x)], I) #doctest: +SKIP
Fourier series generally converge slowly (and may not converge
pointwise). For example, if `f(x) = \cosh(x)`, a 10-term Fourier
series gives an `L^2` error corresponding to 2-digit accuracy::
>>> I = [-1, 1]
>>> cs = fourier(cosh, I, 9)
>>> g = lambda x: (cosh(x) - fourierval(cs, I, x))**2
>>> nprint(sqrt(quad(g, I)))
0.00467963
:func:`~mpmath.fourier` uses numerical quadrature. For nonsmooth functions,
the accuracy (and speed) can be improved by including all singular
points in the interval specification::
>>> nprint(fourier(abs, [-1, 1], 0), 10)
([0.5000441648], [0.0])
>>> nprint(fourier(abs, [-1, 0, 1], 0), 10)
([0.5], [0.0])
"""
interval = ctx._as_points(interval)
a = interval[0]
b = interval[-1]
L = b-a
cos_series = []
sin_series = []
cutoff = ctx.eps*10
for n in xrange(N+1):
m = 2*n*ctx.pi/L
an = 2*ctx.quadgl(lambda t: f(t)*ctx.cos(m*t), interval)/L
bn = 2*ctx.quadgl(lambda t: f(t)*ctx.sin(m*t), interval)/L
if n == 0:
an /= 2
if abs(an) < cutoff: an = ctx.zero
if abs(bn) < cutoff: bn = ctx.zero
cos_series.append(an)
sin_series.append(bn)
return cos_series, sin_series
@defun
def fourierval(ctx, series, interval, x):
"""
Evaluates a Fourier series (in the format computed by
by :func:`~mpmath.fourier` for the given interval) at the point `x`.
The series should be a pair `(c, s)` where `c` is the
cosine series and `s` is the sine series. The two lists
need not have the same length.
"""
cs, ss = series
ab = ctx._as_points(interval)
a = interval[0]
b = interval[-1]
m = 2*ctx.pi/(ab[-1]-ab[0])
s = ctx.zero
s += ctx.fsum(cs[n]*ctx.cos(m*n*x) for n in xrange(len(cs)) if cs[n])
s += ctx.fsum(ss[n]*ctx.sin(m*n*x) for n in xrange(len(ss)) if ss[n])
return s

View File

@ -0,0 +1,6 @@
class CalculusMethods(object):
pass
def defun(f):
setattr(CalculusMethods, f.__name__, f)
return f

View File

@ -0,0 +1,647 @@
from ..libmp.backend import xrange
from .calculus import defun
try:
iteritems = dict.iteritems
except AttributeError:
iteritems = dict.items
#----------------------------------------------------------------------------#
# Differentiation #
#----------------------------------------------------------------------------#
@defun
def difference(ctx, s, n):
r"""
Given a sequence `(s_k)` containing at least `n+1` items, returns the
`n`-th forward difference,
.. math ::
\Delta^n = \sum_{k=0}^{\infty} (-1)^{k+n} {n \choose k} s_k.
"""
n = int(n)
d = ctx.zero
b = (-1) ** (n & 1)
for k in xrange(n+1):
d += b * s[k]
b = (b * (k-n)) // (k+1)
return d
def hsteps(ctx, f, x, n, prec, **options):
singular = options.get('singular')
addprec = options.get('addprec', 10)
direction = options.get('direction', 0)
workprec = (prec+2*addprec) * (n+1)
orig = ctx.prec
try:
ctx.prec = workprec
h = options.get('h')
if h is None:
if options.get('relative'):
hextramag = int(ctx.mag(x))
else:
hextramag = 0
h = ctx.ldexp(1, -prec-addprec-hextramag)
else:
h = ctx.convert(h)
# Directed: steps x, x+h, ... x+n*h
direction = options.get('direction', 0)
if direction:
h *= ctx.sign(direction)
steps = xrange(n+1)
norm = h
# Central: steps x-n*h, x-(n-2)*h ..., x, ..., x+(n-2)*h, x+n*h
else:
steps = xrange(-n, n+1, 2)
norm = (2*h)
# Perturb
if singular:
x += 0.5*h
values = [f(x+k*h) for k in steps]
return values, norm, workprec
finally:
ctx.prec = orig
@defun
def diff(ctx, f, x, n=1, **options):
r"""
Numerically computes the derivative of `f`, `f'(x)`, or generally for
an integer `n \ge 0`, the `n`-th derivative `f^{(n)}(x)`.
A few basic examples are::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> diff(lambda x: x**2 + x, 1.0)
3.0
>>> diff(lambda x: x**2 + x, 1.0, 2)
2.0
>>> diff(lambda x: x**2 + x, 1.0, 3)
0.0
>>> nprint([diff(exp, 3, n) for n in range(5)]) # exp'(x) = exp(x)
[20.0855, 20.0855, 20.0855, 20.0855, 20.0855]
Even more generally, given a tuple of arguments `(x_1, \ldots, x_k)`
and order `(n_1, \ldots, n_k)`, the partial derivative
`f^{(n_1,\ldots,n_k)}(x_1,\ldots,x_k)` is evaluated. For example::
>>> diff(lambda x,y: 3*x*y + 2*y - x, (0.25, 0.5), (0,1))
2.75
>>> diff(lambda x,y: 3*x*y + 2*y - x, (0.25, 0.5), (1,1))
3.0
**Options**
The following optional keyword arguments are recognized:
``method``
Supported methods are ``'step'`` or ``'quad'``: derivatives may be
computed using either a finite difference with a small step
size `h` (default), or numerical quadrature.
``direction``
Direction of finite difference: can be -1 for a left
difference, 0 for a central difference (default), or +1
for a right difference; more generally can be any complex number.
``addprec``
Extra precision for `h` used to account for the function's
sensitivity to perturbations (default = 10).
``relative``
Choose `h` relative to the magnitude of `x`, rather than an
absolute value; useful for large or tiny `x` (default = False).
``h``
As an alternative to ``addprec`` and ``relative``, manually
select the step size `h`.
``singular``
If True, evaluation exactly at the point `x` is avoided; this is
useful for differentiating functions with removable singularities.
Default = False.
``radius``
Radius of integration contour (with ``method = 'quad'``).
Default = 0.25. A larger radius typically is faster and more
accurate, but it must be chosen so that `f` has no
singularities within the radius from the evaluation point.
A finite difference requires `n+1` function evaluations and must be
performed at `(n+1)` times the target precision. Accordingly, `f` must
support fast evaluation at high precision.
With integration, a larger number of function evaluations is
required, but not much extra precision is required. For high order
derivatives, this method may thus be faster if f is very expensive to
evaluate at high precision.
**Further examples**
The direction option is useful for computing left- or right-sided
derivatives of nonsmooth functions::
>>> diff(abs, 0, direction=0)
0.0
>>> diff(abs, 0, direction=1)
1.0
>>> diff(abs, 0, direction=-1)
-1.0
More generally, if the direction is nonzero, a right difference
is computed where the step size is multiplied by sign(direction).
For example, with direction=+j, the derivative from the positive
imaginary direction will be computed::
>>> diff(abs, 0, direction=j)
(0.0 - 1.0j)
With integration, the result may have a small imaginary part
even even if the result is purely real::
>>> diff(sqrt, 1, method='quad') # doctest:+ELLIPSIS
(0.5 - 4.59...e-26j)
>>> chop(_)
0.5
Adding precision to obtain an accurate value::
>>> diff(cos, 1e-30)
0.0
>>> diff(cos, 1e-30, h=0.0001)
-9.99999998328279e-31
>>> diff(cos, 1e-30, addprec=100)
-1.0e-30
"""
partial = False
try:
orders = list(n)
x = list(x)
partial = True
except TypeError:
pass
if partial:
x = [ctx.convert(_) for _ in x]
return _partial_diff(ctx, f, x, orders, options)
method = options.get('method', 'step')
if n == 0 and method != 'quad' and not options.get('singular'):
return f(ctx.convert(x))
prec = ctx.prec
try:
if method == 'step':
values, norm, workprec = hsteps(ctx, f, x, n, prec, **options)
ctx.prec = workprec
v = ctx.difference(values, n) / norm**n
elif method == 'quad':
ctx.prec += 10
radius = ctx.convert(options.get('radius', 0.25))
def g(t):
rei = radius*ctx.expj(t)
z = x + rei
return f(z) / rei**n
d = ctx.quadts(g, [0, 2*ctx.pi])
v = d * ctx.factorial(n) / (2*ctx.pi)
else:
raise ValueError("unknown method: %r" % method)
finally:
ctx.prec = prec
return +v
def _partial_diff(ctx, f, xs, orders, options):
if not orders:
return f()
if not sum(orders):
return f(*xs)
i = 0
for i in range(len(orders)):
if orders[i]:
break
order = orders[i]
def fdiff_inner(*f_args):
def inner(t):
return f(*(f_args[:i] + (t,) + f_args[i+1:]))
return ctx.diff(inner, f_args[i], order, **options)
orders[i] = 0
return _partial_diff(ctx, fdiff_inner, xs, orders, options)
@defun
def diffs(ctx, f, x, n=None, **options):
r"""
Returns a generator that yields the sequence of derivatives
.. math ::
f(x), f'(x), f''(x), \ldots, f^{(k)}(x), \ldots
With ``method='step'``, :func:`~mpmath.diffs` uses only `O(k)`
function evaluations to generate the first `k` derivatives,
rather than the roughly `O(k^2)` evaluations
required if one calls :func:`~mpmath.diff` `k` separate times.
With `n < \infty`, the generator stops as soon as the
`n`-th derivative has been generated. If the exact number of
needed derivatives is known in advance, this is further
slightly more efficient.
Options are the same as for :func:`~mpmath.diff`.
**Examples**
>>> from mpmath import *
>>> mp.dps = 15
>>> nprint(list(diffs(cos, 1, 5)))
[0.540302, -0.841471, -0.540302, 0.841471, 0.540302, -0.841471]
>>> for i, d in zip(range(6), diffs(cos, 1)):
... print("%s %s" % (i, d))
...
0 0.54030230586814
1 -0.841470984807897
2 -0.54030230586814
3 0.841470984807897
4 0.54030230586814
5 -0.841470984807897
"""
if n is None:
n = ctx.inf
else:
n = int(n)
if options.get('method', 'step') != 'step':
k = 0
while k < n + 1:
yield ctx.diff(f, x, k, **options)
k += 1
return
singular = options.get('singular')
if singular:
yield ctx.diff(f, x, 0, singular=True)
else:
yield f(ctx.convert(x))
if n < 1:
return
if n == ctx.inf:
A, B = 1, 2
else:
A, B = 1, n+1
while 1:
callprec = ctx.prec
y, norm, workprec = hsteps(ctx, f, x, B, callprec, **options)
for k in xrange(A, B):
try:
ctx.prec = workprec
d = ctx.difference(y, k) / norm**k
finally:
ctx.prec = callprec
yield +d
if k >= n:
return
A, B = B, int(A*1.4+1)
B = min(B, n)
def iterable_to_function(gen):
gen = iter(gen)
data = []
def f(k):
for i in xrange(len(data), k+1):
data.append(next(gen))
return data[k]
return f
@defun
def diffs_prod(ctx, factors):
r"""
Given a list of `N` iterables or generators yielding
`f_k(x), f'_k(x), f''_k(x), \ldots` for `k = 1, \ldots, N`,
generate `g(x), g'(x), g''(x), \ldots` where
`g(x) = f_1(x) f_2(x) \cdots f_N(x)`.
At high precision and for large orders, this is typically more efficient
than numerical differentiation if the derivatives of each `f_k(x)`
admit direct computation.
Note: This function does not increase the working precision internally,
so guard digits may have to be added externally for full accuracy.
**Examples**
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> f = lambda x: exp(x)*cos(x)*sin(x)
>>> u = diffs(f, 1)
>>> v = mp.diffs_prod([diffs(exp,1), diffs(cos,1), diffs(sin,1)])
>>> next(u); next(v)
1.23586333600241
1.23586333600241
>>> next(u); next(v)
0.104658952245596
0.104658952245596
>>> next(u); next(v)
-5.96999877552086
-5.96999877552086
>>> next(u); next(v)
-12.4632923122697
-12.4632923122697
"""
N = len(factors)
if N == 1:
for c in factors[0]:
yield c
else:
u = iterable_to_function(ctx.diffs_prod(factors[:N//2]))
v = iterable_to_function(ctx.diffs_prod(factors[N//2:]))
n = 0
while 1:
#yield sum(binomial(n,k)*u(n-k)*v(k) for k in xrange(n+1))
s = u(n) * v(0)
a = 1
for k in xrange(1,n+1):
a = a * (n-k+1) // k
s += a * u(n-k) * v(k)
yield s
n += 1
def dpoly(n, _cache={}):
"""
nth differentiation polynomial for exp (Faa di Bruno's formula).
TODO: most exponents are zero, so maybe a sparse representation
would be better.
"""
if n in _cache:
return _cache[n]
if not _cache:
_cache[0] = {(0,):1}
R = dpoly(n-1)
R = dict((c+(0,),v) for (c,v) in iteritems(R))
Ra = {}
for powers, count in iteritems(R):
powers1 = (powers[0]+1,) + powers[1:]
if powers1 in Ra:
Ra[powers1] += count
else:
Ra[powers1] = count
for powers, count in iteritems(R):
if not sum(powers):
continue
for k,p in enumerate(powers):
if p:
powers2 = powers[:k] + (p-1,powers[k+1]+1) + powers[k+2:]
if powers2 in Ra:
Ra[powers2] += p*count
else:
Ra[powers2] = p*count
_cache[n] = Ra
return _cache[n]
@defun
def diffs_exp(ctx, fdiffs):
r"""
Given an iterable or generator yielding `f(x), f'(x), f''(x), \ldots`
generate `g(x), g'(x), g''(x), \ldots` where `g(x) = \exp(f(x))`.
At high precision and for large orders, this is typically more efficient
than numerical differentiation if the derivatives of `f(x)`
admit direct computation.
Note: This function does not increase the working precision internally,
so guard digits may have to be added externally for full accuracy.
**Examples**
The derivatives of the gamma function can be computed using
logarithmic differentiation::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>>
>>> def diffs_loggamma(x):
... yield loggamma(x)
... i = 0
... while 1:
... yield psi(i,x)
... i += 1
...
>>> u = diffs_exp(diffs_loggamma(3))
>>> v = diffs(gamma, 3)
>>> next(u); next(v)
2.0
2.0
>>> next(u); next(v)
1.84556867019693
1.84556867019693
>>> next(u); next(v)
2.49292999190269
2.49292999190269
>>> next(u); next(v)
3.44996501352367
3.44996501352367
"""
fn = iterable_to_function(fdiffs)
f0 = ctx.exp(fn(0))
yield f0
i = 1
while 1:
s = ctx.mpf(0)
for powers, c in iteritems(dpoly(i)):
s += c*ctx.fprod(fn(k+1)**p for (k,p) in enumerate(powers) if p)
yield s * f0
i += 1
@defun
def differint(ctx, f, x, n=1, x0=0):
r"""
Calculates the Riemann-Liouville differintegral, or fractional
derivative, defined by
.. math ::
\,_{x_0}{\mathbb{D}}^n_xf(x) = \frac{1}{\Gamma(m-n)} \frac{d^m}{dx^m}
\int_{x_0}^{x}(x-t)^{m-n-1}f(t)dt
where `f` is a given (presumably well-behaved) function,
`x` is the evaluation point, `n` is the order, and `x_0` is
the reference point of integration (`m` is an arbitrary
parameter selected automatically).
With `n = 1`, this is just the standard derivative `f'(x)`; with `n = 2`,
the second derivative `f''(x)`, etc. With `n = -1`, it gives
`\int_{x_0}^x f(t) dt`, with `n = -2`
it gives `\int_{x_0}^x \left( \int_{x_0}^t f(u) du \right) dt`, etc.
As `n` is permitted to be any number, this operator generalizes
iterated differentiation and iterated integration to a single
operator with a continuous order parameter.
**Examples**
There is an exact formula for the fractional derivative of a
monomial `x^p`, which may be used as a reference. For example,
the following gives a half-derivative (order 0.5)::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> x = mpf(3); p = 2; n = 0.5
>>> differint(lambda t: t**p, x, n)
7.81764019044672
>>> gamma(p+1)/gamma(p-n+1) * x**(p-n)
7.81764019044672
Another useful test function is the exponential function, whose
integration / differentiation formula easy generalizes
to arbitrary order. Here we first compute a third derivative,
and then a triply nested integral. (The reference point `x_0`
is set to `-\infty` to avoid nonzero endpoint terms.)::
>>> differint(lambda x: exp(pi*x), -1.5, 3)
0.278538406900792
>>> exp(pi*-1.5) * pi**3
0.278538406900792
>>> differint(lambda x: exp(pi*x), 3.5, -3, -inf)
1922.50563031149
>>> exp(pi*3.5) / pi**3
1922.50563031149
However, for noninteger `n`, the differentiation formula for the
exponential function must be modified to give the same result as the
Riemann-Liouville differintegral::
>>> x = mpf(3.5)
>>> c = pi
>>> n = 1+2*j
>>> differint(lambda x: exp(c*x), x, n)
(-123295.005390743 + 140955.117867654j)
>>> x**(-n) * exp(c)**x * (x*c)**n * gammainc(-n, 0, x*c) / gamma(-n)
(-123295.005390743 + 140955.117867654j)
"""
m = max(int(ctx.ceil(ctx.re(n)))+1, 1)
r = m-n-1
g = lambda x: ctx.quad(lambda t: (x-t)**r * f(t), [x0, x])
return ctx.diff(g, x, m) / ctx.gamma(m-n)
@defun
def diffun(ctx, f, n=1, **options):
r"""
Given a function `f`, returns a function `g(x)` that evaluates the nth
derivative `f^{(n)}(x)`::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> cos2 = diffun(sin)
>>> sin2 = diffun(sin, 4)
>>> cos(1.3), cos2(1.3)
(0.267498828624587, 0.267498828624587)
>>> sin(1.3), sin2(1.3)
(0.963558185417193, 0.963558185417193)
The function `f` must support arbitrary precision evaluation.
See :func:`~mpmath.diff` for additional details and supported
keyword options.
"""
if n == 0:
return f
def g(x):
return ctx.diff(f, x, n, **options)
return g
@defun
def taylor(ctx, f, x, n, **options):
r"""
Produces a degree-`n` Taylor polynomial around the point `x` of the
given function `f`. The coefficients are returned as a list.
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> nprint(chop(taylor(sin, 0, 5)))
[0.0, 1.0, 0.0, -0.166667, 0.0, 0.00833333]
The coefficients are computed using high-order numerical
differentiation. The function must be possible to evaluate
to arbitrary precision. See :func:`~mpmath.diff` for additional details
and supported keyword options.
Note that to evaluate the Taylor polynomial as an approximation
of `f`, e.g. with :func:`~mpmath.polyval`, the coefficients must be reversed,
and the point of the Taylor expansion must be subtracted from
the argument:
>>> p = taylor(exp, 2.0, 10)
>>> polyval(p[::-1], 2.5 - 2.0)
12.1824939606092
>>> exp(2.5)
12.1824939607035
"""
gen = enumerate(ctx.diffs(f, x, n, **options))
if options.get("chop", True):
return [ctx.chop(d)/ctx.factorial(i) for i, d in gen]
else:
return [d/ctx.factorial(i) for i, d in gen]
@defun
def pade(ctx, a, L, M):
r"""
Computes a Pade approximation of degree `(L, M)` to a function.
Given at least `L+M+1` Taylor coefficients `a` approximating
a function `A(x)`, :func:`~mpmath.pade` returns coefficients of
polynomials `P, Q` satisfying
.. math ::
P = \sum_{k=0}^L p_k x^k
Q = \sum_{k=0}^M q_k x^k
Q_0 = 1
A(x) Q(x) = P(x) + O(x^{L+M+1})
`P(x)/Q(x)` can provide a good approximation to an analytic function
beyond the radius of convergence of its Taylor series (example
from G.A. Baker 'Essentials of Pade Approximants' Academic Press,
Ch.1A)::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> one = mpf(1)
>>> def f(x):
... return sqrt((one + 2*x)/(one + x))
...
>>> a = taylor(f, 0, 6)
>>> p, q = pade(a, 3, 3)
>>> x = 10
>>> polyval(p[::-1], x)/polyval(q[::-1], x)
1.38169105566806
>>> f(x)
1.38169855941551
"""
# To determine L+1 coefficients of P and M coefficients of Q
# L+M+1 coefficients of A must be provided
if len(a) < L+M+1:
raise ValueError("L+M+1 Coefficients should be provided")
if M == 0:
if L == 0:
return [ctx.one], [ctx.one]
else:
return a[:L+1], [ctx.one]
# Solve first
# a[L]*q[1] + ... + a[L-M+1]*q[M] = -a[L+1]
# ...
# a[L+M-1]*q[1] + ... + a[L]*q[M] = -a[L+M]
A = ctx.matrix(M)
for j in range(M):
for i in range(min(M, L+j+1)):
A[j, i] = a[L+j-i]
v = -ctx.matrix(a[(L+1):(L+M+1)])
x = ctx.lu_solve(A, v)
q = [ctx.one] + list(x)
# compute p
p = [0]*(L+1)
for i in range(L+1):
s = a[i]
for j in range(1, min(M,i) + 1):
s += q[j]*a[i-j]
p[i] = s
return p, q

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,973 @@
# contributed to mpmath by Kristopher L. Kuhlman, February 2017
# contributed to mpmath by Guillermo Navas-Palencia, February 2022
class InverseLaplaceTransform(object):
r"""
Inverse Laplace transform methods are implemented using this
class, in order to simplify the code and provide a common
infrastructure.
Implement a custom inverse Laplace transform algorithm by
subclassing :class:`InverseLaplaceTransform` and implementing the
appropriate methods. The subclass can then be used by
:func:`~mpmath.invertlaplace` by passing it as the *method*
argument.
"""
def __init__(self, ctx):
self.ctx = ctx
def calc_laplace_parameter(self, t, **kwargs):
r"""
Determine the vector of Laplace parameter values needed for an
algorithm, this will depend on the choice of algorithm (de
Hoog is default), the algorithm-specific parameters passed (or
default ones), and desired time.
"""
raise NotImplementedError
def calc_time_domain_solution(self, fp):
r"""
Compute the time domain solution, after computing the
Laplace-space function evaluations at the abscissa required
for the algorithm. Abscissa computed for one algorithm are
typically not useful for another algorithm.
"""
raise NotImplementedError
class FixedTalbot(InverseLaplaceTransform):
def calc_laplace_parameter(self, t, **kwargs):
r"""The "fixed" Talbot method deforms the Bromwich contour towards
`-\infty` in the shape of a parabola. Traditionally the Talbot
algorithm has adjustable parameters, but the "fixed" version
does not. The `r` parameter could be passed in as a parameter,
if you want to override the default given by (Abate & Valko,
2004).
The Laplace parameter is sampled along a parabola opening
along the negative imaginary axis, with the base of the
parabola along the real axis at
`p=\frac{r}{t_\mathrm{max}}`. As the number of terms used in
the approximation (degree) grows, the abscissa required for
function evaluation tend towards `-\infty`, requiring high
precision to prevent overflow. If any poles, branch cuts or
other singularities exist such that the deformed Bromwich
contour lies to the left of the singularity, the method will
fail.
**Optional arguments**
:class:`~mpmath.calculus.inverselaplace.FixedTalbot.calc_laplace_parameter`
recognizes the following keywords
*tmax*
maximum time associated with vector of times
(typically just the time requested)
*degree*
integer order of approximation (M = number of terms)
*r*
abscissa for `p_0` (otherwise computed using rule
of thumb `2M/5`)
The working precision will be increased according to a rule of
thumb. If 'degree' is not specified, the working precision and
degree are chosen to hopefully achieve the dps of the calling
context. If 'degree' is specified, the working precision is
chosen to achieve maximum resulting precision for the
specified degree.
.. math ::
p_0=\frac{r}{t}
.. math ::
p_i=\frac{i r \pi}{Mt_\mathrm{max}}\left[\cot\left(
\frac{i\pi}{M}\right) + j \right] \qquad 1\le i <M
where `j=\sqrt{-1}`, `r=2M/5`, and `t_\mathrm{max}` is the
maximum specified time.
"""
# required
# ------------------------------
# time of desired approximation
self.t = self.ctx.convert(t)
# optional
# ------------------------------
# maximum time desired (used for scaling) default is requested
# time.
self.tmax = self.ctx.convert(kwargs.get('tmax', self.t))
# empirical relationships used here based on a linear fit of
# requested and delivered dps for exponentially decaying time
# functions for requested dps up to 512.
if 'degree' in kwargs:
self.degree = kwargs['degree']
self.dps_goal = self.degree
else:
self.dps_goal = int(1.72*self.ctx.dps)
self.degree = max(12, int(1.38*self.dps_goal))
M = self.degree
# this is adjusting the dps of the calling context hopefully
# the caller doesn't monkey around with it between calling
# this routine and calc_time_domain_solution()
self.dps_orig = self.ctx.dps
self.ctx.dps = self.dps_goal
# Abate & Valko rule of thumb for r parameter
self.r = kwargs.get('r', self.ctx.fraction(2, 5)*M)
self.theta = self.ctx.linspace(0.0, self.ctx.pi, M+1)
self.cot_theta = self.ctx.matrix(M, 1)
self.cot_theta[0] = 0 # not used
# all but time-dependent part of p
self.delta = self.ctx.matrix(M, 1)
self.delta[0] = self.r
for i in range(1, M):
self.cot_theta[i] = self.ctx.cot(self.theta[i])
self.delta[i] = self.r*self.theta[i]*(self.cot_theta[i] + 1j)
self.p = self.ctx.matrix(M, 1)
self.p = self.delta/self.tmax
# NB: p is complex (mpc)
def calc_time_domain_solution(self, fp, t, manual_prec=False):
r"""The fixed Talbot time-domain solution is computed from the
Laplace-space function evaluations using
.. math ::
f(t,M)=\frac{2}{5t}\sum_{k=0}^{M-1}\Re \left[
\gamma_k \bar{f}(p_k)\right]
where
.. math ::
\gamma_0 = \frac{1}{2}e^{r}\bar{f}(p_0)
.. math ::
\gamma_k = e^{tp_k}\left\lbrace 1 + \frac{jk\pi}{M}\left[1 +
\cot \left( \frac{k \pi}{M} \right)^2 \right] - j\cot\left(
\frac{k \pi}{M}\right)\right \rbrace \qquad 1\le k<M.
Again, `j=\sqrt{-1}`.
Before calling this function, call
:class:`~mpmath.calculus.inverselaplace.FixedTalbot.calc_laplace_parameter`
to set the parameters and compute the required coefficients.
**References**
1. Abate, J., P. Valko (2004). Multi-precision Laplace
transform inversion. *International Journal for Numerical
Methods in Engineering* 60:979-993,
http://dx.doi.org/10.1002/nme.995
2. Talbot, A. (1979). The accurate numerical inversion of
Laplace transforms. *IMA Journal of Applied Mathematics*
23(1):97, http://dx.doi.org/10.1093/imamat/23.1.97
"""
# required
# ------------------------------
self.t = self.ctx.convert(t)
# assume fp was computed from p matrix returned from
# calc_laplace_parameter(), so is already a list or matrix of
# mpmath 'mpc' types
# these were computed in previous call to
# calc_laplace_parameter()
theta = self.theta
delta = self.delta
M = self.degree
p = self.p
r = self.r
ans = self.ctx.matrix(M, 1)
ans[0] = self.ctx.exp(delta[0])*fp[0]/2
for i in range(1, M):
ans[i] = self.ctx.exp(delta[i])*fp[i]*(
1 + 1j*theta[i]*(1 + self.cot_theta[i]**2) -
1j*self.cot_theta[i])
result = self.ctx.fraction(2, 5)*self.ctx.fsum(ans)/self.t
# setting dps back to value when calc_laplace_parameter was
# called, unless flag is set.
if not manual_prec:
self.ctx.dps = self.dps_orig
return result.real
# ****************************************
class Stehfest(InverseLaplaceTransform):
def calc_laplace_parameter(self, t, **kwargs):
r"""
The Gaver-Stehfest method is a discrete approximation of the
Widder-Post inversion algorithm, rather than a direct
approximation of the Bromwich contour integral.
The method abscissa along the real axis, and therefore has
issues inverting oscillatory functions (which have poles in
pairs away from the real axis).
The working precision will be increased according to a rule of
thumb. If 'degree' is not specified, the working precision and
degree are chosen to hopefully achieve the dps of the calling
context. If 'degree' is specified, the working precision is
chosen to achieve maximum resulting precision for the
specified degree.
.. math ::
p_k = \frac{k \log 2}{t} \qquad 1 \le k \le M
"""
# required
# ------------------------------
# time of desired approximation
self.t = self.ctx.convert(t)
# optional
# ------------------------------
# empirical relationships used here based on a linear fit of
# requested and delivered dps for exponentially decaying time
# functions for requested dps up to 512.
if 'degree' in kwargs:
self.degree = kwargs['degree']
self.dps_goal = int(1.38*self.degree)
else:
self.dps_goal = int(2.93*self.ctx.dps)
self.degree = max(16, self.dps_goal)
# _coeff routine requires even degree
if self.degree % 2 > 0:
self.degree += 1
M = self.degree
# this is adjusting the dps of the calling context
# hopefully the caller doesn't monkey around with it
# between calling this routine and calc_time_domain_solution()
self.dps_orig = self.ctx.dps
self.ctx.dps = self.dps_goal
self.V = self._coeff()
self.p = self.ctx.matrix(self.ctx.arange(1, M+1))*self.ctx.ln2/self.t
# NB: p is real (mpf)
def _coeff(self):
r"""Salzer summation weights (aka, "Stehfest coefficients")
only depend on the approximation order (M) and the precision"""
M = self.degree
M2 = int(M/2) # checked earlier that M is even
V = self.ctx.matrix(M, 1)
# Salzer summation weights
# get very large in magnitude and oscillate in sign,
# if the precision is not high enough, there will be
# catastrophic cancellation
for k in range(1, M+1):
z = self.ctx.matrix(min(k, M2)+1, 1)
for j in range(int((k+1)/2), min(k, M2)+1):
z[j] = (self.ctx.power(j, M2)*self.ctx.fac(2*j)/
(self.ctx.fac(M2-j)*self.ctx.fac(j)*
self.ctx.fac(j-1)*self.ctx.fac(k-j)*
self.ctx.fac(2*j-k)))
V[k-1] = self.ctx.power(-1, k+M2)*self.ctx.fsum(z)
return V
def calc_time_domain_solution(self, fp, t, manual_prec=False):
r"""Compute time-domain Stehfest algorithm solution.
.. math ::
f(t,M) = \frac{\log 2}{t} \sum_{k=1}^{M} V_k \bar{f}\left(
p_k \right)
where
.. math ::
V_k = (-1)^{k + N/2} \sum^{\min(k,N/2)}_{i=\lfloor(k+1)/2 \rfloor}
\frac{i^{\frac{N}{2}}(2i)!}{\left(\frac{N}{2}-i \right)! \, i! \,
\left(i-1 \right)! \, \left(k-i\right)! \, \left(2i-k \right)!}
As the degree increases, the abscissa (`p_k`) only increase
linearly towards `\infty`, but the Stehfest coefficients
(`V_k`) alternate in sign and increase rapidly in sign,
requiring high precision to prevent overflow or loss of
significance when evaluating the sum.
**References**
1. Widder, D. (1941). *The Laplace Transform*. Princeton.
2. Stehfest, H. (1970). Algorithm 368: numerical inversion of
Laplace transforms. *Communications of the ACM* 13(1):47-49,
http://dx.doi.org/10.1145/361953.361969
"""
# required
self.t = self.ctx.convert(t)
# assume fp was computed from p matrix returned from
# calc_laplace_parameter(), so is already
# a list or matrix of mpmath 'mpf' types
result = self.ctx.fdot(self.V, fp)*self.ctx.ln2/self.t
# setting dps back to value when calc_laplace_parameter was called
if not manual_prec:
self.ctx.dps = self.dps_orig
# ignore any small imaginary part
return result.real
# ****************************************
class deHoog(InverseLaplaceTransform):
def calc_laplace_parameter(self, t, **kwargs):
r"""the de Hoog, Knight & Stokes algorithm is an
accelerated form of the Fourier series numerical
inverse Laplace transform algorithms.
.. math ::
p_k = \gamma + \frac{jk}{T} \qquad 0 \le k < 2M+1
where
.. math ::
\gamma = \alpha - \frac{\log \mathrm{tol}}{2T},
`j=\sqrt{-1}`, `T = 2t_\mathrm{max}` is a scaled time,
`\alpha=10^{-\mathrm{dps\_goal}}` is the real part of the
rightmost pole or singularity, which is chosen based on the
desired accuracy (assuming the rightmost singularity is 0),
and `\mathrm{tol}=10\alpha` is the desired tolerance, which is
chosen in relation to `\alpha`.`
When increasing the degree, the abscissa increase towards
`j\infty`, but more slowly than the fixed Talbot
algorithm. The de Hoog et al. algorithm typically does better
with oscillatory functions of time, and less well-behaved
functions. The method tends to be slower than the Talbot and
Stehfest algorithsm, especially so at very high precision
(e.g., `>500` digits precision).
"""
# required
# ------------------------------
self.t = self.ctx.convert(t)
# optional
# ------------------------------
self.tmax = kwargs.get('tmax', self.t)
# empirical relationships used here based on a linear fit of
# requested and delivered dps for exponentially decaying time
# functions for requested dps up to 512.
if 'degree' in kwargs:
self.degree = kwargs['degree']
self.dps_goal = int(1.38*self.degree)
else:
self.dps_goal = int(self.ctx.dps*1.36)
self.degree = max(10, self.dps_goal)
# 2*M+1 terms in approximation
M = self.degree
# adjust alpha component of abscissa of convergence for higher
# precision
tmp = self.ctx.power(10.0, -self.dps_goal)
self.alpha = self.ctx.convert(kwargs.get('alpha', tmp))
# desired tolerance (here simply related to alpha)
self.tol = self.ctx.convert(kwargs.get('tol', self.alpha*10.0))
self.np = 2*self.degree+1 # number of terms in approximation
# this is adjusting the dps of the calling context
# hopefully the caller doesn't monkey around with it
# between calling this routine and calc_time_domain_solution()
self.dps_orig = self.ctx.dps
self.ctx.dps = self.dps_goal
# scaling factor (likely tun-able, but 2 is typical)
self.scale = kwargs.get('scale', 2)
self.T = self.ctx.convert(kwargs.get('T', self.scale*self.tmax))
self.p = self.ctx.matrix(2*M+1, 1)
self.gamma = self.alpha - self.ctx.log(self.tol)/(self.scale*self.T)
self.p = (self.gamma + self.ctx.pi*
self.ctx.matrix(self.ctx.arange(self.np))/self.T*1j)
# NB: p is complex (mpc)
def calc_time_domain_solution(self, fp, t, manual_prec=False):
r"""Calculate time-domain solution for
de Hoog, Knight & Stokes algorithm.
The un-accelerated Fourier series approach is:
.. math ::
f(t,2M+1) = \frac{e^{\gamma t}}{T} \sum_{k=0}^{2M}{}^{'}
\Re\left[\bar{f}\left( p_k \right)
e^{i\pi t/T} \right],
where the prime on the summation indicates the first term is halved.
This simplistic approach requires so many function evaluations
that it is not practical. Non-linear acceleration is
accomplished via Pade-approximation and an analytic expression
for the remainder of the continued fraction. See the original
paper (reference 2 below) a detailed description of the
numerical approach.
**References**
1. Davies, B. (2005). *Integral Transforms and their
Applications*, Third Edition. Springer.
2. de Hoog, F., J. Knight, A. Stokes (1982). An improved
method for numerical inversion of Laplace transforms. *SIAM
Journal of Scientific and Statistical Computing* 3:357-366,
http://dx.doi.org/10.1137/0903022
"""
M = self.degree
np = self.np
T = self.T
self.t = self.ctx.convert(t)
# would it be useful to try re-using
# space between e&q and A&B?
e = self.ctx.zeros(np, M+1)
q = self.ctx.matrix(2*M, M)
d = self.ctx.matrix(np, 1)
A = self.ctx.zeros(np+1, 1)
B = self.ctx.ones(np+1, 1)
# initialize Q-D table
e[:, 0] = 0.0 + 0j
q[0, 0] = fp[1]/(fp[0]/2)
for i in range(1, 2*M):
q[i, 0] = fp[i+1]/fp[i]
# rhombus rule for filling triangular Q-D table (e & q)
for r in range(1, M+1):
# start with e, column 1, 0:2*M-2
mr = 2*(M-r) + 1
e[0:mr, r] = q[1:mr+1, r-1] - q[0:mr, r-1] + e[1:mr+1, r-1]
if not r == M:
rq = r+1
mr = 2*(M-rq)+1 + 2
for i in range(mr):
q[i, rq-1] = q[i+1, rq-2]*e[i+1, rq-1]/e[i, rq-1]
# build up continued fraction coefficients (d)
d[0] = fp[0]/2
for r in range(1, M+1):
d[2*r-1] = -q[0, r-1] # even terms
d[2*r] = -e[0, r] # odd terms
# seed A and B for recurrence
A[0] = 0.0 + 0.0j
A[1] = d[0]
B[0:2] = 1.0 + 0.0j
# base of the power series
z = self.ctx.expjpi(self.t/T) # i*pi is already in fcn
# coefficients of Pade approximation (A & B)
# using recurrence for all but last term
for i in range(1, 2*M):
A[i+1] = A[i] + d[i]*A[i-1]*z
B[i+1] = B[i] + d[i]*B[i-1]*z
# "improved remainder" to continued fraction
brem = (1 + (d[2*M-1] - d[2*M])*z)/2
# powm1(x,y) computes x^y - 1 more accurately near zero
rem = brem*self.ctx.powm1(1 + d[2*M]*z/brem,
self.ctx.fraction(1, 2))
# last term of recurrence using new remainder
A[np] = A[2*M] + rem*A[2*M-1]
B[np] = B[2*M] + rem*B[2*M-1]
# diagonal Pade approximation
# F=A/B represents accelerated trapezoid rule
result = self.ctx.exp(self.gamma*self.t)/T*(A[np]/B[np]).real
# setting dps back to value when calc_laplace_parameter was called
if not manual_prec:
self.ctx.dps = self.dps_orig
return result
# ****************************************
class Cohen(InverseLaplaceTransform):
def calc_laplace_parameter(self, t, **kwargs):
r"""The Cohen algorithm accelerates the convergence of the nearly
alternating series resulting from the application of the trapezoidal
rule to the Bromwich contour inversion integral.
.. math ::
p_k = \frac{\gamma}{2 t} + \frac{\pi i k}{t} \qquad 0 \le k < M
where
.. math ::
\gamma = \frac{2}{3} (d + \log(10) + \log(2 t)),
`d = \mathrm{dps\_goal}`, which is chosen based on the desired
accuracy using the method developed in [1] to improve numerical
stability. The Cohen algorithm shows robustness similar to the de Hoog
et al. algorithm, but it is faster than the fixed Talbot algorithm.
**Optional arguments**
*degree*
integer order of the approximation (M = number of terms)
*alpha*
abscissa for `p_0` (controls the discretization error)
The working precision will be increased according to a rule of
thumb. If 'degree' is not specified, the working precision and
degree are chosen to hopefully achieve the dps of the calling
context. If 'degree' is specified, the working precision is
chosen to achieve maximum resulting precision for the
specified degree.
**References**
1. P. Glasserman, J. Ruiz-Mata (2006). Computing the credit loss
distribution in the Gaussian copula model: a comparison of methods.
*Journal of Credit Risk* 2(4):33-66, 10.21314/JCR.2006.057
"""
self.t = self.ctx.convert(t)
if 'degree' in kwargs:
self.degree = kwargs['degree']
self.dps_goal = int(1.5 * self.degree)
else:
self.dps_goal = int(self.ctx.dps * 1.74)
self.degree = max(22, int(1.31 * self.dps_goal))
M = self.degree + 1
# this is adjusting the dps of the calling context hopefully
# the caller doesn't monkey around with it between calling
# this routine and calc_time_domain_solution()
self.dps_orig = self.ctx.dps
self.ctx.dps = self.dps_goal
ttwo = 2 * self.t
tmp = self.ctx.dps * self.ctx.log(10) + self.ctx.log(ttwo)
tmp = self.ctx.fraction(2, 3) * tmp
self.alpha = self.ctx.convert(kwargs.get('alpha', tmp))
# all but time-dependent part of p
a_t = self.alpha / ttwo
p_t = self.ctx.pi * 1j / self.t
self.p = self.ctx.matrix(M, 1)
self.p[0] = a_t
for i in range(1, M):
self.p[i] = a_t + i * p_t
def calc_time_domain_solution(self, fp, t, manual_prec=False):
r"""Calculate time-domain solution for Cohen algorithm.
The accelerated nearly alternating series is:
.. math ::
f(t, M) = \frac{e^{\gamma / 2}}{t} \left[\frac{1}{2}
\Re\left(\bar{f}\left(\frac{\gamma}{2t}\right) \right) -
\sum_{k=0}^{M-1}\frac{c_{M,k}}{d_M}\Re\left(\bar{f}
\left(\frac{\gamma + 2(k+1) \pi i}{2t}\right)\right)\right],
where coefficients `\frac{c_{M, k}}{d_M}` are described in [1].
1. H. Cohen, F. Rodriguez Villegas, D. Zagier (2000). Convergence
acceleration of alternating series. *Experiment. Math* 9(1):3-12
"""
self.t = self.ctx.convert(t)
n = self.degree
M = n + 1
A = self.ctx.matrix(M, 1)
for i in range(M):
A[i] = fp[i].real
d = (3 + self.ctx.sqrt(8)) ** n
d = (d + 1 / d) / 2
b = -self.ctx.one
c = -d
s = 0
for k in range(n):
c = b - c
s = s + c * A[k + 1]
b = 2 * (k + n) * (k - n) * b / ((2 * k + 1) * (k + self.ctx.one))
result = self.ctx.exp(self.alpha / 2) / self.t * (A[0] / 2 - s / d)
# setting dps back to value when calc_laplace_parameter was
# called, unless flag is set.
if not manual_prec:
self.ctx.dps = self.dps_orig
return result
# ****************************************
class LaplaceTransformInversionMethods(object):
def __init__(ctx, *args, **kwargs):
ctx._fixed_talbot = FixedTalbot(ctx)
ctx._stehfest = Stehfest(ctx)
ctx._de_hoog = deHoog(ctx)
ctx._cohen = Cohen(ctx)
def invertlaplace(ctx, f, t, **kwargs):
r"""Computes the numerical inverse Laplace transform for a
Laplace-space function at a given time. The function being
evaluated is assumed to be a real-valued function of time.
The user must supply a Laplace-space function `\bar{f}(p)`,
and a desired time at which to estimate the time-domain
solution `f(t)`.
A few basic examples of Laplace-space functions with known
inverses (see references [1,2]) :
.. math ::
\mathcal{L}\left\lbrace f(t) \right\rbrace=\bar{f}(p)
.. math ::
\mathcal{L}^{-1}\left\lbrace \bar{f}(p) \right\rbrace = f(t)
.. math ::
\bar{f}(p) = \frac{1}{(p+1)^2}
.. math ::
f(t) = t e^{-t}
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> tt = [0.001, 0.01, 0.1, 1, 10]
>>> fp = lambda p: 1/(p+1)**2
>>> ft = lambda t: t*exp(-t)
>>> ft(tt[0]),ft(tt[0])-invertlaplace(fp,tt[0],method='talbot')
(0.000999000499833375, 8.57923043561212e-20)
>>> ft(tt[1]),ft(tt[1])-invertlaplace(fp,tt[1],method='talbot')
(0.00990049833749168, 3.27007646698047e-19)
>>> ft(tt[2]),ft(tt[2])-invertlaplace(fp,tt[2],method='talbot')
(0.090483741803596, -1.75215800052168e-18)
>>> ft(tt[3]),ft(tt[3])-invertlaplace(fp,tt[3],method='talbot')
(0.367879441171442, 1.2428864009344e-17)
>>> ft(tt[4]),ft(tt[4])-invertlaplace(fp,tt[4],method='talbot')
(0.000453999297624849, 4.04513489306658e-20)
The methods also work for higher precision:
>>> mp.dps = 100; mp.pretty = True
>>> nstr(ft(tt[0]),15),nstr(ft(tt[0])-invertlaplace(fp,tt[0],method='talbot'),15)
('0.000999000499833375', '-4.96868310693356e-105')
>>> nstr(ft(tt[1]),15),nstr(ft(tt[1])-invertlaplace(fp,tt[1],method='talbot'),15)
('0.00990049833749168', '1.23032291513122e-104')
.. math ::
\bar{f}(p) = \frac{1}{p^2+1}
.. math ::
f(t) = \mathrm{J}_0(t)
>>> mp.dps = 15; mp.pretty = True
>>> fp = lambda p: 1/sqrt(p*p + 1)
>>> ft = lambda t: besselj(0,t)
>>> ft(tt[0]),ft(tt[0])-invertlaplace(fp,tt[0],method='dehoog')
(0.999999750000016, -6.09717765032273e-18)
>>> ft(tt[1]),ft(tt[1])-invertlaplace(fp,tt[1],method='dehoog')
(0.99997500015625, -5.61756281076169e-17)
.. math ::
\bar{f}(p) = \frac{\log p}{p}
.. math ::
f(t) = -\gamma -\log t
>>> mp.dps = 15; mp.pretty = True
>>> fp = lambda p: log(p)/p
>>> ft = lambda t: -euler-log(t)
>>> ft(tt[0]),ft(tt[0])-invertlaplace(fp,tt[0],method='stehfest')
(6.3305396140806, -1.92126634837863e-16)
>>> ft(tt[1]),ft(tt[1])-invertlaplace(fp,tt[1],method='stehfest')
(4.02795452108656, -4.81486093200704e-16)
**Options**
:func:`~mpmath.invertlaplace` recognizes the following optional
keywords valid for all methods:
*method*
Chooses numerical inverse Laplace transform algorithm
(described below).
*degree*
Number of terms used in the approximation
**Algorithms**
Mpmath implements four numerical inverse Laplace transform
algorithms, attributed to: Talbot, Stehfest, and de Hoog,
Knight and Stokes. These can be selected by using
*method='talbot'*, *method='stehfest'*, *method='dehoog'* or
*method='cohen'* or by passing the classes *method=FixedTalbot*,
*method=Stehfest*, *method=deHoog*, or *method=Cohen*. The functions
:func:`~mpmath.invlaptalbot`, :func:`~mpmath.invlapstehfest`,
:func:`~mpmath.invlapdehoog`, and :func:`~mpmath.invlapcohen`
are also available as shortcuts.
All four algorithms implement a heuristic balance between the
requested precision and the precision used internally for the
calculations. This has been tuned for a typical exponentially
decaying function and precision up to few hundred decimal
digits.
The Laplace transform converts the variable time (i.e., along
a line) into a parameter given by the right half of the
complex `p`-plane. Singularities, poles, and branch cuts in
the complex `p`-plane contain all the information regarding
the time behavior of the corresponding function. Any numerical
method must therefore sample `p`-plane "close enough" to the
singularities to accurately characterize them, while not
getting too close to have catastrophic cancellation, overflow,
or underflow issues. Most significantly, if one or more of the
singularities in the `p`-plane is not on the left side of the
Bromwich contour, its effects will be left out of the computed
solution, and the answer will be completely wrong.
*Talbot*
The fixed Talbot method is high accuracy and fast, but the
method can catastrophically fail for certain classes of time-domain
behavior, including a Heaviside step function for positive
time (e.g., `H(t-2)`), or some oscillatory behaviors. The
Talbot method usually has adjustable parameters, but the
"fixed" variety implemented here does not. This method
deforms the Bromwich integral contour in the shape of a
parabola towards `-\infty`, which leads to problems
when the solution has a decaying exponential in it (e.g., a
Heaviside step function is equivalent to multiplying by a
decaying exponential in Laplace space).
*Stehfest*
The Stehfest algorithm only uses abscissa along the real axis
of the complex `p`-plane to estimate the time-domain
function. Oscillatory time-domain functions have poles away
from the real axis, so this method does not work well with
oscillatory functions, especially high-frequency ones. This
method also depends on summation of terms in a series that
grows very large, and will have catastrophic cancellation
during summation if the working precision is too low.
*de Hoog et al.*
The de Hoog, Knight, and Stokes method is essentially a
Fourier-series quadrature-type approximation to the Bromwich
contour integral, with non-linear series acceleration and an
analytical expression for the remainder term. This method is
typically one of the most robust. This method also involves the
greatest amount of overhead, so it is typically the slowest of the
four methods at high precision.
*Cohen*
The Cohen method is a trapezoidal rule approximation to the Bromwich
contour integral, with linear acceleration for alternating
series. This method is as robust as the de Hoog et al method and the
fastest of the four methods at high precision, and is therefore the
default method.
**Singularities**
All numerical inverse Laplace transform methods have problems
at large time when the Laplace-space function has poles,
singularities, or branch cuts to the right of the origin in
the complex plane. For simple poles in `\bar{f}(p)` at the
`p`-plane origin, the time function is constant in time (e.g.,
`\mathcal{L}\left\lbrace 1 \right\rbrace=1/p` has a pole at
`p=0`). A pole in `\bar{f}(p)` to the left of the origin is a
decreasing function of time (e.g., `\mathcal{L}\left\lbrace
e^{-t/2} \right\rbrace=1/(p+1/2)` has a pole at `p=-1/2`), and
a pole to the right of the origin leads to an increasing
function in time (e.g., `\mathcal{L}\left\lbrace t e^{t/4}
\right\rbrace = 1/(p-1/4)^2` has a pole at `p=1/4`). When
singularities occur off the real `p` axis, the time-domain
function is oscillatory. For example `\mathcal{L}\left\lbrace
\mathrm{J}_0(t) \right\rbrace=1/\sqrt{p^2+1}` has a branch cut
starting at `p=j=\sqrt{-1}` and is a decaying oscillatory
function, This range of behaviors is illustrated in Duffy [3]
Figure 4.10.4, p. 228.
In general as `p \rightarrow \infty` `t \rightarrow 0` and
vice-versa. All numerical inverse Laplace transform methods
require their abscissa to shift closer to the origin for
larger times. If the abscissa shift left of the rightmost
singularity in the Laplace domain, the answer will be
completely wrong (the effect of singularities to the right of
the Bromwich contour are not included in the results).
For example, the following exponentially growing function has
a pole at `p=3`:
.. math ::
\bar{f}(p)=\frac{1}{p^2-9}
.. math ::
f(t)=\frac{1}{3}\sinh 3t
>>> mp.dps = 15; mp.pretty = True
>>> fp = lambda p: 1/(p*p-9)
>>> ft = lambda t: sinh(3*t)/3
>>> tt = [0.01,0.1,1.0,10.0]
>>> ft(tt[0]),invertlaplace(fp,tt[0],method='talbot')
(0.0100015000675014, 0.0100015000675014)
>>> ft(tt[1]),invertlaplace(fp,tt[1],method='talbot')
(0.101506764482381, 0.101506764482381)
>>> ft(tt[2]),invertlaplace(fp,tt[2],method='talbot')
(3.33929164246997, 3.33929164246997)
>>> ft(tt[3]),invertlaplace(fp,tt[3],method='talbot')
(1781079096920.74, -1.61331069624091e-14)
**References**
1. [DLMF]_ section 1.14 (http://dlmf.nist.gov/1.14T4)
2. Cohen, A.M. (2007). Numerical Methods for Laplace Transform
Inversion, Springer.
3. Duffy, D.G. (1998). Advanced Engineering Mathematics, CRC Press.
**Numerical Inverse Laplace Transform Reviews**
1. Bellman, R., R.E. Kalaba, J.A. Lockett (1966). *Numerical
inversion of the Laplace transform: Applications to Biology,
Economics, Engineering, and Physics*. Elsevier.
2. Davies, B., B. Martin (1979). Numerical inversion of the
Laplace transform: a survey and comparison of methods. *Journal
of Computational Physics* 33:1-32,
http://dx.doi.org/10.1016/0021-9991(79)90025-1
3. Duffy, D.G. (1993). On the numerical inversion of Laplace
transforms: Comparison of three new methods on characteristic
problems from applications. *ACM Transactions on Mathematical
Software* 19(3):333-359, http://dx.doi.org/10.1145/155743.155788
4. Kuhlman, K.L., (2013). Review of Inverse Laplace Transform
Algorithms for Laplace-Space Numerical Approaches, *Numerical
Algorithms*, 63(2):339-355.
http://dx.doi.org/10.1007/s11075-012-9625-3
"""
rule = kwargs.get('method', 'cohen')
if type(rule) is str:
lrule = rule.lower()
if lrule == 'talbot':
rule = ctx._fixed_talbot
elif lrule == 'stehfest':
rule = ctx._stehfest
elif lrule == 'dehoog':
rule = ctx._de_hoog
elif rule == 'cohen':
rule = ctx._cohen
else:
raise ValueError("unknown invlap algorithm: %s" % rule)
else:
rule = rule(ctx)
# determine the vector of Laplace-space parameter
# needed for the requested method and desired time
rule.calc_laplace_parameter(t, **kwargs)
# compute the Laplace-space function evalutations
# at the required abscissa.
fp = [f(p) for p in rule.p]
# compute the time-domain solution from the
# Laplace-space function evaluations
return rule.calc_time_domain_solution(fp, t)
# shortcuts for the above function for specific methods
def invlaptalbot(ctx, *args, **kwargs):
kwargs['method'] = 'talbot'
return ctx.invertlaplace(*args, **kwargs)
def invlapstehfest(ctx, *args, **kwargs):
kwargs['method'] = 'stehfest'
return ctx.invertlaplace(*args, **kwargs)
def invlapdehoog(ctx, *args, **kwargs):
kwargs['method'] = 'dehoog'
return ctx.invertlaplace(*args, **kwargs)
def invlapcohen(ctx, *args, **kwargs):
kwargs['method'] = 'cohen'
return ctx.invertlaplace(*args, **kwargs)
# ****************************************
if __name__ == '__main__':
import doctest
doctest.testmod()

View File

@ -0,0 +1,288 @@
from bisect import bisect
from ..libmp.backend import xrange
class ODEMethods(object):
pass
def ode_taylor(ctx, derivs, x0, y0, tol_prec, n):
h = tol = ctx.ldexp(1, -tol_prec)
dim = len(y0)
xs = [x0]
ys = [y0]
x = x0
y = y0
orig = ctx.prec
try:
ctx.prec = orig*(1+n)
# Use n steps with Euler's method to get
# evaluation points for derivatives
for i in range(n):
fxy = derivs(x, y)
y = [y[i]+h*fxy[i] for i in xrange(len(y))]
x += h
xs.append(x)
ys.append(y)
# Compute derivatives
ser = [[] for d in range(dim)]
for j in range(n+1):
s = [0]*dim
b = (-1) ** (j & 1)
k = 1
for i in range(j+1):
for d in range(dim):
s[d] += b * ys[i][d]
b = (b * (j-k+1)) // (-k)
k += 1
scale = h**(-j) / ctx.fac(j)
for d in range(dim):
s[d] = s[d] * scale
ser[d].append(s[d])
finally:
ctx.prec = orig
# Estimate radius for which we can get full accuracy.
# XXX: do this right for zeros
radius = ctx.one
for ts in ser:
if ts[-1]:
radius = min(radius, ctx.nthroot(tol/abs(ts[-1]), n))
radius /= 2 # XXX
return ser, x0+radius
def odefun(ctx, F, x0, y0, tol=None, degree=None, method='taylor', verbose=False):
r"""
Returns a function `y(x) = [y_0(x), y_1(x), \ldots, y_n(x)]`
that is a numerical solution of the `n+1`-dimensional first-order
ordinary differential equation (ODE) system
.. math ::
y_0'(x) = F_0(x, [y_0(x), y_1(x), \ldots, y_n(x)])
y_1'(x) = F_1(x, [y_0(x), y_1(x), \ldots, y_n(x)])
\vdots
y_n'(x) = F_n(x, [y_0(x), y_1(x), \ldots, y_n(x)])
The derivatives are specified by the vector-valued function
*F* that evaluates
`[y_0', \ldots, y_n'] = F(x, [y_0, \ldots, y_n])`.
The initial point `x_0` is specified by the scalar argument *x0*,
and the initial value `y(x_0) = [y_0(x_0), \ldots, y_n(x_0)]` is
specified by the vector argument *y0*.
For convenience, if the system is one-dimensional, you may optionally
provide just a scalar value for *y0*. In this case, *F* should accept
a scalar *y* argument and return a scalar. The solution function
*y* will return scalar values instead of length-1 vectors.
Evaluation of the solution function `y(x)` is permitted
for any `x \ge x_0`.
A high-order ODE can be solved by transforming it into first-order
vector form. This transformation is described in standard texts
on ODEs. Examples will also be given below.
**Options, speed and accuracy**
By default, :func:`~mpmath.odefun` uses a high-order Taylor series
method. For reasonably well-behaved problems, the solution will
be fully accurate to within the working precision. Note that
*F* must be possible to evaluate to very high precision
for the generation of Taylor series to work.
To get a faster but less accurate solution, you can set a large
value for *tol* (which defaults roughly to *eps*). If you just
want to plot the solution or perform a basic simulation,
*tol = 0.01* is likely sufficient.
The *degree* argument controls the degree of the solver (with
*method='taylor'*, this is the degree of the Taylor series
expansion). A higher degree means that a longer step can be taken
before a new local solution must be generated from *F*,
meaning that fewer steps are required to get from `x_0` to a given
`x_1`. On the other hand, a higher degree also means that each
local solution becomes more expensive (i.e., more evaluations of
*F* are required per step, and at higher precision).
The optimal setting therefore involves a tradeoff. Generally,
decreasing the *degree* for Taylor series is likely to give faster
solution at low precision, while increasing is likely to be better
at higher precision.
The function
object returned by :func:`~mpmath.odefun` caches the solutions at all step
points and uses polynomial interpolation between step points.
Therefore, once `y(x_1)` has been evaluated for some `x_1`,
`y(x)` can be evaluated very quickly for any `x_0 \le x \le x_1`.
and continuing the evaluation up to `x_2 > x_1` is also fast.
**Examples of first-order ODEs**
We will solve the standard test problem `y'(x) = y(x), y(0) = 1`
which has explicit solution `y(x) = \exp(x)`::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> f = odefun(lambda x, y: y, 0, 1)
>>> for x in [0, 1, 2.5]:
... print((f(x), exp(x)))
...
(1.0, 1.0)
(2.71828182845905, 2.71828182845905)
(12.1824939607035, 12.1824939607035)
The solution with high precision::
>>> mp.dps = 50
>>> f = odefun(lambda x, y: y, 0, 1)
>>> f(1)
2.7182818284590452353602874713526624977572470937
>>> exp(1)
2.7182818284590452353602874713526624977572470937
Using the more general vectorized form, the test problem
can be input as (note that *f* returns a 1-element vector)::
>>> mp.dps = 15
>>> f = odefun(lambda x, y: [y[0]], 0, [1])
>>> f(1)
[2.71828182845905]
:func:`~mpmath.odefun` can solve nonlinear ODEs, which are generally
impossible (and at best difficult) to solve analytically. As
an example of a nonlinear ODE, we will solve `y'(x) = x \sin(y(x))`
for `y(0) = \pi/2`. An exact solution happens to be known
for this problem, and is given by
`y(x) = 2 \tan^{-1}\left(\exp\left(x^2/2\right)\right)`::
>>> f = odefun(lambda x, y: x*sin(y), 0, pi/2)
>>> for x in [2, 5, 10]:
... print((f(x), 2*atan(exp(mpf(x)**2/2))))
...
(2.87255666284091, 2.87255666284091)
(3.14158520028345, 3.14158520028345)
(3.14159265358979, 3.14159265358979)
If `F` is independent of `y`, an ODE can be solved using direct
integration. We can therefore obtain a reference solution with
:func:`~mpmath.quad`::
>>> f = lambda x: (1+x**2)/(1+x**3)
>>> g = odefun(lambda x, y: f(x), pi, 0)
>>> g(2*pi)
0.72128263801696
>>> quad(f, [pi, 2*pi])
0.72128263801696
**Examples of second-order ODEs**
We will solve the harmonic oscillator equation `y''(x) + y(x) = 0`.
To do this, we introduce the helper functions `y_0 = y, y_1 = y_0'`
whereby the original equation can be written as `y_1' + y_0' = 0`. Put
together, we get the first-order, two-dimensional vector ODE
.. math ::
\begin{cases}
y_0' = y_1 \\
y_1' = -y_0
\end{cases}
To get a well-defined IVP, we need two initial values. With
`y(0) = y_0(0) = 1` and `-y'(0) = y_1(0) = 0`, the problem will of
course be solved by `y(x) = y_0(x) = \cos(x)` and
`-y'(x) = y_1(x) = \sin(x)`. We check this::
>>> f = odefun(lambda x, y: [-y[1], y[0]], 0, [1, 0])
>>> for x in [0, 1, 2.5, 10]:
... nprint(f(x), 15)
... nprint([cos(x), sin(x)], 15)
... print("---")
...
[1.0, 0.0]
[1.0, 0.0]
---
[0.54030230586814, 0.841470984807897]
[0.54030230586814, 0.841470984807897]
---
[-0.801143615546934, 0.598472144103957]
[-0.801143615546934, 0.598472144103957]
---
[-0.839071529076452, -0.54402111088937]
[-0.839071529076452, -0.54402111088937]
---
Note that we get both the sine and the cosine solutions
simultaneously.
**TODO**
* Better automatic choice of degree and step size
* Make determination of Taylor series convergence radius
more robust
* Allow solution for `x < x_0`
* Allow solution for complex `x`
* Test for difficult (ill-conditioned) problems
* Implement Runge-Kutta and other algorithms
"""
if tol:
tol_prec = int(-ctx.log(tol, 2))+10
else:
tol_prec = ctx.prec+10
degree = degree or (3 + int(3*ctx.dps/2.))
workprec = ctx.prec + 40
try:
len(y0)
return_vector = True
except TypeError:
F_ = F
F = lambda x, y: [F_(x, y[0])]
y0 = [y0]
return_vector = False
ser, xb = ode_taylor(ctx, F, x0, y0, tol_prec, degree)
series_boundaries = [x0, xb]
series_data = [(ser, x0, xb)]
# We will be working with vectors of Taylor series
def mpolyval(ser, a):
return [ctx.polyval(s[::-1], a) for s in ser]
# Find nearest expansion point; compute if necessary
def get_series(x):
if x < x0:
raise ValueError
n = bisect(series_boundaries, x)
if n < len(series_boundaries):
return series_data[n-1]
while 1:
ser, xa, xb = series_data[-1]
if verbose:
print("Computing Taylor series for [%f, %f]" % (xa, xb))
y = mpolyval(ser, xb-xa)
xa = xb
ser, xb = ode_taylor(ctx, F, xb, y, tol_prec, degree)
series_boundaries.append(xb)
series_data.append((ser, xa, xb))
if x <= xb:
return series_data[-1]
# Evaluation function
def interpolant(x):
x = ctx.convert(x)
orig = ctx.prec
try:
ctx.prec = workprec
ser, xa, xb = get_series(x)
y = mpolyval(ser, x-xa)
finally:
ctx.prec = orig
if return_vector:
return [+yk for yk in y]
else:
return +y[0]
return interpolant
ODEMethods.odefun = odefun
if __name__ == "__main__":
import doctest
doctest.testmod()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,213 @@
from ..libmp.backend import xrange
from .calculus import defun
#----------------------------------------------------------------------------#
# Polynomials #
#----------------------------------------------------------------------------#
# XXX: extra precision
@defun
def polyval(ctx, coeffs, x, derivative=False):
r"""
Given coefficients `[c_n, \ldots, c_2, c_1, c_0]` and a number `x`,
:func:`~mpmath.polyval` evaluates the polynomial
.. math ::
P(x) = c_n x^n + \ldots + c_2 x^2 + c_1 x + c_0.
If *derivative=True* is set, :func:`~mpmath.polyval` simultaneously
evaluates `P(x)` with the derivative, `P'(x)`, and returns the
tuple `(P(x), P'(x))`.
>>> from mpmath import *
>>> mp.pretty = True
>>> polyval([3, 0, 2], 0.5)
2.75
>>> polyval([3, 0, 2], 0.5, derivative=True)
(2.75, 3.0)
The coefficients and the evaluation point may be any combination
of real or complex numbers.
"""
if not coeffs:
return ctx.zero
p = ctx.convert(coeffs[0])
q = ctx.zero
for c in coeffs[1:]:
if derivative:
q = p + x*q
p = c + x*p
if derivative:
return p, q
else:
return p
@defun
def polyroots(ctx, coeffs, maxsteps=50, cleanup=True, extraprec=10,
error=False, roots_init=None):
"""
Computes all roots (real or complex) of a given polynomial.
The roots are returned as a sorted list, where real roots appear first
followed by complex conjugate roots as adjacent elements. The polynomial
should be given as a list of coefficients, in the format used by
:func:`~mpmath.polyval`. The leading coefficient must be nonzero.
With *error=True*, :func:`~mpmath.polyroots` returns a tuple *(roots, err)*
where *err* is an estimate of the maximum error among the computed roots.
**Examples**
Finding the three real roots of `x^3 - x^2 - 14x + 24`::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> nprint(polyroots([1,-1,-14,24]), 4)
[-4.0, 2.0, 3.0]
Finding the two complex conjugate roots of `4x^2 + 3x + 2`, with an
error estimate::
>>> roots, err = polyroots([4,3,2], error=True)
>>> for r in roots:
... print(r)
...
(-0.375 + 0.59947894041409j)
(-0.375 - 0.59947894041409j)
>>>
>>> err
2.22044604925031e-16
>>>
>>> polyval([4,3,2], roots[0])
(2.22044604925031e-16 + 0.0j)
>>> polyval([4,3,2], roots[1])
(2.22044604925031e-16 + 0.0j)
The following example computes all the 5th roots of unity; that is,
the roots of `x^5 - 1`::
>>> mp.dps = 20
>>> for r in polyroots([1, 0, 0, 0, 0, -1]):
... print(r)
...
1.0
(-0.8090169943749474241 + 0.58778525229247312917j)
(-0.8090169943749474241 - 0.58778525229247312917j)
(0.3090169943749474241 + 0.95105651629515357212j)
(0.3090169943749474241 - 0.95105651629515357212j)
**Precision and conditioning**
The roots are computed to the current working precision accuracy. If this
accuracy cannot be achieved in ``maxsteps`` steps, then a
``NoConvergence`` exception is raised. The algorithm internally is using
the current working precision extended by ``extraprec``. If
``NoConvergence`` was raised, that is caused either by not having enough
extra precision to achieve convergence (in which case increasing
``extraprec`` should fix the problem) or too low ``maxsteps`` (in which
case increasing ``maxsteps`` should fix the problem), or a combination of
both.
The user should always do a convergence study with regards to
``extraprec`` to ensure accurate results. It is possible to get
convergence to a wrong answer with too low ``extraprec``.
Provided there are no repeated roots, :func:`~mpmath.polyroots` can
typically compute all roots of an arbitrary polynomial to high precision::
>>> mp.dps = 60
>>> for r in polyroots([1, 0, -10, 0, 1]):
... print(r)
...
-3.14626436994197234232913506571557044551247712918732870123249
-0.317837245195782244725757617296174288373133378433432554879127
0.317837245195782244725757617296174288373133378433432554879127
3.14626436994197234232913506571557044551247712918732870123249
>>>
>>> sqrt(3) + sqrt(2)
3.14626436994197234232913506571557044551247712918732870123249
>>> sqrt(3) - sqrt(2)
0.317837245195782244725757617296174288373133378433432554879127
**Algorithm**
:func:`~mpmath.polyroots` implements the Durand-Kerner method [1], which
uses complex arithmetic to locate all roots simultaneously.
The Durand-Kerner method can be viewed as approximately performing
simultaneous Newton iteration for all the roots. In particular,
the convergence to simple roots is quadratic, just like Newton's
method.
Although all roots are internally calculated using complex arithmetic, any
root found to have an imaginary part smaller than the estimated numerical
error is truncated to a real number (small real parts are also chopped).
Real roots are placed first in the returned list, sorted by value. The
remaining complex roots are sorted by their real parts so that conjugate
roots end up next to each other.
**References**
1. http://en.wikipedia.org/wiki/Durand-Kerner_method
"""
if len(coeffs) <= 1:
if not coeffs or not coeffs[0]:
raise ValueError("Input to polyroots must not be the zero polynomial")
# Constant polynomial with no roots
return []
orig = ctx.prec
tol = +ctx.eps
with ctx.extraprec(extraprec):
deg = len(coeffs) - 1
# Must be monic
lead = ctx.convert(coeffs[0])
if lead == 1:
coeffs = [ctx.convert(c) for c in coeffs]
else:
coeffs = [c/lead for c in coeffs]
f = lambda x: ctx.polyval(coeffs, x)
if roots_init is None:
roots = [ctx.mpc((0.4+0.9j)**n) for n in xrange(deg)]
else:
roots = [None]*deg;
deg_init = min(deg, len(roots_init))
roots[:deg_init] = list(roots_init[:deg_init])
roots[deg_init:] = [ctx.mpc((0.4+0.9j)**n) for n
in xrange(deg_init,deg)]
err = [ctx.one for n in xrange(deg)]
# Durand-Kerner iteration until convergence
for step in xrange(maxsteps):
if abs(max(err)) < tol:
break
for i in xrange(deg):
p = roots[i]
x = f(p)
for j in range(deg):
if i != j:
try:
x /= (p-roots[j])
except ZeroDivisionError:
continue
roots[i] = p - x
err[i] = abs(x)
if abs(max(err)) >= tol:
raise ctx.NoConvergence("Didn't converge in maxsteps=%d steps." \
% maxsteps)
# Remove small real or imaginary parts
if cleanup:
for i in xrange(deg):
if abs(roots[i]) < tol:
roots[i] = ctx.zero
elif abs(ctx._im(roots[i])) < tol:
roots[i] = roots[i].real
elif abs(ctx._re(roots[i])) < tol:
roots[i] = roots[i].imag * 1j
roots.sort(key=lambda x: (abs(ctx._im(x)), ctx._re(x)))
if error:
err = max(err)
err = max(err, ctx.ldexp(1, -orig+1))
return [+r for r in roots], +err
else:
return [+r for r in roots]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,494 @@
from operator import gt, lt
from .libmp.backend import xrange
from .functions.functions import SpecialFunctions
from .functions.rszeta import RSCache
from .calculus.quadrature import QuadratureMethods
from .calculus.inverselaplace import LaplaceTransformInversionMethods
from .calculus.calculus import CalculusMethods
from .calculus.optimization import OptimizationMethods
from .calculus.odes import ODEMethods
from .matrices.matrices import MatrixMethods
from .matrices.calculus import MatrixCalculusMethods
from .matrices.linalg import LinearAlgebraMethods
from .matrices.eigen import Eigen
from .identification import IdentificationMethods
from .visualization import VisualizationMethods
from . import libmp
class Context(object):
pass
class StandardBaseContext(Context,
SpecialFunctions,
RSCache,
QuadratureMethods,
LaplaceTransformInversionMethods,
CalculusMethods,
MatrixMethods,
MatrixCalculusMethods,
LinearAlgebraMethods,
Eigen,
IdentificationMethods,
OptimizationMethods,
ODEMethods,
VisualizationMethods):
NoConvergence = libmp.NoConvergence
ComplexResult = libmp.ComplexResult
def __init__(ctx):
ctx._aliases = {}
# Call those that need preinitialization (e.g. for wrappers)
SpecialFunctions.__init__(ctx)
RSCache.__init__(ctx)
QuadratureMethods.__init__(ctx)
LaplaceTransformInversionMethods.__init__(ctx)
CalculusMethods.__init__(ctx)
MatrixMethods.__init__(ctx)
def _init_aliases(ctx):
for alias, value in ctx._aliases.items():
try:
setattr(ctx, alias, getattr(ctx, value))
except AttributeError:
pass
_fixed_precision = False
# XXX
verbose = False
def warn(ctx, msg):
print("Warning:", msg)
def bad_domain(ctx, msg):
raise ValueError(msg)
def _re(ctx, x):
if hasattr(x, "real"):
return x.real
return x
def _im(ctx, x):
if hasattr(x, "imag"):
return x.imag
return ctx.zero
def _as_points(ctx, x):
return x
def fneg(ctx, x, **kwargs):
return -ctx.convert(x)
def fadd(ctx, x, y, **kwargs):
return ctx.convert(x)+ctx.convert(y)
def fsub(ctx, x, y, **kwargs):
return ctx.convert(x)-ctx.convert(y)
def fmul(ctx, x, y, **kwargs):
return ctx.convert(x)*ctx.convert(y)
def fdiv(ctx, x, y, **kwargs):
return ctx.convert(x)/ctx.convert(y)
def fsum(ctx, args, absolute=False, squared=False):
if absolute:
if squared:
return sum((abs(x)**2 for x in args), ctx.zero)
return sum((abs(x) for x in args), ctx.zero)
if squared:
return sum((x**2 for x in args), ctx.zero)
return sum(args, ctx.zero)
def fdot(ctx, xs, ys=None, conjugate=False):
if ys is not None:
xs = zip(xs, ys)
if conjugate:
cf = ctx.conj
return sum((x*cf(y) for (x,y) in xs), ctx.zero)
else:
return sum((x*y for (x,y) in xs), ctx.zero)
def fprod(ctx, args):
prod = ctx.one
for arg in args:
prod *= arg
return prod
def nprint(ctx, x, n=6, **kwargs):
"""
Equivalent to ``print(nstr(x, n))``.
"""
print(ctx.nstr(x, n, **kwargs))
def chop(ctx, x, tol=None):
"""
Chops off small real or imaginary parts, or converts
numbers close to zero to exact zeros. The input can be a
single number or an iterable::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> chop(5+1e-10j, tol=1e-9)
mpf('5.0')
>>> nprint(chop([1.0, 1e-20, 3+1e-18j, -4, 2]))
[1.0, 0.0, 3.0, -4.0, 2.0]
The tolerance defaults to ``100*eps``.
"""
if tol is None:
tol = 100*ctx.eps
try:
x = ctx.convert(x)
absx = abs(x)
if abs(x) < tol:
return ctx.zero
if ctx._is_complex_type(x):
#part_tol = min(tol, absx*tol)
part_tol = max(tol, absx*tol)
if abs(x.imag) < part_tol:
return x.real
if abs(x.real) < part_tol:
return ctx.mpc(0, x.imag)
except TypeError:
if isinstance(x, ctx.matrix):
return x.apply(lambda a: ctx.chop(a, tol))
if hasattr(x, "__iter__"):
return [ctx.chop(a, tol) for a in x]
return x
def almosteq(ctx, s, t, rel_eps=None, abs_eps=None):
r"""
Determine whether the difference between `s` and `t` is smaller
than a given epsilon, either relatively or absolutely.
Both a maximum relative difference and a maximum difference
('epsilons') may be specified. The absolute difference is
defined as `|s-t|` and the relative difference is defined
as `|s-t|/\max(|s|, |t|)`.
If only one epsilon is given, both are set to the same value.
If none is given, both epsilons are set to `2^{-p+m}` where
`p` is the current working precision and `m` is a small
integer. The default setting typically allows :func:`~mpmath.almosteq`
to be used to check for mathematical equality
in the presence of small rounding errors.
**Examples**
>>> from mpmath import *
>>> mp.dps = 15
>>> almosteq(3.141592653589793, 3.141592653589790)
True
>>> almosteq(3.141592653589793, 3.141592653589700)
False
>>> almosteq(3.141592653589793, 3.141592653589700, 1e-10)
True
>>> almosteq(1e-20, 2e-20)
True
>>> almosteq(1e-20, 2e-20, rel_eps=0, abs_eps=0)
False
"""
t = ctx.convert(t)
if abs_eps is None and rel_eps is None:
rel_eps = abs_eps = ctx.ldexp(1, -ctx.prec+4)
if abs_eps is None:
abs_eps = rel_eps
elif rel_eps is None:
rel_eps = abs_eps
diff = abs(s-t)
if diff <= abs_eps:
return True
abss = abs(s)
abst = abs(t)
if abss < abst:
err = diff/abst
else:
err = diff/abss
return err <= rel_eps
def arange(ctx, *args):
r"""
This is a generalized version of Python's :func:`~mpmath.range` function
that accepts fractional endpoints and step sizes and
returns a list of ``mpf`` instances. Like :func:`~mpmath.range`,
:func:`~mpmath.arange` can be called with 1, 2 or 3 arguments:
``arange(b)``
`[0, 1, 2, \ldots, x]`
``arange(a, b)``
`[a, a+1, a+2, \ldots, x]`
``arange(a, b, h)``
`[a, a+h, a+h, \ldots, x]`
where `b-1 \le x < b` (in the third case, `b-h \le x < b`).
Like Python's :func:`~mpmath.range`, the endpoint is not included. To
produce ranges where the endpoint is included, :func:`~mpmath.linspace`
is more convenient.
**Examples**
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> arange(4)
[mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0')]
>>> arange(1, 2, 0.25)
[mpf('1.0'), mpf('1.25'), mpf('1.5'), mpf('1.75')]
>>> arange(1, -1, -0.75)
[mpf('1.0'), mpf('0.25'), mpf('-0.5')]
"""
if not len(args) <= 3:
raise TypeError('arange expected at most 3 arguments, got %i'
% len(args))
if not len(args) >= 1:
raise TypeError('arange expected at least 1 argument, got %i'
% len(args))
# set default
a = 0
dt = 1
# interpret arguments
if len(args) == 1:
b = args[0]
elif len(args) >= 2:
a = args[0]
b = args[1]
if len(args) == 3:
dt = args[2]
a, b, dt = ctx.mpf(a), ctx.mpf(b), ctx.mpf(dt)
assert a + dt != a, 'dt is too small and would cause an infinite loop'
# adapt code for sign of dt
if a > b:
if dt > 0:
return []
op = gt
else:
if dt < 0:
return []
op = lt
# create list
result = []
i = 0
t = a
while 1:
t = a + dt*i
i += 1
if op(t, b):
result.append(t)
else:
break
return result
def linspace(ctx, *args, **kwargs):
"""
``linspace(a, b, n)`` returns a list of `n` evenly spaced
samples from `a` to `b`. The syntax ``linspace(mpi(a,b), n)``
is also valid.
This function is often more convenient than :func:`~mpmath.arange`
for partitioning an interval into subintervals, since
the endpoint is included::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = False
>>> linspace(1, 4, 4)
[mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
You may also provide the keyword argument ``endpoint=False``::
>>> linspace(1, 4, 4, endpoint=False)
[mpf('1.0'), mpf('1.75'), mpf('2.5'), mpf('3.25')]
"""
if len(args) == 3:
a = ctx.mpf(args[0])
b = ctx.mpf(args[1])
n = int(args[2])
elif len(args) == 2:
assert hasattr(args[0], '_mpi_')
a = args[0].a
b = args[0].b
n = int(args[1])
else:
raise TypeError('linspace expected 2 or 3 arguments, got %i' \
% len(args))
if n < 1:
raise ValueError('n must be greater than 0')
if not 'endpoint' in kwargs or kwargs['endpoint']:
if n == 1:
return [ctx.mpf(a)]
step = (b - a) / ctx.mpf(n - 1)
y = [i*step + a for i in xrange(n)]
y[-1] = b
else:
step = (b - a) / ctx.mpf(n)
y = [i*step + a for i in xrange(n)]
return y
def cos_sin(ctx, z, **kwargs):
return ctx.cos(z, **kwargs), ctx.sin(z, **kwargs)
def cospi_sinpi(ctx, z, **kwargs):
return ctx.cospi(z, **kwargs), ctx.sinpi(z, **kwargs)
def _default_hyper_maxprec(ctx, p):
return int(1000 * p**0.25 + 4*p)
_gcd = staticmethod(libmp.gcd)
list_primes = staticmethod(libmp.list_primes)
isprime = staticmethod(libmp.isprime)
bernfrac = staticmethod(libmp.bernfrac)
moebius = staticmethod(libmp.moebius)
_ifac = staticmethod(libmp.ifac)
_eulernum = staticmethod(libmp.eulernum)
_stirling1 = staticmethod(libmp.stirling1)
_stirling2 = staticmethod(libmp.stirling2)
def sum_accurately(ctx, terms, check_step=1):
prec = ctx.prec
try:
extraprec = 10
while 1:
ctx.prec = prec + extraprec + 5
max_mag = ctx.ninf
s = ctx.zero
k = 0
for term in terms():
s += term
if (not k % check_step) and term:
term_mag = ctx.mag(term)
max_mag = max(max_mag, term_mag)
sum_mag = ctx.mag(s)
if sum_mag - term_mag > ctx.prec:
break
k += 1
cancellation = max_mag - sum_mag
if cancellation != cancellation:
break
if cancellation < extraprec or ctx._fixed_precision:
break
extraprec += min(ctx.prec, cancellation)
return s
finally:
ctx.prec = prec
def mul_accurately(ctx, factors, check_step=1):
prec = ctx.prec
try:
extraprec = 10
while 1:
ctx.prec = prec + extraprec + 5
max_mag = ctx.ninf
one = ctx.one
s = one
k = 0
for factor in factors():
s *= factor
term = factor - one
if (not k % check_step):
term_mag = ctx.mag(term)
max_mag = max(max_mag, term_mag)
sum_mag = ctx.mag(s-one)
#if sum_mag - term_mag > ctx.prec:
# break
if -term_mag > ctx.prec:
break
k += 1
cancellation = max_mag - sum_mag
if cancellation != cancellation:
break
if cancellation < extraprec or ctx._fixed_precision:
break
extraprec += min(ctx.prec, cancellation)
return s
finally:
ctx.prec = prec
def power(ctx, x, y):
r"""Converts `x` and `y` to mpmath numbers and evaluates
`x^y = \exp(y \log(x))`::
>>> from mpmath import *
>>> mp.dps = 30; mp.pretty = True
>>> power(2, 0.5)
1.41421356237309504880168872421
This shows the leading few digits of a large Mersenne prime
(performing the exact calculation ``2**43112609-1`` and
displaying the result in Python would be very slow)::
>>> power(2, 43112609)-1
3.16470269330255923143453723949e+12978188
"""
return ctx.convert(x) ** ctx.convert(y)
def _zeta_int(ctx, n):
return ctx.zeta(n)
def maxcalls(ctx, f, N):
"""
Return a wrapped copy of *f* that raises ``NoConvergence`` when *f*
has been called more than *N* times::
>>> from mpmath import *
>>> mp.dps = 15
>>> f = maxcalls(sin, 10)
>>> print(sum(f(n) for n in range(10)))
1.95520948210738
>>> f(10) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
NoConvergence: maxcalls: function evaluated 10 times
"""
counter = [0]
def f_maxcalls_wrapped(*args, **kwargs):
counter[0] += 1
if counter[0] > N:
raise ctx.NoConvergence("maxcalls: function evaluated %i times" % N)
return f(*args, **kwargs)
return f_maxcalls_wrapped
def memoize(ctx, f):
"""
Return a wrapped copy of *f* that caches computed values, i.e.
a memoized copy of *f*. Values are only reused if the cached precision
is equal to or higher than the working precision::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> f = memoize(maxcalls(sin, 1))
>>> f(2)
0.909297426825682
>>> f(2)
0.909297426825682
>>> mp.dps = 25
>>> f(2) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
NoConvergence: maxcalls: function evaluated 1 times
"""
f_cache = {}
def f_cached(*args, **kwargs):
if kwargs:
key = args, tuple(kwargs.items())
else:
key = args
prec = ctx.prec
if key in f_cache:
cprec, cvalue = f_cache[key]
if cprec >= prec:
return +cvalue
value = f(*args, **kwargs)
f_cache[key] = (prec, value)
return value
f_cached.__name__ = f.__name__
f_cached.__doc__ = f.__doc__
return f_cached

View File

@ -0,0 +1,253 @@
from .ctx_base import StandardBaseContext
import math
import cmath
from . import math2
from . import function_docs
from .libmp import mpf_bernoulli, to_float, int_types
from . import libmp
class FPContext(StandardBaseContext):
"""
Context for fast low-precision arithmetic (53-bit precision, giving at most
about 15-digit accuracy), using Python's builtin float and complex.
"""
def __init__(ctx):
StandardBaseContext.__init__(ctx)
# Override SpecialFunctions implementation
ctx.loggamma = math2.loggamma
ctx._bernoulli_cache = {}
ctx.pretty = False
ctx._init_aliases()
_mpq = lambda cls, x: float(x[0])/x[1]
NoConvergence = libmp.NoConvergence
def _get_prec(ctx): return 53
def _set_prec(ctx, p): return
def _get_dps(ctx): return 15
def _set_dps(ctx, p): return
_fixed_precision = True
prec = property(_get_prec, _set_prec)
dps = property(_get_dps, _set_dps)
zero = 0.0
one = 1.0
eps = math2.EPS
inf = math2.INF
ninf = math2.NINF
nan = math2.NAN
j = 1j
# Called by SpecialFunctions.__init__()
@classmethod
def _wrap_specfun(cls, name, f, wrap):
if wrap:
def f_wrapped(ctx, *args, **kwargs):
convert = ctx.convert
args = [convert(a) for a in args]
return f(ctx, *args, **kwargs)
else:
f_wrapped = f
f_wrapped.__doc__ = function_docs.__dict__.get(name, f.__doc__)
setattr(cls, name, f_wrapped)
def bernoulli(ctx, n):
cache = ctx._bernoulli_cache
if n in cache:
return cache[n]
cache[n] = to_float(mpf_bernoulli(n, 53, 'n'), strict=True)
return cache[n]
pi = math2.pi
e = math2.e
euler = math2.euler
sqrt2 = 1.4142135623730950488
sqrt5 = 2.2360679774997896964
phi = 1.6180339887498948482
ln2 = 0.69314718055994530942
ln10 = 2.302585092994045684
euler = 0.57721566490153286061
catalan = 0.91596559417721901505
khinchin = 2.6854520010653064453
apery = 1.2020569031595942854
glaisher = 1.2824271291006226369
absmin = absmax = abs
def is_special(ctx, x):
return x - x != 0.0
def isnan(ctx, x):
return x != x
def isinf(ctx, x):
return abs(x) == math2.INF
def isnormal(ctx, x):
if x:
return x - x == 0.0
return False
def isnpint(ctx, x):
if type(x) is complex:
if x.imag:
return False
x = x.real
return x <= 0.0 and round(x) == x
mpf = float
mpc = complex
def convert(ctx, x):
try:
return float(x)
except:
return complex(x)
power = staticmethod(math2.pow)
sqrt = staticmethod(math2.sqrt)
exp = staticmethod(math2.exp)
ln = log = staticmethod(math2.log)
cos = staticmethod(math2.cos)
sin = staticmethod(math2.sin)
tan = staticmethod(math2.tan)
cos_sin = staticmethod(math2.cos_sin)
acos = staticmethod(math2.acos)
asin = staticmethod(math2.asin)
atan = staticmethod(math2.atan)
cosh = staticmethod(math2.cosh)
sinh = staticmethod(math2.sinh)
tanh = staticmethod(math2.tanh)
gamma = staticmethod(math2.gamma)
rgamma = staticmethod(math2.rgamma)
fac = factorial = staticmethod(math2.factorial)
floor = staticmethod(math2.floor)
ceil = staticmethod(math2.ceil)
cospi = staticmethod(math2.cospi)
sinpi = staticmethod(math2.sinpi)
cbrt = staticmethod(math2.cbrt)
_nthroot = staticmethod(math2.nthroot)
_ei = staticmethod(math2.ei)
_e1 = staticmethod(math2.e1)
_zeta = _zeta_int = staticmethod(math2.zeta)
# XXX: math2
def arg(ctx, z):
z = complex(z)
return math.atan2(z.imag, z.real)
def expj(ctx, x):
return ctx.exp(ctx.j*x)
def expjpi(ctx, x):
return ctx.exp(ctx.j*ctx.pi*x)
ldexp = math.ldexp
frexp = math.frexp
def mag(ctx, z):
if z:
return ctx.frexp(abs(z))[1]
return ctx.ninf
def isint(ctx, z):
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
if z.imag:
return False
z = z.real
try:
return z == int(z)
except:
return False
def nint_distance(ctx, z):
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
n = round(z.real)
else:
n = round(z)
if n == z:
return n, ctx.ninf
return n, ctx.mag(abs(z-n))
def _convert_param(ctx, z):
if type(z) is tuple:
p, q = z
return ctx.mpf(p) / q, 'R'
if hasattr(z, "imag"): # float/int don't have .real/.imag in py2.5
intz = int(z.real)
else:
intz = int(z)
if z == intz:
return intz, 'Z'
return z, 'R'
def _is_real_type(ctx, z):
return isinstance(z, float) or isinstance(z, int_types)
def _is_complex_type(ctx, z):
return isinstance(z, complex)
def hypsum(ctx, p, q, types, coeffs, z, maxterms=6000, **kwargs):
coeffs = list(coeffs)
num = range(p)
den = range(p,p+q)
tol = ctx.eps
s = t = 1.0
k = 0
while 1:
for i in num: t *= (coeffs[i]+k)
for i in den: t /= (coeffs[i]+k)
k += 1; t /= k; t *= z; s += t
if abs(t) < tol:
return s
if k > maxterms:
raise ctx.NoConvergence
def atan2(ctx, x, y):
return math.atan2(x, y)
def psi(ctx, m, z):
m = int(m)
if m == 0:
return ctx.digamma(z)
return (-1)**(m+1) * ctx.fac(m) * ctx.zeta(m+1, z)
digamma = staticmethod(math2.digamma)
def harmonic(ctx, x):
x = ctx.convert(x)
if x == 0 or x == 1:
return x
return ctx.digamma(x+1) + ctx.euler
nstr = str
def to_fixed(ctx, x, prec):
return int(math.ldexp(x, prec))
def rand(ctx):
import random
return random.random()
_erf = staticmethod(math2.erf)
_erfc = staticmethod(math2.erfc)
def sum_accurately(ctx, terms, check_step=1):
s = ctx.zero
k = 0
for term in terms():
s += term
if (not k % check_step) and term:
if abs(term) <= 1e-18*abs(s):
break
k += 1
return s

View File

@ -0,0 +1,551 @@
import operator
from . import libmp
from .libmp.backend import basestring
from .libmp import (
int_types, MPZ_ONE,
prec_to_dps, dps_to_prec, repr_dps,
round_floor, round_ceiling,
fzero, finf, fninf, fnan,
mpf_le, mpf_neg,
from_int, from_float, from_str, from_rational,
mpi_mid, mpi_delta, mpi_str,
mpi_abs, mpi_pos, mpi_neg, mpi_add, mpi_sub,
mpi_mul, mpi_div, mpi_pow_int, mpi_pow,
mpi_from_str,
mpci_pos, mpci_neg, mpci_add, mpci_sub, mpci_mul, mpci_div, mpci_pow,
mpci_abs, mpci_pow, mpci_exp, mpci_log,
ComplexResult,
mpf_hash, mpc_hash)
from .matrices.matrices import _matrix
mpi_zero = (fzero, fzero)
from .ctx_base import StandardBaseContext
new = object.__new__
def convert_mpf_(x, prec, rounding):
if hasattr(x, "_mpf_"): return x._mpf_
if isinstance(x, int_types): return from_int(x, prec, rounding)
if isinstance(x, float): return from_float(x, prec, rounding)
if isinstance(x, basestring): return from_str(x, prec, rounding)
raise NotImplementedError
class ivmpf(object):
"""
Interval arithmetic class. Precision is controlled by iv.prec.
"""
def __new__(cls, x=0):
return cls.ctx.convert(x)
def cast(self, cls, f_convert):
a, b = self._mpi_
if a == b:
return cls(f_convert(a))
raise ValueError
def __int__(self):
return self.cast(int, libmp.to_int)
def __float__(self):
return self.cast(float, libmp.to_float)
def __complex__(self):
return self.cast(complex, libmp.to_float)
def __hash__(self):
a, b = self._mpi_
if a == b:
return mpf_hash(a)
else:
return hash(self._mpi_)
@property
def real(self): return self
@property
def imag(self): return self.ctx.zero
def conjugate(self): return self
@property
def a(self):
a, b = self._mpi_
return self.ctx.make_mpf((a, a))
@property
def b(self):
a, b = self._mpi_
return self.ctx.make_mpf((b, b))
@property
def mid(self):
ctx = self.ctx
v = mpi_mid(self._mpi_, ctx.prec)
return ctx.make_mpf((v, v))
@property
def delta(self):
ctx = self.ctx
v = mpi_delta(self._mpi_, ctx.prec)
return ctx.make_mpf((v,v))
@property
def _mpci_(self):
return self._mpi_, mpi_zero
def _compare(*args):
raise TypeError("no ordering relation is defined for intervals")
__gt__ = _compare
__le__ = _compare
__gt__ = _compare
__ge__ = _compare
def __contains__(self, t):
t = self.ctx.mpf(t)
return (self.a <= t.a) and (t.b <= self.b)
def __str__(self):
return mpi_str(self._mpi_, self.ctx.prec)
def __repr__(self):
if self.ctx.pretty:
return str(self)
a, b = self._mpi_
n = repr_dps(self.ctx.prec)
a = libmp.to_str(a, n)
b = libmp.to_str(b, n)
return "mpi(%r, %r)" % (a, b)
def _compare(s, t, cmpfun):
if not hasattr(t, "_mpi_"):
try:
t = s.ctx.convert(t)
except:
return NotImplemented
return cmpfun(s._mpi_, t._mpi_)
def __eq__(s, t): return s._compare(t, libmp.mpi_eq)
def __ne__(s, t): return s._compare(t, libmp.mpi_ne)
def __lt__(s, t): return s._compare(t, libmp.mpi_lt)
def __le__(s, t): return s._compare(t, libmp.mpi_le)
def __gt__(s, t): return s._compare(t, libmp.mpi_gt)
def __ge__(s, t): return s._compare(t, libmp.mpi_ge)
def __abs__(self):
return self.ctx.make_mpf(mpi_abs(self._mpi_, self.ctx.prec))
def __pos__(self):
return self.ctx.make_mpf(mpi_pos(self._mpi_, self.ctx.prec))
def __neg__(self):
return self.ctx.make_mpf(mpi_neg(self._mpi_, self.ctx.prec))
def ae(s, t, rel_eps=None, abs_eps=None):
return s.ctx.almosteq(s, t, rel_eps, abs_eps)
class ivmpc(object):
def __new__(cls, re=0, im=0):
re = cls.ctx.convert(re)
im = cls.ctx.convert(im)
y = new(cls)
y._mpci_ = re._mpi_, im._mpi_
return y
def __hash__(self):
(a, b), (c,d) = self._mpci_
if a == b and c == d:
return mpc_hash((a, c))
else:
return hash(self._mpci_)
def __repr__(s):
if s.ctx.pretty:
return str(s)
return "iv.mpc(%s, %s)" % (repr(s.real), repr(s.imag))
def __str__(s):
return "(%s + %s*j)" % (str(s.real), str(s.imag))
@property
def a(self):
(a, b), (c,d) = self._mpci_
return self.ctx.make_mpf((a, a))
@property
def b(self):
(a, b), (c,d) = self._mpci_
return self.ctx.make_mpf((b, b))
@property
def c(self):
(a, b), (c,d) = self._mpci_
return self.ctx.make_mpf((c, c))
@property
def d(self):
(a, b), (c,d) = self._mpci_
return self.ctx.make_mpf((d, d))
@property
def real(s):
return s.ctx.make_mpf(s._mpci_[0])
@property
def imag(s):
return s.ctx.make_mpf(s._mpci_[1])
def conjugate(s):
a, b = s._mpci_
return s.ctx.make_mpc((a, mpf_neg(b)))
def overlap(s, t):
t = s.ctx.convert(t)
real_overlap = (s.a <= t.a <= s.b) or (s.a <= t.b <= s.b) or (t.a <= s.a <= t.b) or (t.a <= s.b <= t.b)
imag_overlap = (s.c <= t.c <= s.d) or (s.c <= t.d <= s.d) or (t.c <= s.c <= t.d) or (t.c <= s.d <= t.d)
return real_overlap and imag_overlap
def __contains__(s, t):
t = s.ctx.convert(t)
return t.real in s.real and t.imag in s.imag
def _compare(s, t, ne=False):
if not isinstance(t, s.ctx._types):
try:
t = s.ctx.convert(t)
except:
return NotImplemented
if hasattr(t, '_mpi_'):
tval = t._mpi_, mpi_zero
elif hasattr(t, '_mpci_'):
tval = t._mpci_
if ne:
return s._mpci_ != tval
return s._mpci_ == tval
def __eq__(s, t): return s._compare(t)
def __ne__(s, t): return s._compare(t, True)
def __lt__(s, t): raise TypeError("complex intervals cannot be ordered")
__le__ = __gt__ = __ge__ = __lt__
def __neg__(s): return s.ctx.make_mpc(mpci_neg(s._mpci_, s.ctx.prec))
def __pos__(s): return s.ctx.make_mpc(mpci_pos(s._mpci_, s.ctx.prec))
def __abs__(s): return s.ctx.make_mpf(mpci_abs(s._mpci_, s.ctx.prec))
def ae(s, t, rel_eps=None, abs_eps=None):
return s.ctx.almosteq(s, t, rel_eps, abs_eps)
def _binary_op(f_real, f_complex):
def g_complex(ctx, sval, tval):
return ctx.make_mpc(f_complex(sval, tval, ctx.prec))
def g_real(ctx, sval, tval):
try:
return ctx.make_mpf(f_real(sval, tval, ctx.prec))
except ComplexResult:
sval = (sval, mpi_zero)
tval = (tval, mpi_zero)
return g_complex(ctx, sval, tval)
def lop_real(s, t):
if isinstance(t, _matrix): return NotImplemented
ctx = s.ctx
if not isinstance(t, ctx._types): t = ctx.convert(t)
if hasattr(t, "_mpi_"): return g_real(ctx, s._mpi_, t._mpi_)
if hasattr(t, "_mpci_"): return g_complex(ctx, (s._mpi_, mpi_zero), t._mpci_)
return NotImplemented
def rop_real(s, t):
ctx = s.ctx
if not isinstance(t, ctx._types): t = ctx.convert(t)
if hasattr(t, "_mpi_"): return g_real(ctx, t._mpi_, s._mpi_)
if hasattr(t, "_mpci_"): return g_complex(ctx, t._mpci_, (s._mpi_, mpi_zero))
return NotImplemented
def lop_complex(s, t):
if isinstance(t, _matrix): return NotImplemented
ctx = s.ctx
if not isinstance(t, s.ctx._types):
try:
t = s.ctx.convert(t)
except (ValueError, TypeError):
return NotImplemented
return g_complex(ctx, s._mpci_, t._mpci_)
def rop_complex(s, t):
ctx = s.ctx
if not isinstance(t, s.ctx._types):
t = s.ctx.convert(t)
return g_complex(ctx, t._mpci_, s._mpci_)
return lop_real, rop_real, lop_complex, rop_complex
ivmpf.__add__, ivmpf.__radd__, ivmpc.__add__, ivmpc.__radd__ = _binary_op(mpi_add, mpci_add)
ivmpf.__sub__, ivmpf.__rsub__, ivmpc.__sub__, ivmpc.__rsub__ = _binary_op(mpi_sub, mpci_sub)
ivmpf.__mul__, ivmpf.__rmul__, ivmpc.__mul__, ivmpc.__rmul__ = _binary_op(mpi_mul, mpci_mul)
ivmpf.__div__, ivmpf.__rdiv__, ivmpc.__div__, ivmpc.__rdiv__ = _binary_op(mpi_div, mpci_div)
ivmpf.__pow__, ivmpf.__rpow__, ivmpc.__pow__, ivmpc.__rpow__ = _binary_op(mpi_pow, mpci_pow)
ivmpf.__truediv__ = ivmpf.__div__; ivmpf.__rtruediv__ = ivmpf.__rdiv__
ivmpc.__truediv__ = ivmpc.__div__; ivmpc.__rtruediv__ = ivmpc.__rdiv__
class ivmpf_constant(ivmpf):
def __new__(cls, f):
self = new(cls)
self._f = f
return self
def _get_mpi_(self):
prec = self.ctx._prec[0]
a = self._f(prec, round_floor)
b = self._f(prec, round_ceiling)
return a, b
_mpi_ = property(_get_mpi_)
class MPIntervalContext(StandardBaseContext):
def __init__(ctx):
ctx.mpf = type('ivmpf', (ivmpf,), {})
ctx.mpc = type('ivmpc', (ivmpc,), {})
ctx._types = (ctx.mpf, ctx.mpc)
ctx._constant = type('ivmpf_constant', (ivmpf_constant,), {})
ctx._prec = [53]
ctx._set_prec(53)
ctx._constant._ctxdata = ctx.mpf._ctxdata = ctx.mpc._ctxdata = [ctx.mpf, new, ctx._prec]
ctx._constant.ctx = ctx.mpf.ctx = ctx.mpc.ctx = ctx
ctx.pretty = False
StandardBaseContext.__init__(ctx)
ctx._init_builtins()
def _mpi(ctx, a, b=None):
if b is None:
return ctx.mpf(a)
return ctx.mpf((a,b))
def _init_builtins(ctx):
ctx.one = ctx.mpf(1)
ctx.zero = ctx.mpf(0)
ctx.inf = ctx.mpf('inf')
ctx.ninf = -ctx.inf
ctx.nan = ctx.mpf('nan')
ctx.j = ctx.mpc(0,1)
ctx.exp = ctx._wrap_mpi_function(libmp.mpi_exp, libmp.mpci_exp)
ctx.sqrt = ctx._wrap_mpi_function(libmp.mpi_sqrt)
ctx.ln = ctx._wrap_mpi_function(libmp.mpi_log, libmp.mpci_log)
ctx.cos = ctx._wrap_mpi_function(libmp.mpi_cos, libmp.mpci_cos)
ctx.sin = ctx._wrap_mpi_function(libmp.mpi_sin, libmp.mpci_sin)
ctx.tan = ctx._wrap_mpi_function(libmp.mpi_tan)
ctx.gamma = ctx._wrap_mpi_function(libmp.mpi_gamma, libmp.mpci_gamma)
ctx.loggamma = ctx._wrap_mpi_function(libmp.mpi_loggamma, libmp.mpci_loggamma)
ctx.rgamma = ctx._wrap_mpi_function(libmp.mpi_rgamma, libmp.mpci_rgamma)
ctx.factorial = ctx._wrap_mpi_function(libmp.mpi_factorial, libmp.mpci_factorial)
ctx.fac = ctx.factorial
ctx.eps = ctx._constant(lambda prec, rnd: (0, MPZ_ONE, 1-prec, 1))
ctx.pi = ctx._constant(libmp.mpf_pi)
ctx.e = ctx._constant(libmp.mpf_e)
ctx.ln2 = ctx._constant(libmp.mpf_ln2)
ctx.ln10 = ctx._constant(libmp.mpf_ln10)
ctx.phi = ctx._constant(libmp.mpf_phi)
ctx.euler = ctx._constant(libmp.mpf_euler)
ctx.catalan = ctx._constant(libmp.mpf_catalan)
ctx.glaisher = ctx._constant(libmp.mpf_glaisher)
ctx.khinchin = ctx._constant(libmp.mpf_khinchin)
ctx.twinprime = ctx._constant(libmp.mpf_twinprime)
def _wrap_mpi_function(ctx, f_real, f_complex=None):
def g(x, **kwargs):
if kwargs:
prec = kwargs.get('prec', ctx._prec[0])
else:
prec = ctx._prec[0]
x = ctx.convert(x)
if hasattr(x, "_mpi_"):
return ctx.make_mpf(f_real(x._mpi_, prec))
if hasattr(x, "_mpci_"):
return ctx.make_mpc(f_complex(x._mpci_, prec))
raise ValueError
return g
@classmethod
def _wrap_specfun(cls, name, f, wrap):
if wrap:
def f_wrapped(ctx, *args, **kwargs):
convert = ctx.convert
args = [convert(a) for a in args]
prec = ctx.prec
try:
ctx.prec += 10
retval = f(ctx, *args, **kwargs)
finally:
ctx.prec = prec
return +retval
else:
f_wrapped = f
setattr(cls, name, f_wrapped)
def _set_prec(ctx, n):
ctx._prec[0] = max(1, int(n))
ctx._dps = prec_to_dps(n)
def _set_dps(ctx, n):
ctx._prec[0] = dps_to_prec(n)
ctx._dps = max(1, int(n))
prec = property(lambda ctx: ctx._prec[0], _set_prec)
dps = property(lambda ctx: ctx._dps, _set_dps)
def make_mpf(ctx, v):
a = new(ctx.mpf)
a._mpi_ = v
return a
def make_mpc(ctx, v):
a = new(ctx.mpc)
a._mpci_ = v
return a
def _mpq(ctx, pq):
p, q = pq
a = libmp.from_rational(p, q, ctx.prec, round_floor)
b = libmp.from_rational(p, q, ctx.prec, round_ceiling)
return ctx.make_mpf((a, b))
def convert(ctx, x):
if isinstance(x, (ctx.mpf, ctx.mpc)):
return x
if isinstance(x, ctx._constant):
return +x
if isinstance(x, complex) or hasattr(x, "_mpc_"):
re = ctx.convert(x.real)
im = ctx.convert(x.imag)
return ctx.mpc(re,im)
if isinstance(x, basestring):
v = mpi_from_str(x, ctx.prec)
return ctx.make_mpf(v)
if hasattr(x, "_mpi_"):
a, b = x._mpi_
else:
try:
a, b = x
except (TypeError, ValueError):
a = b = x
if hasattr(a, "_mpi_"):
a = a._mpi_[0]
else:
a = convert_mpf_(a, ctx.prec, round_floor)
if hasattr(b, "_mpi_"):
b = b._mpi_[1]
else:
b = convert_mpf_(b, ctx.prec, round_ceiling)
if a == fnan or b == fnan:
a = fninf
b = finf
assert mpf_le(a, b), "endpoints must be properly ordered"
return ctx.make_mpf((a, b))
def nstr(ctx, x, n=5, **kwargs):
x = ctx.convert(x)
if hasattr(x, "_mpi_"):
return libmp.mpi_to_str(x._mpi_, n, **kwargs)
if hasattr(x, "_mpci_"):
re = libmp.mpi_to_str(x._mpci_[0], n, **kwargs)
im = libmp.mpi_to_str(x._mpci_[1], n, **kwargs)
return "(%s + %s*j)" % (re, im)
def mag(ctx, x):
x = ctx.convert(x)
if isinstance(x, ctx.mpc):
return max(ctx.mag(x.real), ctx.mag(x.imag)) + 1
a, b = libmp.mpi_abs(x._mpi_)
sign, man, exp, bc = b
if man:
return exp+bc
if b == fzero:
return ctx.ninf
if b == fnan:
return ctx.nan
return ctx.inf
def isnan(ctx, x):
return False
def isinf(ctx, x):
return x == ctx.inf
def isint(ctx, x):
x = ctx.convert(x)
a, b = x._mpi_
if a == b:
sign, man, exp, bc = a
if man:
return exp >= 0
return a == fzero
return None
def ldexp(ctx, x, n):
a, b = ctx.convert(x)._mpi_
a = libmp.mpf_shift(a, n)
b = libmp.mpf_shift(b, n)
return ctx.make_mpf((a,b))
def absmin(ctx, x):
return abs(ctx.convert(x)).a
def absmax(ctx, x):
return abs(ctx.convert(x)).b
def atan2(ctx, y, x):
y = ctx.convert(y)._mpi_
x = ctx.convert(x)._mpi_
return ctx.make_mpf(libmp.mpi_atan2(y,x,ctx.prec))
def _convert_param(ctx, x):
if isinstance(x, libmp.int_types):
return x, 'Z'
if isinstance(x, tuple):
p, q = x
return (ctx.mpf(p) / ctx.mpf(q), 'R')
x = ctx.convert(x)
if isinstance(x, ctx.mpf):
return x, 'R'
if isinstance(x, ctx.mpc):
return x, 'C'
raise ValueError
def _is_real_type(ctx, z):
return isinstance(z, ctx.mpf) or isinstance(z, int_types)
def _is_complex_type(ctx, z):
return isinstance(z, ctx.mpc)
def hypsum(ctx, p, q, types, coeffs, z, maxterms=6000, **kwargs):
coeffs = list(coeffs)
num = range(p)
den = range(p,p+q)
#tol = ctx.eps
s = t = ctx.one
k = 0
while 1:
for i in num: t *= (coeffs[i]+k)
for i in den: t /= (coeffs[i]+k)
k += 1; t /= k; t *= z; s += t
if t == 0:
return s
#if abs(t) < tol:
# return s
if k > maxterms:
raise ctx.NoConvergence
# Register with "numbers" ABC
# We do not subclass, hence we do not use the @abstractmethod checks. While
# this is less invasive it may turn out that we do not actually support
# parts of the expected interfaces. See
# http://docs.python.org/2/library/numbers.html for list of abstract
# methods.
try:
import numbers
numbers.Complex.register(ivmpc)
numbers.Real.register(ivmpf)
except ImportError:
pass

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,14 @@
from . import functions
# Hack to update methods
from . import factorials
from . import hypergeometric
from . import expintegrals
from . import bessel
from . import orthogonal
from . import theta
from . import elliptic
from . import signals
from . import zeta
from . import rszeta
from . import zetazeros
from . import qfunctions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,425 @@
from .functions import defun, defun_wrapped
@defun_wrapped
def _erf_complex(ctx, z):
z2 = ctx.square_exp_arg(z, -1)
#z2 = -z**2
v = (2/ctx.sqrt(ctx.pi))*z * ctx.hyp1f1((1,2),(3,2), z2)
if not ctx._re(z):
v = ctx._im(v)*ctx.j
return v
@defun_wrapped
def _erfc_complex(ctx, z):
if ctx.re(z) > 2:
z2 = ctx.square_exp_arg(z)
nz2 = ctx.fneg(z2, exact=True)
v = ctx.exp(nz2)/ctx.sqrt(ctx.pi) * ctx.hyperu((1,2),(1,2), z2)
else:
v = 1 - ctx._erf_complex(z)
if not ctx._re(z):
v = 1+ctx._im(v)*ctx.j
return v
@defun
def erf(ctx, z):
z = ctx.convert(z)
if ctx._is_real_type(z):
try:
return ctx._erf(z)
except NotImplementedError:
pass
if ctx._is_complex_type(z) and not z.imag:
try:
return type(z)(ctx._erf(z.real))
except NotImplementedError:
pass
return ctx._erf_complex(z)
@defun
def erfc(ctx, z):
z = ctx.convert(z)
if ctx._is_real_type(z):
try:
return ctx._erfc(z)
except NotImplementedError:
pass
if ctx._is_complex_type(z) and not z.imag:
try:
return type(z)(ctx._erfc(z.real))
except NotImplementedError:
pass
return ctx._erfc_complex(z)
@defun
def square_exp_arg(ctx, z, mult=1, reciprocal=False):
prec = ctx.prec*4+20
if reciprocal:
z2 = ctx.fmul(z, z, prec=prec)
z2 = ctx.fdiv(ctx.one, z2, prec=prec)
else:
z2 = ctx.fmul(z, z, prec=prec)
if mult != 1:
z2 = ctx.fmul(z2, mult, exact=True)
return z2
@defun_wrapped
def erfi(ctx, z):
if not z:
return z
z2 = ctx.square_exp_arg(z)
v = (2/ctx.sqrt(ctx.pi)*z) * ctx.hyp1f1((1,2), (3,2), z2)
if not ctx._re(z):
v = ctx._im(v)*ctx.j
return v
@defun_wrapped
def erfinv(ctx, x):
xre = ctx._re(x)
if (xre != x) or (xre < -1) or (xre > 1):
return ctx.bad_domain("erfinv(x) is defined only for -1 <= x <= 1")
x = xre
#if ctx.isnan(x): return x
if not x: return x
if x == 1: return ctx.inf
if x == -1: return ctx.ninf
if abs(x) < 0.9:
a = 0.53728*x**3 + 0.813198*x
else:
# An asymptotic formula
u = ctx.ln(2/ctx.pi/(abs(x)-1)**2)
a = ctx.sign(x) * ctx.sqrt(u - ctx.ln(u))/ctx.sqrt(2)
ctx.prec += 10
return ctx.findroot(lambda t: ctx.erf(t)-x, a)
@defun_wrapped
def npdf(ctx, x, mu=0, sigma=1):
sigma = ctx.convert(sigma)
return ctx.exp(-(x-mu)**2/(2*sigma**2)) / (sigma*ctx.sqrt(2*ctx.pi))
@defun_wrapped
def ncdf(ctx, x, mu=0, sigma=1):
a = (x-mu)/(sigma*ctx.sqrt(2))
if a < 0:
return ctx.erfc(-a)/2
else:
return (1+ctx.erf(a))/2
@defun_wrapped
def betainc(ctx, a, b, x1=0, x2=1, regularized=False):
if x1 == x2:
v = 0
elif not x1:
if x1 == 0 and x2 == 1:
v = ctx.beta(a, b)
else:
v = x2**a * ctx.hyp2f1(a, 1-b, a+1, x2) / a
else:
m, d = ctx.nint_distance(a)
if m <= 0:
if d < -ctx.prec:
h = +ctx.eps
ctx.prec *= 2
a += h
elif d < -4:
ctx.prec -= d
s1 = x2**a * ctx.hyp2f1(a,1-b,a+1,x2)
s2 = x1**a * ctx.hyp2f1(a,1-b,a+1,x1)
v = (s1 - s2) / a
if regularized:
v /= ctx.beta(a,b)
return v
@defun
def gammainc(ctx, z, a=0, b=None, regularized=False):
regularized = bool(regularized)
z = ctx.convert(z)
if a is None:
a = ctx.zero
lower_modified = False
else:
a = ctx.convert(a)
lower_modified = a != ctx.zero
if b is None:
b = ctx.inf
upper_modified = False
else:
b = ctx.convert(b)
upper_modified = b != ctx.inf
# Complete gamma function
if not (upper_modified or lower_modified):
if regularized:
if ctx.re(z) < 0:
return ctx.inf
elif ctx.re(z) > 0:
return ctx.one
else:
return ctx.nan
return ctx.gamma(z)
if a == b:
return ctx.zero
# Standardize
if ctx.re(a) > ctx.re(b):
return -ctx.gammainc(z, b, a, regularized)
# Generalized gamma
if upper_modified and lower_modified:
return +ctx._gamma3(z, a, b, regularized)
# Upper gamma
elif lower_modified:
return ctx._upper_gamma(z, a, regularized)
# Lower gamma
elif upper_modified:
return ctx._lower_gamma(z, b, regularized)
@defun
def _lower_gamma(ctx, z, b, regularized=False):
# Pole
if ctx.isnpint(z):
return type(z)(ctx.inf)
G = [z] * regularized
negb = ctx.fneg(b, exact=True)
def h(z):
T1 = [ctx.exp(negb), b, z], [1, z, -1], [], G, [1], [1+z], b
return (T1,)
return ctx.hypercomb(h, [z])
@defun
def _upper_gamma(ctx, z, a, regularized=False):
# Fast integer case, when available
if ctx.isint(z):
try:
if regularized:
# Gamma pole
if ctx.isnpint(z):
return type(z)(ctx.zero)
orig = ctx.prec
try:
ctx.prec += 10
return ctx._gamma_upper_int(z, a) / ctx.gamma(z)
finally:
ctx.prec = orig
else:
return ctx._gamma_upper_int(z, a)
except NotImplementedError:
pass
# hypercomb is unable to detect the exact zeros, so handle them here
if z == 2 and a == -1:
return (z+a)*0
if z == 3 and (a == -1-1j or a == -1+1j):
return (z+a)*0
nega = ctx.fneg(a, exact=True)
G = [z] * regularized
# Use 2F0 series when possible; fall back to lower gamma representation
try:
def h(z):
r = z-1
return [([ctx.exp(nega), a], [1, r], [], G, [1, -r], [], 1/nega)]
return ctx.hypercomb(h, [z], force_series=True)
except ctx.NoConvergence:
def h(z):
T1 = [], [1, z-1], [z], G, [], [], 0
T2 = [-ctx.exp(nega), a, z], [1, z, -1], [], G, [1], [1+z], a
return T1, T2
return ctx.hypercomb(h, [z])
@defun
def _gamma3(ctx, z, a, b, regularized=False):
pole = ctx.isnpint(z)
if regularized and pole:
return ctx.zero
try:
ctx.prec += 15
# We don't know in advance whether it's better to write as a difference
# of lower or upper gamma functions, so try both
T1 = ctx.gammainc(z, a, regularized=regularized)
T2 = ctx.gammainc(z, b, regularized=regularized)
R = T1 - T2
if ctx.mag(R) - max(ctx.mag(T1), ctx.mag(T2)) > -10:
return R
if not pole:
T1 = ctx.gammainc(z, 0, b, regularized=regularized)
T2 = ctx.gammainc(z, 0, a, regularized=regularized)
R = T1 - T2
# May be ok, but should probably at least print a warning
# about possible cancellation
if 1: #ctx.mag(R) - max(ctx.mag(T1), ctx.mag(T2)) > -10:
return R
finally:
ctx.prec -= 15
raise NotImplementedError
@defun_wrapped
def expint(ctx, n, z):
if ctx.isint(n) and ctx._is_real_type(z):
try:
return ctx._expint_int(n, z)
except NotImplementedError:
pass
if ctx.isnan(n) or ctx.isnan(z):
return z*n
if z == ctx.inf:
return 1/z
if z == 0:
# integral from 1 to infinity of t^n
if ctx.re(n) <= 1:
# TODO: reasonable sign of infinity
return type(z)(ctx.inf)
else:
return ctx.one/(n-1)
if n == 0:
return ctx.exp(-z)/z
if n == -1:
return ctx.exp(-z)*(z+1)/z**2
return z**(n-1) * ctx.gammainc(1-n, z)
@defun_wrapped
def li(ctx, z, offset=False):
if offset:
if z == 2:
return ctx.zero
return ctx.ei(ctx.ln(z)) - ctx.ei(ctx.ln2)
if not z:
return z
if z == 1:
return ctx.ninf
return ctx.ei(ctx.ln(z))
@defun
def ei(ctx, z):
try:
return ctx._ei(z)
except NotImplementedError:
return ctx._ei_generic(z)
@defun_wrapped
def _ei_generic(ctx, z):
# Note: the following is currently untested because mp and fp
# both use special-case ei code
if z == ctx.inf:
return z
if z == ctx.ninf:
return ctx.zero
if ctx.mag(z) > 1:
try:
r = ctx.one/z
v = ctx.exp(z)*ctx.hyper([1,1],[],r,
maxterms=ctx.prec, force_series=True)/z
im = ctx._im(z)
if im > 0:
v += ctx.pi*ctx.j
if im < 0:
v -= ctx.pi*ctx.j
return v
except ctx.NoConvergence:
pass
v = z*ctx.hyp2f2(1,1,2,2,z) + ctx.euler
if ctx._im(z):
v += 0.5*(ctx.log(z) - ctx.log(ctx.one/z))
else:
v += ctx.log(abs(z))
return v
@defun
def e1(ctx, z):
try:
return ctx._e1(z)
except NotImplementedError:
return ctx.expint(1, z)
@defun
def ci(ctx, z):
try:
return ctx._ci(z)
except NotImplementedError:
return ctx._ci_generic(z)
@defun_wrapped
def _ci_generic(ctx, z):
if ctx.isinf(z):
if z == ctx.inf: return ctx.zero
if z == ctx.ninf: return ctx.pi*1j
jz = ctx.fmul(ctx.j,z,exact=True)
njz = ctx.fneg(jz,exact=True)
v = 0.5*(ctx.ei(jz) + ctx.ei(njz))
zreal = ctx._re(z)
zimag = ctx._im(z)
if zreal == 0:
if zimag > 0: v += ctx.pi*0.5j
if zimag < 0: v -= ctx.pi*0.5j
if zreal < 0:
if zimag >= 0: v += ctx.pi*1j
if zimag < 0: v -= ctx.pi*1j
if ctx._is_real_type(z) and zreal > 0:
v = ctx._re(v)
return v
@defun
def si(ctx, z):
try:
return ctx._si(z)
except NotImplementedError:
return ctx._si_generic(z)
@defun_wrapped
def _si_generic(ctx, z):
if ctx.isinf(z):
if z == ctx.inf: return 0.5*ctx.pi
if z == ctx.ninf: return -0.5*ctx.pi
# Suffers from cancellation near 0
if ctx.mag(z) >= -1:
jz = ctx.fmul(ctx.j,z,exact=True)
njz = ctx.fneg(jz,exact=True)
v = (-0.5j)*(ctx.ei(jz) - ctx.ei(njz))
zreal = ctx._re(z)
if zreal > 0:
v -= 0.5*ctx.pi
if zreal < 0:
v += 0.5*ctx.pi
if ctx._is_real_type(z):
v = ctx._re(v)
return v
else:
return z*ctx.hyp1f2((1,2),(3,2),(3,2),-0.25*z*z)
@defun_wrapped
def chi(ctx, z):
nz = ctx.fneg(z, exact=True)
v = 0.5*(ctx.ei(z) + ctx.ei(nz))
zreal = ctx._re(z)
zimag = ctx._im(z)
if zimag > 0:
v += ctx.pi*0.5j
elif zimag < 0:
v -= ctx.pi*0.5j
elif zreal < 0:
v += ctx.pi*1j
return v
@defun_wrapped
def shi(ctx, z):
# Suffers from cancellation near 0
if ctx.mag(z) >= -1:
nz = ctx.fneg(z, exact=True)
v = 0.5*(ctx.ei(z) - ctx.ei(nz))
zimag = ctx._im(z)
if zimag > 0: v -= 0.5j*ctx.pi
if zimag < 0: v += 0.5j*ctx.pi
return v
else:
return z * ctx.hyp1f2((1,2),(3,2),(3,2),0.25*z*z)
@defun_wrapped
def fresnels(ctx, z):
if z == ctx.inf:
return ctx.mpf(0.5)
if z == ctx.ninf:
return ctx.mpf(-0.5)
return ctx.pi*z**3/6*ctx.hyp1f2((3,4),(3,2),(7,4),-ctx.pi**2*z**4/16)
@defun_wrapped
def fresnelc(ctx, z):
if z == ctx.inf:
return ctx.mpf(0.5)
if z == ctx.ninf:
return ctx.mpf(-0.5)
return z*ctx.hyp1f2((1,4),(1,2),(5,4),-ctx.pi**2*z**4/16)

View File

@ -0,0 +1,187 @@
from ..libmp.backend import xrange
from .functions import defun, defun_wrapped
@defun
def gammaprod(ctx, a, b, _infsign=False):
a = [ctx.convert(x) for x in a]
b = [ctx.convert(x) for x in b]
poles_num = []
poles_den = []
regular_num = []
regular_den = []
for x in a: [regular_num, poles_num][ctx.isnpint(x)].append(x)
for x in b: [regular_den, poles_den][ctx.isnpint(x)].append(x)
# One more pole in numerator or denominator gives 0 or inf
if len(poles_num) < len(poles_den): return ctx.zero
if len(poles_num) > len(poles_den):
# Get correct sign of infinity for x+h, h -> 0 from above
# XXX: hack, this should be done properly
if _infsign:
a = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_num]
b = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_den]
return ctx.sign(ctx.gammaprod(a+regular_num,b+regular_den)) * ctx.inf
else:
return ctx.inf
# All poles cancel
# lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
p = ctx.one
orig = ctx.prec
try:
ctx.prec = orig + 15
while poles_num:
i = poles_num.pop()
j = poles_den.pop()
p *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
for x in regular_num: p *= ctx.gamma(x)
for x in regular_den: p /= ctx.gamma(x)
finally:
ctx.prec = orig
return +p
@defun
def beta(ctx, x, y):
x = ctx.convert(x)
y = ctx.convert(y)
if ctx.isinf(y):
x, y = y, x
if ctx.isinf(x):
if x == ctx.inf and not ctx._im(y):
if y == ctx.ninf:
return ctx.nan
if y > 0:
return ctx.zero
if ctx.isint(y):
return ctx.nan
if y < 0:
return ctx.sign(ctx.gamma(y)) * ctx.inf
return ctx.nan
xy = ctx.fadd(x, y, prec=2*ctx.prec)
return ctx.gammaprod([x, y], [xy])
@defun
def binomial(ctx, n, k):
n1 = ctx.fadd(n, 1, prec=2*ctx.prec)
k1 = ctx.fadd(k, 1, prec=2*ctx.prec)
nk1 = ctx.fsub(n1, k, prec=2*ctx.prec)
return ctx.gammaprod([n1], [k1, nk1])
@defun
def rf(ctx, x, n):
xn = ctx.fadd(x, n, prec=2*ctx.prec)
return ctx.gammaprod([xn], [x])
@defun
def ff(ctx, x, n):
x1 = ctx.fadd(x, 1, prec=2*ctx.prec)
xn1 = ctx.fadd(ctx.fsub(x, n, prec=2*ctx.prec), 1, prec=2*ctx.prec)
return ctx.gammaprod([x1], [xn1])
@defun_wrapped
def fac2(ctx, x):
if ctx.isinf(x):
if x == ctx.inf:
return x
return ctx.nan
return 2**(x/2)*(ctx.pi/2)**((ctx.cospi(x)-1)/4)*ctx.gamma(x/2+1)
@defun_wrapped
def barnesg(ctx, z):
if ctx.isinf(z):
if z == ctx.inf:
return z
return ctx.nan
if ctx.isnan(z):
return z
if (not ctx._im(z)) and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
return z*0
# Account for size (would not be needed if computing log(G))
if abs(z) > 5:
ctx.dps += 2*ctx.log(abs(z),2)
# Reflection formula
if ctx.re(z) < -ctx.dps:
w = 1-z
pi2 = 2*ctx.pi
u = ctx.expjpi(2*w)
v = ctx.j*ctx.pi/12 - ctx.j*ctx.pi*w**2/2 + w*ctx.ln(1-u) - \
ctx.j*ctx.polylog(2, u)/pi2
v = ctx.barnesg(2-z)*ctx.exp(v)/pi2**w
if ctx._is_real_type(z):
v = ctx._re(v)
return v
# Estimate terms for asymptotic expansion
# TODO: fixme, obviously
N = ctx.dps // 2 + 5
G = 1
while abs(z) < N or ctx.re(z) < 1:
G /= ctx.gamma(z)
z += 1
z -= 1
s = ctx.mpf(1)/12
s -= ctx.log(ctx.glaisher)
s += z*ctx.log(2*ctx.pi)/2
s += (z**2/2-ctx.mpf(1)/12)*ctx.log(z)
s -= 3*z**2/4
z2k = z2 = z**2
for k in xrange(1, N+1):
t = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z2k)
if abs(t) < ctx.eps:
#print k, N # check how many terms were needed
break
z2k *= z2
s += t
#if k == N:
# print "warning: series for barnesg failed to converge", ctx.dps
return G*ctx.exp(s)
@defun
def superfac(ctx, z):
return ctx.barnesg(z+2)
@defun_wrapped
def hyperfac(ctx, z):
# XXX: estimate needed extra bits accurately
if z == ctx.inf:
return z
if abs(z) > 5:
extra = 4*int(ctx.log(abs(z),2))
else:
extra = 0
ctx.prec += extra
if not ctx._im(z) and ctx._re(z) < 0 and ctx.isint(ctx._re(z)):
n = int(ctx.re(z))
h = ctx.hyperfac(-n-1)
if ((n+1)//2) & 1:
h = -h
if ctx._is_complex_type(z):
return h + 0j
return h
zp1 = z+1
# Wrong branch cut
#v = ctx.gamma(zp1)**z
#ctx.prec -= extra
#return v / ctx.barnesg(zp1)
v = ctx.exp(z*ctx.loggamma(zp1))
ctx.prec -= extra
return v / ctx.barnesg(zp1)
'''
@defun
def psi0(ctx, z):
"""Shortcut for psi(0,z) (the digamma function)"""
return ctx.psi(0, z)
@defun
def psi1(ctx, z):
"""Shortcut for psi(1,z) (the trigamma function)"""
return ctx.psi(1, z)
@defun
def psi2(ctx, z):
"""Shortcut for psi(2,z) (the tetragamma function)"""
return ctx.psi(2, z)
@defun
def psi3(ctx, z):
"""Shortcut for psi(3,z) (the pentagamma function)"""
return ctx.psi(3, z)
'''

View File

@ -0,0 +1,645 @@
from ..libmp.backend import xrange
class SpecialFunctions(object):
"""
This class implements special functions using high-level code.
Elementary and some other functions (e.g. gamma function, basecase
hypergeometric series) are assumed to be predefined by the context as
"builtins" or "low-level" functions.
"""
defined_functions = {}
# The series for the Jacobi theta functions converge for |q| < 1;
# in the current implementation they throw a ValueError for
# abs(q) > THETA_Q_LIM
THETA_Q_LIM = 1 - 10**-7
def __init__(self):
cls = self.__class__
for name in cls.defined_functions:
f, wrap = cls.defined_functions[name]
cls._wrap_specfun(name, f, wrap)
self.mpq_1 = self._mpq((1,1))
self.mpq_0 = self._mpq((0,1))
self.mpq_1_2 = self._mpq((1,2))
self.mpq_3_2 = self._mpq((3,2))
self.mpq_1_4 = self._mpq((1,4))
self.mpq_1_16 = self._mpq((1,16))
self.mpq_3_16 = self._mpq((3,16))
self.mpq_5_2 = self._mpq((5,2))
self.mpq_3_4 = self._mpq((3,4))
self.mpq_7_4 = self._mpq((7,4))
self.mpq_5_4 = self._mpq((5,4))
self.mpq_1_3 = self._mpq((1,3))
self.mpq_2_3 = self._mpq((2,3))
self.mpq_4_3 = self._mpq((4,3))
self.mpq_1_6 = self._mpq((1,6))
self.mpq_5_6 = self._mpq((5,6))
self.mpq_5_3 = self._mpq((5,3))
self._misc_const_cache = {}
self._aliases.update({
'phase' : 'arg',
'conjugate' : 'conj',
'nthroot' : 'root',
'polygamma' : 'psi',
'hurwitz' : 'zeta',
#'digamma' : 'psi0',
#'trigamma' : 'psi1',
#'tetragamma' : 'psi2',
#'pentagamma' : 'psi3',
'fibonacci' : 'fib',
'factorial' : 'fac',
})
self.zetazero_memoized = self.memoize(self.zetazero)
# Default -- do nothing
@classmethod
def _wrap_specfun(cls, name, f, wrap):
setattr(cls, name, f)
# Optional fast versions of common functions in common cases.
# If not overridden, default (generic hypergeometric series)
# implementations will be used
def _besselj(ctx, n, z): raise NotImplementedError
def _erf(ctx, z): raise NotImplementedError
def _erfc(ctx, z): raise NotImplementedError
def _gamma_upper_int(ctx, z, a): raise NotImplementedError
def _expint_int(ctx, n, z): raise NotImplementedError
def _zeta(ctx, s): raise NotImplementedError
def _zetasum_fast(ctx, s, a, n, derivatives, reflect): raise NotImplementedError
def _ei(ctx, z): raise NotImplementedError
def _e1(ctx, z): raise NotImplementedError
def _ci(ctx, z): raise NotImplementedError
def _si(ctx, z): raise NotImplementedError
def _altzeta(ctx, s): raise NotImplementedError
def defun_wrapped(f):
SpecialFunctions.defined_functions[f.__name__] = f, True
return f
def defun(f):
SpecialFunctions.defined_functions[f.__name__] = f, False
return f
def defun_static(f):
setattr(SpecialFunctions, f.__name__, f)
return f
@defun_wrapped
def cot(ctx, z): return ctx.one / ctx.tan(z)
@defun_wrapped
def sec(ctx, z): return ctx.one / ctx.cos(z)
@defun_wrapped
def csc(ctx, z): return ctx.one / ctx.sin(z)
@defun_wrapped
def coth(ctx, z): return ctx.one / ctx.tanh(z)
@defun_wrapped
def sech(ctx, z): return ctx.one / ctx.cosh(z)
@defun_wrapped
def csch(ctx, z): return ctx.one / ctx.sinh(z)
@defun_wrapped
def acot(ctx, z):
if not z:
return ctx.pi * 0.5
else:
return ctx.atan(ctx.one / z)
@defun_wrapped
def asec(ctx, z): return ctx.acos(ctx.one / z)
@defun_wrapped
def acsc(ctx, z): return ctx.asin(ctx.one / z)
@defun_wrapped
def acoth(ctx, z):
if not z:
return ctx.pi * 0.5j
else:
return ctx.atanh(ctx.one / z)
@defun_wrapped
def asech(ctx, z): return ctx.acosh(ctx.one / z)
@defun_wrapped
def acsch(ctx, z): return ctx.asinh(ctx.one / z)
@defun
def sign(ctx, x):
x = ctx.convert(x)
if not x or ctx.isnan(x):
return x
if ctx._is_real_type(x):
if x > 0:
return ctx.one
else:
return -ctx.one
return x / abs(x)
@defun
def agm(ctx, a, b=1):
if b == 1:
return ctx.agm1(a)
a = ctx.convert(a)
b = ctx.convert(b)
return ctx._agm(a, b)
@defun_wrapped
def sinc(ctx, x):
if ctx.isinf(x):
return 1/x
if not x:
return x+1
return ctx.sin(x)/x
@defun_wrapped
def sincpi(ctx, x):
if ctx.isinf(x):
return 1/x
if not x:
return x+1
return ctx.sinpi(x)/(ctx.pi*x)
# TODO: tests; improve implementation
@defun_wrapped
def expm1(ctx, x):
if not x:
return ctx.zero
# exp(x) - 1 ~ x
if ctx.mag(x) < -ctx.prec:
return x + 0.5*x**2
# TODO: accurately eval the smaller of the real/imag parts
return ctx.sum_accurately(lambda: iter([ctx.exp(x),-1]),1)
@defun_wrapped
def log1p(ctx, x):
if not x:
return ctx.zero
if ctx.mag(x) < -ctx.prec:
return x - 0.5*x**2
return ctx.log(ctx.fadd(1, x, prec=2*ctx.prec))
@defun_wrapped
def powm1(ctx, x, y):
mag = ctx.mag
one = ctx.one
w = x**y - one
M = mag(w)
# Only moderate cancellation
if M > -8:
return w
# Check for the only possible exact cases
if not w:
if (not y) or (x in (1, -1, 1j, -1j) and ctx.isint(y)):
return w
x1 = x - one
magy = mag(y)
lnx = ctx.ln(x)
# Small y: x^y - 1 ~ log(x)*y + O(log(x)^2 * y^2)
if magy + mag(lnx) < -ctx.prec:
return lnx*y + (lnx*y)**2/2
# TODO: accurately eval the smaller of the real/imag part
return ctx.sum_accurately(lambda: iter([x**y, -1]), 1)
@defun
def _rootof1(ctx, k, n):
k = int(k)
n = int(n)
k %= n
if not k:
return ctx.one
elif 2*k == n:
return -ctx.one
elif 4*k == n:
return ctx.j
elif 4*k == 3*n:
return -ctx.j
return ctx.expjpi(2*ctx.mpf(k)/n)
@defun
def root(ctx, x, n, k=0):
n = int(n)
x = ctx.convert(x)
if k:
# Special case: there is an exact real root
if (n & 1 and 2*k == n-1) and (not ctx.im(x)) and (ctx.re(x) < 0):
return -ctx.root(-x, n)
# Multiply by root of unity
prec = ctx.prec
try:
ctx.prec += 10
v = ctx.root(x, n, 0) * ctx._rootof1(k, n)
finally:
ctx.prec = prec
return +v
return ctx._nthroot(x, n)
@defun
def unitroots(ctx, n, primitive=False):
gcd = ctx._gcd
prec = ctx.prec
try:
ctx.prec += 10
if primitive:
v = [ctx._rootof1(k,n) for k in range(n) if gcd(k,n) == 1]
else:
# TODO: this can be done *much* faster
v = [ctx._rootof1(k,n) for k in range(n)]
finally:
ctx.prec = prec
return [+x for x in v]
@defun
def arg(ctx, x):
x = ctx.convert(x)
re = ctx._re(x)
im = ctx._im(x)
return ctx.atan2(im, re)
@defun
def fabs(ctx, x):
return abs(ctx.convert(x))
@defun
def re(ctx, x):
x = ctx.convert(x)
if hasattr(x, "real"): # py2.5 doesn't have .real/.imag for all numbers
return x.real
return x
@defun
def im(ctx, x):
x = ctx.convert(x)
if hasattr(x, "imag"): # py2.5 doesn't have .real/.imag for all numbers
return x.imag
return ctx.zero
@defun
def conj(ctx, x):
x = ctx.convert(x)
try:
return x.conjugate()
except AttributeError:
return x
@defun
def polar(ctx, z):
return (ctx.fabs(z), ctx.arg(z))
@defun_wrapped
def rect(ctx, r, phi):
return r * ctx.mpc(*ctx.cos_sin(phi))
@defun
def log(ctx, x, b=None):
if b is None:
return ctx.ln(x)
wp = ctx.prec + 20
return ctx.ln(x, prec=wp) / ctx.ln(b, prec=wp)
@defun
def log10(ctx, x):
return ctx.log(x, 10)
@defun
def fmod(ctx, x, y):
return ctx.convert(x) % ctx.convert(y)
@defun
def degrees(ctx, x):
return x / ctx.degree
@defun
def radians(ctx, x):
return x * ctx.degree
def _lambertw_special(ctx, z, k):
# W(0,0) = 0; all other branches are singular
if not z:
if not k:
return z
return ctx.ninf + z
if z == ctx.inf:
if k == 0:
return z
else:
return z + 2*k*ctx.pi*ctx.j
if z == ctx.ninf:
return (-z) + (2*k+1)*ctx.pi*ctx.j
# Some kind of nan or complex inf/nan?
return ctx.ln(z)
import math
import cmath
def _lambertw_approx_hybrid(z, k):
imag_sign = 0
if hasattr(z, "imag"):
x = float(z.real)
y = z.imag
if y:
imag_sign = (-1) ** (y < 0)
y = float(y)
else:
x = float(z)
y = 0.0
imag_sign = 0
# hack to work regardless of whether Python supports -0.0
if not y:
y = 0.0
z = complex(x,y)
if k == 0:
if -4.0 < y < 4.0 and -1.0 < x < 2.5:
if imag_sign:
# Taylor series in upper/lower half-plane
if y > 1.00: return (0.876+0.645j) + (0.118-0.174j)*(z-(0.75+2.5j))
if y > 0.25: return (0.505+0.204j) + (0.375-0.132j)*(z-(0.75+0.5j))
if y < -1.00: return (0.876-0.645j) + (0.118+0.174j)*(z-(0.75-2.5j))
if y < -0.25: return (0.505-0.204j) + (0.375+0.132j)*(z-(0.75-0.5j))
# Taylor series near -1
if x < -0.5:
if imag_sign >= 0:
return (-0.318+1.34j) + (-0.697-0.593j)*(z+1)
else:
return (-0.318-1.34j) + (-0.697+0.593j)*(z+1)
# return real type
r = -0.367879441171442
if (not imag_sign) and x > r:
z = x
# Singularity near -1/e
if x < -0.2:
return -1 + 2.33164398159712*(z-r)**0.5 - 1.81218788563936*(z-r)
# Taylor series near 0
if x < 0.5: return z
# Simple linear approximation
return 0.2 + 0.3*z
if (not imag_sign) and x > 0.0:
L1 = math.log(x); L2 = math.log(L1)
else:
L1 = cmath.log(z); L2 = cmath.log(L1)
elif k == -1:
# return real type
r = -0.367879441171442
if (not imag_sign) and r < x < 0.0:
z = x
if (imag_sign >= 0) and y < 0.1 and -0.6 < x < -0.2:
return -1 - 2.33164398159712*(z-r)**0.5 - 1.81218788563936*(z-r)
if (not imag_sign) and -0.2 <= x < 0.0:
L1 = math.log(-x)
return L1 - math.log(-L1)
else:
if imag_sign == -1 and (not y) and x < 0.0:
L1 = cmath.log(z) - 3.1415926535897932j
else:
L1 = cmath.log(z) - 6.2831853071795865j
L2 = cmath.log(L1)
return L1 - L2 + L2/L1 + L2*(L2-2)/(2*L1**2)
def _lambertw_series(ctx, z, k, tol):
"""
Return rough approximation for W_k(z) from an asymptotic series,
sufficiently accurate for the Halley iteration to converge to
the correct value.
"""
magz = ctx.mag(z)
if (-10 < magz < 900) and (-1000 < k < 1000):
# Near the branch point at -1/e
if magz < 1 and abs(z+0.36787944117144) < 0.05:
if k == 0 or (k == -1 and ctx._im(z) >= 0) or \
(k == 1 and ctx._im(z) < 0):
delta = ctx.sum_accurately(lambda: [z, ctx.exp(-1)])
cancellation = -ctx.mag(delta)
ctx.prec += cancellation
# Use series given in Corless et al.
p = ctx.sqrt(2*(ctx.e*z+1))
ctx.prec -= cancellation
u = {0:ctx.mpf(-1), 1:ctx.mpf(1)}
a = {0:ctx.mpf(2), 1:ctx.mpf(-1)}
if k != 0:
p = -p
s = ctx.zero
# The series converges, so we could use it directly, but unless
# *extremely* close, it is better to just use the first few
# terms to get a good approximation for the iteration
for l in xrange(max(2,cancellation)):
if l not in u:
a[l] = ctx.fsum(u[j]*u[l+1-j] for j in xrange(2,l))
u[l] = (l-1)*(u[l-2]/2+a[l-2]/4)/(l+1)-a[l]/2-u[l-1]/(l+1)
term = u[l] * p**l
s += term
if ctx.mag(term) < -tol:
return s, True
l += 1
ctx.prec += cancellation//2
return s, False
if k == 0 or k == -1:
return _lambertw_approx_hybrid(z, k), False
if k == 0:
if magz < -1:
return z*(1-z), False
L1 = ctx.ln(z)
L2 = ctx.ln(L1)
elif k == -1 and (not ctx._im(z)) and (-0.36787944117144 < ctx._re(z) < 0):
L1 = ctx.ln(-z)
return L1 - ctx.ln(-L1), False
else:
# This holds both as z -> 0 and z -> inf.
# Relative error is O(1/log(z)).
L1 = ctx.ln(z) + 2j*ctx.pi*k
L2 = ctx.ln(L1)
return L1 - L2 + L2/L1 + L2*(L2-2)/(2*L1**2), False
@defun
def lambertw(ctx, z, k=0):
z = ctx.convert(z)
k = int(k)
if not ctx.isnormal(z):
return _lambertw_special(ctx, z, k)
prec = ctx.prec
ctx.prec += 20 + ctx.mag(k or 1)
wp = ctx.prec
tol = wp - 5
w, done = _lambertw_series(ctx, z, k, tol)
if not done:
# Use Halley iteration to solve w*exp(w) = z
two = ctx.mpf(2)
for i in xrange(100):
ew = ctx.exp(w)
wew = w*ew
wewz = wew-z
wn = w - wewz/(wew+ew-(w+two)*wewz/(two*w+two))
if ctx.mag(wn-w) <= ctx.mag(wn) - tol:
w = wn
break
else:
w = wn
if i == 100:
ctx.warn("Lambert W iteration failed to converge for z = %s" % z)
ctx.prec = prec
return +w
@defun_wrapped
def bell(ctx, n, x=1):
x = ctx.convert(x)
if not n:
if ctx.isnan(x):
return x
return type(x)(1)
if ctx.isinf(x) or ctx.isinf(n) or ctx.isnan(x) or ctx.isnan(n):
return x**n
if n == 1: return x
if n == 2: return x*(x+1)
if x == 0: return ctx.sincpi(n)
return _polyexp(ctx, n, x, True) / ctx.exp(x)
def _polyexp(ctx, n, x, extra=False):
def _terms():
if extra:
yield ctx.sincpi(n)
t = x
k = 1
while 1:
yield k**n * t
k += 1
t = t*x/k
return ctx.sum_accurately(_terms, check_step=4)
@defun_wrapped
def polyexp(ctx, s, z):
if ctx.isinf(z) or ctx.isinf(s) or ctx.isnan(z) or ctx.isnan(s):
return z**s
if z == 0: return z*s
if s == 0: return ctx.expm1(z)
if s == 1: return ctx.exp(z)*z
if s == 2: return ctx.exp(z)*z*(z+1)
return _polyexp(ctx, s, z)
@defun_wrapped
def cyclotomic(ctx, n, z):
n = int(n)
if n < 0:
raise ValueError("n cannot be negative")
p = ctx.one
if n == 0:
return p
if n == 1:
return z - p
if n == 2:
return z + p
# Use divisor product representation. Unfortunately, this sometimes
# includes singularities for roots of unity, which we have to cancel out.
# Matching zeros/poles pairwise, we have (1-z^a)/(1-z^b) ~ a/b + O(z-1).
a_prod = 1
b_prod = 1
num_zeros = 0
num_poles = 0
for d in range(1,n+1):
if not n % d:
w = ctx.moebius(n//d)
# Use powm1 because it is important that we get 0 only
# if it really is exactly 0
b = -ctx.powm1(z, d)
if b:
p *= b**w
else:
if w == 1:
a_prod *= d
num_zeros += 1
elif w == -1:
b_prod *= d
num_poles += 1
#print n, num_zeros, num_poles
if num_zeros:
if num_zeros > num_poles:
p *= 0
else:
p *= a_prod
p /= b_prod
return p
@defun
def mangoldt(ctx, n):
r"""
Evaluates the von Mangoldt function `\Lambda(n) = \log p`
if `n = p^k` a power of a prime, and `\Lambda(n) = 0` otherwise.
**Examples**
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> [mangoldt(n) for n in range(-2,3)]
[0.0, 0.0, 0.0, 0.0, 0.6931471805599453094172321]
>>> mangoldt(6)
0.0
>>> mangoldt(7)
1.945910149055313305105353
>>> mangoldt(8)
0.6931471805599453094172321
>>> fsum(mangoldt(n) for n in range(101))
94.04531122935739224600493
>>> fsum(mangoldt(n) for n in range(10001))
10013.39669326311478372032
"""
n = int(n)
if n < 2:
return ctx.zero
if n % 2 == 0:
# Must be a power of two
if n & (n-1) == 0:
return +ctx.ln2
else:
return ctx.zero
# TODO: the following could be generalized into a perfect
# power testing function
# ---
# Look for a small factor
for p in (3,5,7,11,13,17,19,23,29,31):
if not n % p:
q, r = n // p, 0
while q > 1:
q, r = divmod(q, p)
if r:
return ctx.zero
return ctx.ln(p)
if ctx.isprime(n):
return ctx.ln(n)
# Obviously, we could use arbitrary-precision arithmetic for this...
if n > 10**30:
raise NotImplementedError
k = 2
while 1:
p = int(n**(1./k) + 0.5)
if p < 2:
return ctx.zero
if p ** k == n:
if ctx.isprime(p):
return ctx.ln(p)
k += 1
@defun
def stirling1(ctx, n, k, exact=False):
v = ctx._stirling1(int(n), int(k))
if exact:
return int(v)
else:
return ctx.mpf(v)
@defun
def stirling2(ctx, n, k, exact=False):
v = ctx._stirling2(int(n), int(k))
if exact:
return int(v)
else:
return ctx.mpf(v)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,493 @@
from .functions import defun, defun_wrapped
def _hermite_param(ctx, n, z, parabolic_cylinder):
"""
Combined calculation of the Hermite polynomial H_n(z) (and its
generalization to complex n) and the parabolic cylinder
function D.
"""
n, ntyp = ctx._convert_param(n)
z = ctx.convert(z)
q = -ctx.mpq_1_2
# For re(z) > 0, 2F0 -- http://functions.wolfram.com/
# HypergeometricFunctions/HermiteHGeneral/06/02/0009/
# Otherwise, there is a reflection formula
# 2F0 + http://functions.wolfram.com/HypergeometricFunctions/
# HermiteHGeneral/16/01/01/0006/
#
# TODO:
# An alternative would be to use
# http://functions.wolfram.com/HypergeometricFunctions/
# HermiteHGeneral/06/02/0006/
#
# Also, the 1F1 expansion
# http://functions.wolfram.com/HypergeometricFunctions/
# HermiteHGeneral/26/01/02/0001/
# should probably be used for tiny z
if not z:
T1 = [2, ctx.pi], [n, 0.5], [], [q*(n-1)], [], [], 0
if parabolic_cylinder:
T1[1][0] += q*n
return T1,
can_use_2f0 = ctx.isnpint(-n) or ctx.re(z) > 0 or \
(ctx.re(z) == 0 and ctx.im(z) > 0)
expprec = ctx.prec*4 + 20
if parabolic_cylinder:
u = ctx.fmul(ctx.fmul(z,z,prec=expprec), -0.25, exact=True)
w = ctx.fmul(z, ctx.sqrt(0.5,prec=expprec), prec=expprec)
else:
w = z
w2 = ctx.fmul(w, w, prec=expprec)
rw2 = ctx.fdiv(1, w2, prec=expprec)
nrw2 = ctx.fneg(rw2, exact=True)
nw = ctx.fneg(w, exact=True)
if can_use_2f0:
T1 = [2, w], [n, n], [], [], [q*n, q*(n-1)], [], nrw2
terms = [T1]
else:
T1 = [2, nw], [n, n], [], [], [q*n, q*(n-1)], [], nrw2
T2 = [2, ctx.pi, nw], [n+2, 0.5, 1], [], [q*n], [q*(n-1)], [1-q], w2
terms = [T1,T2]
# Multiply by prefactor for D_n
if parabolic_cylinder:
expu = ctx.exp(u)
for i in range(len(terms)):
terms[i][1][0] += q*n
terms[i][0].append(expu)
terms[i][1].append(1)
return tuple(terms)
@defun
def hermite(ctx, n, z, **kwargs):
return ctx.hypercomb(lambda: _hermite_param(ctx, n, z, 0), [], **kwargs)
@defun
def pcfd(ctx, n, z, **kwargs):
r"""
Gives the parabolic cylinder function in Whittaker's notation
`D_n(z) = U(-n-1/2, z)` (see :func:`~mpmath.pcfu`).
It solves the differential equation
.. math ::
y'' + \left(n + \frac{1}{2} - \frac{1}{4} z^2\right) y = 0.
and can be represented in terms of Hermite polynomials
(see :func:`~mpmath.hermite`) as
.. math ::
D_n(z) = 2^{-n/2} e^{-z^2/4} H_n\left(\frac{z}{\sqrt{2}}\right).
**Plots**
.. literalinclude :: /plots/pcfd.py
.. image :: /plots/pcfd.png
**Examples**
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> pcfd(0,0); pcfd(1,0); pcfd(2,0); pcfd(3,0)
1.0
0.0
-1.0
0.0
>>> pcfd(4,0); pcfd(-3,0)
3.0
0.6266570686577501256039413
>>> pcfd('1/2', 2+3j)
(-5.363331161232920734849056 - 3.858877821790010714163487j)
>>> pcfd(2, -10)
1.374906442631438038871515e-9
Verifying the differential equation::
>>> n = mpf(2.5)
>>> y = lambda z: pcfd(n,z)
>>> z = 1.75
>>> chop(diff(y,z,2) + (n+0.5-0.25*z**2)*y(z))
0.0
Rational Taylor series expansion when `n` is an integer::
>>> taylor(lambda z: pcfd(5,z), 0, 7)
[0.0, 15.0, 0.0, -13.75, 0.0, 3.96875, 0.0, -0.6015625]
"""
return ctx.hypercomb(lambda: _hermite_param(ctx, n, z, 1), [], **kwargs)
@defun
def pcfu(ctx, a, z, **kwargs):
r"""
Gives the parabolic cylinder function `U(a,z)`, which may be
defined for `\Re(z) > 0` in terms of the confluent
U-function (see :func:`~mpmath.hyperu`) by
.. math ::
U(a,z) = 2^{-\frac{1}{4}-\frac{a}{2}} e^{-\frac{1}{4} z^2}
U\left(\frac{a}{2}+\frac{1}{4},
\frac{1}{2}, \frac{1}{2}z^2\right)
or, for arbitrary `z`,
.. math ::
e^{-\frac{1}{4}z^2} U(a,z) =
U(a,0) \,_1F_1\left(-\tfrac{a}{2}+\tfrac{1}{4};
\tfrac{1}{2}; -\tfrac{1}{2}z^2\right) +
U'(a,0) z \,_1F_1\left(-\tfrac{a}{2}+\tfrac{3}{4};
\tfrac{3}{2}; -\tfrac{1}{2}z^2\right).
**Examples**
Connection to other functions::
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> z = mpf(3)
>>> pcfu(0.5,z)
0.03210358129311151450551963
>>> sqrt(pi/2)*exp(z**2/4)*erfc(z/sqrt(2))
0.03210358129311151450551963
>>> pcfu(0.5,-z)
23.75012332835297233711255
>>> sqrt(pi/2)*exp(z**2/4)*erfc(-z/sqrt(2))
23.75012332835297233711255
>>> pcfu(0.5,-z)
23.75012332835297233711255
>>> sqrt(pi/2)*exp(z**2/4)*erfc(-z/sqrt(2))
23.75012332835297233711255
"""
n, _ = ctx._convert_param(a)
return ctx.pcfd(-n-ctx.mpq_1_2, z)
@defun
def pcfv(ctx, a, z, **kwargs):
r"""
Gives the parabolic cylinder function `V(a,z)`, which can be
represented in terms of :func:`~mpmath.pcfu` as
.. math ::
V(a,z) = \frac{\Gamma(a+\tfrac{1}{2}) (U(a,-z)-\sin(\pi a) U(a,z)}{\pi}.
**Examples**
Wronskian relation between `U` and `V`::
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> a, z = 2, 3
>>> pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z)
0.7978845608028653558798921
>>> sqrt(2/pi)
0.7978845608028653558798921
>>> a, z = 2.5, 3
>>> pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z)
0.7978845608028653558798921
>>> a, z = 0.25, -1
>>> pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z)
0.7978845608028653558798921
>>> a, z = 2+1j, 2+3j
>>> chop(pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z))
0.7978845608028653558798921
"""
n, ntype = ctx._convert_param(a)
z = ctx.convert(z)
q = ctx.mpq_1_2
r = ctx.mpq_1_4
if ntype == 'Q' and ctx.isint(n*2):
# Faster for half-integers
def h():
jz = ctx.fmul(z, -1j, exact=True)
T1terms = _hermite_param(ctx, -n-q, z, 1)
T2terms = _hermite_param(ctx, n-q, jz, 1)
for T in T1terms:
T[0].append(1j)
T[1].append(1)
T[3].append(q-n)
u = ctx.expjpi((q*n-r)) * ctx.sqrt(2/ctx.pi)
for T in T2terms:
T[0].append(u)
T[1].append(1)
return T1terms + T2terms
v = ctx.hypercomb(h, [], **kwargs)
if ctx._is_real_type(n) and ctx._is_real_type(z):
v = ctx._re(v)
return v
else:
def h(n):
w = ctx.square_exp_arg(z, -0.25)
u = ctx.square_exp_arg(z, 0.5)
e = ctx.exp(w)
l = [ctx.pi, q, ctx.exp(w)]
Y1 = l, [-q, n*q+r, 1], [r-q*n], [], [q*n+r], [q], u
Y2 = l + [z], [-q, n*q-r, 1, 1], [1-r-q*n], [], [q*n+1-r], [1+q], u
c, s = ctx.cospi_sinpi(r+q*n)
Y1[0].append(s)
Y2[0].append(c)
for Y in (Y1, Y2):
Y[1].append(1)
Y[3].append(q-n)
return Y1, Y2
return ctx.hypercomb(h, [n], **kwargs)
@defun
def pcfw(ctx, a, z, **kwargs):
r"""
Gives the parabolic cylinder function `W(a,z)` defined in (DLMF 12.14).
**Examples**
Value at the origin::
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> a = mpf(0.25)
>>> pcfw(a,0)
0.9722833245718180765617104
>>> power(2,-0.75)*sqrt(abs(gamma(0.25+0.5j*a)/gamma(0.75+0.5j*a)))
0.9722833245718180765617104
>>> diff(pcfw,(a,0),(0,1))
-0.5142533944210078966003624
>>> -power(2,-0.25)*sqrt(abs(gamma(0.75+0.5j*a)/gamma(0.25+0.5j*a)))
-0.5142533944210078966003624
"""
n, _ = ctx._convert_param(a)
z = ctx.convert(z)
def terms():
phi2 = ctx.arg(ctx.gamma(0.5 + ctx.j*n))
phi2 = (ctx.loggamma(0.5+ctx.j*n) - ctx.loggamma(0.5-ctx.j*n))/2j
rho = ctx.pi/8 + 0.5*phi2
# XXX: cancellation computing k
k = ctx.sqrt(1 + ctx.exp(2*ctx.pi*n)) - ctx.exp(ctx.pi*n)
C = ctx.sqrt(k/2) * ctx.exp(0.25*ctx.pi*n)
yield C * ctx.expj(rho) * ctx.pcfu(ctx.j*n, z*ctx.expjpi(-0.25))
yield C * ctx.expj(-rho) * ctx.pcfu(-ctx.j*n, z*ctx.expjpi(0.25))
v = ctx.sum_accurately(terms)
if ctx._is_real_type(n) and ctx._is_real_type(z):
v = ctx._re(v)
return v
"""
Even/odd PCFs. Useful?
@defun
def pcfy1(ctx, a, z, **kwargs):
a, _ = ctx._convert_param(n)
z = ctx.convert(z)
def h():
w = ctx.square_exp_arg(z)
w1 = ctx.fmul(w, -0.25, exact=True)
w2 = ctx.fmul(w, 0.5, exact=True)
e = ctx.exp(w1)
return [e], [1], [], [], [ctx.mpq_1_2*a+ctx.mpq_1_4], [ctx.mpq_1_2], w2
return ctx.hypercomb(h, [], **kwargs)
@defun
def pcfy2(ctx, a, z, **kwargs):
a, _ = ctx._convert_param(n)
z = ctx.convert(z)
def h():
w = ctx.square_exp_arg(z)
w1 = ctx.fmul(w, -0.25, exact=True)
w2 = ctx.fmul(w, 0.5, exact=True)
e = ctx.exp(w1)
return [e, z], [1, 1], [], [], [ctx.mpq_1_2*a+ctx.mpq_3_4], \
[ctx.mpq_3_2], w2
return ctx.hypercomb(h, [], **kwargs)
"""
@defun_wrapped
def gegenbauer(ctx, n, a, z, **kwargs):
# Special cases: a+0.5, a*2 poles
if ctx.isnpint(a):
return 0*(z+n)
if ctx.isnpint(a+0.5):
# TODO: something else is required here
# E.g.: gegenbauer(-2, -0.5, 3) == -12
if ctx.isnpint(n+1):
raise NotImplementedError("Gegenbauer function with two limits")
def h(a):
a2 = 2*a
T = [], [], [n+a2], [n+1, a2], [-n, n+a2], [a+0.5], 0.5*(1-z)
return [T]
return ctx.hypercomb(h, [a], **kwargs)
def h(n):
a2 = 2*a
T = [], [], [n+a2], [n+1, a2], [-n, n+a2], [a+0.5], 0.5*(1-z)
return [T]
return ctx.hypercomb(h, [n], **kwargs)
@defun_wrapped
def jacobi(ctx, n, a, b, x, **kwargs):
if not ctx.isnpint(a):
def h(n):
return (([], [], [a+n+1], [n+1, a+1], [-n, a+b+n+1], [a+1], (1-x)*0.5),)
return ctx.hypercomb(h, [n], **kwargs)
if not ctx.isint(b):
def h(n, a):
return (([], [], [-b], [n+1, -b-n], [-n, a+b+n+1], [b+1], (x+1)*0.5),)
return ctx.hypercomb(h, [n, a], **kwargs)
# XXX: determine appropriate limit
return ctx.binomial(n+a,n) * ctx.hyp2f1(-n,1+n+a+b,a+1,(1-x)/2, **kwargs)
@defun_wrapped
def laguerre(ctx, n, a, z, **kwargs):
# XXX: limits, poles
#if ctx.isnpint(n):
# return 0*(a+z)
def h(a):
return (([], [], [a+n+1], [a+1, n+1], [-n], [a+1], z),)
return ctx.hypercomb(h, [a], **kwargs)
@defun_wrapped
def legendre(ctx, n, x, **kwargs):
if ctx.isint(n):
n = int(n)
# Accuracy near zeros
if (n + (n < 0)) & 1:
if not x:
return x
mag = ctx.mag(x)
if mag < -2*ctx.prec-10:
return x
if mag < -5:
ctx.prec += -mag
return ctx.hyp2f1(-n,n+1,1,(1-x)/2, **kwargs)
@defun
def legenp(ctx, n, m, z, type=2, **kwargs):
# Legendre function, 1st kind
n = ctx.convert(n)
m = ctx.convert(m)
# Faster
if not m:
return ctx.legendre(n, z, **kwargs)
# TODO: correct evaluation at singularities
if type == 2:
def h(n,m):
g = m*0.5
T = [1+z, 1-z], [g, -g], [], [1-m], [-n, n+1], [1-m], 0.5*(1-z)
return (T,)
return ctx.hypercomb(h, [n,m], **kwargs)
if type == 3:
def h(n,m):
g = m*0.5
T = [z+1, z-1], [g, -g], [], [1-m], [-n, n+1], [1-m], 0.5*(1-z)
return (T,)
return ctx.hypercomb(h, [n,m], **kwargs)
raise ValueError("requires type=2 or type=3")
@defun
def legenq(ctx, n, m, z, type=2, **kwargs):
# Legendre function, 2nd kind
n = ctx.convert(n)
m = ctx.convert(m)
z = ctx.convert(z)
if z in (1, -1):
#if ctx.isint(m):
# return ctx.nan
#return ctx.inf # unsigned
return ctx.nan
if type == 2:
def h(n, m):
cos, sin = ctx.cospi_sinpi(m)
s = 2 * sin / ctx.pi
c = cos
a = 1+z
b = 1-z
u = m/2
w = (1-z)/2
T1 = [s, c, a, b], [-1, 1, u, -u], [], [1-m], \
[-n, n+1], [1-m], w
T2 = [-s, a, b], [-1, -u, u], [n+m+1], [n-m+1, m+1], \
[-n, n+1], [m+1], w
return T1, T2
return ctx.hypercomb(h, [n, m], **kwargs)
if type == 3:
# The following is faster when there only is a single series
# Note: not valid for -1 < z < 0 (?)
if abs(z) > 1:
def h(n, m):
T1 = [ctx.expjpi(m), 2, ctx.pi, z, z-1, z+1], \
[1, -n-1, 0.5, -n-m-1, 0.5*m, 0.5*m], \
[n+m+1], [n+1.5], \
[0.5*(2+n+m), 0.5*(1+n+m)], [n+1.5], z**(-2)
return [T1]
return ctx.hypercomb(h, [n, m], **kwargs)
else:
# not valid for 1 < z < inf ?
def h(n, m):
s = 2 * ctx.sinpi(m) / ctx.pi
c = ctx.expjpi(m)
a = 1+z
b = z-1
u = m/2
w = (1-z)/2
T1 = [s, c, a, b], [-1, 1, u, -u], [], [1-m], \
[-n, n+1], [1-m], w
T2 = [-s, c, a, b], [-1, 1, -u, u], [n+m+1], [n-m+1, m+1], \
[-n, n+1], [m+1], w
return T1, T2
return ctx.hypercomb(h, [n, m], **kwargs)
raise ValueError("requires type=2 or type=3")
@defun_wrapped
def chebyt(ctx, n, x, **kwargs):
if (not x) and ctx.isint(n) and int(ctx._re(n)) % 2 == 1:
return x * 0
return ctx.hyp2f1(-n,n,(1,2),(1-x)/2, **kwargs)
@defun_wrapped
def chebyu(ctx, n, x, **kwargs):
if (not x) and ctx.isint(n) and int(ctx._re(n)) % 2 == 1:
return x * 0
return (n+1) * ctx.hyp2f1(-n, n+2, (3,2), (1-x)/2, **kwargs)
@defun
def spherharm(ctx, l, m, theta, phi, **kwargs):
l = ctx.convert(l)
m = ctx.convert(m)
theta = ctx.convert(theta)
phi = ctx.convert(phi)
l_isint = ctx.isint(l)
l_natural = l_isint and l >= 0
m_isint = ctx.isint(m)
if l_isint and l < 0 and m_isint:
return ctx.spherharm(-(l+1), m, theta, phi, **kwargs)
if theta == 0 and m_isint and m < 0:
return ctx.zero * 1j
if l_natural and m_isint:
if abs(m) > l:
return ctx.zero * 1j
# http://functions.wolfram.com/Polynomials/
# SphericalHarmonicY/26/01/02/0004/
def h(l,m):
absm = abs(m)
C = [-1, ctx.expj(m*phi),
(2*l+1)*ctx.fac(l+absm)/ctx.pi/ctx.fac(l-absm),
ctx.sin(theta)**2,
ctx.fac(absm), 2]
P = [0.5*m*(ctx.sign(m)+1), 1, 0.5, 0.5*absm, -1, -absm-1]
return ((C, P, [], [], [absm-l, l+absm+1], [absm+1],
ctx.sin(0.5*theta)**2),)
else:
# http://functions.wolfram.com/HypergeometricFunctions/
# SphericalHarmonicYGeneral/26/01/02/0001/
def h(l,m):
if ctx.isnpint(l-m+1) or ctx.isnpint(l+m+1) or ctx.isnpint(1-m):
return (([0], [-1], [], [], [], [], 0),)
cos, sin = ctx.cos_sin(0.5*theta)
C = [0.5*ctx.expj(m*phi), (2*l+1)/ctx.pi,
ctx.gamma(l-m+1), ctx.gamma(l+m+1),
cos**2, sin**2]
P = [1, 0.5, 0.5, -0.5, 0.5*m, -0.5*m]
return ((C, P, [], [1-m], [-l,l+1], [1-m], sin**2),)
return ctx.hypercomb(h, [l,m], **kwargs)

View File

@ -0,0 +1,280 @@
from .functions import defun, defun_wrapped
@defun
def qp(ctx, a, q=None, n=None, **kwargs):
r"""
Evaluates the q-Pochhammer symbol (or q-rising factorial)
.. math ::
(a; q)_n = \prod_{k=0}^{n-1} (1-a q^k)
where `n = \infty` is permitted if `|q| < 1`. Called with two arguments,
``qp(a,q)`` computes `(a;q)_{\infty}`; with a single argument, ``qp(q)``
computes `(q;q)_{\infty}`. The special case
.. math ::
\phi(q) = (q; q)_{\infty} = \prod_{k=1}^{\infty} (1-q^k) =
\sum_{k=-\infty}^{\infty} (-1)^k q^{(3k^2-k)/2}
is also known as the Euler function, or (up to a factor `q^{-1/24}`)
the Dedekind eta function.
**Examples**
If `n` is a positive integer, the function amounts to a finite product::
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> qp(2,3,5)
-725305.0
>>> fprod(1-2*3**k for k in range(5))
-725305.0
>>> qp(2,3,0)
1.0
Complex arguments are allowed::
>>> qp(2-1j, 0.75j)
(0.4628842231660149089976379 + 4.481821753552703090628793j)
The regular Pochhammer symbol `(a)_n` is obtained in the
following limit as `q \to 1`::
>>> a, n = 4, 7
>>> limit(lambda q: qp(q**a,q,n) / (1-q)**n, 1)
604800.0
>>> rf(a,n)
604800.0
The Taylor series of the reciprocal Euler function gives
the partition function `P(n)`, i.e. the number of ways of writing
`n` as a sum of positive integers::
>>> taylor(lambda q: 1/qp(q), 0, 10)
[1.0, 1.0, 2.0, 3.0, 5.0, 7.0, 11.0, 15.0, 22.0, 30.0, 42.0]
Special values include::
>>> qp(0)
1.0
>>> findroot(diffun(qp), -0.4) # location of maximum
-0.4112484791779547734440257
>>> qp(_)
1.228348867038575112586878
The q-Pochhammer symbol is related to the Jacobi theta functions.
For example, the following identity holds::
>>> q = mpf(0.5) # arbitrary
>>> qp(q)
0.2887880950866024212788997
>>> root(3,-2)*root(q,-24)*jtheta(2,pi/6,root(q,6))
0.2887880950866024212788997
"""
a = ctx.convert(a)
if n is None:
n = ctx.inf
else:
n = ctx.convert(n)
if n < 0:
raise ValueError("n cannot be negative")
if q is None:
q = a
else:
q = ctx.convert(q)
if n == 0:
return ctx.one + 0*(a+q)
infinite = (n == ctx.inf)
same = (a == q)
if infinite:
if abs(q) >= 1:
if same and (q == -1 or q == 1):
return ctx.zero * q
raise ValueError("q-function only defined for |q| < 1")
elif q == 0:
return ctx.one - a
maxterms = kwargs.get('maxterms', 50*ctx.prec)
if infinite and same:
# Euler's pentagonal theorem
def terms():
t = 1
yield t
k = 1
x1 = q
x2 = q**2
while 1:
yield (-1)**k * x1
yield (-1)**k * x2
x1 *= q**(3*k+1)
x2 *= q**(3*k+2)
k += 1
if k > maxterms:
raise ctx.NoConvergence
return ctx.sum_accurately(terms)
# return ctx.nprod(lambda k: 1-a*q**k, [0,n-1])
def factors():
k = 0
r = ctx.one
while 1:
yield 1 - a*r
r *= q
k += 1
if k >= n:
return
if k > maxterms:
raise ctx.NoConvergence
return ctx.mul_accurately(factors)
@defun_wrapped
def qgamma(ctx, z, q, **kwargs):
r"""
Evaluates the q-gamma function
.. math ::
\Gamma_q(z) = \frac{(q; q)_{\infty}}{(q^z; q)_{\infty}} (1-q)^{1-z}.
**Examples**
Evaluation for real and complex arguments::
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> qgamma(4,0.75)
4.046875
>>> qgamma(6,6)
121226245.0
>>> qgamma(3+4j, 0.5j)
(0.1663082382255199834630088 + 0.01952474576025952984418217j)
The q-gamma function satisfies a functional equation similar
to that of the ordinary gamma function::
>>> q = mpf(0.25)
>>> z = mpf(2.5)
>>> qgamma(z+1,q)
1.428277424823760954685912
>>> (1-q**z)/(1-q)*qgamma(z,q)
1.428277424823760954685912
"""
if abs(q) > 1:
return ctx.qgamma(z,1/q)*q**((z-2)*(z-1)*0.5)
return ctx.qp(q, q, None, **kwargs) / \
ctx.qp(q**z, q, None, **kwargs) * (1-q)**(1-z)
@defun_wrapped
def qfac(ctx, z, q, **kwargs):
r"""
Evaluates the q-factorial,
.. math ::
[n]_q! = (1+q)(1+q+q^2)\cdots(1+q+\cdots+q^{n-1})
or more generally
.. math ::
[z]_q! = \frac{(q;q)_z}{(1-q)^z}.
**Examples**
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> qfac(0,0)
1.0
>>> qfac(4,3)
2080.0
>>> qfac(5,6)
121226245.0
>>> qfac(1+1j, 2+1j)
(0.4370556551322672478613695 + 0.2609739839216039203708921j)
"""
if ctx.isint(z) and ctx._re(z) > 0:
n = int(ctx._re(z))
return ctx.qp(q, q, n, **kwargs) / (1-q)**n
return ctx.qgamma(z+1, q, **kwargs)
@defun
def qhyper(ctx, a_s, b_s, q, z, **kwargs):
r"""
Evaluates the basic hypergeometric series or hypergeometric q-series
.. math ::
\,_r\phi_s \left[\begin{matrix}
a_1 & a_2 & \ldots & a_r \\
b_1 & b_2 & \ldots & b_s
\end{matrix} ; q,z \right] =
\sum_{n=0}^\infty
\frac{(a_1;q)_n, \ldots, (a_r;q)_n}
{(b_1;q)_n, \ldots, (b_s;q)_n}
\left((-1)^n q^{n\choose 2}\right)^{1+s-r}
\frac{z^n}{(q;q)_n}
where `(a;q)_n` denotes the q-Pochhammer symbol (see :func:`~mpmath.qp`).
**Examples**
Evaluation works for real and complex arguments::
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> qhyper([0.5], [2.25], 0.25, 4)
-0.1975849091263356009534385
>>> qhyper([0.5], [2.25], 0.25-0.25j, 4)
(2.806330244925716649839237 + 3.568997623337943121769938j)
>>> qhyper([1+j], [2,3+0.5j], 0.25, 3+4j)
(9.112885171773400017270226 - 1.272756997166375050700388j)
Comparing with a summation of the defining series, using
:func:`~mpmath.nsum`::
>>> b, q, z = 3, 0.25, 0.5
>>> qhyper([], [b], q, z)
0.6221136748254495583228324
>>> nsum(lambda n: z**n / qp(q,q,n)/qp(b,q,n) * q**(n*(n-1)), [0,inf])
0.6221136748254495583228324
"""
#a_s = [ctx._convert_param(a)[0] for a in a_s]
#b_s = [ctx._convert_param(b)[0] for b in b_s]
#q = ctx._convert_param(q)[0]
a_s = [ctx.convert(a) for a in a_s]
b_s = [ctx.convert(b) for b in b_s]
q = ctx.convert(q)
z = ctx.convert(z)
r = len(a_s)
s = len(b_s)
d = 1+s-r
maxterms = kwargs.get('maxterms', 50*ctx.prec)
def terms():
t = ctx.one
yield t
qk = 1
k = 0
x = 1
while 1:
for a in a_s:
p = 1 - a*qk
t *= p
for b in b_s:
p = 1 - b*qk
if not p:
raise ValueError
t /= p
t *= z
x *= (-1)**d * qk ** d
qk *= q
t /= (1 - qk)
k += 1
yield t * x
if k > maxterms:
raise ctx.NoConvergence
return ctx.sum_accurately(terms)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,32 @@
from .functions import defun_wrapped
@defun_wrapped
def squarew(ctx, t, amplitude=1, period=1):
P = period
A = amplitude
return A*((-1)**ctx.floor(2*t/P))
@defun_wrapped
def trianglew(ctx, t, amplitude=1, period=1):
A = amplitude
P = period
return 2*A*(0.5 - ctx.fabs(1 - 2*ctx.frac(t/P + 0.25)))
@defun_wrapped
def sawtoothw(ctx, t, amplitude=1, period=1):
A = amplitude
P = period
return A*ctx.frac(t/P)
@defun_wrapped
def unit_triangle(ctx, t, amplitude=1):
A = amplitude
if t <= -1 or t >= 1:
return ctx.zero
return A*(-ctx.fabs(t) + 1)
@defun_wrapped
def sigmoid(ctx, t, amplitude=1):
A = amplitude
return A / (1 + ctx.exp(-t))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,844 @@
"""
Implements the PSLQ algorithm for integer relation detection,
and derivative algorithms for constant recognition.
"""
from .libmp.backend import xrange
from .libmp import int_types, sqrt_fixed
# round to nearest integer (can be done more elegantly...)
def round_fixed(x, prec):
return ((x + (1<<(prec-1))) >> prec) << prec
class IdentificationMethods(object):
pass
def pslq(ctx, x, tol=None, maxcoeff=1000, maxsteps=100, verbose=False):
r"""
Given a vector of real numbers `x = [x_0, x_1, ..., x_n]`, ``pslq(x)``
uses the PSLQ algorithm to find a list of integers
`[c_0, c_1, ..., c_n]` such that
.. math ::
|c_1 x_1 + c_2 x_2 + ... + c_n x_n| < \mathrm{tol}
and such that `\max |c_k| < \mathrm{maxcoeff}`. If no such vector
exists, :func:`~mpmath.pslq` returns ``None``. The tolerance defaults to
3/4 of the working precision.
**Examples**
Find rational approximations for `\pi`::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> pslq([-1, pi], tol=0.01)
[22, 7]
>>> pslq([-1, pi], tol=0.001)
[355, 113]
>>> mpf(22)/7; mpf(355)/113; +pi
3.14285714285714
3.14159292035398
3.14159265358979
Pi is not a rational number with denominator less than 1000::
>>> pslq([-1, pi])
>>>
To within the standard precision, it can however be approximated
by at least one rational number with denominator less than `10^{12}`::
>>> p, q = pslq([-1, pi], maxcoeff=10**12)
>>> print(p); print(q)
238410049439
75888275702
>>> mpf(p)/q
3.14159265358979
The PSLQ algorithm can be applied to long vectors. For example,
we can investigate the rational (in)dependence of integer square
roots::
>>> mp.dps = 30
>>> pslq([sqrt(n) for n in range(2, 5+1)])
>>>
>>> pslq([sqrt(n) for n in range(2, 6+1)])
>>>
>>> pslq([sqrt(n) for n in range(2, 8+1)])
[2, 0, 0, 0, 0, 0, -1]
**Machin formulas**
A famous formula for `\pi` is Machin's,
.. math ::
\frac{\pi}{4} = 4 \operatorname{acot} 5 - \operatorname{acot} 239
There are actually infinitely many formulas of this type. Two
others are
.. math ::
\frac{\pi}{4} = \operatorname{acot} 1
\frac{\pi}{4} = 12 \operatorname{acot} 49 + 32 \operatorname{acot} 57
+ 5 \operatorname{acot} 239 + 12 \operatorname{acot} 110443
We can easily verify the formulas using the PSLQ algorithm::
>>> mp.dps = 30
>>> pslq([pi/4, acot(1)])
[1, -1]
>>> pslq([pi/4, acot(5), acot(239)])
[1, -4, 1]
>>> pslq([pi/4, acot(49), acot(57), acot(239), acot(110443)])
[1, -12, -32, 5, -12]
We could try to generate a custom Machin-like formula by running
the PSLQ algorithm with a few inverse cotangent values, for example
acot(2), acot(3) ... acot(10). Unfortunately, there is a linear
dependence among these values, resulting in only that dependence
being detected, with a zero coefficient for `\pi`::
>>> pslq([pi] + [acot(n) for n in range(2,11)])
[0, 1, -1, 0, 0, 0, -1, 0, 0, 0]
We get better luck by removing linearly dependent terms::
>>> pslq([pi] + [acot(n) for n in range(2,11) if n not in (3, 5)])
[1, -8, 0, 0, 4, 0, 0, 0]
In other words, we found the following formula::
>>> 8*acot(2) - 4*acot(7)
3.14159265358979323846264338328
>>> +pi
3.14159265358979323846264338328
**Algorithm**
This is a fairly direct translation to Python of the pseudocode given by
David Bailey, "The PSLQ Integer Relation Algorithm":
http://www.cecm.sfu.ca/organics/papers/bailey/paper/html/node3.html
The present implementation uses fixed-point instead of floating-point
arithmetic, since this is significantly (about 7x) faster.
"""
n = len(x)
if n < 2:
raise ValueError("n cannot be less than 2")
# At too low precision, the algorithm becomes meaningless
prec = ctx.prec
if prec < 53:
raise ValueError("prec cannot be less than 53")
if verbose and prec // max(2,n) < 5:
print("Warning: precision for PSLQ may be too low")
target = int(prec * 0.75)
if tol is None:
tol = ctx.mpf(2)**(-target)
else:
tol = ctx.convert(tol)
extra = 60
prec += extra
if verbose:
print("PSLQ using prec %i and tol %s" % (prec, ctx.nstr(tol)))
tol = ctx.to_fixed(tol, prec)
assert tol
# Convert to fixed-point numbers. The dummy None is added so we can
# use 1-based indexing. (This just allows us to be consistent with
# Bailey's indexing. The algorithm is 100 lines long, so debugging
# a single wrong index can be painful.)
x = [None] + [ctx.to_fixed(ctx.mpf(xk), prec) for xk in x]
# Sanity check on magnitudes
minx = min(abs(xx) for xx in x[1:])
if not minx:
raise ValueError("PSLQ requires a vector of nonzero numbers")
if minx < tol//100:
if verbose:
print("STOPPING: (one number is too small)")
return None
g = sqrt_fixed((4<<prec)//3, prec)
A = {}
B = {}
H = {}
# Initialization
# step 1
for i in xrange(1, n+1):
for j in xrange(1, n+1):
A[i,j] = B[i,j] = (i==j) << prec
H[i,j] = 0
# step 2
s = [None] + [0] * n
for k in xrange(1, n+1):
t = 0
for j in xrange(k, n+1):
t += (x[j]**2 >> prec)
s[k] = sqrt_fixed(t, prec)
t = s[1]
y = x[:]
for k in xrange(1, n+1):
y[k] = (x[k] << prec) // t
s[k] = (s[k] << prec) // t
# step 3
for i in xrange(1, n+1):
for j in xrange(i+1, n):
H[i,j] = 0
if i <= n-1:
if s[i]:
H[i,i] = (s[i+1] << prec) // s[i]
else:
H[i,i] = 0
for j in range(1, i):
sjj1 = s[j]*s[j+1]
if sjj1:
H[i,j] = ((-y[i]*y[j])<<prec)//sjj1
else:
H[i,j] = 0
# step 4
for i in xrange(2, n+1):
for j in xrange(i-1, 0, -1):
#t = floor(H[i,j]/H[j,j] + 0.5)
if H[j,j]:
t = round_fixed((H[i,j] << prec)//H[j,j], prec)
else:
#t = 0
continue
y[j] = y[j] + (t*y[i] >> prec)
for k in xrange(1, j+1):
H[i,k] = H[i,k] - (t*H[j,k] >> prec)
for k in xrange(1, n+1):
A[i,k] = A[i,k] - (t*A[j,k] >> prec)
B[k,j] = B[k,j] + (t*B[k,i] >> prec)
# Main algorithm
for REP in range(maxsteps):
# Step 1
m = -1
szmax = -1
for i in range(1, n):
h = H[i,i]
sz = (g**i * abs(h)) >> (prec*(i-1))
if sz > szmax:
m = i
szmax = sz
# Step 2
y[m], y[m+1] = y[m+1], y[m]
for i in xrange(1,n+1): H[m,i], H[m+1,i] = H[m+1,i], H[m,i]
for i in xrange(1,n+1): A[m,i], A[m+1,i] = A[m+1,i], A[m,i]
for i in xrange(1,n+1): B[i,m], B[i,m+1] = B[i,m+1], B[i,m]
# Step 3
if m <= n - 2:
t0 = sqrt_fixed((H[m,m]**2 + H[m,m+1]**2)>>prec, prec)
# A zero element probably indicates that the precision has
# been exhausted. XXX: this could be spurious, due to
# using fixed-point arithmetic
if not t0:
break
t1 = (H[m,m] << prec) // t0
t2 = (H[m,m+1] << prec) // t0
for i in xrange(m, n+1):
t3 = H[i,m]
t4 = H[i,m+1]
H[i,m] = (t1*t3+t2*t4) >> prec
H[i,m+1] = (-t2*t3+t1*t4) >> prec
# Step 4
for i in xrange(m+1, n+1):
for j in xrange(min(i-1, m+1), 0, -1):
try:
t = round_fixed((H[i,j] << prec)//H[j,j], prec)
# Precision probably exhausted
except ZeroDivisionError:
break
y[j] = y[j] + ((t*y[i]) >> prec)
for k in xrange(1, j+1):
H[i,k] = H[i,k] - (t*H[j,k] >> prec)
for k in xrange(1, n+1):
A[i,k] = A[i,k] - (t*A[j,k] >> prec)
B[k,j] = B[k,j] + (t*B[k,i] >> prec)
# Until a relation is found, the error typically decreases
# slowly (e.g. a factor 1-10) with each step TODO: we could
# compare err from two successive iterations. If there is a
# large drop (several orders of magnitude), that indicates a
# "high quality" relation was detected. Reporting this to
# the user somehow might be useful.
best_err = maxcoeff<<prec
for i in xrange(1, n+1):
err = abs(y[i])
# Maybe we are done?
if err < tol:
# We are done if the coefficients are acceptable
vec = [int(round_fixed(B[j,i], prec) >> prec) for j in \
range(1,n+1)]
if max(abs(v) for v in vec) < maxcoeff:
if verbose:
print("FOUND relation at iter %i/%i, error: %s" % \
(REP, maxsteps, ctx.nstr(err / ctx.mpf(2)**prec, 1)))
return vec
best_err = min(err, best_err)
# Calculate a lower bound for the norm. We could do this
# more exactly (using the Euclidean norm) but there is probably
# no practical benefit.
recnorm = max(abs(h) for h in H.values())
if recnorm:
norm = ((1 << (2*prec)) // recnorm) >> prec
norm //= 100
else:
norm = ctx.inf
if verbose:
print("%i/%i: Error: %8s Norm: %s" % \
(REP, maxsteps, ctx.nstr(best_err / ctx.mpf(2)**prec, 1), norm))
if norm >= maxcoeff:
break
if verbose:
print("CANCELLING after step %i/%i." % (REP, maxsteps))
print("Could not find an integer relation. Norm bound: %s" % norm)
return None
def findpoly(ctx, x, n=1, **kwargs):
r"""
``findpoly(x, n)`` returns the coefficients of an integer
polynomial `P` of degree at most `n` such that `P(x) \approx 0`.
If no polynomial having `x` as a root can be found,
:func:`~mpmath.findpoly` returns ``None``.
:func:`~mpmath.findpoly` works by successively calling :func:`~mpmath.pslq` with
the vectors `[1, x]`, `[1, x, x^2]`, `[1, x, x^2, x^3]`, ...,
`[1, x, x^2, .., x^n]` as input. Keyword arguments given to
:func:`~mpmath.findpoly` are forwarded verbatim to :func:`~mpmath.pslq`. In
particular, you can specify a tolerance for `P(x)` with ``tol``
and a maximum permitted coefficient size with ``maxcoeff``.
For large values of `n`, it is recommended to run :func:`~mpmath.findpoly`
at high precision; preferably 50 digits or more.
**Examples**
By default (degree `n = 1`), :func:`~mpmath.findpoly` simply finds a linear
polynomial with a rational root::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> findpoly(0.7)
[-10, 7]
The generated coefficient list is valid input to ``polyval`` and
``polyroots``::
>>> nprint(polyval(findpoly(phi, 2), phi), 1)
-2.0e-16
>>> for r in polyroots(findpoly(phi, 2)):
... print(r)
...
-0.618033988749895
1.61803398874989
Numbers of the form `m + n \sqrt p` for integers `(m, n, p)` are
solutions to quadratic equations. As we find here, `1+\sqrt 2`
is a root of the polynomial `x^2 - 2x - 1`::
>>> findpoly(1+sqrt(2), 2)
[1, -2, -1]
>>> findroot(lambda x: x**2 - 2*x - 1, 1)
2.4142135623731
Despite only containing square roots, the following number results
in a polynomial of degree 4::
>>> findpoly(sqrt(2)+sqrt(3), 4)
[1, 0, -10, 0, 1]
In fact, `x^4 - 10x^2 + 1` is the *minimal polynomial* of
`r = \sqrt 2 + \sqrt 3`, meaning that a rational polynomial of
lower degree having `r` as a root does not exist. Given sufficient
precision, :func:`~mpmath.findpoly` will usually find the correct
minimal polynomial of a given algebraic number.
**Non-algebraic numbers**
If :func:`~mpmath.findpoly` fails to find a polynomial with given
coefficient size and tolerance constraints, that means no such
polynomial exists.
We can verify that `\pi` is not an algebraic number of degree 3 with
coefficients less than 1000::
>>> mp.dps = 15
>>> findpoly(pi, 3)
>>>
It is always possible to find an algebraic approximation of a number
using one (or several) of the following methods:
1. Increasing the permitted degree
2. Allowing larger coefficients
3. Reducing the tolerance
One example of each method is shown below::
>>> mp.dps = 15
>>> findpoly(pi, 4)
[95, -545, 863, -183, -298]
>>> findpoly(pi, 3, maxcoeff=10000)
[836, -1734, -2658, -457]
>>> findpoly(pi, 3, tol=1e-7)
[-4, 22, -29, -2]
It is unknown whether Euler's constant is transcendental (or even
irrational). We can use :func:`~mpmath.findpoly` to check that if is
an algebraic number, its minimal polynomial must have degree
at least 7 and a coefficient of magnitude at least 1000000::
>>> mp.dps = 200
>>> findpoly(euler, 6, maxcoeff=10**6, tol=1e-100, maxsteps=1000)
>>>
Note that the high precision and strict tolerance is necessary
for such high-degree runs, since otherwise unwanted low-accuracy
approximations will be detected. It may also be necessary to set
maxsteps high to prevent a premature exit (before the coefficient
bound has been reached). Running with ``verbose=True`` to get an
idea what is happening can be useful.
"""
x = ctx.mpf(x)
if n < 1:
raise ValueError("n cannot be less than 1")
if x == 0:
return [1, 0]
xs = [ctx.mpf(1)]
for i in range(1,n+1):
xs.append(x**i)
a = ctx.pslq(xs, **kwargs)
if a is not None:
return a[::-1]
def fracgcd(p, q):
x, y = p, q
while y:
x, y = y, x % y
if x != 1:
p //= x
q //= x
if q == 1:
return p
return p, q
def pslqstring(r, constants):
q = r[0]
r = r[1:]
s = []
for i in range(len(r)):
p = r[i]
if p:
z = fracgcd(-p,q)
cs = constants[i][1]
if cs == '1':
cs = ''
else:
cs = '*' + cs
if isinstance(z, int_types):
if z > 0: term = str(z) + cs
else: term = ("(%s)" % z) + cs
else:
term = ("(%s/%s)" % z) + cs
s.append(term)
s = ' + '.join(s)
if '+' in s or '*' in s:
s = '(' + s + ')'
return s or '0'
def prodstring(r, constants):
q = r[0]
r = r[1:]
num = []
den = []
for i in range(len(r)):
p = r[i]
if p:
z = fracgcd(-p,q)
cs = constants[i][1]
if isinstance(z, int_types):
if abs(z) == 1: t = cs
else: t = '%s**%s' % (cs, abs(z))
([num,den][z<0]).append(t)
else:
t = '%s**(%s/%s)' % (cs, abs(z[0]), z[1])
([num,den][z[0]<0]).append(t)
num = '*'.join(num)
den = '*'.join(den)
if num and den: return "(%s)/(%s)" % (num, den)
if num: return num
if den: return "1/(%s)" % den
def quadraticstring(ctx,t,a,b,c):
if c < 0:
a,b,c = -a,-b,-c
u1 = (-b+ctx.sqrt(b**2-4*a*c))/(2*c)
u2 = (-b-ctx.sqrt(b**2-4*a*c))/(2*c)
if abs(u1-t) < abs(u2-t):
if b: s = '((%s+sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
else: s = '(sqrt(%s)/%s)' % (-4*a*c,2*c)
else:
if b: s = '((%s-sqrt(%s))/%s)' % (-b,b**2-4*a*c,2*c)
else: s = '(-sqrt(%s)/%s)' % (-4*a*c,2*c)
return s
# Transformation y = f(x,c), with inverse function x = f(y,c)
# The third entry indicates whether the transformation is
# redundant when c = 1
transforms = [
(lambda ctx,x,c: x*c, '$y/$c', 0),
(lambda ctx,x,c: x/c, '$c*$y', 1),
(lambda ctx,x,c: c/x, '$c/$y', 0),
(lambda ctx,x,c: (x*c)**2, 'sqrt($y)/$c', 0),
(lambda ctx,x,c: (x/c)**2, '$c*sqrt($y)', 1),
(lambda ctx,x,c: (c/x)**2, '$c/sqrt($y)', 0),
(lambda ctx,x,c: c*x**2, 'sqrt($y)/sqrt($c)', 1),
(lambda ctx,x,c: x**2/c, 'sqrt($c)*sqrt($y)', 1),
(lambda ctx,x,c: c/x**2, 'sqrt($c)/sqrt($y)', 1),
(lambda ctx,x,c: ctx.sqrt(x*c), '$y**2/$c', 0),
(lambda ctx,x,c: ctx.sqrt(x/c), '$c*$y**2', 1),
(lambda ctx,x,c: ctx.sqrt(c/x), '$c/$y**2', 0),
(lambda ctx,x,c: c*ctx.sqrt(x), '$y**2/$c**2', 1),
(lambda ctx,x,c: ctx.sqrt(x)/c, '$c**2*$y**2', 1),
(lambda ctx,x,c: c/ctx.sqrt(x), '$c**2/$y**2', 1),
(lambda ctx,x,c: ctx.exp(x*c), 'log($y)/$c', 0),
(lambda ctx,x,c: ctx.exp(x/c), '$c*log($y)', 1),
(lambda ctx,x,c: ctx.exp(c/x), '$c/log($y)', 0),
(lambda ctx,x,c: c*ctx.exp(x), 'log($y/$c)', 1),
(lambda ctx,x,c: ctx.exp(x)/c, 'log($c*$y)', 1),
(lambda ctx,x,c: c/ctx.exp(x), 'log($c/$y)', 0),
(lambda ctx,x,c: ctx.ln(x*c), 'exp($y)/$c', 0),
(lambda ctx,x,c: ctx.ln(x/c), '$c*exp($y)', 1),
(lambda ctx,x,c: ctx.ln(c/x), '$c/exp($y)', 0),
(lambda ctx,x,c: c*ctx.ln(x), 'exp($y/$c)', 1),
(lambda ctx,x,c: ctx.ln(x)/c, 'exp($c*$y)', 1),
(lambda ctx,x,c: c/ctx.ln(x), 'exp($c/$y)', 0),
]
def identify(ctx, x, constants=[], tol=None, maxcoeff=1000, full=False,
verbose=False):
r"""
Given a real number `x`, ``identify(x)`` attempts to find an exact
formula for `x`. This formula is returned as a string. If no match
is found, ``None`` is returned. With ``full=True``, a list of
matching formulas is returned.
As a simple example, :func:`~mpmath.identify` will find an algebraic
formula for the golden ratio::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> identify(phi)
'((1+sqrt(5))/2)'
:func:`~mpmath.identify` can identify simple algebraic numbers and simple
combinations of given base constants, as well as certain basic
transformations thereof. More specifically, :func:`~mpmath.identify`
looks for the following:
1. Fractions
2. Quadratic algebraic numbers
3. Rational linear combinations of the base constants
4. Any of the above after first transforming `x` into `f(x)` where
`f(x)` is `1/x`, `\sqrt x`, `x^2`, `\log x` or `\exp x`, either
directly or with `x` or `f(x)` multiplied or divided by one of
the base constants
5. Products of fractional powers of the base constants and
small integers
Base constants can be given as a list of strings representing mpmath
expressions (:func:`~mpmath.identify` will ``eval`` the strings to numerical
values and use the original strings for the output), or as a dict of
formula:value pairs.
In order not to produce spurious results, :func:`~mpmath.identify` should
be used with high precision; preferably 50 digits or more.
**Examples**
Simple identifications can be performed safely at standard
precision. Here the default recognition of rational, algebraic,
and exp/log of algebraic numbers is demonstrated::
>>> mp.dps = 15
>>> identify(0.22222222222222222)
'(2/9)'
>>> identify(1.9662210973805663)
'sqrt(((24+sqrt(48))/8))'
>>> identify(4.1132503787829275)
'exp((sqrt(8)/2))'
>>> identify(0.881373587019543)
'log(((2+sqrt(8))/2))'
By default, :func:`~mpmath.identify` does not recognize `\pi`. At standard
precision it finds a not too useful approximation. At slightly
increased precision, this approximation is no longer accurate
enough and :func:`~mpmath.identify` more correctly returns ``None``::
>>> identify(pi)
'(2**(176/117)*3**(20/117)*5**(35/39))/(7**(92/117))'
>>> mp.dps = 30
>>> identify(pi)
>>>
Numbers such as `\pi`, and simple combinations of user-defined
constants, can be identified if they are provided explicitly::
>>> identify(3*pi-2*e, ['pi', 'e'])
'(3*pi + (-2)*e)'
Here is an example using a dict of constants. Note that the
constants need not be "atomic"; :func:`~mpmath.identify` can just
as well express the given number in terms of expressions
given by formulas::
>>> identify(pi+e, {'a':pi+2, 'b':2*e})
'((-2) + 1*a + (1/2)*b)'
Next, we attempt some identifications with a set of base constants.
It is necessary to increase the precision a bit.
>>> mp.dps = 50
>>> base = ['sqrt(2)','pi','log(2)']
>>> identify(0.25, base)
'(1/4)'
>>> identify(3*pi + 2*sqrt(2) + 5*log(2)/7, base)
'(2*sqrt(2) + 3*pi + (5/7)*log(2))'
>>> identify(exp(pi+2), base)
'exp((2 + 1*pi))'
>>> identify(1/(3+sqrt(2)), base)
'((3/7) + (-1/7)*sqrt(2))'
>>> identify(sqrt(2)/(3*pi+4), base)
'sqrt(2)/(4 + 3*pi)'
>>> identify(5**(mpf(1)/3)*pi*log(2)**2, base)
'5**(1/3)*pi*log(2)**2'
An example of an erroneous solution being found when too low
precision is used::
>>> mp.dps = 15
>>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
'((11/25) + (-158/75)*pi + (76/75)*e + (44/15)*sqrt(2))'
>>> mp.dps = 50
>>> identify(1/(3*pi-4*e+sqrt(8)), ['pi', 'e', 'sqrt(2)'])
'1/(3*pi + (-4)*e + 2*sqrt(2))'
**Finding approximate solutions**
The tolerance ``tol`` defaults to 3/4 of the working precision.
Lowering the tolerance is useful for finding approximate matches.
We can for example try to generate approximations for pi::
>>> mp.dps = 15
>>> identify(pi, tol=1e-2)
'(22/7)'
>>> identify(pi, tol=1e-3)
'(355/113)'
>>> identify(pi, tol=1e-10)
'(5**(339/269))/(2**(64/269)*3**(13/269)*7**(92/269))'
With ``full=True``, and by supplying a few base constants,
``identify`` can generate almost endless lists of approximations
for any number (the output below has been truncated to show only
the first few)::
>>> for p in identify(pi, ['e', 'catalan'], tol=1e-5, full=True):
... print(p)
... # doctest: +ELLIPSIS
e/log((6 + (-4/3)*e))
(3**3*5*e*catalan**2)/(2*7**2)
sqrt(((-13) + 1*e + 22*catalan))
log(((-6) + 24*e + 4*catalan)/e)
exp(catalan*((-1/5) + (8/15)*e))
catalan*(6 + (-6)*e + 15*catalan)
sqrt((5 + 26*e + (-3)*catalan))/e
e*sqrt(((-27) + 2*e + 25*catalan))
log(((-1) + (-11)*e + 59*catalan))
((3/20) + (21/20)*e + (3/20)*catalan)
...
The numerical values are roughly as close to `\pi` as permitted by the
specified tolerance:
>>> e/log(6-4*e/3)
3.14157719846001
>>> 135*e*catalan**2/98
3.14166950419369
>>> sqrt(e-13+22*catalan)
3.14158000062992
>>> log(24*e-6+4*catalan)-1
3.14158791577159
**Symbolic processing**
The output formula can be evaluated as a Python expression.
Note however that if fractions (like '2/3') are present in
the formula, Python's :func:`~mpmath.eval()` may erroneously perform
integer division. Note also that the output is not necessarily
in the algebraically simplest form::
>>> identify(sqrt(2))
'(sqrt(8)/2)'
As a solution to both problems, consider using SymPy's
:func:`~mpmath.sympify` to convert the formula into a symbolic expression.
SymPy can be used to pretty-print or further simplify the formula
symbolically::
>>> from sympy import sympify # doctest: +SKIP
>>> sympify(identify(sqrt(2))) # doctest: +SKIP
2**(1/2)
Sometimes :func:`~mpmath.identify` can simplify an expression further than
a symbolic algorithm::
>>> from sympy import simplify # doctest: +SKIP
>>> x = sympify('-1/(-3/2+(1/2)*5**(1/2))*(3/2-1/2*5**(1/2))**(1/2)') # doctest: +SKIP
>>> x # doctest: +SKIP
(3/2 - 5**(1/2)/2)**(-1/2)
>>> x = simplify(x) # doctest: +SKIP
>>> x # doctest: +SKIP
2/(6 - 2*5**(1/2))**(1/2)
>>> mp.dps = 30 # doctest: +SKIP
>>> x = sympify(identify(x.evalf(30))) # doctest: +SKIP
>>> x # doctest: +SKIP
1/2 + 5**(1/2)/2
(In fact, this functionality is available directly in SymPy as the
function :func:`~mpmath.nsimplify`, which is essentially a wrapper for
:func:`~mpmath.identify`.)
**Miscellaneous issues and limitations**
The input `x` must be a real number. All base constants must be
positive real numbers and must not be rationals or rational linear
combinations of each other.
The worst-case computation time grows quickly with the number of
base constants. Already with 3 or 4 base constants,
:func:`~mpmath.identify` may require several seconds to finish. To search
for relations among a large number of constants, you should
consider using :func:`~mpmath.pslq` directly.
The extended transformations are applied to x, not the constants
separately. As a result, ``identify`` will for example be able to
recognize ``exp(2*pi+3)`` with ``pi`` given as a base constant, but
not ``2*exp(pi)+3``. It will be able to recognize the latter if
``exp(pi)`` is given explicitly as a base constant.
"""
solutions = []
def addsolution(s):
if verbose: print("Found: ", s)
solutions.append(s)
x = ctx.mpf(x)
# Further along, x will be assumed positive
if x == 0:
if full: return ['0']
else: return '0'
if x < 0:
sol = ctx.identify(-x, constants, tol, maxcoeff, full, verbose)
if sol is None:
return sol
if full:
return ["-(%s)"%s for s in sol]
else:
return "-(%s)" % sol
if tol:
tol = ctx.mpf(tol)
else:
tol = ctx.eps**0.7
M = maxcoeff
if constants:
if isinstance(constants, dict):
constants = [(ctx.mpf(v), name) for (name, v) in sorted(constants.items())]
else:
namespace = dict((name, getattr(ctx,name)) for name in dir(ctx))
constants = [(eval(p, namespace), p) for p in constants]
else:
constants = []
# We always want to find at least rational terms
if 1 not in [value for (name, value) in constants]:
constants = [(ctx.mpf(1), '1')] + constants
# PSLQ with simple algebraic and functional transformations
for ft, ftn, red in transforms:
for c, cn in constants:
if red and cn == '1':
continue
t = ft(ctx,x,c)
# Prevent exponential transforms from wreaking havoc
if abs(t) > M**2 or abs(t) < tol:
continue
# Linear combination of base constants
r = ctx.pslq([t] + [a[0] for a in constants], tol, M)
s = None
if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
s = pslqstring(r, constants)
# Quadratic algebraic numbers
else:
q = ctx.pslq([ctx.one, t, t**2], tol, M)
if q is not None and len(q) == 3 and q[2]:
aa, bb, cc = q
if max(abs(aa),abs(bb),abs(cc)) <= M:
s = quadraticstring(ctx,t,aa,bb,cc)
if s:
if cn == '1' and ('/$c' in ftn):
s = ftn.replace('$y', s).replace('/$c', '')
else:
s = ftn.replace('$y', s).replace('$c', cn)
addsolution(s)
if not full: return solutions[0]
if verbose:
print(".")
# Check for a direct multiplicative formula
if x != 1:
# Allow fractional powers of fractions
ilogs = [2,3,5,7]
# Watch out for existing fractional powers of fractions
logs = []
for a, s in constants:
if not sum(bool(ctx.findpoly(ctx.ln(a)/ctx.ln(i),1)) for i in ilogs):
logs.append((ctx.ln(a), s))
logs = [(ctx.ln(i),str(i)) for i in ilogs] + logs
r = ctx.pslq([ctx.ln(x)] + [a[0] for a in logs], tol, M)
if r is not None and max(abs(uw) for uw in r) <= M and r[0]:
addsolution(prodstring(r, logs))
if not full: return solutions[0]
if full:
return sorted(solutions, key=len)
else:
return None
IdentificationMethods.pslq = pslq
IdentificationMethods.findpoly = findpoly
IdentificationMethods.identify = identify
if __name__ == '__main__':
import doctest
doctest.testmod()

View File

@ -0,0 +1,77 @@
from .libmpf import (prec_to_dps, dps_to_prec, repr_dps,
round_down, round_up, round_floor, round_ceiling, round_nearest,
to_pickable, from_pickable, ComplexResult,
fzero, fnzero, fone, fnone, ftwo, ften, fhalf, fnan, finf, fninf,
math_float_inf, round_int, normalize, normalize1,
from_man_exp, from_int, to_man_exp, to_int, mpf_ceil, mpf_floor,
mpf_nint, mpf_frac,
from_float, from_npfloat, from_Decimal, to_float, from_rational, to_rational, to_fixed,
mpf_rand, mpf_eq, mpf_hash, mpf_cmp, mpf_lt, mpf_le, mpf_gt, mpf_ge,
mpf_pos, mpf_neg, mpf_abs, mpf_sign, mpf_add, mpf_sub, mpf_sum,
mpf_mul, mpf_mul_int, mpf_shift, mpf_frexp,
mpf_div, mpf_rdiv_int, mpf_mod, mpf_pow_int,
mpf_perturb,
to_digits_exp, to_str, str_to_man_exp, from_str, from_bstr, to_bstr,
mpf_sqrt, mpf_hypot)
from .libmpc import (mpc_one, mpc_zero, mpc_two, mpc_half,
mpc_is_inf, mpc_is_infnan, mpc_to_str, mpc_to_complex, mpc_hash,
mpc_conjugate, mpc_is_nonzero, mpc_add, mpc_add_mpf,
mpc_sub, mpc_sub_mpf, mpc_pos, mpc_neg, mpc_shift, mpc_abs,
mpc_arg, mpc_floor, mpc_ceil, mpc_nint, mpc_frac, mpc_mul, mpc_square,
mpc_mul_mpf, mpc_mul_imag_mpf, mpc_mul_int,
mpc_div, mpc_div_mpf, mpc_reciprocal, mpc_mpf_div,
complex_int_pow, mpc_pow, mpc_pow_mpf, mpc_pow_int,
mpc_sqrt, mpc_nthroot, mpc_cbrt, mpc_exp, mpc_log, mpc_cos, mpc_sin,
mpc_tan, mpc_cos_pi, mpc_sin_pi, mpc_cosh, mpc_sinh, mpc_tanh,
mpc_atan, mpc_acos, mpc_asin, mpc_asinh, mpc_acosh, mpc_atanh,
mpc_fibonacci, mpf_expj, mpf_expjpi, mpc_expj, mpc_expjpi,
mpc_cos_sin, mpc_cos_sin_pi)
from .libelefun import (ln2_fixed, mpf_ln2, ln10_fixed, mpf_ln10,
pi_fixed, mpf_pi, e_fixed, mpf_e, phi_fixed, mpf_phi,
degree_fixed, mpf_degree,
mpf_pow, mpf_nthroot, mpf_cbrt, log_int_fixed, agm_fixed,
mpf_log, mpf_log_hypot, mpf_exp, mpf_cos_sin, mpf_cos, mpf_sin, mpf_tan,
mpf_cos_sin_pi, mpf_cos_pi, mpf_sin_pi, mpf_cosh_sinh,
mpf_cosh, mpf_sinh, mpf_tanh, mpf_atan, mpf_atan2, mpf_asin,
mpf_acos, mpf_asinh, mpf_acosh, mpf_atanh, mpf_fibonacci)
from .libhyper import (NoConvergence, make_hyp_summator,
mpf_erf, mpf_erfc, mpf_ei, mpc_ei, mpf_e1, mpc_e1, mpf_expint,
mpf_ci_si, mpf_ci, mpf_si, mpc_ci, mpc_si, mpf_besseljn,
mpc_besseljn, mpf_agm, mpf_agm1, mpc_agm, mpc_agm1,
mpf_ellipk, mpc_ellipk, mpf_ellipe, mpc_ellipe)
from .gammazeta import (catalan_fixed, mpf_catalan,
khinchin_fixed, mpf_khinchin, glaisher_fixed, mpf_glaisher,
apery_fixed, mpf_apery, euler_fixed, mpf_euler, mertens_fixed,
mpf_mertens, twinprime_fixed, mpf_twinprime,
mpf_bernoulli, bernfrac, mpf_gamma_int,
mpf_factorial, mpc_factorial, mpf_gamma, mpc_gamma,
mpf_loggamma, mpc_loggamma, mpf_rgamma, mpc_rgamma,
mpf_harmonic, mpc_harmonic, mpf_psi0, mpc_psi0,
mpf_psi, mpc_psi, mpf_zeta_int, mpf_zeta, mpc_zeta,
mpf_altzeta, mpc_altzeta, mpf_zetasum, mpc_zetasum)
from .libmpi import (mpi_str,
mpi_from_str, mpi_to_str,
mpi_eq, mpi_ne,
mpi_lt, mpi_le, mpi_gt, mpi_ge,
mpi_add, mpi_sub, mpi_delta, mpi_mid,
mpi_pos, mpi_neg, mpi_abs, mpi_mul, mpi_div, mpi_exp,
mpi_log, mpi_sqrt, mpi_pow_int, mpi_pow, mpi_cos_sin,
mpi_cos, mpi_sin, mpi_tan, mpi_cot,
mpi_atan, mpi_atan2,
mpci_pos, mpci_neg, mpci_add, mpci_sub, mpci_mul, mpci_div, mpci_pow,
mpci_abs, mpci_pow, mpci_exp, mpci_log, mpci_cos, mpci_sin,
mpi_gamma, mpci_gamma, mpi_loggamma, mpci_loggamma,
mpi_rgamma, mpci_rgamma, mpi_factorial, mpci_factorial)
from .libintmath import (trailing, bitcount, numeral, bin_to_radix,
isqrt, isqrt_small, isqrt_fast, sqrt_fixed, sqrtrem, ifib, ifac,
list_primes, isprime, moebius, gcd, eulernum, stirling1, stirling2)
from .backend import (gmpy, sage, BACKEND, STRICT, MPZ, MPZ_TYPE,
MPZ_ZERO, MPZ_ONE, MPZ_TWO, MPZ_THREE, MPZ_FIVE, int_types,
HASH_MODULUS, HASH_BITS)

View File

@ -0,0 +1,115 @@
import os
import sys
#----------------------------------------------------------------------------#
# Support GMPY for high-speed large integer arithmetic. #
# #
# To allow an external module to handle arithmetic, we need to make sure #
# that all high-precision variables are declared of the correct type. MPZ #
# is the constructor for the high-precision type. It defaults to Python's #
# long type but can be assinged another type, typically gmpy.mpz. #
# #
# MPZ must be used for the mantissa component of an mpf and must be used #
# for internal fixed-point operations. #
# #
# Side-effects #
# 1) "is" cannot be used to test for special values. Must use "==". #
# 2) There are bugs in GMPY prior to v1.02 so we must use v1.03 or later. #
#----------------------------------------------------------------------------#
# So we can import it from this module
gmpy = None
sage = None
sage_utils = None
if sys.version_info[0] < 3:
python3 = False
else:
python3 = True
BACKEND = 'python'
if not python3:
MPZ = long
xrange = xrange
basestring = basestring
def exec_(_code_, _globs_=None, _locs_=None):
"""Execute code in a namespace."""
if _globs_ is None:
frame = sys._getframe(1)
_globs_ = frame.f_globals
if _locs_ is None:
_locs_ = frame.f_locals
del frame
elif _locs_ is None:
_locs_ = _globs_
exec("""exec _code_ in _globs_, _locs_""")
else:
MPZ = int
xrange = range
basestring = str
import builtins
exec_ = getattr(builtins, "exec")
# Define constants for calculating hash on Python 3.2.
if sys.version_info >= (3, 2):
HASH_MODULUS = sys.hash_info.modulus
if sys.hash_info.width == 32:
HASH_BITS = 31
else:
HASH_BITS = 61
else:
HASH_MODULUS = None
HASH_BITS = None
if 'MPMATH_NOGMPY' not in os.environ:
try:
try:
import gmpy2 as gmpy
except ImportError:
try:
import gmpy
except ImportError:
raise ImportError
if gmpy.version() >= '1.03':
BACKEND = 'gmpy'
MPZ = gmpy.mpz
except:
pass
if ('MPMATH_NOSAGE' not in os.environ and 'SAGE_ROOT' in os.environ or
'MPMATH_SAGE' in os.environ):
try:
import sage.all
import sage.libs.mpmath.utils as _sage_utils
sage = sage.all
sage_utils = _sage_utils
BACKEND = 'sage'
MPZ = sage.Integer
except:
pass
if 'MPMATH_STRICT' in os.environ:
STRICT = True
else:
STRICT = False
MPZ_TYPE = type(MPZ(0))
MPZ_ZERO = MPZ(0)
MPZ_ONE = MPZ(1)
MPZ_TWO = MPZ(2)
MPZ_THREE = MPZ(3)
MPZ_FIVE = MPZ(5)
try:
if BACKEND == 'python':
int_types = (int, long)
else:
int_types = (int, long, MPZ_TYPE)
except NameError:
if BACKEND == 'python':
int_types = (int,)
else:
int_types = (int, MPZ_TYPE)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,584 @@
"""
Utility functions for integer math.
TODO: rename, cleanup, perhaps move the gmpy wrapper code
here from settings.py
"""
import math
from bisect import bisect
from .backend import xrange
from .backend import BACKEND, gmpy, sage, sage_utils, MPZ, MPZ_ONE, MPZ_ZERO
small_trailing = [0] * 256
for j in range(1,8):
small_trailing[1<<j::1<<(j+1)] = [j] * (1<<(7-j))
def giant_steps(start, target, n=2):
"""
Return a list of integers ~=
[start, n*start, ..., target/n^2, target/n, target]
but conservatively rounded so that the quotient between two
successive elements is actually slightly less than n.
With n = 2, this describes suitable precision steps for a
quadratically convergent algorithm such as Newton's method;
with n = 3 steps for cubic convergence (Halley's method), etc.
>>> giant_steps(50,1000)
[66, 128, 253, 502, 1000]
>>> giant_steps(50,1000,4)
[65, 252, 1000]
"""
L = [target]
while L[-1] > start*n:
L = L + [L[-1]//n + 2]
return L[::-1]
def rshift(x, n):
"""For an integer x, calculate x >> n with the fastest (floor)
rounding. Unlike the plain Python expression (x >> n), n is
allowed to be negative, in which case a left shift is performed."""
if n >= 0: return x >> n
else: return x << (-n)
def lshift(x, n):
"""For an integer x, calculate x << n. Unlike the plain Python
expression (x << n), n is allowed to be negative, in which case a
right shift with default (floor) rounding is performed."""
if n >= 0: return x << n
else: return x >> (-n)
if BACKEND == 'sage':
import operator
rshift = operator.rshift
lshift = operator.lshift
def python_trailing(n):
"""Count the number of trailing zero bits in abs(n)."""
if not n:
return 0
low_byte = n & 0xff
if low_byte:
return small_trailing[low_byte]
t = 8
n >>= 8
while not n & 0xff:
n >>= 8
t += 8
return t + small_trailing[n & 0xff]
if BACKEND == 'gmpy':
if gmpy.version() >= '2':
def gmpy_trailing(n):
"""Count the number of trailing zero bits in abs(n) using gmpy."""
if n: return MPZ(n).bit_scan1()
else: return 0
else:
def gmpy_trailing(n):
"""Count the number of trailing zero bits in abs(n) using gmpy."""
if n: return MPZ(n).scan1()
else: return 0
# Small powers of 2
powers = [1<<_ for _ in range(300)]
def python_bitcount(n):
"""Calculate bit size of the nonnegative integer n."""
bc = bisect(powers, n)
if bc != 300:
return bc
bc = int(math.log(n, 2)) - 4
return bc + bctable[n>>bc]
def gmpy_bitcount(n):
"""Calculate bit size of the nonnegative integer n."""
if n: return MPZ(n).numdigits(2)
else: return 0
#def sage_bitcount(n):
# if n: return MPZ(n).nbits()
# else: return 0
def sage_trailing(n):
return MPZ(n).trailing_zero_bits()
if BACKEND == 'gmpy':
bitcount = gmpy_bitcount
trailing = gmpy_trailing
elif BACKEND == 'sage':
sage_bitcount = sage_utils.bitcount
bitcount = sage_bitcount
trailing = sage_trailing
else:
bitcount = python_bitcount
trailing = python_trailing
if BACKEND == 'gmpy' and 'bit_length' in dir(gmpy):
bitcount = gmpy.bit_length
# Used to avoid slow function calls as far as possible
trailtable = [trailing(n) for n in range(256)]
bctable = [bitcount(n) for n in range(1024)]
# TODO: speed up for bases 2, 4, 8, 16, ...
def bin_to_radix(x, xbits, base, bdigits):
"""Changes radix of a fixed-point number; i.e., converts
x * 2**xbits to floor(x * 10**bdigits)."""
return x * (MPZ(base)**bdigits) >> xbits
stddigits = '0123456789abcdefghijklmnopqrstuvwxyz'
def small_numeral(n, base=10, digits=stddigits):
"""Return the string numeral of a positive integer in an arbitrary
base. Most efficient for small input."""
if base == 10:
return str(n)
digs = []
while n:
n, digit = divmod(n, base)
digs.append(digits[digit])
return "".join(digs[::-1])
def numeral_python(n, base=10, size=0, digits=stddigits):
"""Represent the integer n as a string of digits in the given base.
Recursive division is used to make this function about 3x faster
than Python's str() for converting integers to decimal strings.
The 'size' parameters specifies the number of digits in n; this
number is only used to determine splitting points and need not be
exact."""
if n <= 0:
if not n:
return "0"
return "-" + numeral(-n, base, size, digits)
# Fast enough to do directly
if size < 250:
return small_numeral(n, base, digits)
# Divide in half
half = (size // 2) + (size & 1)
A, B = divmod(n, base**half)
ad = numeral(A, base, half, digits)
bd = numeral(B, base, half, digits).rjust(half, "0")
return ad + bd
def numeral_gmpy(n, base=10, size=0, digits=stddigits):
"""Represent the integer n as a string of digits in the given base.
Recursive division is used to make this function about 3x faster
than Python's str() for converting integers to decimal strings.
The 'size' parameters specifies the number of digits in n; this
number is only used to determine splitting points and need not be
exact."""
if n < 0:
return "-" + numeral(-n, base, size, digits)
# gmpy.digits() may cause a segmentation fault when trying to convert
# extremely large values to a string. The size limit may need to be
# adjusted on some platforms, but 1500000 works on Windows and Linux.
if size < 1500000:
return gmpy.digits(n, base)
# Divide in half
half = (size // 2) + (size & 1)
A, B = divmod(n, MPZ(base)**half)
ad = numeral(A, base, half, digits)
bd = numeral(B, base, half, digits).rjust(half, "0")
return ad + bd
if BACKEND == "gmpy":
numeral = numeral_gmpy
else:
numeral = numeral_python
_1_800 = 1<<800
_1_600 = 1<<600
_1_400 = 1<<400
_1_200 = 1<<200
_1_100 = 1<<100
_1_50 = 1<<50
def isqrt_small_python(x):
"""
Correctly (floor) rounded integer square root, using
division. Fast up to ~200 digits.
"""
if not x:
return x
if x < _1_800:
# Exact with IEEE double precision arithmetic
if x < _1_50:
return int(x**0.5)
# Initial estimate can be any integer >= the true root; round up
r = int(x**0.5 * 1.00000000000001) + 1
else:
bc = bitcount(x)
n = bc//2
r = int((x>>(2*n-100))**0.5+2)<<(n-50) # +2 is to round up
# The following iteration now precisely computes floor(sqrt(x))
# See e.g. Crandall & Pomerance, "Prime Numbers: A Computational
# Perspective"
while 1:
y = (r+x//r)>>1
if y >= r:
return r
r = y
def isqrt_fast_python(x):
"""
Fast approximate integer square root, computed using division-free
Newton iteration for large x. For random integers the result is almost
always correct (floor(sqrt(x))), but is 1 ulp too small with a roughly
0.1% probability. If x is very close to an exact square, the answer is
1 ulp wrong with high probability.
With 0 guard bits, the largest error over a set of 10^5 random
inputs of size 1-10^5 bits was 3 ulp. The use of 10 guard bits
almost certainly guarantees a max 1 ulp error.
"""
# Use direct division-based iteration if sqrt(x) < 2^400
# Assume floating-point square root accurate to within 1 ulp, then:
# 0 Newton iterations good to 52 bits
# 1 Newton iterations good to 104 bits
# 2 Newton iterations good to 208 bits
# 3 Newton iterations good to 416 bits
if x < _1_800:
y = int(x**0.5)
if x >= _1_100:
y = (y + x//y) >> 1
if x >= _1_200:
y = (y + x//y) >> 1
if x >= _1_400:
y = (y + x//y) >> 1
return y
bc = bitcount(x)
guard_bits = 10
x <<= 2*guard_bits
bc += 2*guard_bits
bc += (bc&1)
hbc = bc//2
startprec = min(50, hbc)
# Newton iteration for 1/sqrt(x), with floating-point starting value
r = int(2.0**(2*startprec) * (x >> (bc-2*startprec)) ** -0.5)
pp = startprec
for p in giant_steps(startprec, hbc):
# r**2, scaled from real size 2**(-bc) to 2**p
r2 = (r*r) >> (2*pp - p)
# x*r**2, scaled from real size ~1.0 to 2**p
xr2 = ((x >> (bc-p)) * r2) >> p
# New value of r, scaled from real size 2**(-bc/2) to 2**p
r = (r * ((3<<p) - xr2)) >> (pp+1)
pp = p
# (1/sqrt(x))*x = sqrt(x)
return (r*(x>>hbc)) >> (p+guard_bits)
def sqrtrem_python(x):
"""Correctly rounded integer (floor) square root with remainder."""
# to check cutoff:
# plot(lambda x: timing(isqrt, 2**int(x)), [0,2000])
if x < _1_600:
y = isqrt_small_python(x)
return y, x - y*y
y = isqrt_fast_python(x) + 1
rem = x - y*y
# Correct remainder
while rem < 0:
y -= 1
rem += (1+2*y)
else:
if rem:
while rem > 2*(1+y):
y += 1
rem -= (1+2*y)
return y, rem
def isqrt_python(x):
"""Integer square root with correct (floor) rounding."""
return sqrtrem_python(x)[0]
def sqrt_fixed(x, prec):
return isqrt_fast(x<<prec)
sqrt_fixed2 = sqrt_fixed
if BACKEND == 'gmpy':
if gmpy.version() >= '2':
isqrt_small = isqrt_fast = isqrt = gmpy.isqrt
sqrtrem = gmpy.isqrt_rem
else:
isqrt_small = isqrt_fast = isqrt = gmpy.sqrt
sqrtrem = gmpy.sqrtrem
elif BACKEND == 'sage':
isqrt_small = isqrt_fast = isqrt = \
getattr(sage_utils, "isqrt", lambda n: MPZ(n).isqrt())
sqrtrem = lambda n: MPZ(n).sqrtrem()
else:
isqrt_small = isqrt_small_python
isqrt_fast = isqrt_fast_python
isqrt = isqrt_python
sqrtrem = sqrtrem_python
def ifib(n, _cache={}):
"""Computes the nth Fibonacci number as an integer, for
integer n."""
if n < 0:
return (-1)**(-n+1) * ifib(-n)
if n in _cache:
return _cache[n]
m = n
# Use Dijkstra's logarithmic algorithm
# The following implementation is basically equivalent to
# http://en.literateprograms.org/Fibonacci_numbers_(Scheme)
a, b, p, q = MPZ_ONE, MPZ_ZERO, MPZ_ZERO, MPZ_ONE
while n:
if n & 1:
aq = a*q
a, b = b*q+aq+a*p, b*p+aq
n -= 1
else:
qq = q*q
p, q = p*p+qq, qq+2*p*q
n >>= 1
if m < 250:
_cache[m] = b
return b
MAX_FACTORIAL_CACHE = 1000
def ifac(n, memo={0:1, 1:1}):
"""Return n factorial (for integers n >= 0 only)."""
f = memo.get(n)
if f:
return f
k = len(memo)
p = memo[k-1]
MAX = MAX_FACTORIAL_CACHE
while k <= n:
p *= k
if k <= MAX:
memo[k] = p
k += 1
return p
def ifac2(n, memo_pair=[{0:1}, {1:1}]):
"""Return n!! (double factorial), integers n >= 0 only."""
memo = memo_pair[n&1]
f = memo.get(n)
if f:
return f
k = max(memo)
p = memo[k]
MAX = MAX_FACTORIAL_CACHE
while k < n:
k += 2
p *= k
if k <= MAX:
memo[k] = p
return p
if BACKEND == 'gmpy':
ifac = gmpy.fac
elif BACKEND == 'sage':
ifac = lambda n: int(sage.factorial(n))
ifib = sage.fibonacci
def list_primes(n):
n = n + 1
sieve = list(xrange(n))
sieve[:2] = [0, 0]
for i in xrange(2, int(n**0.5)+1):
if sieve[i]:
for j in xrange(i**2, n, i):
sieve[j] = 0
return [p for p in sieve if p]
if BACKEND == 'sage':
# Note: it is *VERY* important for performance that we convert
# the list to Python ints.
def list_primes(n):
return [int(_) for _ in sage.primes(n+1)]
small_odd_primes = (3,5,7,11,13,17,19,23,29,31,37,41,43,47)
small_odd_primes_set = set(small_odd_primes)
def isprime(n):
"""
Determines whether n is a prime number. A probabilistic test is
performed if n is very large. No special trick is used for detecting
perfect powers.
>>> sum(list_primes(100000))
454396537
>>> sum(n*isprime(n) for n in range(100000))
454396537
"""
n = int(n)
if not n & 1:
return n == 2
if n < 50:
return n in small_odd_primes_set
for p in small_odd_primes:
if not n % p:
return False
m = n-1
s = trailing(m)
d = m >> s
def test(a):
x = pow(a,d,n)
if x == 1 or x == m:
return True
for r in xrange(1,s):
x = x**2 % n
if x == m:
return True
return False
# See http://primes.utm.edu/prove/prove2_3.html
if n < 1373653:
witnesses = [2,3]
elif n < 341550071728321:
witnesses = [2,3,5,7,11,13,17]
else:
witnesses = small_odd_primes
for a in witnesses:
if not test(a):
return False
return True
def moebius(n):
"""
Evaluates the Moebius function which is `mu(n) = (-1)^k` if `n`
is a product of `k` distinct primes and `mu(n) = 0` otherwise.
TODO: speed up using factorization
"""
n = abs(int(n))
if n < 2:
return n
factors = []
for p in xrange(2, n+1):
if not (n % p):
if not (n % p**2):
return 0
if not sum(p % f for f in factors):
factors.append(p)
return (-1)**len(factors)
def gcd(*args):
a = 0
for b in args:
if a:
while b:
a, b = b, a % b
else:
a = b
return a
# Comment by Juan Arias de Reyna:
#
# I learn this method to compute EulerE[2n] from van de Lune.
#
# We apply the formula EulerE[2n] = (-1)^n 2**(-2n) sum_{j=0}^n a(2n,2j+1)
#
# where the numbers a(n,j) vanish for j > n+1 or j <= -1 and satisfies
#
# a(0,-1) = a(0,0) = 0; a(0,1)= 1; a(0,2) = a(0,3) = 0
#
# a(n,j) = a(n-1,j) when n+j is even
# a(n,j) = (j-1) a(n-1,j-1) + (j+1) a(n-1,j+1) when n+j is odd
#
#
# But we can use only one array unidimensional a(j) since to compute
# a(n,j) we only need to know a(n-1,k) where k and j are of different parity
# and we have not to conserve the used values.
#
# We cached up the values of Euler numbers to sufficiently high order.
#
# Important Observation: If we pretend to use the numbers
# EulerE[1], EulerE[2], ... , EulerE[n]
# it is convenient to compute first EulerE[n], since the algorithm
# computes first all
# the previous ones, and keeps them in the CACHE
MAX_EULER_CACHE = 500
def eulernum(m, _cache={0:MPZ_ONE}):
r"""
Computes the Euler numbers `E(n)`, which can be defined as
coefficients of the Taylor expansion of `1/cosh x`:
.. math ::
\frac{1}{\cosh x} = \sum_{n=0}^\infty \frac{E_n}{n!} x^n
Example::
>>> [int(eulernum(n)) for n in range(11)]
[1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
>>> [int(eulernum(n)) for n in range(11)] # test cache
[1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
"""
# for odd m > 1, the Euler numbers are zero
if m & 1:
return MPZ_ZERO
f = _cache.get(m)
if f:
return f
MAX = MAX_EULER_CACHE
n = m
a = [MPZ(_) for _ in [0,0,1,0,0,0]]
for n in range(1, m+1):
for j in range(n+1, -1, -2):
a[j+1] = (j-1)*a[j] + (j+1)*a[j+2]
a.append(0)
suma = 0
for k in range(n+1, -1, -2):
suma += a[k+1]
if n <= MAX:
_cache[n] = ((-1)**(n//2))*(suma // 2**n)
if n == m:
return ((-1)**(n//2))*suma // 2**n
def stirling1(n, k):
"""
Stirling number of the first kind.
"""
if n < 0 or k < 0:
raise ValueError
if k >= n:
return MPZ(n == k)
if k < 1:
return MPZ_ZERO
L = [MPZ_ZERO] * (k+1)
L[1] = MPZ_ONE
for m in xrange(2, n+1):
for j in xrange(min(k, m), 0, -1):
L[j] = (m-1) * L[j] + L[j-1]
return (-1)**(n+k) * L[k]
def stirling2(n, k):
"""
Stirling number of the second kind.
"""
if n < 0 or k < 0:
raise ValueError
if k >= n:
return MPZ(n == k)
if k <= 1:
return MPZ(k == 1)
s = MPZ_ZERO
t = MPZ_ONE
for j in xrange(k+1):
if (k + j) & 1:
s -= t * MPZ(j)**n
else:
s += t * MPZ(j)**n
t = t * (k - j) // (j + 1)
return s // ifac(k)

View File

@ -0,0 +1,835 @@
"""
Low-level functions for complex arithmetic.
"""
import sys
from .backend import MPZ, MPZ_ZERO, MPZ_ONE, MPZ_TWO, BACKEND
from .libmpf import (\
round_floor, round_ceiling, round_down, round_up,
round_nearest, round_fast, bitcount,
bctable, normalize, normalize1, reciprocal_rnd, rshift, lshift, giant_steps,
negative_rnd,
to_str, to_fixed, from_man_exp, from_float, to_float, from_int, to_int,
fzero, fone, ftwo, fhalf, finf, fninf, fnan, fnone,
mpf_abs, mpf_pos, mpf_neg, mpf_add, mpf_sub, mpf_mul,
mpf_div, mpf_mul_int, mpf_shift, mpf_sqrt, mpf_hypot,
mpf_rdiv_int, mpf_floor, mpf_ceil, mpf_nint, mpf_frac,
mpf_sign, mpf_hash,
ComplexResult
)
from .libelefun import (\
mpf_pi, mpf_exp, mpf_log, mpf_cos_sin, mpf_cosh_sinh, mpf_tan, mpf_pow_int,
mpf_log_hypot,
mpf_cos_sin_pi, mpf_phi,
mpf_cos, mpf_sin, mpf_cos_pi, mpf_sin_pi,
mpf_atan, mpf_atan2, mpf_cosh, mpf_sinh, mpf_tanh,
mpf_asin, mpf_acos, mpf_acosh, mpf_nthroot, mpf_fibonacci
)
# An mpc value is a (real, imag) tuple
mpc_one = fone, fzero
mpc_zero = fzero, fzero
mpc_two = ftwo, fzero
mpc_half = (fhalf, fzero)
_infs = (finf, fninf)
_infs_nan = (finf, fninf, fnan)
def mpc_is_inf(z):
"""Check if either real or imaginary part is infinite"""
re, im = z
if re in _infs: return True
if im in _infs: return True
return False
def mpc_is_infnan(z):
"""Check if either real or imaginary part is infinite or nan"""
re, im = z
if re in _infs_nan: return True
if im in _infs_nan: return True
return False
def mpc_to_str(z, dps, **kwargs):
re, im = z
rs = to_str(re, dps)
if im[0]:
return rs + " - " + to_str(mpf_neg(im), dps, **kwargs) + "j"
else:
return rs + " + " + to_str(im, dps, **kwargs) + "j"
def mpc_to_complex(z, strict=False, rnd=round_fast):
re, im = z
return complex(to_float(re, strict, rnd), to_float(im, strict, rnd))
def mpc_hash(z):
if sys.version_info >= (3, 2):
re, im = z
h = mpf_hash(re) + sys.hash_info.imag * mpf_hash(im)
# Need to reduce either module 2^32 or 2^64
h = h % (2**sys.hash_info.width)
return int(h)
else:
try:
return hash(mpc_to_complex(z, strict=True))
except OverflowError:
return hash(z)
def mpc_conjugate(z, prec, rnd=round_fast):
re, im = z
return re, mpf_neg(im, prec, rnd)
def mpc_is_nonzero(z):
return z != mpc_zero
def mpc_add(z, w, prec, rnd=round_fast):
a, b = z
c, d = w
return mpf_add(a, c, prec, rnd), mpf_add(b, d, prec, rnd)
def mpc_add_mpf(z, x, prec, rnd=round_fast):
a, b = z
return mpf_add(a, x, prec, rnd), b
def mpc_sub(z, w, prec=0, rnd=round_fast):
a, b = z
c, d = w
return mpf_sub(a, c, prec, rnd), mpf_sub(b, d, prec, rnd)
def mpc_sub_mpf(z, p, prec=0, rnd=round_fast):
a, b = z
return mpf_sub(a, p, prec, rnd), b
def mpc_pos(z, prec, rnd=round_fast):
a, b = z
return mpf_pos(a, prec, rnd), mpf_pos(b, prec, rnd)
def mpc_neg(z, prec=None, rnd=round_fast):
a, b = z
return mpf_neg(a, prec, rnd), mpf_neg(b, prec, rnd)
def mpc_shift(z, n):
a, b = z
return mpf_shift(a, n), mpf_shift(b, n)
def mpc_abs(z, prec, rnd=round_fast):
"""Absolute value of a complex number, |a+bi|.
Returns an mpf value."""
a, b = z
return mpf_hypot(a, b, prec, rnd)
def mpc_arg(z, prec, rnd=round_fast):
"""Argument of a complex number. Returns an mpf value."""
a, b = z
return mpf_atan2(b, a, prec, rnd)
def mpc_floor(z, prec, rnd=round_fast):
a, b = z
return mpf_floor(a, prec, rnd), mpf_floor(b, prec, rnd)
def mpc_ceil(z, prec, rnd=round_fast):
a, b = z
return mpf_ceil(a, prec, rnd), mpf_ceil(b, prec, rnd)
def mpc_nint(z, prec, rnd=round_fast):
a, b = z
return mpf_nint(a, prec, rnd), mpf_nint(b, prec, rnd)
def mpc_frac(z, prec, rnd=round_fast):
a, b = z
return mpf_frac(a, prec, rnd), mpf_frac(b, prec, rnd)
def mpc_mul(z, w, prec, rnd=round_fast):
"""
Complex multiplication.
Returns the real and imaginary part of (a+bi)*(c+di), rounded to
the specified precision. The rounding mode applies to the real and
imaginary parts separately.
"""
a, b = z
c, d = w
p = mpf_mul(a, c)
q = mpf_mul(b, d)
r = mpf_mul(a, d)
s = mpf_mul(b, c)
re = mpf_sub(p, q, prec, rnd)
im = mpf_add(r, s, prec, rnd)
return re, im
def mpc_square(z, prec, rnd=round_fast):
# (a+b*I)**2 == a**2 - b**2 + 2*I*a*b
a, b = z
p = mpf_mul(a,a)
q = mpf_mul(b,b)
r = mpf_mul(a,b, prec, rnd)
re = mpf_sub(p, q, prec, rnd)
im = mpf_shift(r, 1)
return re, im
def mpc_mul_mpf(z, p, prec, rnd=round_fast):
a, b = z
re = mpf_mul(a, p, prec, rnd)
im = mpf_mul(b, p, prec, rnd)
return re, im
def mpc_mul_imag_mpf(z, x, prec, rnd=round_fast):
"""
Multiply the mpc value z by I*x where x is an mpf value.
"""
a, b = z
re = mpf_neg(mpf_mul(b, x, prec, rnd))
im = mpf_mul(a, x, prec, rnd)
return re, im
def mpc_mul_int(z, n, prec, rnd=round_fast):
a, b = z
re = mpf_mul_int(a, n, prec, rnd)
im = mpf_mul_int(b, n, prec, rnd)
return re, im
def mpc_div(z, w, prec, rnd=round_fast):
a, b = z
c, d = w
wp = prec + 10
# mag = c*c + d*d
mag = mpf_add(mpf_mul(c, c), mpf_mul(d, d), wp)
# (a*c+b*d)/mag, (b*c-a*d)/mag
t = mpf_add(mpf_mul(a,c), mpf_mul(b,d), wp)
u = mpf_sub(mpf_mul(b,c), mpf_mul(a,d), wp)
return mpf_div(t,mag,prec,rnd), mpf_div(u,mag,prec,rnd)
def mpc_div_mpf(z, p, prec, rnd=round_fast):
"""Calculate z/p where p is real"""
a, b = z
re = mpf_div(a, p, prec, rnd)
im = mpf_div(b, p, prec, rnd)
return re, im
def mpc_reciprocal(z, prec, rnd=round_fast):
"""Calculate 1/z efficiently"""
a, b = z
m = mpf_add(mpf_mul(a,a),mpf_mul(b,b),prec+10)
re = mpf_div(a, m, prec, rnd)
im = mpf_neg(mpf_div(b, m, prec, rnd))
return re, im
def mpc_mpf_div(p, z, prec, rnd=round_fast):
"""Calculate p/z where p is real efficiently"""
a, b = z
m = mpf_add(mpf_mul(a,a),mpf_mul(b,b), prec+10)
re = mpf_div(mpf_mul(a,p), m, prec, rnd)
im = mpf_div(mpf_neg(mpf_mul(b,p)), m, prec, rnd)
return re, im
def complex_int_pow(a, b, n):
"""Complex integer power: computes (a+b*I)**n exactly for
nonnegative n (a and b must be Python ints)."""
wre = 1
wim = 0
while n:
if n & 1:
wre, wim = wre*a - wim*b, wim*a + wre*b
n -= 1
a, b = a*a - b*b, 2*a*b
n //= 2
return wre, wim
def mpc_pow(z, w, prec, rnd=round_fast):
if w[1] == fzero:
return mpc_pow_mpf(z, w[0], prec, rnd)
return mpc_exp(mpc_mul(mpc_log(z, prec+10), w, prec+10), prec, rnd)
def mpc_pow_mpf(z, p, prec, rnd=round_fast):
psign, pman, pexp, pbc = p
if pexp >= 0:
return mpc_pow_int(z, (-1)**psign * (pman<<pexp), prec, rnd)
if pexp == -1:
sqrtz = mpc_sqrt(z, prec+10)
return mpc_pow_int(sqrtz, (-1)**psign * pman, prec, rnd)
return mpc_exp(mpc_mul_mpf(mpc_log(z, prec+10), p, prec+10), prec, rnd)
def mpc_pow_int(z, n, prec, rnd=round_fast):
a, b = z
if b == fzero:
return mpf_pow_int(a, n, prec, rnd), fzero
if a == fzero:
v = mpf_pow_int(b, n, prec, rnd)
n %= 4
if n == 0:
return v, fzero
elif n == 1:
return fzero, v
elif n == 2:
return mpf_neg(v), fzero
elif n == 3:
return fzero, mpf_neg(v)
if n == 0: return mpc_one
if n == 1: return mpc_pos(z, prec, rnd)
if n == 2: return mpc_square(z, prec, rnd)
if n == -1: return mpc_reciprocal(z, prec, rnd)
if n < 0: return mpc_reciprocal(mpc_pow_int(z, -n, prec+4), prec, rnd)
asign, aman, aexp, abc = a
bsign, bman, bexp, bbc = b
if asign: aman = -aman
if bsign: bman = -bman
de = aexp - bexp
abs_de = abs(de)
exact_size = n*(abs_de + max(abc, bbc))
if exact_size < 10000:
if de > 0:
aman <<= de
aexp = bexp
else:
bman <<= (-de)
bexp = aexp
re, im = complex_int_pow(aman, bman, n)
re = from_man_exp(re, int(n*aexp), prec, rnd)
im = from_man_exp(im, int(n*bexp), prec, rnd)
return re, im
return mpc_exp(mpc_mul_int(mpc_log(z, prec+10), n, prec+10), prec, rnd)
def mpc_sqrt(z, prec, rnd=round_fast):
"""Complex square root (principal branch).
We have sqrt(a+bi) = sqrt((r+a)/2) + b/sqrt(2*(r+a))*i where
r = abs(a+bi), when a+bi is not a negative real number."""
a, b = z
if b == fzero:
if a == fzero:
return (a, b)
# When a+bi is a negative real number, we get a real sqrt times i
if a[0]:
im = mpf_sqrt(mpf_neg(a), prec, rnd)
return (fzero, im)
else:
re = mpf_sqrt(a, prec, rnd)
return (re, fzero)
wp = prec+20
if not a[0]: # case a positive
t = mpf_add(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) + a
u = mpf_shift(t, -1) # u = t/2
re = mpf_sqrt(u, prec, rnd) # re = sqrt(u)
v = mpf_shift(t, 1) # v = 2*t
w = mpf_sqrt(v, wp) # w = sqrt(v)
im = mpf_div(b, w, prec, rnd) # im = b / w
else: # case a negative
t = mpf_sub(mpc_abs((a, b), wp), a, wp) # t = abs(a+bi) - a
u = mpf_shift(t, -1) # u = t/2
im = mpf_sqrt(u, prec, rnd) # im = sqrt(u)
v = mpf_shift(t, 1) # v = 2*t
w = mpf_sqrt(v, wp) # w = sqrt(v)
re = mpf_div(b, w, prec, rnd) # re = b/w
if b[0]:
re = mpf_neg(re)
im = mpf_neg(im)
return re, im
def mpc_nthroot_fixed(a, b, n, prec):
# a, b signed integers at fixed precision prec
start = 50
a1 = int(rshift(a, prec - n*start))
b1 = int(rshift(b, prec - n*start))
try:
r = (a1 + 1j * b1)**(1.0/n)
re = r.real
im = r.imag
re = MPZ(int(re))
im = MPZ(int(im))
except OverflowError:
a1 = from_int(a1, start)
b1 = from_int(b1, start)
fn = from_int(n)
nth = mpf_rdiv_int(1, fn, start)
re, im = mpc_pow((a1, b1), (nth, fzero), start)
re = to_int(re)
im = to_int(im)
extra = 10
prevp = start
extra1 = n
for p in giant_steps(start, prec+extra):
# this is slow for large n, unlike int_pow_fixed
re2, im2 = complex_int_pow(re, im, n-1)
re2 = rshift(re2, (n-1)*prevp - p - extra1)
im2 = rshift(im2, (n-1)*prevp - p - extra1)
r4 = (re2*re2 + im2*im2) >> (p + extra1)
ap = rshift(a, prec - p)
bp = rshift(b, prec - p)
rec = (ap * re2 + bp * im2) >> p
imc = (-ap * im2 + bp * re2) >> p
reb = (rec << p) // r4
imb = (imc << p) // r4
re = (reb + (n-1)*lshift(re, p-prevp))//n
im = (imb + (n-1)*lshift(im, p-prevp))//n
prevp = p
return re, im
def mpc_nthroot(z, n, prec, rnd=round_fast):
"""
Complex n-th root.
Use Newton method as in the real case when it is faster,
otherwise use z**(1/n)
"""
a, b = z
if a[0] == 0 and b == fzero:
re = mpf_nthroot(a, n, prec, rnd)
return (re, fzero)
if n < 2:
if n == 0:
return mpc_one
if n == 1:
return mpc_pos((a, b), prec, rnd)
if n == -1:
return mpc_div(mpc_one, (a, b), prec, rnd)
inverse = mpc_nthroot((a, b), -n, prec+5, reciprocal_rnd[rnd])
return mpc_div(mpc_one, inverse, prec, rnd)
if n <= 20:
prec2 = int(1.2 * (prec + 10))
asign, aman, aexp, abc = a
bsign, bman, bexp, bbc = b
pf = mpc_abs((a,b), prec)
if pf[-2] + pf[-1] > -10 and pf[-2] + pf[-1] < prec:
af = to_fixed(a, prec2)
bf = to_fixed(b, prec2)
re, im = mpc_nthroot_fixed(af, bf, n, prec2)
extra = 10
re = from_man_exp(re, -prec2-extra, prec2, rnd)
im = from_man_exp(im, -prec2-extra, prec2, rnd)
return re, im
fn = from_int(n)
prec2 = prec+10 + 10
nth = mpf_rdiv_int(1, fn, prec2)
re, im = mpc_pow((a, b), (nth, fzero), prec2, rnd)
re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
return re, im
def mpc_cbrt(z, prec, rnd=round_fast):
"""
Complex cubic root.
"""
return mpc_nthroot(z, 3, prec, rnd)
def mpc_exp(z, prec, rnd=round_fast):
"""
Complex exponential function.
We use the direct formula exp(a+bi) = exp(a) * (cos(b) + sin(b)*i)
for the computation. This formula is very nice because it is
pefectly stable; since we just do real multiplications, the only
numerical errors that can creep in are single-ulp rounding errors.
The formula is efficient since mpmath's real exp is quite fast and
since we can compute cos and sin simultaneously.
It is no problem if a and b are large; if the implementations of
exp/cos/sin are accurate and efficient for all real numbers, then
so is this function for all complex numbers.
"""
a, b = z
if a == fzero:
return mpf_cos_sin(b, prec, rnd)
if b == fzero:
return mpf_exp(a, prec, rnd), fzero
mag = mpf_exp(a, prec+4, rnd)
c, s = mpf_cos_sin(b, prec+4, rnd)
re = mpf_mul(mag, c, prec, rnd)
im = mpf_mul(mag, s, prec, rnd)
return re, im
def mpc_log(z, prec, rnd=round_fast):
re = mpf_log_hypot(z[0], z[1], prec, rnd)
im = mpc_arg(z, prec, rnd)
return re, im
def mpc_cos(z, prec, rnd=round_fast):
"""Complex cosine. The formula used is cos(a+bi) = cos(a)*cosh(b) -
sin(a)*sinh(b)*i.
The same comments apply as for the complex exp: only real
multiplications are pewrormed, so no cancellation errors are
possible. The formula is also efficient since we can compute both
pairs (cos, sin) and (cosh, sinh) in single stwps."""
a, b = z
if b == fzero:
return mpf_cos(a, prec, rnd), fzero
if a == fzero:
return mpf_cosh(b, prec, rnd), fzero
wp = prec + 6
c, s = mpf_cos_sin(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
re = mpf_mul(c, ch, prec, rnd)
im = mpf_mul(s, sh, prec, rnd)
return re, mpf_neg(im)
def mpc_sin(z, prec, rnd=round_fast):
"""Complex sine. We have sin(a+bi) = sin(a)*cosh(b) +
cos(a)*sinh(b)*i. See the docstring for mpc_cos for additional
comments."""
a, b = z
if b == fzero:
return mpf_sin(a, prec, rnd), fzero
if a == fzero:
return fzero, mpf_sinh(b, prec, rnd)
wp = prec + 6
c, s = mpf_cos_sin(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
re = mpf_mul(s, ch, prec, rnd)
im = mpf_mul(c, sh, prec, rnd)
return re, im
def mpc_tan(z, prec, rnd=round_fast):
"""Complex tangent. Computed as tan(a+bi) = sin(2a)/M + sinh(2b)/M*i
where M = cos(2a) + cosh(2b)."""
a, b = z
asign, aman, aexp, abc = a
bsign, bman, bexp, bbc = b
if b == fzero: return mpf_tan(a, prec, rnd), fzero
if a == fzero: return fzero, mpf_tanh(b, prec, rnd)
wp = prec + 15
a = mpf_shift(a, 1)
b = mpf_shift(b, 1)
c, s = mpf_cos_sin(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
# TODO: handle cancellation when c ~= -1 and ch ~= 1
mag = mpf_add(c, ch, wp)
re = mpf_div(s, mag, prec, rnd)
im = mpf_div(sh, mag, prec, rnd)
return re, im
def mpc_cos_pi(z, prec, rnd=round_fast):
a, b = z
if b == fzero:
return mpf_cos_pi(a, prec, rnd), fzero
b = mpf_mul(b, mpf_pi(prec+5), prec+5)
if a == fzero:
return mpf_cosh(b, prec, rnd), fzero
wp = prec + 6
c, s = mpf_cos_sin_pi(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
re = mpf_mul(c, ch, prec, rnd)
im = mpf_mul(s, sh, prec, rnd)
return re, mpf_neg(im)
def mpc_sin_pi(z, prec, rnd=round_fast):
a, b = z
if b == fzero:
return mpf_sin_pi(a, prec, rnd), fzero
b = mpf_mul(b, mpf_pi(prec+5), prec+5)
if a == fzero:
return fzero, mpf_sinh(b, prec, rnd)
wp = prec + 6
c, s = mpf_cos_sin_pi(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
re = mpf_mul(s, ch, prec, rnd)
im = mpf_mul(c, sh, prec, rnd)
return re, im
def mpc_cos_sin(z, prec, rnd=round_fast):
a, b = z
if a == fzero:
ch, sh = mpf_cosh_sinh(b, prec, rnd)
return (ch, fzero), (fzero, sh)
if b == fzero:
c, s = mpf_cos_sin(a, prec, rnd)
return (c, fzero), (s, fzero)
wp = prec + 6
c, s = mpf_cos_sin(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
cre = mpf_mul(c, ch, prec, rnd)
cim = mpf_mul(s, sh, prec, rnd)
sre = mpf_mul(s, ch, prec, rnd)
sim = mpf_mul(c, sh, prec, rnd)
return (cre, mpf_neg(cim)), (sre, sim)
def mpc_cos_sin_pi(z, prec, rnd=round_fast):
a, b = z
if b == fzero:
c, s = mpf_cos_sin_pi(a, prec, rnd)
return (c, fzero), (s, fzero)
b = mpf_mul(b, mpf_pi(prec+5), prec+5)
if a == fzero:
ch, sh = mpf_cosh_sinh(b, prec, rnd)
return (ch, fzero), (fzero, sh)
wp = prec + 6
c, s = mpf_cos_sin_pi(a, wp)
ch, sh = mpf_cosh_sinh(b, wp)
cre = mpf_mul(c, ch, prec, rnd)
cim = mpf_mul(s, sh, prec, rnd)
sre = mpf_mul(s, ch, prec, rnd)
sim = mpf_mul(c, sh, prec, rnd)
return (cre, mpf_neg(cim)), (sre, sim)
def mpc_cosh(z, prec, rnd=round_fast):
"""Complex hyperbolic cosine. Computed as cosh(z) = cos(z*i)."""
a, b = z
return mpc_cos((b, mpf_neg(a)), prec, rnd)
def mpc_sinh(z, prec, rnd=round_fast):
"""Complex hyperbolic sine. Computed as sinh(z) = -i*sin(z*i)."""
a, b = z
b, a = mpc_sin((b, a), prec, rnd)
return a, b
def mpc_tanh(z, prec, rnd=round_fast):
"""Complex hyperbolic tangent. Computed as tanh(z) = -i*tan(z*i)."""
a, b = z
b, a = mpc_tan((b, a), prec, rnd)
return a, b
# TODO: avoid loss of accuracy
def mpc_atan(z, prec, rnd=round_fast):
a, b = z
# atan(z) = (I/2)*(log(1-I*z) - log(1+I*z))
# x = 1-I*z = 1 + b - I*a
# y = 1+I*z = 1 - b + I*a
wp = prec + 15
x = mpf_add(fone, b, wp), mpf_neg(a)
y = mpf_sub(fone, b, wp), a
l1 = mpc_log(x, wp)
l2 = mpc_log(y, wp)
a, b = mpc_sub(l1, l2, prec, rnd)
# (I/2) * (a+b*I) = (-b/2 + a/2*I)
v = mpf_neg(mpf_shift(b,-1)), mpf_shift(a,-1)
# Subtraction at infinity gives correct real part but
# wrong imaginary part (should be zero)
if v[1] == fnan and mpc_is_inf(z):
v = (v[0], fzero)
return v
beta_crossover = from_float(0.6417)
alpha_crossover = from_float(1.5)
def acos_asin(z, prec, rnd, n):
""" complex acos for n = 0, asin for n = 1
The algorithm is described in
T.E. Hull, T.F. Fairgrieve and P.T.P. Tang
'Implementing the Complex Arcsine and Arcosine Functions
using Exception Handling',
ACM Trans. on Math. Software Vol. 23 (1997), p299
The complex acos and asin can be defined as
acos(z) = acos(beta) - I*sign(a)* log(alpha + sqrt(alpha**2 -1))
asin(z) = asin(beta) + I*sign(a)* log(alpha + sqrt(alpha**2 -1))
where z = a + I*b
alpha = (1/2)*(r + s); beta = (1/2)*(r - s) = a/alpha
r = sqrt((a+1)**2 + y**2); s = sqrt((a-1)**2 + y**2)
These expressions are rewritten in different ways in different
regions, delimited by two crossovers alpha_crossover and beta_crossover,
and by abs(a) <= 1, in order to improve the numerical accuracy.
"""
a, b = z
wp = prec + 10
# special cases with real argument
if b == fzero:
am = mpf_sub(fone, mpf_abs(a), wp)
# case abs(a) <= 1
if not am[0]:
if n == 0:
return mpf_acos(a, prec, rnd), fzero
else:
return mpf_asin(a, prec, rnd), fzero
# cases abs(a) > 1
else:
# case a < -1
if a[0]:
pi = mpf_pi(prec, rnd)
c = mpf_acosh(mpf_neg(a), prec, rnd)
if n == 0:
return pi, mpf_neg(c)
else:
return mpf_neg(mpf_shift(pi, -1)), c
# case a > 1
else:
c = mpf_acosh(a, prec, rnd)
if n == 0:
return fzero, c
else:
pi = mpf_pi(prec, rnd)
return mpf_shift(pi, -1), mpf_neg(c)
asign = bsign = 0
if a[0]:
a = mpf_neg(a)
asign = 1
if b[0]:
b = mpf_neg(b)
bsign = 1
am = mpf_sub(fone, a, wp)
ap = mpf_add(fone, a, wp)
r = mpf_hypot(ap, b, wp)
s = mpf_hypot(am, b, wp)
alpha = mpf_shift(mpf_add(r, s, wp), -1)
beta = mpf_div(a, alpha, wp)
b2 = mpf_mul(b,b, wp)
# case beta <= beta_crossover
if not mpf_sub(beta_crossover, beta, wp)[0]:
if n == 0:
re = mpf_acos(beta, wp)
else:
re = mpf_asin(beta, wp)
else:
# to compute the real part in this region use the identity
# asin(beta) = atan(beta/sqrt(1-beta**2))
# beta/sqrt(1-beta**2) = (alpha + a) * (alpha - a)
# alpha + a is numerically accurate; alpha - a can have
# cancellations leading to numerical inaccuracies, so rewrite
# it in differente ways according to the region
Ax = mpf_add(alpha, a, wp)
# case a <= 1
if not am[0]:
# c = b*b/(r + (a+1)); d = (s + (1-a))
# alpha - a = (1/2)*(c + d)
# case n=0: re = atan(sqrt((1/2) * Ax * (c + d))/a)
# case n=1: re = atan(a/sqrt((1/2) * Ax * (c + d)))
c = mpf_div(b2, mpf_add(r, ap, wp), wp)
d = mpf_add(s, am, wp)
re = mpf_shift(mpf_mul(Ax, mpf_add(c, d, wp), wp), -1)
if n == 0:
re = mpf_atan(mpf_div(mpf_sqrt(re, wp), a, wp), wp)
else:
re = mpf_atan(mpf_div(a, mpf_sqrt(re, wp), wp), wp)
else:
# c = Ax/(r + (a+1)); d = Ax/(s - (1-a))
# alpha - a = (1/2)*(c + d)
# case n = 0: re = atan(b*sqrt(c + d)/2/a)
# case n = 1: re = atan(a/(b*sqrt(c + d)/2)
c = mpf_div(Ax, mpf_add(r, ap, wp), wp)
d = mpf_div(Ax, mpf_sub(s, am, wp), wp)
re = mpf_shift(mpf_add(c, d, wp), -1)
re = mpf_mul(b, mpf_sqrt(re, wp), wp)
if n == 0:
re = mpf_atan(mpf_div(re, a, wp), wp)
else:
re = mpf_atan(mpf_div(a, re, wp), wp)
# to compute alpha + sqrt(alpha**2 - 1), if alpha <= alpha_crossover
# replace it with 1 + Am1 + sqrt(Am1*(alpha+1)))
# where Am1 = alpha -1
# if alpha <= alpha_crossover:
if not mpf_sub(alpha_crossover, alpha, wp)[0]:
c1 = mpf_div(b2, mpf_add(r, ap, wp), wp)
# case a < 1
if mpf_neg(am)[0]:
# Am1 = (1/2) * (b*b/(r + (a+1)) + b*b/(s + (1-a))
c2 = mpf_add(s, am, wp)
c2 = mpf_div(b2, c2, wp)
Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
else:
# Am1 = (1/2) * (b*b/(r + (a+1)) + (s - (1-a)))
c2 = mpf_sub(s, am, wp)
Am1 = mpf_shift(mpf_add(c1, c2, wp), -1)
# im = log(1 + Am1 + sqrt(Am1*(alpha+1)))
im = mpf_mul(Am1, mpf_add(alpha, fone, wp), wp)
im = mpf_log(mpf_add(fone, mpf_add(Am1, mpf_sqrt(im, wp), wp), wp), wp)
else:
# im = log(alpha + sqrt(alpha*alpha - 1))
im = mpf_sqrt(mpf_sub(mpf_mul(alpha, alpha, wp), fone, wp), wp)
im = mpf_log(mpf_add(alpha, im, wp), wp)
if asign:
if n == 0:
re = mpf_sub(mpf_pi(wp), re, wp)
else:
re = mpf_neg(re)
if not bsign and n == 0:
im = mpf_neg(im)
if bsign and n == 1:
im = mpf_neg(im)
re = normalize(re[0], re[1], re[2], re[3], prec, rnd)
im = normalize(im[0], im[1], im[2], im[3], prec, rnd)
return re, im
def mpc_acos(z, prec, rnd=round_fast):
return acos_asin(z, prec, rnd, 0)
def mpc_asin(z, prec, rnd=round_fast):
return acos_asin(z, prec, rnd, 1)
def mpc_asinh(z, prec, rnd=round_fast):
# asinh(z) = I * asin(-I z)
a, b = z
a, b = mpc_asin((b, mpf_neg(a)), prec, rnd)
return mpf_neg(b), a
def mpc_acosh(z, prec, rnd=round_fast):
# acosh(z) = -I * acos(z) for Im(acos(z)) <= 0
# +I * acos(z) otherwise
a, b = mpc_acos(z, prec, rnd)
if b[0] or b == fzero:
return mpf_neg(b), a
else:
return b, mpf_neg(a)
def mpc_atanh(z, prec, rnd=round_fast):
# atanh(z) = (log(1+z)-log(1-z))/2
wp = prec + 15
a = mpc_add(z, mpc_one, wp)
b = mpc_sub(mpc_one, z, wp)
a = mpc_log(a, wp)
b = mpc_log(b, wp)
v = mpc_shift(mpc_sub(a, b, wp), -1)
# Subtraction at infinity gives correct imaginary part but
# wrong real part (should be zero)
if v[0] == fnan and mpc_is_inf(z):
v = (fzero, v[1])
return v
def mpc_fibonacci(z, prec, rnd=round_fast):
re, im = z
if im == fzero:
return (mpf_fibonacci(re, prec, rnd), fzero)
size = max(abs(re[2]+re[3]), abs(re[2]+re[3]))
wp = prec + size + 20
a = mpf_phi(wp)
b = mpf_add(mpf_shift(a, 1), fnone, wp)
u = mpc_pow((a, fzero), z, wp)
v = mpc_cos_pi(z, wp)
v = mpc_div(v, u, wp)
u = mpc_sub(u, v, wp)
u = mpc_div_mpf(u, b, prec, rnd)
return u
def mpf_expj(x, prec, rnd='f'):
raise ComplexResult
def mpc_expj(z, prec, rnd='f'):
re, im = z
if im == fzero:
return mpf_cos_sin(re, prec, rnd)
if re == fzero:
return mpf_exp(mpf_neg(im), prec, rnd), fzero
ey = mpf_exp(mpf_neg(im), prec+10)
c, s = mpf_cos_sin(re, prec+10)
re = mpf_mul(ey, c, prec, rnd)
im = mpf_mul(ey, s, prec, rnd)
return re, im
def mpf_expjpi(x, prec, rnd='f'):
raise ComplexResult
def mpc_expjpi(z, prec, rnd='f'):
re, im = z
if im == fzero:
return mpf_cos_sin_pi(re, prec, rnd)
sign, man, exp, bc = im
wp = prec+10
if man:
wp += max(0, exp+bc)
im = mpf_neg(mpf_mul(mpf_pi(wp), im, wp))
if re == fzero:
return mpf_exp(im, prec, rnd), fzero
ey = mpf_exp(im, prec+10)
c, s = mpf_cos_sin_pi(re, prec+10)
re = mpf_mul(ey, c, prec, rnd)
im = mpf_mul(ey, s, prec, rnd)
return re, im
if BACKEND == 'sage':
try:
import sage.libs.mpmath.ext_libmp as _lbmp
mpc_exp = _lbmp.mpc_exp
mpc_sqrt = _lbmp.mpc_sqrt
except (ImportError, AttributeError):
print("Warning: Sage imports in libmpc failed")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,935 @@
"""
Computational functions for interval arithmetic.
"""
from .backend import xrange
from .libmpf import (
ComplexResult,
round_down, round_up, round_floor, round_ceiling, round_nearest,
prec_to_dps, repr_dps, dps_to_prec,
bitcount,
from_float,
fnan, finf, fninf, fzero, fhalf, fone, fnone,
mpf_sign, mpf_lt, mpf_le, mpf_gt, mpf_ge, mpf_eq, mpf_cmp,
mpf_min_max,
mpf_floor, from_int, to_int, to_str, from_str,
mpf_abs, mpf_neg, mpf_pos, mpf_add, mpf_sub, mpf_mul, mpf_mul_int,
mpf_div, mpf_shift, mpf_pow_int,
from_man_exp, MPZ_ONE)
from .libelefun import (
mpf_log, mpf_exp, mpf_sqrt, mpf_atan, mpf_atan2,
mpf_pi, mod_pi2, mpf_cos_sin
)
from .gammazeta import mpf_gamma, mpf_rgamma, mpf_loggamma, mpc_loggamma
def mpi_str(s, prec):
sa, sb = s
dps = prec_to_dps(prec) + 5
return "[%s, %s]" % (to_str(sa, dps), to_str(sb, dps))
#dps = prec_to_dps(prec)
#m = mpi_mid(s, prec)
#d = mpf_shift(mpi_delta(s, 20), -1)
#return "%s +/- %s" % (to_str(m, dps), to_str(d, 3))
mpi_zero = (fzero, fzero)
mpi_one = (fone, fone)
def mpi_eq(s, t):
return s == t
def mpi_ne(s, t):
return s != t
def mpi_lt(s, t):
sa, sb = s
ta, tb = t
if mpf_lt(sb, ta): return True
if mpf_ge(sa, tb): return False
return None
def mpi_le(s, t):
sa, sb = s
ta, tb = t
if mpf_le(sb, ta): return True
if mpf_gt(sa, tb): return False
return None
def mpi_gt(s, t): return mpi_lt(t, s)
def mpi_ge(s, t): return mpi_le(t, s)
def mpi_add(s, t, prec=0):
sa, sb = s
ta, tb = t
a = mpf_add(sa, ta, prec, round_floor)
b = mpf_add(sb, tb, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
return a, b
def mpi_sub(s, t, prec=0):
sa, sb = s
ta, tb = t
a = mpf_sub(sa, tb, prec, round_floor)
b = mpf_sub(sb, ta, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
return a, b
def mpi_delta(s, prec):
sa, sb = s
return mpf_sub(sb, sa, prec, round_up)
def mpi_mid(s, prec):
sa, sb = s
return mpf_shift(mpf_add(sa, sb, prec, round_nearest), -1)
def mpi_pos(s, prec):
sa, sb = s
a = mpf_pos(sa, prec, round_floor)
b = mpf_pos(sb, prec, round_ceiling)
return a, b
def mpi_neg(s, prec=0):
sa, sb = s
a = mpf_neg(sb, prec, round_floor)
b = mpf_neg(sa, prec, round_ceiling)
return a, b
def mpi_abs(s, prec=0):
sa, sb = s
sas = mpf_sign(sa)
sbs = mpf_sign(sb)
# Both points nonnegative?
if sas >= 0:
a = mpf_pos(sa, prec, round_floor)
b = mpf_pos(sb, prec, round_ceiling)
# Upper point nonnegative?
elif sbs >= 0:
a = fzero
negsa = mpf_neg(sa)
if mpf_lt(negsa, sb):
b = mpf_pos(sb, prec, round_ceiling)
else:
b = mpf_pos(negsa, prec, round_ceiling)
# Both negative?
else:
a = mpf_neg(sb, prec, round_floor)
b = mpf_neg(sa, prec, round_ceiling)
return a, b
# TODO: optimize
def mpi_mul_mpf(s, t, prec):
return mpi_mul(s, (t, t), prec)
def mpi_div_mpf(s, t, prec):
return mpi_div(s, (t, t), prec)
def mpi_mul(s, t, prec=0):
sa, sb = s
ta, tb = t
sas = mpf_sign(sa)
sbs = mpf_sign(sb)
tas = mpf_sign(ta)
tbs = mpf_sign(tb)
if sas == sbs == 0:
# Should maybe be undefined
if ta == fninf or tb == finf:
return fninf, finf
return fzero, fzero
if tas == tbs == 0:
# Should maybe be undefined
if sa == fninf or sb == finf:
return fninf, finf
return fzero, fzero
if sas >= 0:
# positive * positive
if tas >= 0:
a = mpf_mul(sa, ta, prec, round_floor)
b = mpf_mul(sb, tb, prec, round_ceiling)
if a == fnan: a = fzero
if b == fnan: b = finf
# positive * negative
elif tbs <= 0:
a = mpf_mul(sb, ta, prec, round_floor)
b = mpf_mul(sa, tb, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = fzero
# positive * both signs
else:
a = mpf_mul(sb, ta, prec, round_floor)
b = mpf_mul(sb, tb, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
elif sbs <= 0:
# negative * positive
if tas >= 0:
a = mpf_mul(sa, tb, prec, round_floor)
b = mpf_mul(sb, ta, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = fzero
# negative * negative
elif tbs <= 0:
a = mpf_mul(sb, tb, prec, round_floor)
b = mpf_mul(sa, ta, prec, round_ceiling)
if a == fnan: a = fzero
if b == fnan: b = finf
# negative * both signs
else:
a = mpf_mul(sa, tb, prec, round_floor)
b = mpf_mul(sa, ta, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
else:
# General case: perform all cross-multiplications and compare
# Since the multiplications can be done exactly, we need only
# do 4 (instead of 8: two for each rounding mode)
cases = [mpf_mul(sa, ta), mpf_mul(sa, tb), mpf_mul(sb, ta), mpf_mul(sb, tb)]
if fnan in cases:
a, b = (fninf, finf)
else:
a, b = mpf_min_max(cases)
a = mpf_pos(a, prec, round_floor)
b = mpf_pos(b, prec, round_ceiling)
return a, b
def mpi_square(s, prec=0):
sa, sb = s
if mpf_ge(sa, fzero):
a = mpf_mul(sa, sa, prec, round_floor)
b = mpf_mul(sb, sb, prec, round_ceiling)
elif mpf_le(sb, fzero):
a = mpf_mul(sb, sb, prec, round_floor)
b = mpf_mul(sa, sa, prec, round_ceiling)
else:
sa = mpf_neg(sa)
sa, sb = mpf_min_max([sa, sb])
a = fzero
b = mpf_mul(sb, sb, prec, round_ceiling)
return a, b
def mpi_div(s, t, prec):
sa, sb = s
ta, tb = t
sas = mpf_sign(sa)
sbs = mpf_sign(sb)
tas = mpf_sign(ta)
tbs = mpf_sign(tb)
# 0 / X
if sas == sbs == 0:
# 0 / <interval containing 0>
if (tas < 0 and tbs > 0) or (tas == 0 or tbs == 0):
return fninf, finf
return fzero, fzero
# Denominator contains both negative and positive numbers;
# this should properly be a multi-interval, but the closest
# match is the entire (extended) real line
if tas < 0 and tbs > 0:
return fninf, finf
# Assume denominator to be nonnegative
if tas < 0:
return mpi_div(mpi_neg(s), mpi_neg(t), prec)
# Division by zero
# XXX: make sure all results make sense
if tas == 0:
# Numerator contains both signs?
if sas < 0 and sbs > 0:
return fninf, finf
if tas == tbs:
return fninf, finf
# Numerator positive?
if sas >= 0:
a = mpf_div(sa, tb, prec, round_floor)
b = finf
if sbs <= 0:
a = fninf
b = mpf_div(sb, tb, prec, round_ceiling)
# Division with positive denominator
# We still have to handle nans resulting from inf/0 or inf/inf
else:
# Nonnegative numerator
if sas >= 0:
a = mpf_div(sa, tb, prec, round_floor)
b = mpf_div(sb, ta, prec, round_ceiling)
if a == fnan: a = fzero
if b == fnan: b = finf
# Nonpositive numerator
elif sbs <= 0:
a = mpf_div(sa, ta, prec, round_floor)
b = mpf_div(sb, tb, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = fzero
# Numerator contains both signs?
else:
a = mpf_div(sa, ta, prec, round_floor)
b = mpf_div(sb, ta, prec, round_ceiling)
if a == fnan: a = fninf
if b == fnan: b = finf
return a, b
def mpi_pi(prec):
a = mpf_pi(prec, round_floor)
b = mpf_pi(prec, round_ceiling)
return a, b
def mpi_exp(s, prec):
sa, sb = s
# exp is monotonic
a = mpf_exp(sa, prec, round_floor)
b = mpf_exp(sb, prec, round_ceiling)
return a, b
def mpi_log(s, prec):
sa, sb = s
# log is monotonic
a = mpf_log(sa, prec, round_floor)
b = mpf_log(sb, prec, round_ceiling)
return a, b
def mpi_sqrt(s, prec):
sa, sb = s
# sqrt is monotonic
a = mpf_sqrt(sa, prec, round_floor)
b = mpf_sqrt(sb, prec, round_ceiling)
return a, b
def mpi_atan(s, prec):
sa, sb = s
a = mpf_atan(sa, prec, round_floor)
b = mpf_atan(sb, prec, round_ceiling)
return a, b
def mpi_pow_int(s, n, prec):
sa, sb = s
if n < 0:
return mpi_div((fone, fone), mpi_pow_int(s, -n, prec+20), prec)
if n == 0:
return (fone, fone)
if n == 1:
return s
if n == 2:
return mpi_square(s, prec)
# Odd -- signs are preserved
if n & 1:
a = mpf_pow_int(sa, n, prec, round_floor)
b = mpf_pow_int(sb, n, prec, round_ceiling)
# Even -- important to ensure positivity
else:
sas = mpf_sign(sa)
sbs = mpf_sign(sb)
# Nonnegative?
if sas >= 0:
a = mpf_pow_int(sa, n, prec, round_floor)
b = mpf_pow_int(sb, n, prec, round_ceiling)
# Nonpositive?
elif sbs <= 0:
a = mpf_pow_int(sb, n, prec, round_floor)
b = mpf_pow_int(sa, n, prec, round_ceiling)
# Mixed signs?
else:
a = fzero
# max(-a,b)**n
sa = mpf_neg(sa)
if mpf_ge(sa, sb):
b = mpf_pow_int(sa, n, prec, round_ceiling)
else:
b = mpf_pow_int(sb, n, prec, round_ceiling)
return a, b
def mpi_pow(s, t, prec):
ta, tb = t
if ta == tb and ta not in (finf, fninf):
if ta == from_int(to_int(ta)):
return mpi_pow_int(s, to_int(ta), prec)
if ta == fhalf:
return mpi_sqrt(s, prec)
u = mpi_log(s, prec + 20)
v = mpi_mul(u, t, prec + 20)
return mpi_exp(v, prec)
def MIN(x, y):
if mpf_le(x, y):
return x
return y
def MAX(x, y):
if mpf_ge(x, y):
return x
return y
def cos_sin_quadrant(x, wp):
sign, man, exp, bc = x
if x == fzero:
return fone, fzero, 0
# TODO: combine evaluation code to avoid duplicate modulo
c, s = mpf_cos_sin(x, wp)
t, n, wp_ = mod_pi2(man, exp, exp+bc, 15)
if sign:
n = -1-n
return c, s, n
def mpi_cos_sin(x, prec):
a, b = x
if a == b == fzero:
return (fone, fone), (fzero, fzero)
# Guaranteed to contain both -1 and 1
if (finf in x) or (fninf in x):
return (fnone, fone), (fnone, fone)
wp = prec + 20
ca, sa, na = cos_sin_quadrant(a, wp)
cb, sb, nb = cos_sin_quadrant(b, wp)
ca, cb = mpf_min_max([ca, cb])
sa, sb = mpf_min_max([sa, sb])
# Both functions are monotonic within one quadrant
if na == nb:
pass
# Guaranteed to contain both -1 and 1
elif nb - na >= 4:
return (fnone, fone), (fnone, fone)
else:
# cos has maximum between a and b
if na//4 != nb//4:
cb = fone
# cos has minimum
if (na-2)//4 != (nb-2)//4:
ca = fnone
# sin has maximum
if (na-1)//4 != (nb-1)//4:
sb = fone
# sin has minimum
if (na-3)//4 != (nb-3)//4:
sa = fnone
# Perturb to force interval rounding
more = from_man_exp((MPZ_ONE<<wp) + (MPZ_ONE<<10), -wp)
less = from_man_exp((MPZ_ONE<<wp) - (MPZ_ONE<<10), -wp)
def finalize(v, rounding):
if bool(v[0]) == (rounding == round_floor):
p = more
else:
p = less
v = mpf_mul(v, p, prec, rounding)
sign, man, exp, bc = v
if exp+bc >= 1:
if sign:
return fnone
return fone
return v
ca = finalize(ca, round_floor)
cb = finalize(cb, round_ceiling)
sa = finalize(sa, round_floor)
sb = finalize(sb, round_ceiling)
return (ca,cb), (sa,sb)
def mpi_cos(x, prec):
return mpi_cos_sin(x, prec)[0]
def mpi_sin(x, prec):
return mpi_cos_sin(x, prec)[1]
def mpi_tan(x, prec):
cos, sin = mpi_cos_sin(x, prec+20)
return mpi_div(sin, cos, prec)
def mpi_cot(x, prec):
cos, sin = mpi_cos_sin(x, prec+20)
return mpi_div(cos, sin, prec)
def mpi_from_str_a_b(x, y, percent, prec):
wp = prec + 20
xa = from_str(x, wp, round_floor)
xb = from_str(x, wp, round_ceiling)
#ya = from_str(y, wp, round_floor)
y = from_str(y, wp, round_ceiling)
assert mpf_ge(y, fzero)
if percent:
y = mpf_mul(MAX(mpf_abs(xa), mpf_abs(xb)), y, wp, round_ceiling)
y = mpf_div(y, from_int(100), wp, round_ceiling)
a = mpf_sub(xa, y, prec, round_floor)
b = mpf_add(xb, y, prec, round_ceiling)
return a, b
def mpi_from_str(s, prec):
"""
Parse an interval number given as a string.
Allowed forms are
"-1.23e-27"
Any single decimal floating-point literal.
"a +- b" or "a (b)"
a is the midpoint of the interval and b is the half-width
"a +- b%" or "a (b%)"
a is the midpoint of the interval and the half-width
is b percent of a (`a \times b / 100`).
"[a, b]"
The interval indicated directly.
"x[y,z]e"
x are shared digits, y and z are unequal digits, e is the exponent.
"""
e = ValueError("Improperly formed interval number '%s'" % s)
s = s.replace(" ", "")
wp = prec + 20
if "+-" in s:
x, y = s.split("+-")
return mpi_from_str_a_b(x, y, False, prec)
# case 2
elif "(" in s:
# Don't confuse with a complex number (x,y)
if s[0] == "(" or ")" not in s:
raise e
s = s.replace(")", "")
percent = False
if "%" in s:
if s[-1] != "%":
raise e
percent = True
s = s.replace("%", "")
x, y = s.split("(")
return mpi_from_str_a_b(x, y, percent, prec)
elif "," in s:
if ('[' not in s) or (']' not in s):
raise e
if s[0] == '[':
# case 3
s = s.replace("[", "")
s = s.replace("]", "")
a, b = s.split(",")
a = from_str(a, prec, round_floor)
b = from_str(b, prec, round_ceiling)
return a, b
else:
# case 4
x, y = s.split('[')
y, z = y.split(',')
if 'e' in s:
z, e = z.split(']')
else:
z, e = z.rstrip(']'), ''
a = from_str(x+y+e, prec, round_floor)
b = from_str(x+z+e, prec, round_ceiling)
return a, b
else:
a = from_str(s, prec, round_floor)
b = from_str(s, prec, round_ceiling)
return a, b
def mpi_to_str(x, dps, use_spaces=True, brackets='[]', mode='brackets', error_dps=4, **kwargs):
"""
Convert a mpi interval to a string.
**Arguments**
*dps*
decimal places to use for printing
*use_spaces*
use spaces for more readable output, defaults to true
*brackets*
pair of strings (or two-character string) giving left and right brackets
*mode*
mode of display: 'plusminus', 'percent', 'brackets' (default) or 'diff'
*error_dps*
limit the error to *error_dps* digits (mode 'plusminus and 'percent')
Additional keyword arguments are forwarded to the mpf-to-string conversion
for the components of the output.
**Examples**
>>> from mpmath import mpi, mp
>>> mp.dps = 30
>>> x = mpi(1, 2)._mpi_
>>> mpi_to_str(x, 2, mode='plusminus')
'1.5 +- 0.5'
>>> mpi_to_str(x, 2, mode='percent')
'1.5 (33.33%)'
>>> mpi_to_str(x, 2, mode='brackets')
'[1.0, 2.0]'
>>> mpi_to_str(x, 2, mode='brackets' , brackets=('<', '>'))
'<1.0, 2.0>'
>>> x = mpi('5.2582327113062393041', '5.2582327113062749951')._mpi_
>>> mpi_to_str(x, 15, mode='diff')
'5.2582327113062[4, 7]'
>>> mpi_to_str(mpi(0)._mpi_, 2, mode='percent')
'0.0 (0.0%)'
"""
prec = dps_to_prec(dps)
wp = prec + 20
a, b = x
mid = mpi_mid(x, prec)
delta = mpi_delta(x, prec)
a_str = to_str(a, dps, **kwargs)
b_str = to_str(b, dps, **kwargs)
mid_str = to_str(mid, dps, **kwargs)
sp = ""
if use_spaces:
sp = " "
br1, br2 = brackets
if mode == 'plusminus':
delta_str = to_str(mpf_shift(delta,-1), dps, **kwargs)
s = mid_str + sp + "+-" + sp + delta_str
elif mode == 'percent':
if mid == fzero:
p = fzero
else:
# p = 100 * delta(x) / (2*mid(x))
p = mpf_mul(delta, from_int(100))
p = mpf_div(p, mpf_mul(mid, from_int(2)), wp)
s = mid_str + sp + "(" + to_str(p, error_dps) + "%)"
elif mode == 'brackets':
s = br1 + a_str + "," + sp + b_str + br2
elif mode == 'diff':
# use more digits if str(x.a) and str(x.b) are equal
if a_str == b_str:
a_str = to_str(a, dps+3, **kwargs)
b_str = to_str(b, dps+3, **kwargs)
# separate mantissa and exponent
a = a_str.split('e')
if len(a) == 1:
a.append('')
b = b_str.split('e')
if len(b) == 1:
b.append('')
if a[1] == b[1]:
if a[0] != b[0]:
for i in xrange(len(a[0]) + 1):
if a[0][i] != b[0][i]:
break
s = (a[0][:i] + br1 + a[0][i:] + ',' + sp + b[0][i:] + br2
+ 'e'*min(len(a[1]), 1) + a[1])
else: # no difference
s = a[0] + br1 + br2 + 'e'*min(len(a[1]), 1) + a[1]
else:
s = br1 + 'e'.join(a) + ',' + sp + 'e'.join(b) + br2
else:
raise ValueError("'%s' is unknown mode for printing mpi" % mode)
return s
def mpci_add(x, y, prec):
a, b = x
c, d = y
return mpi_add(a, c, prec), mpi_add(b, d, prec)
def mpci_sub(x, y, prec):
a, b = x
c, d = y
return mpi_sub(a, c, prec), mpi_sub(b, d, prec)
def mpci_neg(x, prec=0):
a, b = x
return mpi_neg(a, prec), mpi_neg(b, prec)
def mpci_pos(x, prec):
a, b = x
return mpi_pos(a, prec), mpi_pos(b, prec)
def mpci_mul(x, y, prec):
# TODO: optimize for real/imag cases
a, b = x
c, d = y
r1 = mpi_mul(a,c)
r2 = mpi_mul(b,d)
re = mpi_sub(r1,r2,prec)
i1 = mpi_mul(a,d)
i2 = mpi_mul(b,c)
im = mpi_add(i1,i2,prec)
return re, im
def mpci_div(x, y, prec):
# TODO: optimize for real/imag cases
a, b = x
c, d = y
wp = prec+20
m1 = mpi_square(c)
m2 = mpi_square(d)
m = mpi_add(m1,m2,wp)
re = mpi_add(mpi_mul(a,c), mpi_mul(b,d), wp)
im = mpi_sub(mpi_mul(b,c), mpi_mul(a,d), wp)
re = mpi_div(re, m, prec)
im = mpi_div(im, m, prec)
return re, im
def mpci_exp(x, prec):
a, b = x
wp = prec+20
r = mpi_exp(a, wp)
c, s = mpi_cos_sin(b, wp)
a = mpi_mul(r, c, prec)
b = mpi_mul(r, s, prec)
return a, b
def mpi_shift(x, n):
a, b = x
return mpf_shift(a,n), mpf_shift(b,n)
def mpi_cosh_sinh(x, prec):
# TODO: accuracy for small x
wp = prec+20
e1 = mpi_exp(x, wp)
e2 = mpi_div(mpi_one, e1, wp)
c = mpi_add(e1, e2, prec)
s = mpi_sub(e1, e2, prec)
c = mpi_shift(c, -1)
s = mpi_shift(s, -1)
return c, s
def mpci_cos(x, prec):
a, b = x
wp = prec+10
c, s = mpi_cos_sin(a, wp)
ch, sh = mpi_cosh_sinh(b, wp)
re = mpi_mul(c, ch, prec)
im = mpi_mul(s, sh, prec)
return re, mpi_neg(im)
def mpci_sin(x, prec):
a, b = x
wp = prec+10
c, s = mpi_cos_sin(a, wp)
ch, sh = mpi_cosh_sinh(b, wp)
re = mpi_mul(s, ch, prec)
im = mpi_mul(c, sh, prec)
return re, im
def mpci_abs(x, prec):
a, b = x
if a == mpi_zero:
return mpi_abs(b)
if b == mpi_zero:
return mpi_abs(a)
# Important: nonnegative
a = mpi_square(a)
b = mpi_square(b)
t = mpi_add(a, b, prec+20)
return mpi_sqrt(t, prec)
def mpi_atan2(y, x, prec):
ya, yb = y
xa, xb = x
# Constrained to the real line
if ya == yb == fzero:
if mpf_ge(xa, fzero):
return mpi_zero
return mpi_pi(prec)
# Right half-plane
if mpf_ge(xa, fzero):
if mpf_ge(ya, fzero):
a = mpf_atan2(ya, xb, prec, round_floor)
else:
a = mpf_atan2(ya, xa, prec, round_floor)
if mpf_ge(yb, fzero):
b = mpf_atan2(yb, xa, prec, round_ceiling)
else:
b = mpf_atan2(yb, xb, prec, round_ceiling)
# Upper half-plane
elif mpf_ge(ya, fzero):
b = mpf_atan2(ya, xa, prec, round_ceiling)
if mpf_le(xb, fzero):
a = mpf_atan2(yb, xb, prec, round_floor)
else:
a = mpf_atan2(ya, xb, prec, round_floor)
# Lower half-plane
elif mpf_le(yb, fzero):
a = mpf_atan2(yb, xa, prec, round_floor)
if mpf_le(xb, fzero):
b = mpf_atan2(ya, xb, prec, round_ceiling)
else:
b = mpf_atan2(yb, xb, prec, round_ceiling)
# Covering the origin
else:
b = mpf_pi(prec, round_ceiling)
a = mpf_neg(b)
return a, b
def mpci_arg(z, prec):
x, y = z
return mpi_atan2(y, x, prec)
def mpci_log(z, prec):
x, y = z
re = mpi_log(mpci_abs(z, prec+20), prec)
im = mpci_arg(z, prec)
return re, im
def mpci_pow(x, y, prec):
# TODO: recognize/speed up real cases, integer y
yre, yim = y
if yim == mpi_zero:
ya, yb = yre
if ya == yb:
sign, man, exp, bc = yb
if man and exp >= 0:
return mpci_pow_int(x, (-1)**sign * int(man<<exp), prec)
# x^0
if yb == fzero:
return mpci_pow_int(x, 0, prec)
wp = prec+20
return mpci_exp(mpci_mul(y, mpci_log(x, wp), wp), prec)
def mpci_square(x, prec):
a, b = x
# (a+bi)^2 = (a^2-b^2) + 2abi
re = mpi_sub(mpi_square(a), mpi_square(b), prec)
im = mpi_mul(a, b, prec)
im = mpi_shift(im, 1)
return re, im
def mpci_pow_int(x, n, prec):
if n < 0:
return mpci_div((mpi_one,mpi_zero), mpci_pow_int(x, -n, prec+20), prec)
if n == 0:
return mpi_one, mpi_zero
if n == 1:
return mpci_pos(x, prec)
if n == 2:
return mpci_square(x, prec)
wp = prec + 20
result = (mpi_one, mpi_zero)
while n:
if n & 1:
result = mpci_mul(result, x, wp)
n -= 1
x = mpci_square(x, wp)
n >>= 1
return mpci_pos(result, prec)
gamma_min_a = from_float(1.46163214496)
gamma_min_b = from_float(1.46163214497)
gamma_min = (gamma_min_a, gamma_min_b)
gamma_mono_imag_a = from_float(-1.1)
gamma_mono_imag_b = from_float(1.1)
def mpi_overlap(x, y):
a, b = x
c, d = y
if mpf_lt(d, a): return False
if mpf_gt(c, b): return False
return True
# type = 0 -- gamma
# type = 1 -- factorial
# type = 2 -- 1/gamma
# type = 3 -- log-gamma
def mpi_gamma(z, prec, type=0):
a, b = z
wp = prec+20
if type == 1:
return mpi_gamma(mpi_add(z, mpi_one, wp), prec, 0)
# increasing
if mpf_gt(a, gamma_min_b):
if type == 0:
c = mpf_gamma(a, prec, round_floor)
d = mpf_gamma(b, prec, round_ceiling)
elif type == 2:
c = mpf_rgamma(b, prec, round_floor)
d = mpf_rgamma(a, prec, round_ceiling)
elif type == 3:
c = mpf_loggamma(a, prec, round_floor)
d = mpf_loggamma(b, prec, round_ceiling)
# decreasing
elif mpf_gt(a, fzero) and mpf_lt(b, gamma_min_a):
if type == 0:
c = mpf_gamma(b, prec, round_floor)
d = mpf_gamma(a, prec, round_ceiling)
elif type == 2:
c = mpf_rgamma(a, prec, round_floor)
d = mpf_rgamma(b, prec, round_ceiling)
elif type == 3:
c = mpf_loggamma(b, prec, round_floor)
d = mpf_loggamma(a, prec, round_ceiling)
else:
# TODO: reflection formula
znew = mpi_add(z, mpi_one, wp)
if type == 0: return mpi_div(mpi_gamma(znew, prec+2, 0), z, prec)
if type == 2: return mpi_mul(mpi_gamma(znew, prec+2, 2), z, prec)
if type == 3: return mpi_sub(mpi_gamma(znew, prec+2, 3), mpi_log(z, prec+2), prec)
return c, d
def mpci_gamma(z, prec, type=0):
(a1,a2), (b1,b2) = z
# Real case
if b1 == b2 == fzero and (type != 3 or mpf_gt(a1,fzero)):
return mpi_gamma(z, prec, type), mpi_zero
# Estimate precision
wp = prec+20
if type != 3:
amag = a2[2]+a2[3]
bmag = b2[2]+b2[3]
if a2 != fzero:
mag = max(amag, bmag)
else:
mag = bmag
an = abs(to_int(a2))
bn = abs(to_int(b2))
absn = max(an, bn)
gamma_size = max(0,absn*mag)
wp += bitcount(gamma_size)
# Assume type != 1
if type == 1:
(a1,a2) = mpi_add((a1,a2), mpi_one, wp); z = (a1,a2), (b1,b2)
type = 0
# Avoid non-monotonic region near the negative real axis
if mpf_lt(a1, gamma_min_b):
if mpi_overlap((b1,b2), (gamma_mono_imag_a, gamma_mono_imag_b)):
# TODO: reflection formula
#if mpf_lt(a2, mpf_shift(fone,-1)):
# znew = mpci_sub((mpi_one,mpi_zero),z,wp)
# ...
# Recurrence:
# gamma(z) = gamma(z+1)/z
znew = mpi_add((a1,a2), mpi_one, wp), (b1,b2)
if type == 0: return mpci_div(mpci_gamma(znew, prec+2, 0), z, prec)
if type == 2: return mpci_mul(mpci_gamma(znew, prec+2, 2), z, prec)
if type == 3: return mpci_sub(mpci_gamma(znew, prec+2, 3), mpci_log(z,prec+2), prec)
# Use monotonicity (except for a small region close to the
# origin and near poles)
# upper half-plane
if mpf_ge(b1, fzero):
minre = mpc_loggamma((a1,b2), wp, round_floor)
maxre = mpc_loggamma((a2,b1), wp, round_ceiling)
minim = mpc_loggamma((a1,b1), wp, round_floor)
maxim = mpc_loggamma((a2,b2), wp, round_ceiling)
# lower half-plane
elif mpf_le(b2, fzero):
minre = mpc_loggamma((a1,b1), wp, round_floor)
maxre = mpc_loggamma((a2,b2), wp, round_ceiling)
minim = mpc_loggamma((a2,b1), wp, round_floor)
maxim = mpc_loggamma((a1,b2), wp, round_ceiling)
# crosses real axis
else:
maxre = mpc_loggamma((a2,fzero), wp, round_ceiling)
# stretches more into the lower half-plane
if mpf_gt(mpf_neg(b1), b2):
minre = mpc_loggamma((a1,b1), wp, round_ceiling)
else:
minre = mpc_loggamma((a1,b2), wp, round_ceiling)
minim = mpc_loggamma((a2,b1), wp, round_floor)
maxim = mpc_loggamma((a2,b2), wp, round_floor)
w = (minre[0], maxre[0]), (minim[1], maxim[1])
if type == 3:
return mpi_pos(w[0], prec), mpi_pos(w[1], prec)
if type == 2:
w = mpci_neg(w)
return mpci_exp(w, prec)
def mpi_loggamma(z, prec): return mpi_gamma(z, prec, type=3)
def mpci_loggamma(z, prec): return mpci_gamma(z, prec, type=3)
def mpi_rgamma(z, prec): return mpi_gamma(z, prec, type=2)
def mpci_rgamma(z, prec): return mpci_gamma(z, prec, type=2)
def mpi_factorial(z, prec): return mpi_gamma(z, prec, type=1)
def mpci_factorial(z, prec): return mpci_gamma(z, prec, type=1)

View File

@ -0,0 +1,672 @@
"""
This module complements the math and cmath builtin modules by providing
fast machine precision versions of some additional functions (gamma, ...)
and wrapping math/cmath functions so that they can be called with either
real or complex arguments.
"""
import operator
import math
import cmath
# Irrational (?) constants
pi = 3.1415926535897932385
e = 2.7182818284590452354
sqrt2 = 1.4142135623730950488
sqrt5 = 2.2360679774997896964
phi = 1.6180339887498948482
ln2 = 0.69314718055994530942
ln10 = 2.302585092994045684
euler = 0.57721566490153286061
catalan = 0.91596559417721901505
khinchin = 2.6854520010653064453
apery = 1.2020569031595942854
logpi = 1.1447298858494001741
def _mathfun_real(f_real, f_complex):
def f(x, **kwargs):
if type(x) is float:
return f_real(x)
if type(x) is complex:
return f_complex(x)
try:
x = float(x)
return f_real(x)
except (TypeError, ValueError):
x = complex(x)
return f_complex(x)
f.__name__ = f_real.__name__
return f
def _mathfun(f_real, f_complex):
def f(x, **kwargs):
if type(x) is complex:
return f_complex(x)
try:
return f_real(float(x))
except (TypeError, ValueError):
return f_complex(complex(x))
f.__name__ = f_real.__name__
return f
def _mathfun_n(f_real, f_complex):
def f(*args, **kwargs):
try:
return f_real(*(float(x) for x in args))
except (TypeError, ValueError):
return f_complex(*(complex(x) for x in args))
f.__name__ = f_real.__name__
return f
# Workaround for non-raising log and sqrt in Python 2.5 and 2.4
# on Unix system
try:
math.log(-2.0)
def math_log(x):
if x <= 0.0:
raise ValueError("math domain error")
return math.log(x)
def math_sqrt(x):
if x < 0.0:
raise ValueError("math domain error")
return math.sqrt(x)
except (ValueError, TypeError):
math_log = math.log
math_sqrt = math.sqrt
pow = _mathfun_n(operator.pow, lambda x, y: complex(x)**y)
log = _mathfun_n(math_log, cmath.log)
sqrt = _mathfun(math_sqrt, cmath.sqrt)
exp = _mathfun_real(math.exp, cmath.exp)
cos = _mathfun_real(math.cos, cmath.cos)
sin = _mathfun_real(math.sin, cmath.sin)
tan = _mathfun_real(math.tan, cmath.tan)
acos = _mathfun(math.acos, cmath.acos)
asin = _mathfun(math.asin, cmath.asin)
atan = _mathfun_real(math.atan, cmath.atan)
cosh = _mathfun_real(math.cosh, cmath.cosh)
sinh = _mathfun_real(math.sinh, cmath.sinh)
tanh = _mathfun_real(math.tanh, cmath.tanh)
floor = _mathfun_real(math.floor,
lambda z: complex(math.floor(z.real), math.floor(z.imag)))
ceil = _mathfun_real(math.ceil,
lambda z: complex(math.ceil(z.real), math.ceil(z.imag)))
cos_sin = _mathfun_real(lambda x: (math.cos(x), math.sin(x)),
lambda z: (cmath.cos(z), cmath.sin(z)))
cbrt = _mathfun(lambda x: x**(1./3), lambda z: z**(1./3))
def nthroot(x, n):
r = 1./n
try:
return float(x) ** r
except (ValueError, TypeError):
return complex(x) ** r
def _sinpi_real(x):
if x < 0:
return -_sinpi_real(-x)
n, r = divmod(x, 0.5)
r *= pi
n %= 4
if n == 0: return math.sin(r)
if n == 1: return math.cos(r)
if n == 2: return -math.sin(r)
if n == 3: return -math.cos(r)
def _cospi_real(x):
if x < 0:
x = -x
n, r = divmod(x, 0.5)
r *= pi
n %= 4
if n == 0: return math.cos(r)
if n == 1: return -math.sin(r)
if n == 2: return -math.cos(r)
if n == 3: return math.sin(r)
def _sinpi_complex(z):
if z.real < 0:
return -_sinpi_complex(-z)
n, r = divmod(z.real, 0.5)
z = pi*complex(r, z.imag)
n %= 4
if n == 0: return cmath.sin(z)
if n == 1: return cmath.cos(z)
if n == 2: return -cmath.sin(z)
if n == 3: return -cmath.cos(z)
def _cospi_complex(z):
if z.real < 0:
z = -z
n, r = divmod(z.real, 0.5)
z = pi*complex(r, z.imag)
n %= 4
if n == 0: return cmath.cos(z)
if n == 1: return -cmath.sin(z)
if n == 2: return -cmath.cos(z)
if n == 3: return cmath.sin(z)
cospi = _mathfun_real(_cospi_real, _cospi_complex)
sinpi = _mathfun_real(_sinpi_real, _sinpi_complex)
def tanpi(x):
try:
return sinpi(x) / cospi(x)
except OverflowError:
if complex(x).imag > 10:
return 1j
if complex(x).imag < 10:
return -1j
raise
def cotpi(x):
try:
return cospi(x) / sinpi(x)
except OverflowError:
if complex(x).imag > 10:
return -1j
if complex(x).imag < 10:
return 1j
raise
INF = 1e300*1e300
NINF = -INF
NAN = INF-INF
EPS = 2.2204460492503131e-16
_exact_gamma = (INF, 1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0,
362880.0, 3628800.0, 39916800.0, 479001600.0, 6227020800.0, 87178291200.0,
1307674368000.0, 20922789888000.0, 355687428096000.0, 6402373705728000.0,
121645100408832000.0, 2432902008176640000.0)
_max_exact_gamma = len(_exact_gamma)-1
# Lanczos coefficients used by the GNU Scientific Library
_lanczos_g = 7
_lanczos_p = (0.99999999999980993, 676.5203681218851, -1259.1392167224028,
771.32342877765313, -176.61502916214059, 12.507343278686905,
-0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7)
def _gamma_real(x):
_intx = int(x)
if _intx == x:
if _intx <= 0:
#return (-1)**_intx * INF
raise ZeroDivisionError("gamma function pole")
if _intx <= _max_exact_gamma:
return _exact_gamma[_intx]
if x < 0.5:
# TODO: sinpi
return pi / (_sinpi_real(x)*_gamma_real(1-x))
else:
x -= 1.0
r = _lanczos_p[0]
for i in range(1, _lanczos_g+2):
r += _lanczos_p[i]/(x+i)
t = x + _lanczos_g + 0.5
return 2.506628274631000502417 * t**(x+0.5) * math.exp(-t) * r
def _gamma_complex(x):
if not x.imag:
return complex(_gamma_real(x.real))
if x.real < 0.5:
# TODO: sinpi
return pi / (_sinpi_complex(x)*_gamma_complex(1-x))
else:
x -= 1.0
r = _lanczos_p[0]
for i in range(1, _lanczos_g+2):
r += _lanczos_p[i]/(x+i)
t = x + _lanczos_g + 0.5
return 2.506628274631000502417 * t**(x+0.5) * cmath.exp(-t) * r
gamma = _mathfun_real(_gamma_real, _gamma_complex)
def rgamma(x):
try:
return 1./gamma(x)
except ZeroDivisionError:
return x*0.0
def factorial(x):
return gamma(x+1.0)
def arg(x):
if type(x) is float:
return math.atan2(0.0,x)
return math.atan2(x.imag,x.real)
# XXX: broken for negatives
def loggamma(x):
if type(x) not in (float, complex):
try:
x = float(x)
except (ValueError, TypeError):
x = complex(x)
try:
xreal = x.real
ximag = x.imag
except AttributeError: # py2.5
xreal = x
ximag = 0.0
# Reflection formula
# http://functions.wolfram.com/GammaBetaErf/LogGamma/16/01/01/0003/
if xreal < 0.0:
if abs(x) < 0.5:
v = log(gamma(x))
if ximag == 0:
v = v.conjugate()
return v
z = 1-x
try:
re = z.real
im = z.imag
except AttributeError: # py2.5
re = z
im = 0.0
refloor = floor(re)
if im == 0.0:
imsign = 0
elif im < 0.0:
imsign = -1
else:
imsign = 1
return (-pi*1j)*abs(refloor)*(1-abs(imsign)) + logpi - \
log(sinpi(z-refloor)) - loggamma(z) + 1j*pi*refloor*imsign
if x == 1.0 or x == 2.0:
return x*0
p = 0.
while abs(x) < 11:
p -= log(x)
x += 1.0
s = 0.918938533204672742 + (x-0.5)*log(x) - x
r = 1./x
r2 = r*r
s += 0.083333333333333333333*r; r *= r2
s += -0.0027777777777777777778*r; r *= r2
s += 0.00079365079365079365079*r; r *= r2
s += -0.0005952380952380952381*r; r *= r2
s += 0.00084175084175084175084*r; r *= r2
s += -0.0019175269175269175269*r; r *= r2
s += 0.0064102564102564102564*r; r *= r2
s += -0.02955065359477124183*r
return s + p
_psi_coeff = [
0.083333333333333333333,
-0.0083333333333333333333,
0.003968253968253968254,
-0.0041666666666666666667,
0.0075757575757575757576,
-0.021092796092796092796,
0.083333333333333333333,
-0.44325980392156862745,
3.0539543302701197438,
-26.456212121212121212]
def _digamma_real(x):
_intx = int(x)
if _intx == x:
if _intx <= 0:
raise ZeroDivisionError("polygamma pole")
if x < 0.5:
x = 1.0-x
s = pi*cotpi(x)
else:
s = 0.0
while x < 10.0:
s -= 1.0/x
x += 1.0
x2 = x**-2
t = x2
for c in _psi_coeff:
s -= c*t
if t < 1e-20:
break
t *= x2
return s + math_log(x) - 0.5/x
def _digamma_complex(x):
if not x.imag:
return complex(_digamma_real(x.real))
if x.real < 0.5:
x = 1.0-x
s = pi*cotpi(x)
else:
s = 0.0
while abs(x) < 10.0:
s -= 1.0/x
x += 1.0
x2 = x**-2
t = x2
for c in _psi_coeff:
s -= c*t
if abs(t) < 1e-20:
break
t *= x2
return s + cmath.log(x) - 0.5/x
digamma = _mathfun_real(_digamma_real, _digamma_complex)
# TODO: could implement complex erf and erfc here. Need
# to find an accurate method (avoiding cancellation)
# for approx. 1 < abs(x) < 9.
_erfc_coeff_P = [
1.0000000161203922312,
2.1275306946297962644,
2.2280433377390253297,
1.4695509105618423961,
0.66275911699770787537,
0.20924776504163751585,
0.045459713768411264339,
0.0063065951710717791934,
0.00044560259661560421715][::-1]
_erfc_coeff_Q = [
1.0000000000000000000,
3.2559100272784894318,
4.9019435608903239131,
4.4971472894498014205,
2.7845640601891186528,
1.2146026030046904138,
0.37647108453729465912,
0.080970149639040548613,
0.011178148899483545902,
0.00078981003831980423513][::-1]
def _polyval(coeffs, x):
p = coeffs[0]
for c in coeffs[1:]:
p = c + x*p
return p
def _erf_taylor(x):
# Taylor series assuming 0 <= x <= 1
x2 = x*x
s = t = x
n = 1
while abs(t) > 1e-17:
t *= x2/n
s -= t/(n+n+1)
n += 1
t *= x2/n
s += t/(n+n+1)
n += 1
return 1.1283791670955125739*s
def _erfc_mid(x):
# Rational approximation assuming 0 <= x <= 9
return exp(-x*x)*_polyval(_erfc_coeff_P,x)/_polyval(_erfc_coeff_Q,x)
def _erfc_asymp(x):
# Asymptotic expansion assuming x >= 9
x2 = x*x
v = exp(-x2)/x*0.56418958354775628695
r = t = 0.5 / x2
s = 1.0
for n in range(1,22,4):
s -= t
t *= r * (n+2)
s += t
t *= r * (n+4)
if abs(t) < 1e-17:
break
return s * v
def erf(x):
"""
erf of a real number.
"""
x = float(x)
if x != x:
return x
if x < 0.0:
return -erf(-x)
if x >= 1.0:
if x >= 6.0:
return 1.0
return 1.0 - _erfc_mid(x)
return _erf_taylor(x)
def erfc(x):
"""
erfc of a real number.
"""
x = float(x)
if x != x:
return x
if x < 0.0:
if x < -6.0:
return 2.0
return 2.0-erfc(-x)
if x > 9.0:
return _erfc_asymp(x)
if x >= 1.0:
return _erfc_mid(x)
return 1.0 - _erf_taylor(x)
gauss42 = [\
(0.99839961899006235, 0.0041059986046490839),
(-0.99839961899006235, 0.0041059986046490839),
(0.9915772883408609, 0.009536220301748501),
(-0.9915772883408609,0.009536220301748501),
(0.97934250806374812, 0.014922443697357493),
(-0.97934250806374812, 0.014922443697357493),
(0.96175936533820439,0.020227869569052644),
(-0.96175936533820439, 0.020227869569052644),
(0.93892355735498811, 0.025422959526113047),
(-0.93892355735498811,0.025422959526113047),
(0.91095972490412735, 0.030479240699603467),
(-0.91095972490412735, 0.030479240699603467),
(0.87802056981217269,0.03536907109759211),
(-0.87802056981217269, 0.03536907109759211),
(0.8402859832618168, 0.040065735180692258),
(-0.8402859832618168,0.040065735180692258),
(0.7979620532554873, 0.044543577771965874),
(-0.7979620532554873, 0.044543577771965874),
(0.75127993568948048,0.048778140792803244),
(-0.75127993568948048, 0.048778140792803244),
(0.70049459055617114, 0.052746295699174064),
(-0.70049459055617114,0.052746295699174064),
(0.64588338886924779, 0.056426369358018376),
(-0.64588338886924779, 0.056426369358018376),
(0.58774459748510932, 0.059798262227586649),
(-0.58774459748510932, 0.059798262227586649),
(0.5263957499311922, 0.062843558045002565),
(-0.5263957499311922, 0.062843558045002565),
(0.46217191207042191, 0.065545624364908975),
(-0.46217191207042191, 0.065545624364908975),
(0.39542385204297503, 0.067889703376521934),
(-0.39542385204297503, 0.067889703376521934),
(0.32651612446541151, 0.069862992492594159),
(-0.32651612446541151, 0.069862992492594159),
(0.25582507934287907, 0.071454714265170971),
(-0.25582507934287907, 0.071454714265170971),
(0.18373680656485453, 0.072656175243804091),
(-0.18373680656485453, 0.072656175243804091),
(0.11064502720851986, 0.073460813453467527),
(-0.11064502720851986, 0.073460813453467527),
(0.036948943165351772, 0.073864234232172879),
(-0.036948943165351772, 0.073864234232172879)]
EI_ASYMP_CONVERGENCE_RADIUS = 40.0
def ei_asymp(z, _e1=False):
r = 1./z
s = t = 1.0
k = 1
while 1:
t *= k*r
s += t
if abs(t) < 1e-16:
break
k += 1
v = s*exp(z)/z
if _e1:
if type(z) is complex:
zreal = z.real
zimag = z.imag
else:
zreal = z
zimag = 0.0
if zimag == 0.0 and zreal > 0.0:
v += pi*1j
else:
if type(z) is complex:
if z.imag > 0:
v += pi*1j
if z.imag < 0:
v -= pi*1j
return v
def ei_taylor(z, _e1=False):
s = t = z
k = 2
while 1:
t = t*z/k
term = t/k
if abs(term) < 1e-17:
break
s += term
k += 1
s += euler
if _e1:
s += log(-z)
else:
if type(z) is float or z.imag == 0.0:
s += math_log(abs(z))
else:
s += cmath.log(z)
return s
def ei(z, _e1=False):
typez = type(z)
if typez not in (float, complex):
try:
z = float(z)
typez = float
except (TypeError, ValueError):
z = complex(z)
typez = complex
if not z:
return -INF
absz = abs(z)
if absz > EI_ASYMP_CONVERGENCE_RADIUS:
return ei_asymp(z, _e1)
elif absz <= 2.0 or (typez is float and z > 0.0):
return ei_taylor(z, _e1)
# Integrate, starting from whichever is smaller of a Taylor
# series value or an asymptotic series value
if typez is complex and z.real > 0.0:
zref = z / absz
ref = ei_taylor(zref, _e1)
else:
zref = EI_ASYMP_CONVERGENCE_RADIUS * z / absz
ref = ei_asymp(zref, _e1)
C = (zref-z)*0.5
D = (zref+z)*0.5
s = 0.0
if type(z) is complex:
_exp = cmath.exp
else:
_exp = math.exp
for x,w in gauss42:
t = C*x+D
s += w*_exp(t)/t
ref -= C*s
return ref
def e1(z):
# hack to get consistent signs if the imaginary part if 0
# and signed
typez = type(z)
if type(z) not in (float, complex):
try:
z = float(z)
typez = float
except (TypeError, ValueError):
z = complex(z)
typez = complex
if typez is complex and not z.imag:
z = complex(z.real, 0.0)
# end hack
return -ei(-z, _e1=True)
_zeta_int = [\
-0.5,
0.0,
1.6449340668482264365,1.2020569031595942854,1.0823232337111381915,
1.0369277551433699263,1.0173430619844491397,1.0083492773819228268,
1.0040773561979443394,1.0020083928260822144,1.0009945751278180853,
1.0004941886041194646,1.0002460865533080483,1.0001227133475784891,
1.0000612481350587048,1.0000305882363070205,1.0000152822594086519,
1.0000076371976378998,1.0000038172932649998,1.0000019082127165539,
1.0000009539620338728,1.0000004769329867878,1.0000002384505027277,
1.0000001192199259653,1.0000000596081890513,1.0000000298035035147,
1.0000000149015548284]
_zeta_P = [-3.50000000087575873, -0.701274355654678147,
-0.0672313458590012612, -0.00398731457954257841,
-0.000160948723019303141, -4.67633010038383371e-6,
-1.02078104417700585e-7, -1.68030037095896287e-9,
-1.85231868742346722e-11][::-1]
_zeta_Q = [1.00000000000000000, -0.936552848762465319,
-0.0588835413263763741, -0.00441498861482948666,
-0.000143416758067432622, -5.10691659585090782e-6,
-9.58813053268913799e-8, -1.72963791443181972e-9,
-1.83527919681474132e-11][::-1]
_zeta_1 = [3.03768838606128127e-10, -1.21924525236601262e-8,
2.01201845887608893e-7, -1.53917240683468381e-6,
-5.09890411005967954e-7, 0.000122464707271619326,
-0.000905721539353130232, -0.00239315326074843037,
0.084239750013159168, 0.418938517907442414, 0.500000001921884009]
_zeta_0 = [-3.46092485016748794e-10, -6.42610089468292485e-9,
1.76409071536679773e-7, -1.47141263991560698e-6, -6.38880222546167613e-7,
0.000122641099800668209, -0.000905894913516772796, -0.00239303348507992713,
0.0842396947501199816, 0.418938533204660256, 0.500000000000000052]
def zeta(s):
"""
Riemann zeta function, real argument
"""
if not isinstance(s, (float, int)):
try:
s = float(s)
except (ValueError, TypeError):
try:
s = complex(s)
if not s.imag:
return complex(zeta(s.real))
except (ValueError, TypeError):
pass
raise NotImplementedError
if s == 1:
raise ValueError("zeta(1) pole")
if s >= 27:
return 1.0 + 2.0**(-s) + 3.0**(-s)
n = int(s)
if n == s:
if n >= 0:
return _zeta_int[n]
if not (n % 2):
return 0.0
if s <= 0.0:
return 2.**s*pi**(s-1)*_sinpi_real(0.5*s)*_gamma_real(1-s)*zeta(1-s)
if s <= 2.0:
if s <= 1.0:
return _polyval(_zeta_0,s)/(s-1)
return _polyval(_zeta_1,s)/(s-1)
z = _polyval(_zeta_P,s) / _polyval(_zeta_Q,s)
return 1.0 + 2.0**(-s) + 3.0**(-s) + 4.0**(-s)*z

View File

@ -0,0 +1,2 @@
from . import eigen # to set methods
from . import eigen_symmetric # to set methods

View File

@ -0,0 +1,531 @@
from ..libmp.backend import xrange
# TODO: should use diagonalization-based algorithms
class MatrixCalculusMethods(object):
def _exp_pade(ctx, a):
"""
Exponential of a matrix using Pade approximants.
See G. H. Golub, C. F. van Loan 'Matrix Computations',
third Ed., page 572
TODO:
- find a good estimate for q
- reduce the number of matrix multiplications to improve
performance
"""
def eps_pade(p):
return ctx.mpf(2)**(3-2*p) * \
ctx.factorial(p)**2/(ctx.factorial(2*p)**2 * (2*p + 1))
q = 4
extraq = 8
while 1:
if eps_pade(q) < ctx.eps:
break
q += 1
q += extraq
j = int(max(1, ctx.mag(ctx.mnorm(a,'inf'))))
extra = q
prec = ctx.prec
ctx.dps += extra + 3
try:
a = a/2**j
na = a.rows
den = ctx.eye(na)
num = ctx.eye(na)
x = ctx.eye(na)
c = ctx.mpf(1)
for k in range(1, q+1):
c *= ctx.mpf(q - k + 1)/((2*q - k + 1) * k)
x = a*x
cx = c*x
num += cx
den += (-1)**k * cx
f = ctx.lu_solve_mat(den, num)
for k in range(j):
f = f*f
finally:
ctx.prec = prec
return f*1
def expm(ctx, A, method='taylor'):
r"""
Computes the matrix exponential of a square matrix `A`, which is defined
by the power series
.. math ::
\exp(A) = I + A + \frac{A^2}{2!} + \frac{A^3}{3!} + \ldots
With method='taylor', the matrix exponential is computed
using the Taylor series. With method='pade', Pade approximants
are used instead.
**Examples**
Basic examples::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> expm(zeros(3))
[1.0 0.0 0.0]
[0.0 1.0 0.0]
[0.0 0.0 1.0]
>>> expm(eye(3))
[2.71828182845905 0.0 0.0]
[ 0.0 2.71828182845905 0.0]
[ 0.0 0.0 2.71828182845905]
>>> expm([[1,1,0],[1,0,1],[0,1,0]])
[ 3.86814500615414 2.26812870852145 0.841130841230196]
[ 2.26812870852145 2.44114713886289 1.42699786729125]
[0.841130841230196 1.42699786729125 1.6000162976327]
>>> expm([[1,1,0],[1,0,1],[0,1,0]], method='pade')
[ 3.86814500615414 2.26812870852145 0.841130841230196]
[ 2.26812870852145 2.44114713886289 1.42699786729125]
[0.841130841230196 1.42699786729125 1.6000162976327]
>>> expm([[1+j, 0], [1+j,1]])
[(1.46869393991589 + 2.28735528717884j) 0.0]
[ (1.03776739863568 + 3.536943175722j) (2.71828182845905 + 0.0j)]
Matrices with large entries are allowed::
>>> expm(matrix([[1,2],[2,3]])**25)
[5.65024064048415e+2050488462815550 9.14228140091932e+2050488462815550]
[9.14228140091932e+2050488462815550 1.47925220414035e+2050488462815551]
The identity `\exp(A+B) = \exp(A) \exp(B)` does not hold for
noncommuting matrices::
>>> A = hilbert(3)
>>> B = A + eye(3)
>>> chop(mnorm(A*B - B*A))
0.0
>>> chop(mnorm(expm(A+B) - expm(A)*expm(B)))
0.0
>>> B = A + ones(3)
>>> mnorm(A*B - B*A)
1.8
>>> mnorm(expm(A+B) - expm(A)*expm(B))
42.0927851137247
"""
if method == 'pade':
prec = ctx.prec
try:
A = ctx.matrix(A)
ctx.prec += 2*A.rows
res = ctx._exp_pade(A)
finally:
ctx.prec = prec
return res
A = ctx.matrix(A)
prec = ctx.prec
j = int(max(1, ctx.mag(ctx.mnorm(A,'inf'))))
j += int(0.5*prec**0.5)
try:
ctx.prec += 10 + 2*j
tol = +ctx.eps
A = A/2**j
T = A
Y = A**0 + A
k = 2
while 1:
T *= A * (1/ctx.mpf(k))
if ctx.mnorm(T, 'inf') < tol:
break
Y += T
k += 1
for k in xrange(j):
Y = Y*Y
finally:
ctx.prec = prec
Y *= 1
return Y
def cosm(ctx, A):
r"""
Gives the cosine of a square matrix `A`, defined in analogy
with the matrix exponential.
Examples::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> X = eye(3)
>>> cosm(X)
[0.54030230586814 0.0 0.0]
[ 0.0 0.54030230586814 0.0]
[ 0.0 0.0 0.54030230586814]
>>> X = hilbert(3)
>>> cosm(X)
[ 0.424403834569555 -0.316643413047167 -0.221474945949293]
[-0.316643413047167 0.820646708837824 -0.127183694770039]
[-0.221474945949293 -0.127183694770039 0.909236687217541]
>>> X = matrix([[1+j,-2],[0,-j]])
>>> cosm(X)
[(0.833730025131149 - 0.988897705762865j) (1.07485840848393 - 0.17192140544213j)]
[ 0.0 (1.54308063481524 + 0.0j)]
"""
B = 0.5 * (ctx.expm(A*ctx.j) + ctx.expm(A*(-ctx.j)))
if not sum(A.apply(ctx.im).apply(abs)):
B = B.apply(ctx.re)
return B
def sinm(ctx, A):
r"""
Gives the sine of a square matrix `A`, defined in analogy
with the matrix exponential.
Examples::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> X = eye(3)
>>> sinm(X)
[0.841470984807897 0.0 0.0]
[ 0.0 0.841470984807897 0.0]
[ 0.0 0.0 0.841470984807897]
>>> X = hilbert(3)
>>> sinm(X)
[0.711608512150994 0.339783913247439 0.220742837314741]
[0.339783913247439 0.244113865695532 0.187231271174372]
[0.220742837314741 0.187231271174372 0.155816730769635]
>>> X = matrix([[1+j,-2],[0,-j]])
>>> sinm(X)
[(1.29845758141598 + 0.634963914784736j) (-1.96751511930922 + 0.314700021761367j)]
[ 0.0 (0.0 - 1.1752011936438j)]
"""
B = (-0.5j) * (ctx.expm(A*ctx.j) - ctx.expm(A*(-ctx.j)))
if not sum(A.apply(ctx.im).apply(abs)):
B = B.apply(ctx.re)
return B
def _sqrtm_rot(ctx, A, _may_rotate):
# If the iteration fails to converge, cheat by performing
# a rotation by a complex number
u = ctx.j**0.3
return ctx.sqrtm(u*A, _may_rotate) / ctx.sqrt(u)
def sqrtm(ctx, A, _may_rotate=2):
r"""
Computes a square root of the square matrix `A`, i.e. returns
a matrix `B = A^{1/2}` such that `B^2 = A`. The square root
of a matrix, if it exists, is not unique.
**Examples**
Square roots of some simple matrices::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> sqrtm([[1,0], [0,1]])
[1.0 0.0]
[0.0 1.0]
>>> sqrtm([[0,0], [0,0]])
[0.0 0.0]
[0.0 0.0]
>>> sqrtm([[2,0],[0,1]])
[1.4142135623731 0.0]
[ 0.0 1.0]
>>> sqrtm([[1,1],[1,0]])
[ (0.920442065259926 - 0.21728689675164j) (0.568864481005783 + 0.351577584254143j)]
[(0.568864481005783 + 0.351577584254143j) (0.351577584254143 - 0.568864481005783j)]
>>> sqrtm([[1,0],[0,1]])
[1.0 0.0]
[0.0 1.0]
>>> sqrtm([[-1,0],[0,1]])
[(0.0 - 1.0j) 0.0]
[ 0.0 (1.0 + 0.0j)]
>>> sqrtm([[j,0],[0,j]])
[(0.707106781186547 + 0.707106781186547j) 0.0]
[ 0.0 (0.707106781186547 + 0.707106781186547j)]
A square root of a rotation matrix, giving the corresponding
half-angle rotation matrix::
>>> t1 = 0.75
>>> t2 = t1 * 0.5
>>> A1 = matrix([[cos(t1), -sin(t1)], [sin(t1), cos(t1)]])
>>> A2 = matrix([[cos(t2), -sin(t2)], [sin(t2), cos(t2)]])
>>> sqrtm(A1)
[0.930507621912314 -0.366272529086048]
[0.366272529086048 0.930507621912314]
>>> A2
[0.930507621912314 -0.366272529086048]
[0.366272529086048 0.930507621912314]
The identity `(A^2)^{1/2} = A` does not necessarily hold::
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
>>> sqrtm(A**2)
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> sqrtm(A)**2
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> A = matrix([[-4,1,4],[7,-8,9],[10,2,11]])
>>> sqrtm(A**2)
[ 7.43715112194995 -0.324127569985474 1.8481718827526]
[-0.251549715716942 9.32699765900402 2.48221180985147]
[ 4.11609388833616 0.775751877098258 13.017955697342]
>>> chop(sqrtm(A)**2)
[-4.0 1.0 4.0]
[ 7.0 -8.0 9.0]
[10.0 2.0 11.0]
For some matrices, a square root does not exist::
>>> sqrtm([[0,1], [0,0]])
Traceback (most recent call last):
...
ZeroDivisionError: matrix is numerically singular
Two examples from the documentation for Matlab's ``sqrtm``::
>>> mp.dps = 15; mp.pretty = True
>>> sqrtm([[7,10],[15,22]])
[1.56669890360128 1.74077655955698]
[2.61116483933547 4.17786374293675]
>>>
>>> X = matrix(\
... [[5,-4,1,0,0],
... [-4,6,-4,1,0],
... [1,-4,6,-4,1],
... [0,1,-4,6,-4],
... [0,0,1,-4,5]])
>>> Y = matrix(\
... [[2,-1,-0,-0,-0],
... [-1,2,-1,0,-0],
... [0,-1,2,-1,0],
... [-0,0,-1,2,-1],
... [-0,-0,-0,-1,2]])
>>> mnorm(sqrtm(X) - Y)
4.53155328326114e-19
"""
A = ctx.matrix(A)
# Trivial
if A*0 == A:
return A
prec = ctx.prec
if _may_rotate:
d = ctx.det(A)
if abs(ctx.im(d)) < 16*ctx.eps and ctx.re(d) < 0:
return ctx._sqrtm_rot(A, _may_rotate-1)
try:
ctx.prec += 10
tol = ctx.eps * 128
Y = A
Z = I = A**0
k = 0
# Denman-Beavers iteration
while 1:
Yprev = Y
try:
Y, Z = 0.5*(Y+ctx.inverse(Z)), 0.5*(Z+ctx.inverse(Y))
except ZeroDivisionError:
if _may_rotate:
Y = ctx._sqrtm_rot(A, _may_rotate-1)
break
else:
raise
mag1 = ctx.mnorm(Y-Yprev, 'inf')
mag2 = ctx.mnorm(Y, 'inf')
if mag1 <= mag2*tol:
break
if _may_rotate and k > 6 and not mag1 < mag2 * 0.001:
return ctx._sqrtm_rot(A, _may_rotate-1)
k += 1
if k > ctx.prec:
raise ctx.NoConvergence
finally:
ctx.prec = prec
Y *= 1
return Y
def logm(ctx, A):
r"""
Computes a logarithm of the square matrix `A`, i.e. returns
a matrix `B = \log(A)` such that `\exp(B) = A`. The logarithm
of a matrix, if it exists, is not unique.
**Examples**
Logarithms of some simple matrices::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> X = eye(3)
>>> logm(X)
[0.0 0.0 0.0]
[0.0 0.0 0.0]
[0.0 0.0 0.0]
>>> logm(2*X)
[0.693147180559945 0.0 0.0]
[ 0.0 0.693147180559945 0.0]
[ 0.0 0.0 0.693147180559945]
>>> logm(expm(X))
[1.0 0.0 0.0]
[0.0 1.0 0.0]
[0.0 0.0 1.0]
A logarithm of a complex matrix::
>>> X = matrix([[2+j, 1, 3], [1-j, 1-2*j, 1], [-4, -5, j]])
>>> B = logm(X)
>>> nprint(B)
[ (0.808757 + 0.107759j) (2.20752 + 0.202762j) (1.07376 - 0.773874j)]
[ (0.905709 - 0.107795j) (0.0287395 - 0.824993j) (0.111619 + 0.514272j)]
[(-0.930151 + 0.399512j) (-2.06266 - 0.674397j) (0.791552 + 0.519839j)]
>>> chop(expm(B))
[(2.0 + 1.0j) 1.0 3.0]
[(1.0 - 1.0j) (1.0 - 2.0j) 1.0]
[ -4.0 -5.0 (0.0 + 1.0j)]
A matrix `X` close to the identity matrix, for which
`\log(\exp(X)) = \exp(\log(X)) = X` holds::
>>> X = eye(3) + hilbert(3)/4
>>> X
[ 1.25 0.125 0.0833333333333333]
[ 0.125 1.08333333333333 0.0625]
[0.0833333333333333 0.0625 1.05]
>>> logm(expm(X))
[ 1.25 0.125 0.0833333333333333]
[ 0.125 1.08333333333333 0.0625]
[0.0833333333333333 0.0625 1.05]
>>> expm(logm(X))
[ 1.25 0.125 0.0833333333333333]
[ 0.125 1.08333333333333 0.0625]
[0.0833333333333333 0.0625 1.05]
A logarithm of a rotation matrix, giving back the angle of
the rotation::
>>> t = 3.7
>>> A = matrix([[cos(t),sin(t)],[-sin(t),cos(t)]])
>>> chop(logm(A))
[ 0.0 -2.58318530717959]
[2.58318530717959 0.0]
>>> (2*pi-t)
2.58318530717959
For some matrices, a logarithm does not exist::
>>> logm([[1,0], [0,0]])
Traceback (most recent call last):
...
ZeroDivisionError: matrix is numerically singular
Logarithm of a matrix with large entries::
>>> logm(hilbert(3) * 10**20).apply(re)
[ 45.5597513593433 1.27721006042799 0.317662687717978]
[ 1.27721006042799 42.5222778973542 2.24003708791604]
[0.317662687717978 2.24003708791604 42.395212822267]
"""
A = ctx.matrix(A)
prec = ctx.prec
try:
ctx.prec += 10
tol = ctx.eps * 128
I = A**0
B = A
n = 0
while 1:
B = ctx.sqrtm(B)
n += 1
if ctx.mnorm(B-I, 'inf') < 0.125:
break
T = X = B-I
L = X*0
k = 1
while 1:
if k & 1:
L += T / k
else:
L -= T / k
T *= X
if ctx.mnorm(T, 'inf') < tol:
break
k += 1
if k > ctx.prec:
raise ctx.NoConvergence
finally:
ctx.prec = prec
L *= 2**n
return L
def powm(ctx, A, r):
r"""
Computes `A^r = \exp(A \log r)` for a matrix `A` and complex
number `r`.
**Examples**
Powers and inverse powers of a matrix::
>>> from mpmath import *
>>> mp.dps = 15; mp.pretty = True
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
>>> powm(A, 2)
[ 63.0 20.0 69.0]
[174.0 89.0 199.0]
[164.0 48.0 179.0]
>>> chop(powm(powm(A, 4), 1/4.))
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> powm(extraprec(20)(powm)(A, -4), -1/4.)
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> chop(powm(powm(A, 1+0.5j), 1/(1+0.5j)))
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
>>> powm(extraprec(5)(powm)(A, -1.5), -1/(1.5))
[ 4.0 1.0 4.0]
[ 7.0 8.0 9.0]
[10.0 2.0 11.0]
A Fibonacci-generating matrix::
>>> powm([[1,1],[1,0]], 10)
[89.0 55.0]
[55.0 34.0]
>>> fib(10)
55.0
>>> powm([[1,1],[1,0]], 6.5)
[(16.5166626964253 - 0.0121089837381789j) (10.2078589271083 + 0.0195927472575932j)]
[(10.2078589271083 + 0.0195927472575932j) (6.30880376931698 - 0.0317017309957721j)]
>>> (phi**6.5 - (1-phi)**6.5)/sqrt(5)
(10.2078589271083 - 0.0195927472575932j)
>>> powm([[1,1],[1,0]], 6.2)
[ (14.3076953002666 - 0.008222855781077j) (8.81733464837593 + 0.0133048601383712j)]
[(8.81733464837593 + 0.0133048601383712j) (5.49036065189071 - 0.0215277159194482j)]
>>> (phi**6.2 - (1-phi)**6.2)/sqrt(5)
(8.81733464837593 - 0.0133048601383712j)
"""
A = ctx.matrix(A)
r = ctx.convert(r)
prec = ctx.prec
try:
ctx.prec += 10
if ctx.isint(r):
v = A ** int(r)
elif ctx.isint(r*2):
y = int(r*2)
v = ctx.sqrtm(A) ** y
else:
v = ctx.expm(r*ctx.logm(A))
finally:
ctx.prec = prec
v *= 1
return v

View File

@ -0,0 +1,877 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
##################################################################################################
# module for the eigenvalue problem
# Copyright 2013 Timo Hartmann (thartmann15 at gmail.com)
#
# todo:
# - implement balancing
# - agressive early deflation
#
##################################################################################################
"""
The eigenvalue problem
----------------------
This file contains routines for the eigenvalue problem.
high level routines:
hessenberg : reduction of a real or complex square matrix to upper Hessenberg form
schur : reduction of a real or complex square matrix to upper Schur form
eig : eigenvalues and eigenvectors of a real or complex square matrix
low level routines:
hessenberg_reduce_0 : reduction of a real or complex square matrix to upper Hessenberg form
hessenberg_reduce_1 : auxiliary routine to hessenberg_reduce_0
qr_step : a single implicitly shifted QR step for an upper Hessenberg matrix
hessenberg_qr : Schur decomposition of an upper Hessenberg matrix
eig_tr_r : right eigenvectors of an upper triangular matrix
eig_tr_l : left eigenvectors of an upper triangular matrix
"""
from ..libmp.backend import xrange
class Eigen(object):
pass
def defun(f):
setattr(Eigen, f.__name__, f)
return f
def hessenberg_reduce_0(ctx, A, T):
"""
This routine computes the (upper) Hessenberg decomposition of a square matrix A.
Given A, an unitary matrix Q is calculated such that
Q' A Q = H and Q' Q = Q Q' = 1
where H is an upper Hessenberg matrix, meaning that it only contains zeros
below the first subdiagonal. Here ' denotes the hermitian transpose (i.e.
transposition and conjugation).
parameters:
A (input/output) On input, A contains the square matrix A of
dimension (n,n). On output, A contains a compressed representation
of Q and H.
T (output) An array of length n containing the first elements of
the Householder reflectors.
"""
# internally we work with householder reflections from the right.
# let u be a row vector (i.e. u[i]=A[i,:i]). then
# Q is build up by reflectors of the type (1-v'v) where v is a suitable
# modification of u. these reflectors are applyed to A from the right.
# because we work with reflectors from the right we have to start with
# the bottom row of A and work then upwards (this corresponds to
# some kind of RQ decomposition).
# the first part of the vectors v (i.e. A[i,:(i-1)]) are stored as row vectors
# in the lower left part of A (excluding the diagonal and subdiagonal).
# the last entry of v is stored in T.
# the upper right part of A (including diagonal and subdiagonal) becomes H.
n = A.rows
if n <= 2: return
for i in xrange(n-1, 1, -1):
# scale the vector
scale = 0
for k in xrange(0, i):
scale += abs(ctx.re(A[i,k])) + abs(ctx.im(A[i,k]))
scale_inv = 0
if scale != 0:
scale_inv = 1 / scale
if scale == 0 or ctx.isinf(scale_inv):
# sadly there are floating point numbers not equal to zero whose reciprocal is infinity
T[i] = 0
A[i,i-1] = 0
continue
# calculate parameters for housholder transformation
H = 0
for k in xrange(0, i):
A[i,k] *= scale_inv
rr = ctx.re(A[i,k])
ii = ctx.im(A[i,k])
H += rr * rr + ii * ii
F = A[i,i-1]
f = abs(F)
G = ctx.sqrt(H)
A[i,i-1] = - G * scale
if f == 0:
T[i] = G
else:
ff = F / f
T[i] = F + G * ff
A[i,i-1] *= ff
H += G * f
H = 1 / ctx.sqrt(H)
T[i] *= H
for k in xrange(0, i - 1):
A[i,k] *= H
for j in xrange(0, i):
# apply housholder transformation (from right)
G = ctx.conj(T[i]) * A[j,i-1]
for k in xrange(0, i-1):
G += ctx.conj(A[i,k]) * A[j,k]
A[j,i-1] -= G * T[i]
for k in xrange(0, i-1):
A[j,k] -= G * A[i,k]
for j in xrange(0, n):
# apply housholder transformation (from left)
G = T[i] * A[i-1,j]
for k in xrange(0, i-1):
G += A[i,k] * A[k,j]
A[i-1,j] -= G * ctx.conj(T[i])
for k in xrange(0, i-1):
A[k,j] -= G * ctx.conj(A[i,k])
def hessenberg_reduce_1(ctx, A, T):
"""
This routine forms the unitary matrix Q described in hessenberg_reduce_0.
parameters:
A (input/output) On input, A is the same matrix as delivered by
hessenberg_reduce_0. On output, A is set to Q.
T (input) On input, T is the same array as delivered by hessenberg_reduce_0.
"""
n = A.rows
if n == 1:
A[0,0] = 1
return
A[0,0] = A[1,1] = 1
A[0,1] = A[1,0] = 0
for i in xrange(2, n):
if T[i] != 0:
for j in xrange(0, i):
G = T[i] * A[i-1,j]
for k in xrange(0, i-1):
G += A[i,k] * A[k,j]
A[i-1,j] -= G * ctx.conj(T[i])
for k in xrange(0, i-1):
A[k,j] -= G * ctx.conj(A[i,k])
A[i,i] = 1
for j in xrange(0, i):
A[j,i] = A[i,j] = 0
@defun
def hessenberg(ctx, A, overwrite_a = False):
"""
This routine computes the Hessenberg decomposition of a square matrix A.
Given A, an unitary matrix Q is determined such that
Q' A Q = H and Q' Q = Q Q' = 1
where H is an upper right Hessenberg matrix. Here ' denotes the hermitian
transpose (i.e. transposition and conjugation).
input:
A : a real or complex square matrix
overwrite_a : if true, allows modification of A which may improve
performance. if false, A is not modified.
output:
Q : an unitary matrix
H : an upper right Hessenberg matrix
example:
>>> from mpmath import mp
>>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
>>> Q, H = mp.hessenberg(A)
>>> mp.nprint(H, 3) # doctest:+SKIP
[ 3.15 2.23 4.44]
[-0.769 4.85 3.05]
[ 0.0 3.61 7.0]
>>> print(mp.chop(A - Q * H * Q.transpose_conj()))
[0.0 0.0 0.0]
[0.0 0.0 0.0]
[0.0 0.0 0.0]
return value: (Q, H)
"""
n = A.rows
if n == 1:
return (ctx.matrix([[1]]), A)
if not overwrite_a:
A = A.copy()
T = ctx.matrix(n, 1)
hessenberg_reduce_0(ctx, A, T)
Q = A.copy()
hessenberg_reduce_1(ctx, Q, T)
for x in xrange(n):
for y in xrange(x+2, n):
A[y,x] = 0
return Q, A
###########################################################################
def qr_step(ctx, n0, n1, A, Q, shift):
"""
This subroutine executes a single implicitly shifted QR step applied to an
upper Hessenberg matrix A. Given A and shift as input, first an QR
decomposition is calculated:
Q R = A - shift * 1 .
The output is then following matrix:
R Q + shift * 1
parameters:
n0, n1 (input) Two integers which specify the submatrix A[n0:n1,n0:n1]
on which this subroutine operators. The subdiagonal elements
to the left and below this submatrix must be deflated (i.e. zero).
following restriction is imposed: n1>=n0+2
A (input/output) On input, A is an upper Hessenberg matrix.
On output, A is replaced by "R Q + shift * 1"
Q (input/output) The parameter Q is multiplied by the unitary matrix
Q arising from the QR decomposition. Q can also be false, in which
case the unitary matrix Q is not computated.
shift (input) a complex number specifying the shift. idealy close to an
eigenvalue of the bottemmost part of the submatrix A[n0:n1,n0:n1].
references:
Stoer, Bulirsch - Introduction to Numerical Analysis.
Kresser : Numerical Methods for General and Structured Eigenvalue Problems
"""
# implicitly shifted and bulge chasing is explained at p.398/399 in "Stoer, Bulirsch - Introduction to Numerical Analysis"
# for bulge chasing see also "Watkins - The Matrix Eigenvalue Problem" sec.4.5,p.173
# the Givens rotation we used is determined as follows: let c,s be two complex
# numbers. then we have following relation:
#
# v = sqrt(|c|^2 + |s|^2)
#
# 1/v [ c~ s~] [c] = [v]
# [-s c ] [s] [0]
#
# the matrix on the left is our Givens rotation.
n = A.rows
# first step
# calculate givens rotation
c = A[n0 ,n0] - shift
s = A[n0+1,n0]
v = ctx.hypot(ctx.hypot(ctx.re(c), ctx.im(c)), ctx.hypot(ctx.re(s), ctx.im(s)))
if v == 0:
v = 1
c = 1
s = 0
else:
c /= v
s /= v
cc = ctx.conj(c)
cs = ctx.conj(s)
for k in xrange(n0, n):
# apply givens rotation from the left
x = A[n0 ,k]
y = A[n0+1,k]
A[n0 ,k] = cc * x + cs * y
A[n0+1,k] = c * y - s * x
for k in xrange(min(n1, n0+3)):
# apply givens rotation from the right
x = A[k,n0 ]
y = A[k,n0+1]
A[k,n0 ] = c * x + s * y
A[k,n0+1] = cc * y - cs * x
if not isinstance(Q, bool):
for k in xrange(n):
# eigenvectors
x = Q[k,n0 ]
y = Q[k,n0+1]
Q[k,n0 ] = c * x + s * y
Q[k,n0+1] = cc * y - cs * x
# chase the bulge
for j in xrange(n0, n1 - 2):
# calculate givens rotation
c = A[j+1,j]
s = A[j+2,j]
v = ctx.hypot(ctx.hypot(ctx.re(c), ctx.im(c)), ctx.hypot(ctx.re(s), ctx.im(s)))
if v == 0:
A[j+1,j] = 0
v = 1
c = 1
s = 0
else:
A[j+1,j] = v
c /= v
s /= v
A[j+2,j] = 0
cc = ctx.conj(c)
cs = ctx.conj(s)
for k in xrange(j+1, n):
# apply givens rotation from the left
x = A[j+1,k]
y = A[j+2,k]
A[j+1,k] = cc * x + cs * y
A[j+2,k] = c * y - s * x
for k in xrange(0, min(n1, j+4)):
# apply givens rotation from the right
x = A[k,j+1]
y = A[k,j+2]
A[k,j+1] = c * x + s * y
A[k,j+2] = cc * y - cs * x
if not isinstance(Q, bool):
for k in xrange(0, n):
# eigenvectors
x = Q[k,j+1]
y = Q[k,j+2]
Q[k,j+1] = c * x + s * y
Q[k,j+2] = cc * y - cs * x
def hessenberg_qr(ctx, A, Q):
"""
This routine computes the Schur decomposition of an upper Hessenberg matrix A.
Given A, an unitary matrix Q is determined such that
Q' A Q = R and Q' Q = Q Q' = 1
where R is an upper right triangular matrix. Here ' denotes the hermitian
transpose (i.e. transposition and conjugation).
parameters:
A (input/output) On input, A contains an upper Hessenberg matrix.
On output, A is replace by the upper right triangluar matrix R.
Q (input/output) The parameter Q is multiplied by the unitary
matrix Q arising from the Schur decomposition. Q can also be
false, in which case the unitary matrix Q is not computated.
"""
n = A.rows
norm = 0
for x in xrange(n):
for y in xrange(min(x+2, n)):
norm += ctx.re(A[y,x]) ** 2 + ctx.im(A[y,x]) ** 2
norm = ctx.sqrt(norm) / n
if norm == 0:
return
n0 = 0
n1 = n
eps = ctx.eps / (100 * n)
maxits = ctx.dps * 4
its = totalits = 0
while 1:
# kressner p.32 algo 3
# the active submatrix is A[n0:n1,n0:n1]
k = n0
while k + 1 < n1:
s = abs(ctx.re(A[k,k])) + abs(ctx.im(A[k,k])) + abs(ctx.re(A[k+1,k+1])) + abs(ctx.im(A[k+1,k+1]))
if s < eps * norm:
s = norm
if abs(A[k+1,k]) < eps * s:
break
k += 1
if k + 1 < n1:
# deflation found at position (k+1, k)
A[k+1,k] = 0
n0 = k + 1
its = 0
if n0 + 1 >= n1:
# block of size at most two has converged
n0 = 0
n1 = k + 1
if n1 < 2:
# QR algorithm has converged
return
else:
if (its % 30) == 10:
# exceptional shift
shift = A[n1-1,n1-2]
elif (its % 30) == 20:
# exceptional shift
shift = abs(A[n1-1,n1-2])
elif (its % 30) == 29:
# exceptional shift
shift = norm
else:
# A = [ a b ] det(x-A)=x*x-x*tr(A)+det(A)
# [ c d ]
#
# eigenvalues bad: (tr(A)+sqrt((tr(A))**2-4*det(A)))/2
# bad because of cancellation if |c| is small and |a-d| is small, too.
#
# eigenvalues good: (a+d+sqrt((a-d)**2+4*b*c))/2
t = A[n1-2,n1-2] + A[n1-1,n1-1]
s = (A[n1-1,n1-1] - A[n1-2,n1-2]) ** 2 + 4 * A[n1-1,n1-2] * A[n1-2,n1-1]
if ctx.re(s) > 0:
s = ctx.sqrt(s)
else:
s = ctx.sqrt(-s) * 1j
a = (t + s) / 2
b = (t - s) / 2
if abs(A[n1-1,n1-1] - a) > abs(A[n1-1,n1-1] - b):
shift = b
else:
shift = a
its += 1
totalits += 1
qr_step(ctx, n0, n1, A, Q, shift)
if its > maxits:
raise RuntimeError("qr: failed to converge after %d steps" % its)
@defun
def schur(ctx, A, overwrite_a = False):
"""
This routine computes the Schur decomposition of a square matrix A.
Given A, an unitary matrix Q is determined such that
Q' A Q = R and Q' Q = Q Q' = 1
where R is an upper right triangular matrix. Here ' denotes the
hermitian transpose (i.e. transposition and conjugation).
input:
A : a real or complex square matrix
overwrite_a : if true, allows modification of A which may improve
performance. if false, A is not modified.
output:
Q : an unitary matrix
R : an upper right triangular matrix
return value: (Q, R)
example:
>>> from mpmath import mp
>>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
>>> Q, R = mp.schur(A)
>>> mp.nprint(R, 3) # doctest:+SKIP
[2.0 0.417 -2.53]
[0.0 4.0 -4.74]
[0.0 0.0 9.0]
>>> print(mp.chop(A - Q * R * Q.transpose_conj()))
[0.0 0.0 0.0]
[0.0 0.0 0.0]
[0.0 0.0 0.0]
warning: The Schur decomposition is not unique.
"""
n = A.rows
if n == 1:
return (ctx.matrix([[1]]), A)
if not overwrite_a:
A = A.copy()
T = ctx.matrix(n, 1)
hessenberg_reduce_0(ctx, A, T)
Q = A.copy()
hessenberg_reduce_1(ctx, Q, T)
for x in xrange(n):
for y in xrange(x + 2, n):
A[y,x] = 0
hessenberg_qr(ctx, A, Q)
return Q, A
def eig_tr_r(ctx, A):
"""
This routine calculates the right eigenvectors of an upper right triangular matrix.
input:
A an upper right triangular matrix
output:
ER a matrix whose columns form the right eigenvectors of A
return value: ER
"""
# this subroutine is inspired by the lapack routines ctrevc.f,clatrs.f
n = A.rows
ER = ctx.eye(n)
eps = ctx.eps
unfl = ctx.ldexp(ctx.one, -ctx.prec * 30)
# since mpmath effectively has no limits on the exponent, we simply scale doubles up
# original double has prec*20
smlnum = unfl * (n / eps)
simin = 1 / ctx.sqrt(eps)
rmax = 1
for i in xrange(1, n):
s = A[i,i]
smin = max(eps * abs(s), smlnum)
for j in xrange(i - 1, -1, -1):
r = 0
for k in xrange(j + 1, i + 1):
r += A[j,k] * ER[k,i]
t = A[j,j] - s
if abs(t) < smin:
t = smin
r = -r / t
ER[j,i] = r
rmax = max(rmax, abs(r))
if rmax > simin:
for k in xrange(j, i+1):
ER[k,i] /= rmax
rmax = 1
if rmax != 1:
for k in xrange(0, i + 1):
ER[k,i] /= rmax
return ER
def eig_tr_l(ctx, A):
"""
This routine calculates the left eigenvectors of an upper right triangular matrix.
input:
A an upper right triangular matrix
output:
EL a matrix whose rows form the left eigenvectors of A
return value: EL
"""
n = A.rows
EL = ctx.eye(n)
eps = ctx.eps
unfl = ctx.ldexp(ctx.one, -ctx.prec * 30)
# since mpmath effectively has no limits on the exponent, we simply scale doubles up
# original double has prec*20
smlnum = unfl * (n / eps)
simin = 1 / ctx.sqrt(eps)
rmax = 1
for i in xrange(0, n - 1):
s = A[i,i]
smin = max(eps * abs(s), smlnum)
for j in xrange(i + 1, n):
r = 0
for k in xrange(i, j):
r += EL[i,k] * A[k,j]
t = A[j,j] - s
if abs(t) < smin:
t = smin
r = -r / t
EL[i,j] = r
rmax = max(rmax, abs(r))
if rmax > simin:
for k in xrange(i, j + 1):
EL[i,k] /= rmax
rmax = 1
if rmax != 1:
for k in xrange(i, n):
EL[i,k] /= rmax
return EL
@defun
def eig(ctx, A, left = False, right = True, overwrite_a = False):
"""
This routine computes the eigenvalues and optionally the left and right
eigenvectors of a square matrix A. Given A, a vector E and matrices ER
and EL are calculated such that
A ER[:,i] = E[i] ER[:,i]
EL[i,:] A = EL[i,:] E[i]
E contains the eigenvalues of A. The columns of ER contain the right eigenvectors
of A whereas the rows of EL contain the left eigenvectors.
input:
A : a real or complex square matrix of shape (n, n)
left : if true, the left eigenvectors are calculated.
right : if true, the right eigenvectors are calculated.
overwrite_a : if true, allows modification of A which may improve
performance. if false, A is not modified.
output:
E : a list of length n containing the eigenvalues of A.
ER : a matrix whose columns contain the right eigenvectors of A.
EL : a matrix whose rows contain the left eigenvectors of A.
return values:
E if left and right are both false.
(E, ER) if right is true and left is false.
(E, EL) if left is true and right is false.
(E, EL, ER) if left and right are true.
examples:
>>> from mpmath import mp
>>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
>>> E, ER = mp.eig(A)
>>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
[0.0]
[0.0]
[0.0]
>>> E, EL, ER = mp.eig(A,left = True, right = True)
>>> E, EL, ER = mp.eig_sort(E, EL, ER)
>>> mp.nprint(E)
[2.0, 4.0, 9.0]
>>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
[0.0]
[0.0]
[0.0]
>>> print(mp.chop( EL[0,:] * A - EL[0,:] * E[0]))
[0.0 0.0 0.0]
warning:
- If there are multiple eigenvalues, the eigenvectors do not necessarily
span the whole vectorspace, i.e. ER and EL may have not full rank.
Furthermore in that case the eigenvectors are numerical ill-conditioned.
- In the general case the eigenvalues have no natural order.
see also:
- eigh (or eigsy, eighe) for the symmetric eigenvalue problem.
- eig_sort for sorting of eigenvalues and eigenvectors
"""
n = A.rows
if n == 1:
if left and (not right):
return ([A[0]], ctx.matrix([[1]]))
if right and (not left):
return ([A[0]], ctx.matrix([[1]]))
return ([A[0]], ctx.matrix([[1]]), ctx.matrix([[1]]))
if not overwrite_a:
A = A.copy()
T = ctx.zeros(n, 1)
hessenberg_reduce_0(ctx, A, T)
if left or right:
Q = A.copy()
hessenberg_reduce_1(ctx, Q, T)
else:
Q = False
for x in xrange(n):
for y in xrange(x + 2, n):
A[y,x] = 0
hessenberg_qr(ctx, A, Q)
E = [0 for i in xrange(n)]
for i in xrange(n):
E[i] = A[i,i]
if not (left or right):
return E
if left:
EL = eig_tr_l(ctx, A)
EL = EL * Q.transpose_conj()
if right:
ER = eig_tr_r(ctx, A)
ER = Q * ER
if left and (not right):
return (E, EL)
if right and (not left):
return (E, ER)
return (E, EL, ER)
@defun
def eig_sort(ctx, E, EL = False, ER = False, f = "real"):
"""
This routine sorts the eigenvalues and eigenvectors delivered by ``eig``.
parameters:
E : the eigenvalues as delivered by eig
EL : the left eigenvectors as delivered by eig, or false
ER : the right eigenvectors as delivered by eig, or false
f : either a string ("real" sort by increasing real part, "imag" sort by
increasing imag part, "abs" sort by absolute value) or a function
mapping complexs to the reals, i.e. ``f = lambda x: -mp.re(x) ``
would sort the eigenvalues by decreasing real part.
return values:
E if EL and ER are both false.
(E, ER) if ER is not false and left is false.
(E, EL) if EL is not false and right is false.
(E, EL, ER) if EL and ER are not false.
example:
>>> from mpmath import mp
>>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
>>> E, EL, ER = mp.eig(A,left = True, right = True)
>>> E, EL, ER = mp.eig_sort(E, EL, ER)
>>> mp.nprint(E)
[2.0, 4.0, 9.0]
>>> E, EL, ER = mp.eig_sort(E, EL, ER,f = lambda x: -mp.re(x))
>>> mp.nprint(E)
[9.0, 4.0, 2.0]
>>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
[0.0]
[0.0]
[0.0]
>>> print(mp.chop( EL[0,:] * A - EL[0,:] * E[0]))
[0.0 0.0 0.0]
"""
if isinstance(f, str):
if f == "real":
f = ctx.re
elif f == "imag":
f = ctx.im
elif f == "abs":
f = abs
else:
raise RuntimeError("unknown function %s" % f)
n = len(E)
# Sort eigenvalues (bubble-sort)
for i in xrange(n):
imax = i
s = f(E[i]) # s is the current maximal element
for j in xrange(i + 1, n):
c = f(E[j])
if c < s:
s = c
imax = j
if imax != i:
# swap eigenvalues
z = E[i]
E[i] = E[imax]
E[imax] = z
if not isinstance(EL, bool):
for j in xrange(n):
z = EL[i,j]
EL[i,j] = EL[imax,j]
EL[imax,j] = z
if not isinstance(ER, bool):
for j in xrange(n):
z = ER[j,i]
ER[j,i] = ER[j,imax]
ER[j,imax] = z
if isinstance(EL, bool) and isinstance(ER, bool):
return E
if isinstance(EL, bool) and not(isinstance(ER, bool)):
return (E, ER)
if isinstance(ER, bool) and not(isinstance(EL, bool)):
return (E, EL)
return (E, EL, ER)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,790 @@
"""
Linear algebra
--------------
Linear equations
................
Basic linear algebra is implemented; you can for example solve the linear
equation system::
x + 2*y = -10
3*x + 4*y = 10
using ``lu_solve``::
>>> from mpmath import *
>>> mp.pretty = False
>>> A = matrix([[1, 2], [3, 4]])
>>> b = matrix([-10, 10])
>>> x = lu_solve(A, b)
>>> x
matrix(
[['30.0'],
['-20.0']])
If you don't trust the result, use ``residual`` to calculate the residual ||A*x-b||::
>>> residual(A, x, b)
matrix(
[['3.46944695195361e-18'],
['3.46944695195361e-18']])
>>> str(eps)
'2.22044604925031e-16'
As you can see, the solution is quite accurate. The error is caused by the
inaccuracy of the internal floating point arithmetic. Though, it's even smaller
than the current machine epsilon, which basically means you can trust the
result.
If you need more speed, use NumPy, or ``fp.lu_solve`` for a floating-point computation.
>>> fp.lu_solve(A, b) # doctest: +ELLIPSIS
matrix(...)
``lu_solve`` accepts overdetermined systems. It is usually not possible to solve
such systems, so the residual is minimized instead. Internally this is done
using Cholesky decomposition to compute a least squares approximation. This means
that that ``lu_solve`` will square the errors. If you can't afford this, use
``qr_solve`` instead. It is twice as slow but more accurate, and it calculates
the residual automatically.
Matrix factorization
....................
The function ``lu`` computes an explicit LU factorization of a matrix::
>>> P, L, U = lu(matrix([[0,2,3],[4,5,6],[7,8,9]]))
>>> print(P)
[0.0 0.0 1.0]
[1.0 0.0 0.0]
[0.0 1.0 0.0]
>>> print(L)
[ 1.0 0.0 0.0]
[ 0.0 1.0 0.0]
[0.571428571428571 0.214285714285714 1.0]
>>> print(U)
[7.0 8.0 9.0]
[0.0 2.0 3.0]
[0.0 0.0 0.214285714285714]
>>> print(P.T*L*U)
[0.0 2.0 3.0]
[4.0 5.0 6.0]
[7.0 8.0 9.0]
Interval matrices
-----------------
Matrices may contain interval elements. This allows one to perform
basic linear algebra operations such as matrix multiplication
and equation solving with rigorous error bounds::
>>> a = iv.matrix([['0.1','0.3','1.0'],
... ['7.1','5.5','4.8'],
... ['3.2','4.4','5.6']])
>>>
>>> b = iv.matrix(['4','0.6','0.5'])
>>> c = iv.lu_solve(a, b)
>>> print(c)
[ [5.2582327113062568605927528666, 5.25823271130625686059275702219]]
[[-13.1550493962678375411635581388, -13.1550493962678375411635540152]]
[ [7.42069154774972557628979076189, 7.42069154774972557628979190734]]
>>> print(a*c)
[ [3.99999999999999999999999844904, 4.00000000000000000000000155096]]
[[0.599999999999999999999968898009, 0.600000000000000000000031763736]]
[[0.499999999999999999999979320485, 0.500000000000000000000020679515]]
"""
# TODO:
# *implement high-level qr()
# *test unitvector
# *iterative solving
from copy import copy
from ..libmp.backend import xrange
class LinearAlgebraMethods(object):
def LU_decomp(ctx, A, overwrite=False, use_cache=True):
"""
LU-factorization of a n*n matrix using the Gauss algorithm.
Returns L and U in one matrix and the pivot indices.
Use overwrite to specify whether A will be overwritten with L and U.
"""
if not A.rows == A.cols:
raise ValueError('need n*n matrix')
# get from cache if possible
if use_cache and isinstance(A, ctx.matrix) and A._LU:
return A._LU
if not overwrite:
orig = A
A = A.copy()
tol = ctx.absmin(ctx.mnorm(A,1) * ctx.eps) # each pivot element has to be bigger
n = A.rows
p = [None]*(n - 1)
for j in xrange(n - 1):
# pivoting, choose max(abs(reciprocal row sum)*abs(pivot element))
biggest = 0
for k in xrange(j, n):
s = ctx.fsum([ctx.absmin(A[k,l]) for l in xrange(j, n)])
if ctx.absmin(s) <= tol:
raise ZeroDivisionError('matrix is numerically singular')
current = 1/s * ctx.absmin(A[k,j])
if current > biggest: # TODO: what if equal?
biggest = current
p[j] = k
# swap rows according to p
ctx.swap_row(A, j, p[j])
if ctx.absmin(A[j,j]) <= tol:
raise ZeroDivisionError('matrix is numerically singular')
# calculate elimination factors and add rows
for i in xrange(j + 1, n):
A[i,j] /= A[j,j]
for k in xrange(j + 1, n):
A[i,k] -= A[i,j]*A[j,k]
if ctx.absmin(A[n - 1,n - 1]) <= tol:
raise ZeroDivisionError('matrix is numerically singular')
# cache decomposition
if not overwrite and isinstance(orig, ctx.matrix):
orig._LU = (A, p)
return A, p
def L_solve(ctx, L, b, p=None):
"""
Solve the lower part of a LU factorized matrix for y.
"""
if L.rows != L.cols:
raise RuntimeError("need n*n matrix")
n = L.rows
if len(b) != n:
raise ValueError("Value should be equal to n")
b = copy(b)
if p: # swap b according to p
for k in xrange(0, len(p)):
ctx.swap_row(b, k, p[k])
# solve
for i in xrange(1, n):
for j in xrange(i):
b[i] -= L[i,j] * b[j]
return b
def U_solve(ctx, U, y):
"""
Solve the upper part of a LU factorized matrix for x.
"""
if U.rows != U.cols:
raise RuntimeError("need n*n matrix")
n = U.rows
if len(y) != n:
raise ValueError("Value should be equal to n")
x = copy(y)
for i in xrange(n - 1, -1, -1):
for j in xrange(i + 1, n):
x[i] -= U[i,j] * x[j]
x[i] /= U[i,i]
return x
def lu_solve(ctx, A, b, **kwargs):
"""
Ax = b => x
Solve a determined or overdetermined linear equations system.
Fast LU decomposition is used, which is less accurate than QR decomposition
(especially for overdetermined systems), but it's twice as efficient.
Use qr_solve if you want more precision or have to solve a very ill-
conditioned system.
If you specify real=True, it does not check for overdeterminded complex
systems.
"""
prec = ctx.prec
try:
ctx.prec += 10
# do not overwrite A nor b
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
if A.rows < A.cols:
raise ValueError('cannot solve underdetermined system')
if A.rows > A.cols:
# use least-squares method if overdetermined
# (this increases errors)
AH = A.H
A = AH * A
b = AH * b
if (kwargs.get('real', False) or
not sum(type(i) is ctx.mpc for i in A)):
# TODO: necessary to check also b?
x = ctx.cholesky_solve(A, b)
else:
x = ctx.lu_solve(A, b)
else:
# LU factorization
A, p = ctx.LU_decomp(A)
b = ctx.L_solve(A, b, p)
x = ctx.U_solve(A, b)
finally:
ctx.prec = prec
return x
def improve_solution(ctx, A, x, b, maxsteps=1):
"""
Improve a solution to a linear equation system iteratively.
This re-uses the LU decomposition and is thus cheap.
Usually 3 up to 4 iterations are giving the maximal improvement.
"""
if A.rows != A.cols:
raise RuntimeError("need n*n matrix") # TODO: really?
for _ in xrange(maxsteps):
r = ctx.residual(A, x, b)
if ctx.norm(r, 2) < 10*ctx.eps:
break
# this uses cached LU decomposition and is thus cheap
dx = ctx.lu_solve(A, -r)
x += dx
return x
def lu(ctx, A):
"""
A -> P, L, U
LU factorisation of a square matrix A. L is the lower, U the upper part.
P is the permutation matrix indicating the row swaps.
P*A = L*U
If you need efficiency, use the low-level method LU_decomp instead, it's
much more memory efficient.
"""
# get factorization
A, p = ctx.LU_decomp(A)
n = A.rows
L = ctx.matrix(n)
U = ctx.matrix(n)
for i in xrange(n):
for j in xrange(n):
if i > j:
L[i,j] = A[i,j]
elif i == j:
L[i,j] = 1
U[i,j] = A[i,j]
else:
U[i,j] = A[i,j]
# calculate permutation matrix
P = ctx.eye(n)
for k in xrange(len(p)):
ctx.swap_row(P, k, p[k])
return P, L, U
def unitvector(ctx, n, i):
"""
Return the i-th n-dimensional unit vector.
"""
assert 0 < i <= n, 'this unit vector does not exist'
return [ctx.zero]*(i-1) + [ctx.one] + [ctx.zero]*(n-i)
def inverse(ctx, A, **kwargs):
"""
Calculate the inverse of a matrix.
If you want to solve an equation system Ax = b, it's recommended to use
solve(A, b) instead, it's about 3 times more efficient.
"""
prec = ctx.prec
try:
ctx.prec += 10
# do not overwrite A
A = ctx.matrix(A, **kwargs).copy()
n = A.rows
# get LU factorisation
A, p = ctx.LU_decomp(A)
cols = []
# calculate unit vectors and solve corresponding system to get columns
for i in xrange(1, n + 1):
e = ctx.unitvector(n, i)
y = ctx.L_solve(A, e, p)
cols.append(ctx.U_solve(A, y))
# convert columns to matrix
inv = []
for i in xrange(n):
row = []
for j in xrange(n):
row.append(cols[j][i])
inv.append(row)
result = ctx.matrix(inv, **kwargs)
finally:
ctx.prec = prec
return result
def householder(ctx, A):
"""
(A|b) -> H, p, x, res
(A|b) is the coefficient matrix with left hand side of an optionally
overdetermined linear equation system.
H and p contain all information about the transformation matrices.
x is the solution, res the residual.
"""
if not isinstance(A, ctx.matrix):
raise TypeError("A should be a type of ctx.matrix")
m = A.rows
n = A.cols
if m < n - 1:
raise RuntimeError("Columns should not be less than rows")
# calculate Householder matrix
p = []
for j in xrange(0, n - 1):
s = ctx.fsum(abs(A[i,j])**2 for i in xrange(j, m))
if not abs(s) > ctx.eps:
raise ValueError('matrix is numerically singular')
p.append(-ctx.sign(ctx.re(A[j,j])) * ctx.sqrt(s))
kappa = ctx.one / (s - p[j] * A[j,j])
A[j,j] -= p[j]
for k in xrange(j+1, n):
y = ctx.fsum(ctx.conj(A[i,j]) * A[i,k] for i in xrange(j, m)) * kappa
for i in xrange(j, m):
A[i,k] -= A[i,j] * y
# solve Rx = c1
x = [A[i,n - 1] for i in xrange(n - 1)]
for i in xrange(n - 2, -1, -1):
x[i] -= ctx.fsum(A[i,j] * x[j] for j in xrange(i + 1, n - 1))
x[i] /= p[i]
# calculate residual
if not m == n - 1:
r = [A[m-1-i, n-1] for i in xrange(m - n + 1)]
else:
# determined system, residual should be 0
r = [0]*m # maybe a bad idea, changing r[i] will change all elements
return A, p, x, r
#def qr(ctx, A):
# """
# A -> Q, R
#
# QR factorisation of a square matrix A using Householder decomposition.
# Q is orthogonal, this leads to very few numerical errors.
#
# A = Q*R
# """
# H, p, x, res = householder(A)
# TODO: implement this
def residual(ctx, A, x, b, **kwargs):
"""
Calculate the residual of a solution to a linear equation system.
r = A*x - b for A*x = b
"""
oldprec = ctx.prec
try:
ctx.prec *= 2
A, x, b = ctx.matrix(A, **kwargs), ctx.matrix(x, **kwargs), ctx.matrix(b, **kwargs)
return A*x - b
finally:
ctx.prec = oldprec
def qr_solve(ctx, A, b, norm=None, **kwargs):
"""
Ax = b => x, ||Ax - b||
Solve a determined or overdetermined linear equations system and
calculate the norm of the residual (error).
QR decomposition using Householder factorization is applied, which gives very
accurate results even for ill-conditioned matrices. qr_solve is twice as
efficient.
"""
if norm is None:
norm = ctx.norm
prec = ctx.prec
try:
ctx.prec += 10
# do not overwrite A nor b
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
if A.rows < A.cols:
raise ValueError('cannot solve underdetermined system')
H, p, x, r = ctx.householder(ctx.extend(A, b))
res = ctx.norm(r)
# calculate residual "manually" for determined systems
if res == 0:
res = ctx.norm(ctx.residual(A, x, b))
return ctx.matrix(x, **kwargs), res
finally:
ctx.prec = prec
def cholesky(ctx, A, tol=None):
r"""
Cholesky decomposition of a symmetric positive-definite matrix `A`.
Returns a lower triangular matrix `L` such that `A = L \times L^T`.
More generally, for a complex Hermitian positive-definite matrix,
a Cholesky decomposition satisfying `A = L \times L^H` is returned.
The Cholesky decomposition can be used to solve linear equation
systems twice as efficiently as LU decomposition, or to
test whether `A` is positive-definite.
The optional parameter ``tol`` determines the tolerance for
verifying positive-definiteness.
**Examples**
Cholesky decomposition of a positive-definite symmetric matrix::
>>> from mpmath import *
>>> mp.dps = 25; mp.pretty = True
>>> A = eye(3) + hilbert(3)
>>> nprint(A)
[ 2.0 0.5 0.333333]
[ 0.5 1.33333 0.25]
[0.333333 0.25 1.2]
>>> L = cholesky(A)
>>> nprint(L)
[ 1.41421 0.0 0.0]
[0.353553 1.09924 0.0]
[0.235702 0.15162 1.05899]
>>> chop(A - L*L.T)
[0.0 0.0 0.0]
[0.0 0.0 0.0]
[0.0 0.0 0.0]
Cholesky decomposition of a Hermitian matrix::
>>> A = eye(3) + matrix([[0,0.25j,-0.5j],[-0.25j,0,0],[0.5j,0,0]])
>>> L = cholesky(A)
>>> nprint(L)
[ 1.0 0.0 0.0]
[(0.0 - 0.25j) (0.968246 + 0.0j) 0.0]
[ (0.0 + 0.5j) (0.129099 + 0.0j) (0.856349 + 0.0j)]
>>> chop(A - L*L.H)
[0.0 0.0 0.0]
[0.0 0.0 0.0]
[0.0 0.0 0.0]
Attempted Cholesky decomposition of a matrix that is not positive
definite::
>>> A = -eye(3) + hilbert(3)
>>> L = cholesky(A)
Traceback (most recent call last):
...
ValueError: matrix is not positive-definite
**References**
1. [Wikipedia]_ http://en.wikipedia.org/wiki/Cholesky_decomposition
"""
if not isinstance(A, ctx.matrix):
raise RuntimeError("A should be a type of ctx.matrix")
if not A.rows == A.cols:
raise ValueError('need n*n matrix')
if tol is None:
tol = +ctx.eps
n = A.rows
L = ctx.matrix(n)
for j in xrange(n):
c = ctx.re(A[j,j])
if abs(c-A[j,j]) > tol:
raise ValueError('matrix is not Hermitian')
s = c - ctx.fsum((L[j,k] for k in xrange(j)),
absolute=True, squared=True)
if s < tol:
raise ValueError('matrix is not positive-definite')
L[j,j] = ctx.sqrt(s)
for i in xrange(j, n):
it1 = (L[i,k] for k in xrange(j))
it2 = (L[j,k] for k in xrange(j))
t = ctx.fdot(it1, it2, conjugate=True)
L[i,j] = (A[i,j] - t) / L[j,j]
return L
def cholesky_solve(ctx, A, b, **kwargs):
"""
Ax = b => x
Solve a symmetric positive-definite linear equation system.
This is twice as efficient as lu_solve.
Typical use cases:
* A.T*A
* Hessian matrix
* differential equations
"""
prec = ctx.prec
try:
ctx.prec += 10
# do not overwrite A nor b
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
if A.rows != A.cols:
raise ValueError('can only solve determined system')
# Cholesky factorization
L = ctx.cholesky(A)
# solve
n = L.rows
if len(b) != n:
raise ValueError("Value should be equal to n")
for i in xrange(n):
b[i] -= ctx.fsum(L[i,j] * b[j] for j in xrange(i))
b[i] /= L[i,i]
x = ctx.U_solve(L.T, b)
return x
finally:
ctx.prec = prec
def det(ctx, A):
"""
Calculate the determinant of a matrix.
"""
prec = ctx.prec
try:
# do not overwrite A
A = ctx.matrix(A).copy()
# use LU factorization to calculate determinant
try:
R, p = ctx.LU_decomp(A)
except ZeroDivisionError:
return 0
z = 1
for i, e in enumerate(p):
if i != e:
z *= -1
for i in xrange(A.rows):
z *= R[i,i]
return z
finally:
ctx.prec = prec
def cond(ctx, A, norm=None):
"""
Calculate the condition number of a matrix using a specified matrix norm.
The condition number estimates the sensitivity of a matrix to errors.
Example: small input errors for ill-conditioned coefficient matrices
alter the solution of the system dramatically.
For ill-conditioned matrices it's recommended to use qr_solve() instead
of lu_solve(). This does not help with input errors however, it just avoids
to add additional errors.
Definition: cond(A) = ||A|| * ||A**-1||
"""
if norm is None:
norm = lambda x: ctx.mnorm(x,1)
return norm(A) * norm(ctx.inverse(A))
def lu_solve_mat(ctx, a, b):
"""Solve a * x = b where a and b are matrices."""
r = ctx.matrix(a.rows, b.cols)
for i in range(b.cols):
c = ctx.lu_solve(a, b.column(i))
for j in range(len(c)):
r[j, i] = c[j]
return r
def qr(ctx, A, mode = 'full', edps = 10):
"""
Compute a QR factorization $A = QR$ where
A is an m x n matrix of real or complex numbers where m >= n
mode has following meanings:
(1) mode = 'raw' returns two matrixes (A, tau) in the
internal format used by LAPACK
(2) mode = 'skinny' returns the leading n columns of Q
and n rows of R
(3) Any other value returns the leading m columns of Q
and m rows of R
edps is the increase in mp precision used for calculations
**Examples**
>>> from mpmath import *
>>> mp.dps = 15
>>> mp.pretty = True
>>> A = matrix([[1, 2], [3, 4], [1, 1]])
>>> Q, R = qr(A)
>>> Q
[-0.301511344577764 0.861640436855329 0.408248290463863]
[-0.904534033733291 -0.123091490979333 -0.408248290463863]
[-0.301511344577764 -0.492365963917331 0.816496580927726]
>>> R
[-3.3166247903554 -4.52267016866645]
[ 0.0 0.738548945875996]
[ 0.0 0.0]
>>> Q * R
[1.0 2.0]
[3.0 4.0]
[1.0 1.0]
>>> chop(Q.T * Q)
[1.0 0.0 0.0]
[0.0 1.0 0.0]
[0.0 0.0 1.0]
>>> B = matrix([[1+0j, 2-3j], [3+j, 4+5j]])
>>> Q, R = qr(B)
>>> nprint(Q)
[ (-0.301511 + 0.0j) (0.0695795 - 0.95092j)]
[(-0.904534 - 0.301511j) (-0.115966 + 0.278318j)]
>>> nprint(R)
[(-3.31662 + 0.0j) (-5.72872 - 2.41209j)]
[ 0.0 (3.91965 + 0.0j)]
>>> Q * R
[(1.0 + 0.0j) (2.0 - 3.0j)]
[(3.0 + 1.0j) (4.0 + 5.0j)]
>>> chop(Q.T * Q.conjugate())
[1.0 0.0]
[0.0 1.0]
"""
# check values before continuing
assert isinstance(A, ctx.matrix)
m = A.rows
n = A.cols
assert n >= 0
assert m >= n
assert edps >= 0
# check for complex data type
cmplx = any(type(x) is ctx.mpc for x in A)
# temporarily increase the precision and initialize
with ctx.extradps(edps):
tau = ctx.matrix(n,1)
A = A.copy()
# ---------------
# FACTOR MATRIX A
# ---------------
if cmplx:
one = ctx.mpc('1.0', '0.0')
zero = ctx.mpc('0.0', '0.0')
rzero = ctx.mpf('0.0')
# main loop to factor A (complex)
for j in xrange(0, n):
alpha = A[j,j]
alphr = ctx.re(alpha)
alphi = ctx.im(alpha)
if (m-j) >= 2:
xnorm = ctx.fsum( A[i,j]*ctx.conj(A[i,j]) for i in xrange(j+1, m) )
xnorm = ctx.re( ctx.sqrt(xnorm) )
else:
xnorm = rzero
if (xnorm == rzero) and (alphi == rzero):
tau[j] = zero
continue
if alphr < rzero:
beta = ctx.sqrt(alphr**2 + alphi**2 + xnorm**2)
else:
beta = -ctx.sqrt(alphr**2 + alphi**2 + xnorm**2)
tau[j] = ctx.mpc( (beta - alphr) / beta, -alphi / beta )
t = -ctx.conj(tau[j])
za = one / (alpha - beta)
for i in xrange(j+1, m):
A[i,j] *= za
A[j,j] = one
for k in xrange(j+1, n):
y = ctx.fsum(A[i,j] * ctx.conj(A[i,k]) for i in xrange(j, m))
temp = t * ctx.conj(y)
for i in xrange(j, m):
A[i,k] += A[i,j] * temp
A[j,j] = ctx.mpc(beta, '0.0')
else:
one = ctx.mpf('1.0')
zero = ctx.mpf('0.0')
# main loop to factor A (real)
for j in xrange(0, n):
alpha = A[j,j]
if (m-j) > 2:
xnorm = ctx.fsum( (A[i,j])**2 for i in xrange(j+1, m) )
xnorm = ctx.sqrt(xnorm)
elif (m-j) == 2:
xnorm = abs( A[m-1,j] )
else:
xnorm = zero
if xnorm == zero:
tau[j] = zero
continue
if alpha < zero:
beta = ctx.sqrt(alpha**2 + xnorm**2)
else:
beta = -ctx.sqrt(alpha**2 + xnorm**2)
tau[j] = (beta - alpha) / beta
t = -tau[j]
da = one / (alpha - beta)
for i in xrange(j+1, m):
A[i,j] *= da
A[j,j] = one
for k in xrange(j+1, n):
y = ctx.fsum( A[i,j] * A[i,k] for i in xrange(j, m) )
temp = t * y
for i in xrange(j,m):
A[i,k] += A[i,j] * temp
A[j,j] = beta
# return factorization in same internal format as LAPACK
if (mode == 'raw') or (mode == 'RAW'):
return A, tau
# ----------------------------------
# FORM Q USING BACKWARD ACCUMULATION
# ----------------------------------
# form R before the values are overwritten
R = A.copy()
for j in xrange(0, n):
for i in xrange(j+1, m):
R[i,j] = zero
# set the value of p (number of columns of Q to return)
p = m
if (mode == 'skinny') or (mode == 'SKINNY'):
p = n
# add columns to A if needed and initialize
A.cols += (p-n)
for j in xrange(0, p):
A[j,j] = one
for i in xrange(0, j):
A[i,j] = zero
# main loop to form Q
for j in xrange(n-1, -1, -1):
t = -tau[j]
A[j,j] += t
for k in xrange(j+1, p):
if cmplx:
y = ctx.fsum(A[i,j] * ctx.conj(A[i,k]) for i in xrange(j+1, m))
temp = t * ctx.conj(y)
else:
y = ctx.fsum(A[i,j] * A[i,k] for i in xrange(j+1, m))
temp = t * y
A[j,k] = temp
for i in xrange(j+1, m):
A[i,k] += A[i,j] * temp
for i in xrange(j+1, m):
A[i, j] *= t
return A, R[0:p,0:n]
# ------------------
# END OF FUNCTION QR
# ------------------

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,240 @@
import operator
import sys
from .libmp import int_types, mpf_hash, bitcount, from_man_exp, HASH_MODULUS
new = object.__new__
def create_reduced(p, q, _cache={}):
key = p, q
if key in _cache:
return _cache[key]
x, y = p, q
while y:
x, y = y, x % y
if x != 1:
p //= x
q //= x
v = new(mpq)
v._mpq_ = p, q
# Speedup integers, half-integers and other small fractions
if q <= 4 and abs(key[0]) < 100:
_cache[key] = v
return v
class mpq(object):
"""
Exact rational type, currently only intended for internal use.
"""
__slots__ = ["_mpq_"]
def __new__(cls, p, q=1):
if type(p) is tuple:
p, q = p
elif hasattr(p, '_mpq_'):
p, q = p._mpq_
return create_reduced(p, q)
def __repr__(s):
return "mpq(%s,%s)" % s._mpq_
def __str__(s):
return "(%s/%s)" % s._mpq_
def __int__(s):
a, b = s._mpq_
return a // b
def __nonzero__(s):
return bool(s._mpq_[0])
__bool__ = __nonzero__
def __hash__(s):
a, b = s._mpq_
if sys.version_info >= (3, 2):
inverse = pow(b, HASH_MODULUS-2, HASH_MODULUS)
if not inverse:
h = sys.hash_info.inf
else:
h = (abs(a) * inverse) % HASH_MODULUS
if a < 0: h = -h
if h == -1: h = -2
return h
else:
if b == 1:
return hash(a)
# Power of two: mpf compatible hash
if not (b & (b-1)):
return mpf_hash(from_man_exp(a, 1-bitcount(b)))
return hash((a,b))
def __eq__(s, t):
ttype = type(t)
if ttype is mpq:
return s._mpq_ == t._mpq_
if ttype in int_types:
a, b = s._mpq_
if b != 1:
return False
return a == t
return NotImplemented
def __ne__(s, t):
ttype = type(t)
if ttype is mpq:
return s._mpq_ != t._mpq_
if ttype in int_types:
a, b = s._mpq_
if b != 1:
return True
return a != t
return NotImplemented
def _cmp(s, t, op):
ttype = type(t)
if ttype in int_types:
a, b = s._mpq_
return op(a, t*b)
if ttype is mpq:
a, b = s._mpq_
c, d = t._mpq_
return op(a*d, b*c)
return NotImplementedError
def __lt__(s, t): return s._cmp(t, operator.lt)
def __le__(s, t): return s._cmp(t, operator.le)
def __gt__(s, t): return s._cmp(t, operator.gt)
def __ge__(s, t): return s._cmp(t, operator.ge)
def __abs__(s):
a, b = s._mpq_
if a >= 0:
return s
v = new(mpq)
v._mpq_ = -a, b
return v
def __neg__(s):
a, b = s._mpq_
v = new(mpq)
v._mpq_ = -a, b
return v
def __pos__(s):
return s
def __add__(s, t):
ttype = type(t)
if ttype is mpq:
a, b = s._mpq_
c, d = t._mpq_
return create_reduced(a*d+b*c, b*d)
if ttype in int_types:
a, b = s._mpq_
v = new(mpq)
v._mpq_ = a+b*t, b
return v
return NotImplemented
__radd__ = __add__
def __sub__(s, t):
ttype = type(t)
if ttype is mpq:
a, b = s._mpq_
c, d = t._mpq_
return create_reduced(a*d-b*c, b*d)
if ttype in int_types:
a, b = s._mpq_
v = new(mpq)
v._mpq_ = a-b*t, b
return v
return NotImplemented
def __rsub__(s, t):
ttype = type(t)
if ttype is mpq:
a, b = s._mpq_
c, d = t._mpq_
return create_reduced(b*c-a*d, b*d)
if ttype in int_types:
a, b = s._mpq_
v = new(mpq)
v._mpq_ = b*t-a, b
return v
return NotImplemented
def __mul__(s, t):
ttype = type(t)
if ttype is mpq:
a, b = s._mpq_
c, d = t._mpq_
return create_reduced(a*c, b*d)
if ttype in int_types:
a, b = s._mpq_
return create_reduced(a*t, b)
return NotImplemented
__rmul__ = __mul__
def __div__(s, t):
ttype = type(t)
if ttype is mpq:
a, b = s._mpq_
c, d = t._mpq_
return create_reduced(a*d, b*c)
if ttype in int_types:
a, b = s._mpq_
return create_reduced(a, b*t)
return NotImplemented
def __rdiv__(s, t):
ttype = type(t)
if ttype is mpq:
a, b = s._mpq_
c, d = t._mpq_
return create_reduced(b*c, a*d)
if ttype in int_types:
a, b = s._mpq_
return create_reduced(b*t, a)
return NotImplemented
def __pow__(s, t):
ttype = type(t)
if ttype in int_types:
a, b = s._mpq_
if t:
if t < 0:
a, b, t = b, a, -t
v = new(mpq)
v._mpq_ = a**t, b**t
return v
raise ZeroDivisionError
return NotImplemented
mpq_1 = mpq((1,1))
mpq_0 = mpq((0,1))
mpq_1_2 = mpq((1,2))
mpq_3_2 = mpq((3,2))
mpq_1_4 = mpq((1,4))
mpq_1_16 = mpq((1,16))
mpq_3_16 = mpq((3,16))
mpq_5_2 = mpq((5,2))
mpq_3_4 = mpq((3,4))
mpq_7_4 = mpq((7,4))
mpq_5_4 = mpq((5,4))
# Register with "numbers" ABC
# We do not subclass, hence we do not use the @abstractmethod checks. While
# this is less invasive it may turn out that we do not actually support
# parts of the expected interfaces. See
# http://docs.python.org/2/library/numbers.html for list of abstract
# methods.
try:
import numbers
numbers.Rational.register(mpq)
except ImportError:
pass

Some files were not shown because too many files have changed in this diff Show More