I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,14 @@
from sympy.core.symbol import symbols
from sympy.codegen.abstract_nodes import List
def test_List():
l = List(2, 3, 4)
assert l == List(2, 3, 4)
assert str(l) == "[2, 3, 4]"
x, y, z = symbols('x y z')
l = List(x**2,y**3,z**4)
# contrary to python's built-in list, we can call e.g. "replace" on List.
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
assert m == [x**2, y-3, z-4]
hash(m)

View File

@ -0,0 +1,179 @@
import tempfile
from sympy import log, Min, Max, sqrt
from sympy.core.numbers import Float
from sympy.core.symbol import Symbol, symbols
from sympy.functions.elementary.trigonometric import cos
from sympy.codegen.ast import Assignment, Raise, RuntimeError_, QuotedString
from sympy.codegen.algorithms import newtons_method, newtons_method_function
from sympy.codegen.cfunctions import expm1
from sympy.codegen.fnodes import bind_C
from sympy.codegen.futils import render_as_module as f_module
from sympy.codegen.pyutils import render_as_module as py_module
from sympy.external import import_module
from sympy.printing.codeprinter import ccode
from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran
from sympy.utilities._compilation.util import may_xfail
from sympy.testing.pytest import skip, raises
cython = import_module('cython')
wurlitzer = import_module('wurlitzer')
def test_newtons_method():
x, dx, atol = symbols('x dx atol')
expr = cos(x) - x**3
algo = newtons_method(expr, x, atol, dx)
assert algo.has(Assignment(dx, -expr/expr.diff(x)))
@may_xfail
def test_newtons_method_function__ccode():
x = Symbol('x', real=True)
expr = cos(x) - x**3
func = newtons_method_function(expr, x)
if not cython:
skip("cython not installed.")
if not has_c():
skip("No C compiler found.")
compile_kw = {"std": 'c99'}
with tempfile.TemporaryDirectory() as folder:
mod, info = compile_link_import_strings([
('newton.c', ('#include <math.h>\n'
'#include <stdio.h>\n') + ccode(func)),
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
"cdef extern double newton(double)\n"
"def py_newton(x):\n"
" return newton(x)\n"))
], build_dir=folder, compile_kwargs=compile_kw)
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
@may_xfail
def test_newtons_method_function__fcode():
x = Symbol('x', real=True)
expr = cos(x) - x**3
func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')])
if not cython:
skip("cython not installed.")
if not has_fortran():
skip("No Fortran compiler found.")
f_mod = f_module([func], 'mod_newton')
with tempfile.TemporaryDirectory() as folder:
mod, info = compile_link_import_strings([
('newton.f90', f_mod),
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
"cdef extern double newton(double*)\n"
"def py_newton(double x):\n"
" return newton(&x)\n"))
], build_dir=folder)
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
def test_newtons_method_function__pycode():
x = Symbol('x', real=True)
expr = cos(x) - x**3
func = newtons_method_function(expr, x)
py_mod = py_module(func)
namespace = {}
exec(py_mod, namespace, namespace)
res = eval('newton(0.5)', namespace)
assert abs(res - 0.865474033102) < 1e-12
@may_xfail
def test_newtons_method_function__ccode_parameters():
args = x, A, k, p = symbols('x A k p')
expr = A*cos(k*x) - p*x**3
raises(ValueError, lambda: newtons_method_function(expr, x))
use_wurlitzer = wurlitzer
func = newtons_method_function(expr, x, args, debug=use_wurlitzer)
if not has_c():
skip("No C compiler found.")
if not cython:
skip("cython not installed.")
compile_kw = {"std": 'c99'}
with tempfile.TemporaryDirectory() as folder:
mod, info = compile_link_import_strings([
('newton_par.c', ('#include <math.h>\n'
'#include <stdio.h>\n') + ccode(func)),
('_newton_par.pyx', ("#cython: language_level={}\n".format("3") +
"cdef extern double newton(double, double, double, double)\n"
"def py_newton(x, A=1, k=1, p=1):\n"
" return newton(x, A, k, p)\n"))
], compile_kwargs=compile_kw, build_dir=folder)
if use_wurlitzer:
with wurlitzer.pipes() as (out, err):
result = mod.py_newton(0.5)
else:
result = mod.py_newton(0.5)
assert abs(result - 0.865474033102) < 1e-12
if not use_wurlitzer:
skip("C-level output only tested when package 'wurlitzer' is available.")
out, err = out.read(), err.read()
assert err == ''
assert out == """\
x= 0.5
x= 1.1121 d_x= 0.61214
x= 0.90967 d_x= -0.20247
x= 0.86726 d_x= -0.042409
x= 0.86548 d_x= -0.0017867
x= 0.86547 d_x= -3.1022e-06
x= 0.86547 d_x= -9.3421e-12
x= 0.86547 d_x= 3.6902e-17
""" # try to run tests with LC_ALL=C if this assertion fails
def test_newtons_method_function__rtol_cse_nan():
a, b, c, N_geo, N_tot = symbols('a b c N_geo N_tot', real=True, nonnegative=True)
i = Symbol('i', integer=True, nonnegative=True)
N_ari = N_tot - N_geo - 1
delta_ari = (c-b)/N_ari
ln_delta_geo = log(b) + log(-expm1((log(a)-log(b))/N_geo))
eqb_log = ln_delta_geo - log(delta_ari)
def _clamp(low, expr, high):
return Min(Max(low, expr), high)
meth_kw = {
'clamped_newton': {'delta_fn': lambda e, x: _clamp(
(sqrt(a*x)-x)*0.99,
-e/e.diff(x),
(sqrt(c*x)-x)*0.99
)},
'halley': {'delta_fn': lambda e, x: (-2*(e*e.diff(x))/(2*e.diff(x)**2 - e*e.diff(x, 2)))},
'halley_alt': {'delta_fn': lambda e, x: (-e/e.diff(x)/(1-e/e.diff(x)*e.diff(x,2)/2/e.diff(x)))},
}
args = eqb_log, b
for use_cse in [False, True]:
kwargs = {
'params': (b, a, c, N_geo, N_tot), 'itermax': 60, 'debug': True, 'cse': use_cse,
'counter': i, 'atol': 1e-100, 'rtol': 2e-16, 'bounds': (a,c),
'handle_nan': Raise(RuntimeError_(QuotedString("encountered NaN.")))
}
func = {k: newtons_method_function(*args, func_name=f"{k}_b", **dict(kwargs, **kw)) for k, kw in meth_kw.items()}
py_mod = {k: py_module(v) for k, v in func.items()}
namespace = {}
root_find_b = {}
for k, v in py_mod.items():
ns = namespace[k] = {}
exec(v, ns, ns)
root_find_b[k] = ns[f'{k}_b']
ref = Float('13.2261515064168768938151923226496')
reftol = {'clamped_newton': 2e-16, 'halley': 2e-16, 'halley_alt': 3e-16}
guess = 4.0
for meth, func in root_find_b.items():
result = func(guess, 1e-2, 1e2, 50, 100)
req = ref*reftol[meth]
if use_cse:
req *= 2
assert abs(result - ref) < req

View File

@ -0,0 +1,57 @@
# This file contains tests that exercise multiple AST nodes
import tempfile
from sympy.external import import_module
from sympy.printing.codeprinter import ccode
from sympy.utilities._compilation import compile_link_import_strings, has_c
from sympy.utilities._compilation.util import may_xfail
from sympy.testing.pytest import skip
from sympy.codegen.ast import (
FunctionDefinition, FunctionPrototype, Variable, Pointer, real, Assignment,
integer, CodeBlock, While
)
from sympy.codegen.cnodes import void, PreIncrement
from sympy.codegen.cutils import render_as_source_file
cython = import_module('cython')
np = import_module('numpy')
def _mk_func1():
declars = n, inp, out = Variable('n', integer), Pointer('inp', real), Pointer('out', real)
i = Variable('i', integer)
whl = While(i<n, [Assignment(out[i], inp[i]), PreIncrement(i)])
body = CodeBlock(i.as_Declaration(value=0), whl)
return FunctionDefinition(void, 'our_test_function', declars, body)
def _render_compile_import(funcdef, build_dir):
code_str = render_as_source_file(funcdef, settings={"contract": False})
declar = ccode(FunctionPrototype.from_FunctionDefinition(funcdef))
return compile_link_import_strings([
('our_test_func.c', code_str),
('_our_test_func.pyx', ("#cython: language_level={}\n".format("3") +
"cdef extern {declar}\n"
"def _{fname}({typ}[:] inp, {typ}[:] out):\n"
" {fname}(inp.size, &inp[0], &out[0])").format(
declar=declar, fname=funcdef.name, typ='double'
))
], build_dir=build_dir)
@may_xfail
def test_copying_function():
if not np:
skip("numpy not installed.")
if not has_c():
skip("No C compiler found.")
if not cython:
skip("Cython not found.")
info = None
with tempfile.TemporaryDirectory() as folder:
mod, info = _render_compile_import(_mk_func1(), build_dir=folder)
inp = np.arange(10.0)
out = np.empty_like(inp)
mod._our_test_function(inp, out)
assert np.allclose(inp, out)

View File

@ -0,0 +1,53 @@
import math
from sympy.core.symbol import symbols
from sympy.functions.elementary.exponential import exp
from sympy.codegen.rewriting import optimize
from sympy.codegen.approximations import SumApprox, SeriesApprox
def test_SumApprox_trivial():
x = symbols('x')
expr1 = 1 + x
sum_approx = SumApprox(bounds={x: (-1e-20, 1e-20)}, reltol=1e-16)
apx1 = optimize(expr1, [sum_approx])
assert apx1 - 1 == 0
def test_SumApprox_monotone_terms():
x, y, z = symbols('x y z')
expr1 = exp(z)*(x**2 + y**2 + 1)
bnds1 = {x: (0, 1e-3), y: (100, 1000)}
sum_approx_m2 = SumApprox(bounds=bnds1, reltol=1e-2)
sum_approx_m5 = SumApprox(bounds=bnds1, reltol=1e-5)
sum_approx_m11 = SumApprox(bounds=bnds1, reltol=1e-11)
assert (optimize(expr1, [sum_approx_m2])/exp(z) - (y**2)).simplify() == 0
assert (optimize(expr1, [sum_approx_m5])/exp(z) - (y**2 + 1)).simplify() == 0
assert (optimize(expr1, [sum_approx_m11])/exp(z) - (y**2 + 1 + x**2)).simplify() == 0
def test_SeriesApprox_trivial():
x, z = symbols('x z')
for factor in [1, exp(z)]:
x = symbols('x')
expr1 = exp(x)*factor
bnds1 = {x: (-1, 1)}
series_approx_50 = SeriesApprox(bounds=bnds1, reltol=0.50)
series_approx_10 = SeriesApprox(bounds=bnds1, reltol=0.10)
series_approx_05 = SeriesApprox(bounds=bnds1, reltol=0.05)
c = (bnds1[x][1] + bnds1[x][0])/2 # 0.0
f0 = math.exp(c) # 1.0
ref_50 = f0 + x + x**2/2
ref_10 = f0 + x + x**2/2 + x**3/6
ref_05 = f0 + x + x**2/2 + x**3/6 + x**4/24
res_50 = optimize(expr1, [series_approx_50])
res_10 = optimize(expr1, [series_approx_10])
res_05 = optimize(expr1, [series_approx_05])
assert (res_50/factor - ref_50).simplify() == 0
assert (res_10/factor - ref_10).simplify() == 0
assert (res_05/factor - ref_05).simplify() == 0
max_ord3 = SeriesApprox(bounds=bnds1, reltol=0.05, max_order=3)
assert optimize(expr1, [max_ord3]) == expr1

View File

@ -0,0 +1,661 @@
import math
from sympy.core.containers import Tuple
from sympy.core.numbers import nan, oo, Float, Integer
from sympy.core.relational import Lt
from sympy.core.symbol import symbols, Symbol
from sympy.functions.elementary.trigonometric import sin
from sympy.matrices.dense import Matrix
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.sets.fancysets import Range
from sympy.tensor.indexed import Idx, IndexedBase
from sympy.testing.pytest import raises
from sympy.codegen.ast import (
Assignment, Attribute, aug_assign, CodeBlock, For, Type, Variable, Pointer, Declaration,
AddAugmentedAssignment, SubAugmentedAssignment, MulAugmentedAssignment,
DivAugmentedAssignment, ModAugmentedAssignment, value_const, pointer_const,
integer, real, complex_, int8, uint8, float16 as f16, float32 as f32,
float64 as f64, float80 as f80, float128 as f128, complex64 as c64, complex128 as c128,
While, Scope, String, Print, QuotedString, FunctionPrototype, FunctionDefinition, Return,
FunctionCall, untyped, IntBaseType, intc, Node, none, NoneToken, Token, Comment
)
x, y, z, t, x0, x1, x2, a, b = symbols("x, y, z, t, x0, x1, x2, a, b")
n = symbols("n", integer=True)
A = MatrixSymbol('A', 3, 1)
mat = Matrix([1, 2, 3])
B = IndexedBase('B')
i = Idx("i", n)
A22 = MatrixSymbol('A22',2,2)
B22 = MatrixSymbol('B22',2,2)
def test_Assignment():
# Here we just do things to show they don't error
Assignment(x, y)
Assignment(x, 0)
Assignment(A, mat)
Assignment(A[1,0], 0)
Assignment(A[1,0], x)
Assignment(B[i], x)
Assignment(B[i], 0)
a = Assignment(x, y)
assert a.func(*a.args) == a
assert a.op == ':='
# Here we test things to show that they error
# Matrix to scalar
raises(ValueError, lambda: Assignment(B[i], A))
raises(ValueError, lambda: Assignment(B[i], mat))
raises(ValueError, lambda: Assignment(x, mat))
raises(ValueError, lambda: Assignment(x, A))
raises(ValueError, lambda: Assignment(A[1,0], mat))
# Scalar to matrix
raises(ValueError, lambda: Assignment(A, x))
raises(ValueError, lambda: Assignment(A, 0))
# Non-atomic lhs
raises(TypeError, lambda: Assignment(mat, A))
raises(TypeError, lambda: Assignment(0, x))
raises(TypeError, lambda: Assignment(x*x, 1))
raises(TypeError, lambda: Assignment(A + A, mat))
raises(TypeError, lambda: Assignment(B, 0))
def test_AugAssign():
# Here we just do things to show they don't error
aug_assign(x, '+', y)
aug_assign(x, '+', 0)
aug_assign(A, '+', mat)
aug_assign(A[1, 0], '+', 0)
aug_assign(A[1, 0], '+', x)
aug_assign(B[i], '+', x)
aug_assign(B[i], '+', 0)
# Check creation via aug_assign vs constructor
for binop, cls in [
('+', AddAugmentedAssignment),
('-', SubAugmentedAssignment),
('*', MulAugmentedAssignment),
('/', DivAugmentedAssignment),
('%', ModAugmentedAssignment),
]:
a = aug_assign(x, binop, y)
b = cls(x, y)
assert a.func(*a.args) == a == b
assert a.binop == binop
assert a.op == binop + '='
# Here we test things to show that they error
# Matrix to scalar
raises(ValueError, lambda: aug_assign(B[i], '+', A))
raises(ValueError, lambda: aug_assign(B[i], '+', mat))
raises(ValueError, lambda: aug_assign(x, '+', mat))
raises(ValueError, lambda: aug_assign(x, '+', A))
raises(ValueError, lambda: aug_assign(A[1, 0], '+', mat))
# Scalar to matrix
raises(ValueError, lambda: aug_assign(A, '+', x))
raises(ValueError, lambda: aug_assign(A, '+', 0))
# Non-atomic lhs
raises(TypeError, lambda: aug_assign(mat, '+', A))
raises(TypeError, lambda: aug_assign(0, '+', x))
raises(TypeError, lambda: aug_assign(x * x, '+', 1))
raises(TypeError, lambda: aug_assign(A + A, '+', mat))
raises(TypeError, lambda: aug_assign(B, '+', 0))
def test_Assignment_printing():
assignment_classes = [
Assignment,
AddAugmentedAssignment,
SubAugmentedAssignment,
MulAugmentedAssignment,
DivAugmentedAssignment,
ModAugmentedAssignment,
]
pairs = [
(x, 2 * y + 2),
(B[i], x),
(A22, B22),
(A[0, 0], x),
]
for cls in assignment_classes:
for lhs, rhs in pairs:
a = cls(lhs, rhs)
assert repr(a) == '%s(%s, %s)' % (cls.__name__, repr(lhs), repr(rhs))
def test_CodeBlock():
c = CodeBlock(Assignment(x, 1), Assignment(y, x + 1))
assert c.func(*c.args) == c
assert c.left_hand_sides == Tuple(x, y)
assert c.right_hand_sides == Tuple(1, x + 1)
def test_CodeBlock_topological_sort():
assignments = [
Assignment(x, y + z),
Assignment(z, 1),
Assignment(t, x),
Assignment(y, 2),
]
ordered_assignments = [
# Note that the unrelated z=1 and y=2 are kept in that order
Assignment(z, 1),
Assignment(y, 2),
Assignment(x, y + z),
Assignment(t, x),
]
c1 = CodeBlock.topological_sort(assignments)
assert c1 == CodeBlock(*ordered_assignments)
# Cycle
invalid_assignments = [
Assignment(x, y + z),
Assignment(z, 1),
Assignment(y, x),
Assignment(y, 2),
]
raises(ValueError, lambda: CodeBlock.topological_sort(invalid_assignments))
# Free symbols
free_assignments = [
Assignment(x, y + z),
Assignment(z, a * b),
Assignment(t, x),
Assignment(y, b + 3),
]
free_assignments_ordered = [
Assignment(z, a * b),
Assignment(y, b + 3),
Assignment(x, y + z),
Assignment(t, x),
]
c2 = CodeBlock.topological_sort(free_assignments)
assert c2 == CodeBlock(*free_assignments_ordered)
def test_CodeBlock_free_symbols():
c1 = CodeBlock(
Assignment(x, y + z),
Assignment(z, 1),
Assignment(t, x),
Assignment(y, 2),
)
assert c1.free_symbols == set()
c2 = CodeBlock(
Assignment(x, y + z),
Assignment(z, a * b),
Assignment(t, x),
Assignment(y, b + 3),
)
assert c2.free_symbols == {a, b}
def test_CodeBlock_cse():
c1 = CodeBlock(
Assignment(y, 1),
Assignment(x, sin(y)),
Assignment(z, sin(y)),
Assignment(t, x*z),
)
assert c1.cse() == CodeBlock(
Assignment(y, 1),
Assignment(x0, sin(y)),
Assignment(x, x0),
Assignment(z, x0),
Assignment(t, x*z),
)
# Multiple assignments to same symbol not supported
raises(NotImplementedError, lambda: CodeBlock(
Assignment(x, 1),
Assignment(y, 1), Assignment(y, 2)
).cse())
# Check auto-generated symbols do not collide with existing ones
c2 = CodeBlock(
Assignment(x0, sin(y) + 1),
Assignment(x1, 2 * sin(y)),
Assignment(z, x * y),
)
assert c2.cse() == CodeBlock(
Assignment(x2, sin(y)),
Assignment(x0, x2 + 1),
Assignment(x1, 2 * x2),
Assignment(z, x * y),
)
def test_CodeBlock_cse__issue_14118():
# see https://github.com/sympy/sympy/issues/14118
c = CodeBlock(
Assignment(A22, Matrix([[x, sin(y)],[3, 4]])),
Assignment(B22, Matrix([[sin(y), 2*sin(y)], [sin(y)**2, 7]]))
)
assert c.cse() == CodeBlock(
Assignment(x0, sin(y)),
Assignment(A22, Matrix([[x, x0],[3, 4]])),
Assignment(B22, Matrix([[x0, 2*x0], [x0**2, 7]]))
)
def test_For():
f = For(n, Range(0, 3), (Assignment(A[n, 0], x + n), aug_assign(x, '+', y)))
f = For(n, (1, 2, 3, 4, 5), (Assignment(A[n, 0], x + n),))
assert f.func(*f.args) == f
raises(TypeError, lambda: For(n, x, (x + y,)))
def test_none():
assert none.is_Atom
assert none == none
class Foo(Token):
pass
foo = Foo()
assert foo != none
assert none == None
assert none == NoneToken()
assert none.func(*none.args) == none
def test_String():
st = String('foobar')
assert st.is_Atom
assert st == String('foobar')
assert st.text == 'foobar'
assert st.func(**st.kwargs()) == st
assert st.func(*st.args) == st
class Signifier(String):
pass
si = Signifier('foobar')
assert si != st
assert si.text == st.text
s = String('foo')
assert str(s) == 'foo'
assert repr(s) == "String('foo')"
def test_Comment():
c = Comment('foobar')
assert c.text == 'foobar'
assert str(c) == 'foobar'
def test_Node():
n = Node()
assert n == Node()
assert n.func(*n.args) == n
def test_Type():
t = Type('MyType')
assert len(t.args) == 1
assert t.name == String('MyType')
assert str(t) == 'MyType'
assert repr(t) == "Type(String('MyType'))"
assert Type(t) == t
assert t.func(*t.args) == t
t1 = Type('t1')
t2 = Type('t2')
assert t1 != t2
assert t1 == t1 and t2 == t2
t1b = Type('t1')
assert t1 == t1b
assert t2 != t1b
def test_Type__from_expr():
assert Type.from_expr(i) == integer
u = symbols('u', real=True)
assert Type.from_expr(u) == real
assert Type.from_expr(n) == integer
assert Type.from_expr(3) == integer
assert Type.from_expr(3.0) == real
assert Type.from_expr(3+1j) == complex_
raises(ValueError, lambda: Type.from_expr(sum))
def test_Type__cast_check__integers():
# Rounding
raises(ValueError, lambda: integer.cast_check(3.5))
assert integer.cast_check('3') == 3
assert integer.cast_check(Float('3.0000000000000000000')) == 3
assert integer.cast_check(Float('3.0000000000000000001')) == 3 # unintuitive maybe?
# Range
assert int8.cast_check(127.0) == 127
raises(ValueError, lambda: int8.cast_check(128))
assert int8.cast_check(-128) == -128
raises(ValueError, lambda: int8.cast_check(-129))
assert uint8.cast_check(0) == 0
assert uint8.cast_check(128) == 128
raises(ValueError, lambda: uint8.cast_check(256.0))
raises(ValueError, lambda: uint8.cast_check(-1))
def test_Attribute():
noexcept = Attribute('noexcept')
assert noexcept == Attribute('noexcept')
alignas16 = Attribute('alignas', [16])
alignas32 = Attribute('alignas', [32])
assert alignas16 != alignas32
assert alignas16.func(*alignas16.args) == alignas16
def test_Variable():
v = Variable(x, type=real)
assert v == Variable(v)
assert v == Variable('x', type=real)
assert v.symbol == x
assert v.type == real
assert value_const not in v.attrs
assert v.func(*v.args) == v
assert str(v) == 'Variable(x, type=real)'
w = Variable(y, f32, attrs={value_const})
assert w.symbol == y
assert w.type == f32
assert value_const in w.attrs
assert w.func(*w.args) == w
v_n = Variable(n, type=Type.from_expr(n))
assert v_n.type == integer
assert v_n.func(*v_n.args) == v_n
v_i = Variable(i, type=Type.from_expr(n))
assert v_i.type == integer
assert v_i != v_n
a_i = Variable.deduced(i)
assert a_i.type == integer
assert Variable.deduced(Symbol('x', real=True)).type == real
assert a_i.func(*a_i.args) == a_i
v_n2 = Variable.deduced(n, value=3.5, cast_check=False)
assert v_n2.func(*v_n2.args) == v_n2
assert abs(v_n2.value - 3.5) < 1e-15
raises(ValueError, lambda: Variable.deduced(n, value=3.5, cast_check=True))
v_n3 = Variable.deduced(n)
assert v_n3.type == integer
assert str(v_n3) == 'Variable(n, type=integer)'
assert Variable.deduced(z, value=3).type == integer
assert Variable.deduced(z, value=3.0).type == real
assert Variable.deduced(z, value=3.0+1j).type == complex_
def test_Pointer():
p = Pointer(x)
assert p.symbol == x
assert p.type == untyped
assert value_const not in p.attrs
assert pointer_const not in p.attrs
assert p.func(*p.args) == p
u = symbols('u', real=True)
pu = Pointer(u, type=Type.from_expr(u), attrs={value_const, pointer_const})
assert pu.symbol is u
assert pu.type == real
assert value_const in pu.attrs
assert pointer_const in pu.attrs
assert pu.func(*pu.args) == pu
i = symbols('i', integer=True)
deref = pu[i]
assert deref.indices == (i,)
def test_Declaration():
u = symbols('u', real=True)
vu = Variable(u, type=Type.from_expr(u))
assert Declaration(vu).variable.type == real
vn = Variable(n, type=Type.from_expr(n))
assert Declaration(vn).variable.type == integer
# PR 19107, does not allow comparison between expressions and Basic
# lt = StrictLessThan(vu, vn)
# assert isinstance(lt, StrictLessThan)
vuc = Variable(u, Type.from_expr(u), value=3.0, attrs={value_const})
assert value_const in vuc.attrs
assert pointer_const not in vuc.attrs
decl = Declaration(vuc)
assert decl.variable == vuc
assert isinstance(decl.variable.value, Float)
assert decl.variable.value == 3.0
assert decl.func(*decl.args) == decl
assert vuc.as_Declaration() == decl
assert vuc.as_Declaration(value=None, attrs=None) == Declaration(vu)
vy = Variable(y, type=integer, value=3)
decl2 = Declaration(vy)
assert decl2.variable == vy
assert decl2.variable.value == Integer(3)
vi = Variable(i, type=Type.from_expr(i), value=3.0)
decl3 = Declaration(vi)
assert decl3.variable.type == integer
assert decl3.variable.value == 3.0
raises(ValueError, lambda: Declaration(vi, 42))
def test_IntBaseType():
assert intc.name == String('intc')
assert intc.args == (intc.name,)
assert str(IntBaseType('a').name) == 'a'
def test_FloatType():
assert f16.dig == 3
assert f32.dig == 6
assert f64.dig == 15
assert f80.dig == 18
assert f128.dig == 33
assert f16.decimal_dig == 5
assert f32.decimal_dig == 9
assert f64.decimal_dig == 17
assert f80.decimal_dig == 21
assert f128.decimal_dig == 36
assert f16.max_exponent == 16
assert f32.max_exponent == 128
assert f64.max_exponent == 1024
assert f80.max_exponent == 16384
assert f128.max_exponent == 16384
assert f16.min_exponent == -13
assert f32.min_exponent == -125
assert f64.min_exponent == -1021
assert f80.min_exponent == -16381
assert f128.min_exponent == -16381
assert abs(f16.eps / Float('0.00097656', precision=16) - 1) < 0.1*10**-f16.dig
assert abs(f32.eps / Float('1.1920929e-07', precision=32) - 1) < 0.1*10**-f32.dig
assert abs(f64.eps / Float('2.2204460492503131e-16', precision=64) - 1) < 0.1*10**-f64.dig
assert abs(f80.eps / Float('1.08420217248550443401e-19', precision=80) - 1) < 0.1*10**-f80.dig
assert abs(f128.eps / Float(' 1.92592994438723585305597794258492732e-34', precision=128) - 1) < 0.1*10**-f128.dig
assert abs(f16.max / Float('65504', precision=16) - 1) < .1*10**-f16.dig
assert abs(f32.max / Float('3.40282347e+38', precision=32) - 1) < 0.1*10**-f32.dig
assert abs(f64.max / Float('1.79769313486231571e+308', precision=64) - 1) < 0.1*10**-f64.dig # cf. np.finfo(np.float64).max
assert abs(f80.max / Float('1.18973149535723176502e+4932', precision=80) - 1) < 0.1*10**-f80.dig
assert abs(f128.max / Float('1.18973149535723176508575932662800702e+4932', precision=128) - 1) < 0.1*10**-f128.dig
# cf. np.finfo(np.float32).tiny
assert abs(f16.tiny / Float('6.1035e-05', precision=16) - 1) < 0.1*10**-f16.dig
assert abs(f32.tiny / Float('1.17549435e-38', precision=32) - 1) < 0.1*10**-f32.dig
assert abs(f64.tiny / Float('2.22507385850720138e-308', precision=64) - 1) < 0.1*10**-f64.dig
assert abs(f80.tiny / Float('3.36210314311209350626e-4932', precision=80) - 1) < 0.1*10**-f80.dig
assert abs(f128.tiny / Float('3.3621031431120935062626778173217526e-4932', precision=128) - 1) < 0.1*10**-f128.dig
assert f64.cast_check(0.5) == Float(0.5, 17)
assert abs(f64.cast_check(3.7) - 3.7) < 3e-17
assert isinstance(f64.cast_check(3), (Float, float))
assert f64.cast_nocheck(oo) == float('inf')
assert f64.cast_nocheck(-oo) == float('-inf')
assert f64.cast_nocheck(float(oo)) == float('inf')
assert f64.cast_nocheck(float(-oo)) == float('-inf')
assert math.isnan(f64.cast_nocheck(nan))
assert f32 != f64
assert f64 == f64.func(*f64.args)
def test_Type__cast_check__floating_point():
raises(ValueError, lambda: f32.cast_check(123.45678949))
raises(ValueError, lambda: f32.cast_check(12.345678949))
raises(ValueError, lambda: f32.cast_check(1.2345678949))
raises(ValueError, lambda: f32.cast_check(.12345678949))
assert abs(123.456789049 - f32.cast_check(123.456789049) - 4.9e-8) < 1e-8
assert abs(0.12345678904 - f32.cast_check(0.12345678904) - 4e-11) < 1e-11
dcm21 = Float('0.123456789012345670499') # 21 decimals
assert abs(dcm21 - f64.cast_check(dcm21) - 4.99e-19) < 1e-19
f80.cast_check(Float('0.12345678901234567890103', precision=88))
raises(ValueError, lambda: f80.cast_check(Float('0.12345678901234567890149', precision=88)))
v10 = 12345.67894
raises(ValueError, lambda: f32.cast_check(v10))
assert abs(Float(str(v10), precision=64+8) - f64.cast_check(v10)) < v10*1e-16
assert abs(f32.cast_check(2147483647) - 2147483650) < 1
def test_Type__cast_check__complex_floating_point():
val9_11 = 123.456789049 + 0.123456789049j
raises(ValueError, lambda: c64.cast_check(.12345678949 + .12345678949j))
assert abs(val9_11 - c64.cast_check(val9_11) - 4.9e-8) < 1e-8
dcm21 = Float('0.123456789012345670499') + 1e-20j # 21 decimals
assert abs(dcm21 - c128.cast_check(dcm21) - 4.99e-19) < 1e-19
v19 = Float('0.1234567890123456749') + 1j*Float('0.1234567890123456749')
raises(ValueError, lambda: c128.cast_check(v19))
def test_While():
xpp = AddAugmentedAssignment(x, 1)
whl1 = While(x < 2, [xpp])
assert whl1.condition.args[0] == x
assert whl1.condition.args[1] == 2
assert whl1.condition == Lt(x, 2, evaluate=False)
assert whl1.body.args == (xpp,)
assert whl1.func(*whl1.args) == whl1
cblk = CodeBlock(AddAugmentedAssignment(x, 1))
whl2 = While(x < 2, cblk)
assert whl1 == whl2
assert whl1 != While(x < 3, [xpp])
def test_Scope():
assign = Assignment(x, y)
incr = AddAugmentedAssignment(x, 1)
scp = Scope([assign, incr])
cblk = CodeBlock(assign, incr)
assert scp.body == cblk
assert scp == Scope(cblk)
assert scp != Scope([incr, assign])
assert scp.func(*scp.args) == scp
def test_Print():
fmt = "%d %.3f"
ps = Print([n, x], fmt)
assert str(ps.format_string) == fmt
assert ps.print_args == Tuple(n, x)
assert ps.args == (Tuple(n, x), QuotedString(fmt), none)
assert ps == Print((n, x), fmt)
assert ps != Print([x, n], fmt)
assert ps.func(*ps.args) == ps
ps2 = Print([n, x])
assert ps2 == Print([n, x])
assert ps2 != ps
assert ps2.format_string == None
def test_FunctionPrototype_and_FunctionDefinition():
vx = Variable(x, type=real)
vn = Variable(n, type=integer)
fp1 = FunctionPrototype(real, 'power', [vx, vn])
assert fp1.return_type == real
assert fp1.name == String('power')
assert fp1.parameters == Tuple(vx, vn)
assert fp1 == FunctionPrototype(real, 'power', [vx, vn])
assert fp1 != FunctionPrototype(real, 'power', [vn, vx])
assert fp1.func(*fp1.args) == fp1
body = [Assignment(x, x**n), Return(x)]
fd1 = FunctionDefinition(real, 'power', [vx, vn], body)
assert fd1.return_type == real
assert str(fd1.name) == 'power'
assert fd1.parameters == Tuple(vx, vn)
assert fd1.body == CodeBlock(*body)
assert fd1 == FunctionDefinition(real, 'power', [vx, vn], body)
assert fd1 != FunctionDefinition(real, 'power', [vx, vn], body[::-1])
assert fd1.func(*fd1.args) == fd1
fp2 = FunctionPrototype.from_FunctionDefinition(fd1)
assert fp2 == fp1
fd2 = FunctionDefinition.from_FunctionPrototype(fp1, body)
assert fd2 == fd1
def test_Return():
rs = Return(x)
assert rs.args == (x,)
assert rs == Return(x)
assert rs != Return(y)
assert rs.func(*rs.args) == rs
def test_FunctionCall():
fc = FunctionCall('power', (x, 3))
assert fc.function_args[0] == x
assert fc.function_args[1] == 3
assert len(fc.function_args) == 2
assert isinstance(fc.function_args[1], Integer)
assert fc == FunctionCall('power', (x, 3))
assert fc != FunctionCall('power', (3, x))
assert fc != FunctionCall('Power', (x, 3))
assert fc.func(*fc.args) == fc
fc2 = FunctionCall('fma', [2, 3, 4])
assert len(fc2.function_args) == 3
assert fc2.function_args[0] == 2
assert fc2.function_args[1] == 3
assert fc2.function_args[2] == 4
assert str(fc2) in ( # not sure if QuotedString is a better default...
'FunctionCall(fma, function_args=(2, 3, 4))',
'FunctionCall("fma", function_args=(2, 3, 4))',
)
def test_ast_replace():
x = Variable('x', real)
y = Variable('y', real)
n = Variable('n', integer)
pwer = FunctionDefinition(real, 'pwer', [x, n], [pow(x.symbol, n.symbol)])
pname = pwer.name
pcall = FunctionCall('pwer', [y, 3])
tree1 = CodeBlock(pwer, pcall)
assert str(tree1.args[0].name) == 'pwer'
assert str(tree1.args[1].name) == 'pwer'
for a, b in zip(tree1, [pwer, pcall]):
assert a == b
tree2 = tree1.replace(pname, String('power'))
assert str(tree1.args[0].name) == 'pwer'
assert str(tree1.args[1].name) == 'pwer'
assert str(tree2.args[0].name) == 'power'
assert str(tree2.args[1].name) == 'power'

View File

@ -0,0 +1,165 @@
from sympy.core.numbers import (Rational, pi)
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.codegen.cfunctions import (
expm1, log1p, exp2, log2, fma, log10, Sqrt, Cbrt, hypot
)
from sympy.core.function import expand_log
def test_expm1():
# Eval
assert expm1(0) == 0
x = Symbol('x', real=True)
# Expand and rewrite
assert expm1(x).expand(func=True) - exp(x) == -1
assert expm1(x).rewrite('tractable') - exp(x) == -1
assert expm1(x).rewrite('exp') - exp(x) == -1
# Precision
assert not ((exp(1e-10).evalf() - 1) - 1e-10 - 5e-21) < 1e-22 # for comparison
assert abs(expm1(1e-10).evalf() - 1e-10 - 5e-21) < 1e-22
# Properties
assert expm1(x).is_real
assert expm1(x).is_finite
# Diff
assert expm1(42*x).diff(x) - 42*exp(42*x) == 0
assert expm1(42*x).diff(x) - expm1(42*x).expand(func=True).diff(x) == 0
def test_log1p():
# Eval
assert log1p(0) == 0
d = S(10)
assert expand_log(log1p(d**-1000) - log(d**1000 + 1) + log(d**1000)) == 0
x = Symbol('x', real=True)
# Expand and rewrite
assert log1p(x).expand(func=True) - log(x + 1) == 0
assert log1p(x).rewrite('tractable') - log(x + 1) == 0
assert log1p(x).rewrite('log') - log(x + 1) == 0
# Precision
assert not abs(log(1e-99 + 1).evalf() - 1e-99) < 1e-100 # for comparison
assert abs(expand_log(log1p(1e-99)).evalf() - 1e-99) < 1e-100
# Properties
assert log1p(-2**Rational(-1, 2)).is_real
assert not log1p(-1).is_finite
assert log1p(pi).is_finite
assert not log1p(x).is_positive
assert log1p(Symbol('y', positive=True)).is_positive
assert not log1p(x).is_zero
assert log1p(Symbol('z', zero=True)).is_zero
assert not log1p(x).is_nonnegative
assert log1p(Symbol('o', nonnegative=True)).is_nonnegative
# Diff
assert log1p(42*x).diff(x) - 42/(42*x + 1) == 0
assert log1p(42*x).diff(x) - log1p(42*x).expand(func=True).diff(x) == 0
def test_exp2():
# Eval
assert exp2(2) == 4
x = Symbol('x', real=True)
# Expand
assert exp2(x).expand(func=True) - 2**x == 0
# Diff
assert exp2(42*x).diff(x) - 42*exp2(42*x)*log(2) == 0
assert exp2(42*x).diff(x) - exp2(42*x).diff(x) == 0
def test_log2():
# Eval
assert log2(8) == 3
assert log2(pi) != log(pi)/log(2) # log2 should *save* (CPU) instructions
x = Symbol('x', real=True)
assert log2(x) != log(x)/log(2)
assert log2(2**x) == x
# Expand
assert log2(x).expand(func=True) - log(x)/log(2) == 0
# Diff
assert log2(42*x).diff() - 1/(log(2)*x) == 0
assert log2(42*x).diff() - log2(42*x).expand(func=True).diff(x) == 0
def test_fma():
x, y, z = symbols('x y z')
# Expand
assert fma(x, y, z).expand(func=True) - x*y - z == 0
expr = fma(17*x, 42*y, 101*z)
# Diff
assert expr.diff(x) - expr.expand(func=True).diff(x) == 0
assert expr.diff(y) - expr.expand(func=True).diff(y) == 0
assert expr.diff(z) - expr.expand(func=True).diff(z) == 0
assert expr.diff(x) - 17*42*y == 0
assert expr.diff(y) - 17*42*x == 0
assert expr.diff(z) - 101 == 0
def test_log10():
x = Symbol('x')
# Expand
assert log10(x).expand(func=True) - log(x)/log(10) == 0
# Diff
assert log10(42*x).diff(x) - 1/(log(10)*x) == 0
assert log10(42*x).diff(x) - log10(42*x).expand(func=True).diff(x) == 0
def test_Cbrt():
x = Symbol('x')
# Expand
assert Cbrt(x).expand(func=True) - x**Rational(1, 3) == 0
# Diff
assert Cbrt(42*x).diff(x) - 42*(42*x)**(Rational(1, 3) - 1)/3 == 0
assert Cbrt(42*x).diff(x) - Cbrt(42*x).expand(func=True).diff(x) == 0
def test_Sqrt():
x = Symbol('x')
# Expand
assert Sqrt(x).expand(func=True) - x**S.Half == 0
# Diff
assert Sqrt(42*x).diff(x) - 42*(42*x)**(S.Half - 1)/2 == 0
assert Sqrt(42*x).diff(x) - Sqrt(42*x).expand(func=True).diff(x) == 0
def test_hypot():
x, y = symbols('x y')
# Expand
assert hypot(x, y).expand(func=True) - (x**2 + y**2)**S.Half == 0
# Diff
assert hypot(17*x, 42*y).diff(x).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(x) == 0
assert hypot(17*x, 42*y).diff(y).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(y) == 0
assert hypot(17*x, 42*y).diff(x).expand(func=True) - 2*17*17*x*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
assert hypot(17*x, 42*y).diff(y).expand(func=True) - 2*42*42*y*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0

View File

@ -0,0 +1,112 @@
from sympy.core.symbol import symbols
from sympy.printing.codeprinter import ccode
from sympy.codegen.ast import Declaration, Variable, float64, int64, String, CodeBlock
from sympy.codegen.cnodes import (
alignof, CommaOperator, goto, Label, PreDecrement, PostDecrement, PreIncrement, PostIncrement,
sizeof, union, struct
)
x, y = symbols('x y')
def test_alignof():
ax = alignof(x)
assert ccode(ax) == 'alignof(x)'
assert ax.func(*ax.args) == ax
def test_CommaOperator():
expr = CommaOperator(PreIncrement(x), 2*x)
assert ccode(expr) == '(++(x), 2*x)'
assert expr.func(*expr.args) == expr
def test_goto_Label():
s = 'early_exit'
g = goto(s)
assert g.func(*g.args) == g
assert g != goto('foobar')
assert ccode(g) == 'goto early_exit'
l1 = Label(s)
assert ccode(l1) == 'early_exit:'
assert l1 == Label('early_exit')
assert l1 != Label('foobar')
body = [PreIncrement(x)]
l2 = Label(s, body)
assert l2.name == String("early_exit")
assert l2.body == CodeBlock(PreIncrement(x))
assert ccode(l2) == ("early_exit:\n"
"++(x);")
body = [PreIncrement(x), PreDecrement(y)]
l2 = Label(s, body)
assert l2.name == String("early_exit")
assert l2.body == CodeBlock(PreIncrement(x), PreDecrement(y))
assert ccode(l2) == ("early_exit:\n"
"{\n ++(x);\n --(y);\n}")
def test_PreDecrement():
p = PreDecrement(x)
assert p.func(*p.args) == p
assert ccode(p) == '--(x)'
def test_PostDecrement():
p = PostDecrement(x)
assert p.func(*p.args) == p
assert ccode(p) == '(x)--'
def test_PreIncrement():
p = PreIncrement(x)
assert p.func(*p.args) == p
assert ccode(p) == '++(x)'
def test_PostIncrement():
p = PostIncrement(x)
assert p.func(*p.args) == p
assert ccode(p) == '(x)++'
def test_sizeof():
typename = 'unsigned int'
sz = sizeof(typename)
assert ccode(sz) == 'sizeof(%s)' % typename
assert sz.func(*sz.args) == sz
assert not sz.is_Atom
assert sz.atoms() == {String('unsigned int'), String('sizeof')}
def test_struct():
vx, vy = Variable(x, type=float64), Variable(y, type=float64)
s = struct('vec2', [vx, vy])
assert s.func(*s.args) == s
assert s == struct('vec2', (vx, vy))
assert s != struct('vec2', (vy, vx))
assert str(s.name) == 'vec2'
assert len(s.declarations) == 2
assert all(isinstance(arg, Declaration) for arg in s.declarations)
assert ccode(s) == (
"struct vec2 {\n"
" double x;\n"
" double y;\n"
"}")
def test_union():
vx, vy = Variable(x, type=float64), Variable(y, type=int64)
u = union('dualuse', [vx, vy])
assert u.func(*u.args) == u
assert u == union('dualuse', (vx, vy))
assert str(u.name) == 'dualuse'
assert len(u.declarations) == 2
assert all(isinstance(arg, Declaration) for arg in u.declarations)
assert ccode(u) == (
"union dualuse {\n"
" double x;\n"
" int64_t y;\n"
"}")

View File

@ -0,0 +1,14 @@
from sympy.core.symbol import Symbol
from sympy.codegen.ast import Type
from sympy.codegen.cxxnodes import using
from sympy.printing.codeprinter import cxxcode
x = Symbol('x')
def test_using():
v = Type('std::vector')
u1 = using(v)
assert cxxcode(u1) == 'using std::vector'
u2 = using(v, 'vec')
assert cxxcode(u2) == 'using vec = std::vector'

View File

@ -0,0 +1,213 @@
import os
import tempfile
from sympy.core.symbol import (Symbol, symbols)
from sympy.codegen.ast import (
Assignment, Print, Declaration, FunctionDefinition, Return, real,
FunctionCall, Variable, Element, integer
)
from sympy.codegen.fnodes import (
allocatable, ArrayConstructor, isign, dsign, cmplx, kind, literal_dp,
Program, Module, use, Subroutine, dimension, assumed_extent, ImpliedDoLoop,
intent_out, size, Do, SubroutineCall, sum_, array, bind_C
)
from sympy.codegen.futils import render_as_module
from sympy.core.expr import unchanged
from sympy.external import import_module
from sympy.printing.codeprinter import fcode
from sympy.utilities._compilation import has_fortran, compile_run_strings, compile_link_import_strings
from sympy.utilities._compilation.util import may_xfail
from sympy.testing.pytest import skip, XFAIL
cython = import_module('cython')
np = import_module('numpy')
def test_size():
x = Symbol('x', real=True)
sx = size(x)
assert fcode(sx, source_format='free') == 'size(x)'
@may_xfail
def test_size_assumed_shape():
if not has_fortran():
skip("No fortran compiler found.")
a = Symbol('a', real=True)
body = [Return((sum_(a**2)/size(a))**.5)]
arr = array(a, dim=[':'], intent='in')
fd = FunctionDefinition(real, 'rms', [arr], body)
render_as_module([fd], 'mod_rms')
(stdout, stderr), info = compile_run_strings([
('rms.f90', render_as_module([fd], 'mod_rms')),
('main.f90', (
'program myprog\n'
'use mod_rms, only: rms\n'
'real*8, dimension(4), parameter :: x = [4, 2, 2, 2]\n'
'print *, dsqrt(7d0) - rms(x)\n'
'end program\n'
))
], clean=True)
assert '0.00000' in stdout
assert stderr == ''
assert info['exit_status'] == os.EX_OK
@XFAIL # https://github.com/sympy/sympy/issues/20265
@may_xfail
def test_ImpliedDoLoop():
if not has_fortran():
skip("No fortran compiler found.")
a, i = symbols('a i', integer=True)
idl = ImpliedDoLoop(i**3, i, -3, 3, 2)
ac = ArrayConstructor([-28, idl, 28])
a = array(a, dim=[':'], attrs=[allocatable])
prog = Program('idlprog', [
a.as_Declaration(),
Assignment(a, ac),
Print([a])
])
fsrc = fcode(prog, standard=2003, source_format='free')
(stdout, stderr), info = compile_run_strings([('main.f90', fsrc)], clean=True)
for numstr in '-28 -27 -1 1 27 28'.split():
assert numstr in stdout
assert stderr == ''
assert info['exit_status'] == os.EX_OK
@may_xfail
def test_Program():
x = Symbol('x', real=True)
vx = Variable.deduced(x, 42)
decl = Declaration(vx)
prnt = Print([x, x+1])
prog = Program('foo', [decl, prnt])
if not has_fortran():
skip("No fortran compiler found.")
(stdout, stderr), info = compile_run_strings([('main.f90', fcode(prog, standard=90))], clean=True)
assert '42' in stdout
assert '43' in stdout
assert stderr == ''
assert info['exit_status'] == os.EX_OK
@may_xfail
def test_Module():
x = Symbol('x', real=True)
v_x = Variable.deduced(x)
sq = FunctionDefinition(real, 'sqr', [v_x], [Return(x**2)])
mod_sq = Module('mod_sq', [], [sq])
sq_call = FunctionCall('sqr', [42.])
prg_sq = Program('foobar', [
use('mod_sq', only=['sqr']),
Print(['"Square of 42 = "', sq_call])
])
if not has_fortran():
skip("No fortran compiler found.")
(stdout, stderr), info = compile_run_strings([
('mod_sq.f90', fcode(mod_sq, standard=90)),
('main.f90', fcode(prg_sq, standard=90))
], clean=True)
assert '42' in stdout
assert str(42**2) in stdout
assert stderr == ''
@XFAIL # https://github.com/sympy/sympy/issues/20265
@may_xfail
def test_Subroutine():
# Code to generate the subroutine in the example from
# http://www.fortran90.org/src/best-practices.html#arrays
r = Symbol('r', real=True)
i = Symbol('i', integer=True)
v_r = Variable.deduced(r, attrs=(dimension(assumed_extent), intent_out))
v_i = Variable.deduced(i)
v_n = Variable('n', integer)
do_loop = Do([
Assignment(Element(r, [i]), literal_dp(1)/i**2)
], i, 1, v_n)
sub = Subroutine("f", [v_r], [
Declaration(v_n),
Declaration(v_i),
Assignment(v_n, size(r)),
do_loop
])
x = Symbol('x', real=True)
v_x3 = Variable.deduced(x, attrs=[dimension(3)])
mod = Module('mymod', definitions=[sub])
prog = Program('foo', [
use(mod, only=[sub]),
Declaration(v_x3),
SubroutineCall(sub, [v_x3]),
Print([sum_(v_x3), v_x3])
])
if not has_fortran():
skip("No fortran compiler found.")
(stdout, stderr), info = compile_run_strings([
('a.f90', fcode(mod, standard=90)),
('b.f90', fcode(prog, standard=90))
], clean=True)
ref = [1.0/i**2 for i in range(1, 4)]
assert str(sum(ref))[:-3] in stdout
for _ in ref:
assert str(_)[:-3] in stdout
assert stderr == ''
def test_isign():
x = Symbol('x', integer=True)
assert unchanged(isign, 1, x)
assert fcode(isign(1, x), standard=95, source_format='free') == 'isign(1, x)'
def test_dsign():
x = Symbol('x')
assert unchanged(dsign, 1, x)
assert fcode(dsign(literal_dp(1), x), standard=95, source_format='free') == 'dsign(1d0, x)'
def test_cmplx():
x = Symbol('x')
assert unchanged(cmplx, 1, x)
def test_kind():
x = Symbol('x')
assert unchanged(kind, x)
def test_literal_dp():
assert fcode(literal_dp(0), source_format='free') == '0d0'
@may_xfail
def test_bind_C():
if not has_fortran():
skip("No fortran compiler found.")
if not cython:
skip("Cython not found.")
if not np:
skip("NumPy not found.")
a = Symbol('a', real=True)
s = Symbol('s', integer=True)
body = [Return((sum_(a**2)/s)**.5)]
arr = array(a, dim=[s], intent='in')
fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
f_mod = render_as_module([fd], 'mod_rms')
with tempfile.TemporaryDirectory() as folder:
mod, info = compile_link_import_strings([
('rms.f90', f_mod),
('_rms.pyx', (
"#cython: language_level={}\n".format("3") +
"cdef extern double rms(double*, int*)\n"
"def py_rms(double[::1] x):\n"
" cdef int s = x.size\n"
" return rms(&x[0], &s)\n"))
], build_dir=folder)
assert abs(mod.py_rms(np.array([2., 4., 2., 2.])) - 7**0.5) < 1e-14

View File

@ -0,0 +1,50 @@
from sympy.core.symbol import symbols
from sympy.core.function import Function
from sympy.matrices.dense import Matrix
from sympy.matrices.dense import zeros
from sympy.simplify.simplify import simplify
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.utilities.lambdify import lambdify
from sympy.printing.numpy import NumPyPrinter
from sympy.testing.pytest import skip
from sympy.external import import_module
def test_matrix_solve_issue_24862():
A = Matrix(3, 3, symbols('a:9'))
b = Matrix(3, 1, symbols('b:3'))
hash(MatrixSolve(A, b))
def test_matrix_solve_derivative_exact():
q = symbols('q')
a11, a12, a21, a22, b1, b2 = (
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
A = Matrix([[a11, a12], [a21, a22]])
b = Matrix([b1, b2])
x_lu = A.LUsolve(b)
dxdq_lu = A.LUsolve(b.diff(q) - A.diff(q) * A.LUsolve(b))
assert simplify(x_lu.diff(q) - dxdq_lu) == zeros(2, 1)
# dxdq_ms is the MatrixSolve equivalent of dxdq_lu
dxdq_ms = MatrixSolve(A, b.diff(q) - A.diff(q) * MatrixSolve(A, b))
assert MatrixSolve(A, b).diff(q) == dxdq_ms
def test_matrix_solve_derivative_numpy():
np = import_module('numpy')
if not np:
skip("numpy not installed.")
q = symbols('q')
a11, a12, a21, a22, b1, b2 = (
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
A = Matrix([[a11, a12], [a21, a22]])
b = Matrix([b1, b2])
dx_lu = A.LUsolve(b).diff(q)
subs = {a11.diff(q): 0.2, a12.diff(q): 0.3, a21.diff(q): 0.1,
a22.diff(q): 0.5, b1.diff(q): 0.4, b2.diff(q): 0.9,
a11: 1.3, a12: 0.5, a21: 1.2, a22: 4, b1: 6.2, b2: 3.5}
p, p_vals = zip(*subs.items())
dx_sm = MatrixSolve(A, b).diff(q)
np.testing.assert_allclose(
lambdify(p, dx_sm, printer=NumPyPrinter)(*p_vals),
lambdify(p, dx_lu, printer=NumPyPrinter)(*p_vals))

View File

@ -0,0 +1,50 @@
from itertools import product
from sympy.core.singleton import S
from sympy.core.symbol import symbols
from sympy.functions.elementary.exponential import (exp, log)
from sympy.printing.repr import srepr
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
x, y, z = symbols('x y z')
def test_logaddexp():
lae_xy = logaddexp(x, y)
ref_xy = log(exp(x) + exp(y))
for wrt, deriv_order in product([x, y, z], range(3)):
assert (
lae_xy.diff(wrt, deriv_order) -
ref_xy.diff(wrt, deriv_order)
).rewrite(log).simplify() == 0
one_third_e = 1*exp(1)/3
two_thirds_e = 2*exp(1)/3
logThirdE = log(one_third_e)
logTwoThirdsE = log(two_thirds_e)
lae_sum_to_e = logaddexp(logThirdE, logTwoThirdsE)
assert lae_sum_to_e.rewrite(log) == 1
assert lae_sum_to_e.simplify() == 1
was = logaddexp(2, 3)
assert srepr(was) == srepr(was.simplify()) # cannot simplify with 2, 3
def test_logaddexp2():
lae2_xy = logaddexp2(x, y)
ref2_xy = log(2**x + 2**y)/log(2)
for wrt, deriv_order in product([x, y, z], range(3)):
assert (
lae2_xy.diff(wrt, deriv_order) -
ref2_xy.diff(wrt, deriv_order)
).rewrite(log).cancel() == 0
def lb(x):
return log(x)/log(2)
two_thirds = S.One*2/3
four_thirds = 2*two_thirds
lbTwoThirds = lb(two_thirds)
lbFourThirds = lb(four_thirds)
lae2_sum_to_2 = logaddexp2(lbTwoThirds, lbFourThirds)
assert lae2_sum_to_2.rewrite(log) == 1
assert lae2_sum_to_2.simplify() == 1
was = logaddexp2(x, y)
assert srepr(was) == srepr(was.simplify()) # cannot simplify with x, y

View File

@ -0,0 +1,13 @@
from sympy.core.symbol import symbols
from sympy.codegen.pynodes import List
def test_List():
l = List(2, 3, 4)
assert l == List(2, 3, 4)
assert str(l) == "[2, 3, 4]"
x, y, z = symbols('x y z')
l = List(x**2,y**3,z**4)
# contrary to python's built-in list, we can call e.g. "replace" on List.
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
assert m == [x**2, y-3, z-4]

View File

@ -0,0 +1,7 @@
from sympy.codegen.ast import Print
from sympy.codegen.pyutils import render_as_module
def test_standard():
ast = Print('x y'.split(), r"coordinate: %12.5g %12.5g\n")
assert render_as_module(ast, standard='python3') == \
'\n\nprint("coordinate: %12.5g %12.5g\\n" % (x, y), end="")'

View File

@ -0,0 +1,479 @@
import tempfile
from sympy.core.numbers import pi, Rational
from sympy.core.power import Pow
from sympy.core.singleton import S
from sympy.core.symbol import Symbol
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.trigonometric import (cos, sin, sinc)
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.assumptions import assuming, Q
from sympy.external import import_module
from sympy.printing.codeprinter import ccode
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.codegen.cfunctions import log2, exp2, expm1, log1p
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
from sympy.codegen.scipy_nodes import cosm1, powm1
from sympy.codegen.rewriting import (
optimize, cosm1_opt, log2_opt, exp2_opt, expm1_opt, log1p_opt, powm1_opt, optims_c99,
create_expand_pow_optimization, matinv_opt, logaddexp_opt, logaddexp2_opt,
optims_numpy, optims_scipy, sinc_opts, FuncMinusOneOptim
)
from sympy.testing.pytest import XFAIL, skip
from sympy.utilities import lambdify
from sympy.utilities._compilation import compile_link_import_strings, has_c
from sympy.utilities._compilation.util import may_xfail
cython = import_module('cython')
numpy = import_module('numpy')
scipy = import_module('scipy')
def test_log2_opt():
x = Symbol('x')
expr1 = 7*log(3*x + 5)/(log(2))
opt1 = optimize(expr1, [log2_opt])
assert opt1 == 7*log2(3*x + 5)
assert opt1.rewrite(log) == expr1
expr2 = 3*log(5*x + 7)/(13*log(2))
opt2 = optimize(expr2, [log2_opt])
assert opt2 == 3*log2(5*x + 7)/13
assert opt2.rewrite(log) == expr2
expr3 = log(x)/log(2)
opt3 = optimize(expr3, [log2_opt])
assert opt3 == log2(x)
assert opt3.rewrite(log) == expr3
expr4 = log(x)/log(2) + log(x+1)
opt4 = optimize(expr4, [log2_opt])
assert opt4 == log2(x) + log(2)*log2(x+1)
assert opt4.rewrite(log) == expr4
expr5 = log(17)
opt5 = optimize(expr5, [log2_opt])
assert opt5 == expr5
expr6 = log(x + 3)/log(2)
opt6 = optimize(expr6, [log2_opt])
assert str(opt6) == 'log2(x + 3)'
assert opt6.rewrite(log) == expr6
def test_exp2_opt():
x = Symbol('x')
expr1 = 1 + 2**x
opt1 = optimize(expr1, [exp2_opt])
assert opt1 == 1 + exp2(x)
assert opt1.rewrite(Pow) == expr1
expr2 = 1 + 3**x
assert expr2 == optimize(expr2, [exp2_opt])
def test_expm1_opt():
x = Symbol('x')
expr1 = exp(x) - 1
opt1 = optimize(expr1, [expm1_opt])
assert expm1(x) - opt1 == 0
assert opt1.rewrite(exp) == expr1
expr2 = 3*exp(x) - 3
opt2 = optimize(expr2, [expm1_opt])
assert 3*expm1(x) == opt2
assert opt2.rewrite(exp) == expr2
expr3 = 3*exp(x) - 5
opt3 = optimize(expr3, [expm1_opt])
assert 3*expm1(x) - 2 == opt3
assert opt3.rewrite(exp) == expr3
expm1_opt_non_opportunistic = FuncMinusOneOptim(exp, expm1, opportunistic=False)
assert expr3 == optimize(expr3, [expm1_opt_non_opportunistic])
assert opt1 == optimize(expr1, [expm1_opt_non_opportunistic])
assert opt2 == optimize(expr2, [expm1_opt_non_opportunistic])
expr4 = 3*exp(x) + log(x) - 3
opt4 = optimize(expr4, [expm1_opt])
assert 3*expm1(x) + log(x) == opt4
assert opt4.rewrite(exp) == expr4
expr5 = 3*exp(2*x) - 3
opt5 = optimize(expr5, [expm1_opt])
assert 3*expm1(2*x) == opt5
assert opt5.rewrite(exp) == expr5
expr6 = (2*exp(x) + 1)/(exp(x) + 1) + 1
opt6 = optimize(expr6, [expm1_opt])
assert opt6.count_ops() <= expr6.count_ops()
def ev(e):
return e.subs(x, 3).evalf()
assert abs(ev(expr6) - ev(opt6)) < 1e-15
y = Symbol('y')
expr7 = (2*exp(x) - 1)/(1 - exp(y)) - 1/(1-exp(y))
opt7 = optimize(expr7, [expm1_opt])
assert -2*expm1(x)/expm1(y) == opt7
assert (opt7.rewrite(exp) - expr7).factor() == 0
expr8 = (1+exp(x))**2 - 4
opt8 = optimize(expr8, [expm1_opt])
tgt8a = (exp(x) + 3)*expm1(x)
tgt8b = 2*expm1(x) + expm1(2*x)
# Both tgt8a & tgt8b seem to give full precision (~16 digits for double)
# for x=1e-7 (compare with expr8 which only achieves ~8 significant digits).
# If we can show that either tgt8a or tgt8b is preferable, we can
# change this test to ensure the preferable version is returned.
assert (tgt8a - tgt8b).rewrite(exp).factor() == 0
assert opt8 in (tgt8a, tgt8b)
assert (opt8.rewrite(exp) - expr8).factor() == 0
expr9 = sin(expr8)
opt9 = optimize(expr9, [expm1_opt])
tgt9a = sin(tgt8a)
tgt9b = sin(tgt8b)
assert opt9 in (tgt9a, tgt9b)
assert (opt9.rewrite(exp) - expr9.rewrite(exp)).factor().is_zero
def test_expm1_two_exp_terms():
x, y = map(Symbol, 'x y'.split())
expr1 = exp(x) + exp(y) - 2
opt1 = optimize(expr1, [expm1_opt])
assert opt1 == expm1(x) + expm1(y)
def test_cosm1_opt():
x = Symbol('x')
expr1 = cos(x) - 1
opt1 = optimize(expr1, [cosm1_opt])
assert cosm1(x) - opt1 == 0
assert opt1.rewrite(cos) == expr1
expr2 = 3*cos(x) - 3
opt2 = optimize(expr2, [cosm1_opt])
assert 3*cosm1(x) == opt2
assert opt2.rewrite(cos) == expr2
expr3 = 3*cos(x) - 5
opt3 = optimize(expr3, [cosm1_opt])
assert 3*cosm1(x) - 2 == opt3
assert opt3.rewrite(cos) == expr3
cosm1_opt_non_opportunistic = FuncMinusOneOptim(cos, cosm1, opportunistic=False)
assert expr3 == optimize(expr3, [cosm1_opt_non_opportunistic])
assert opt1 == optimize(expr1, [cosm1_opt_non_opportunistic])
assert opt2 == optimize(expr2, [cosm1_opt_non_opportunistic])
expr4 = 3*cos(x) + log(x) - 3
opt4 = optimize(expr4, [cosm1_opt])
assert 3*cosm1(x) + log(x) == opt4
assert opt4.rewrite(cos) == expr4
expr5 = 3*cos(2*x) - 3
opt5 = optimize(expr5, [cosm1_opt])
assert 3*cosm1(2*x) == opt5
assert opt5.rewrite(cos) == expr5
expr6 = 2 - 2*cos(x)
opt6 = optimize(expr6, [cosm1_opt])
assert -2*cosm1(x) == opt6
assert opt6.rewrite(cos) == expr6
def test_cosm1_two_cos_terms():
x, y = map(Symbol, 'x y'.split())
expr1 = cos(x) + cos(y) - 2
opt1 = optimize(expr1, [cosm1_opt])
assert opt1 == cosm1(x) + cosm1(y)
def test_expm1_cosm1_mixed():
x = Symbol('x')
expr1 = exp(x) + cos(x) - 2
opt1 = optimize(expr1, [expm1_opt, cosm1_opt])
assert opt1 == cosm1(x) + expm1(x)
def _check_num_lambdify(expr, opt, val_subs, approx_ref, lambdify_kw=None, poorness=1e10):
""" poorness=1e10 signifies that `expr` loses precision of at least ten decimal digits. """
num_ref = expr.subs(val_subs).evalf()
eps = numpy.finfo(numpy.float64).eps
assert abs(num_ref - approx_ref) < approx_ref*eps
f1 = lambdify(list(val_subs.keys()), opt, **(lambdify_kw or {}))
args_float = tuple(map(float, val_subs.values()))
num_err1 = abs(f1(*args_float) - approx_ref)
assert num_err1 < abs(num_ref*eps)
f2 = lambdify(list(val_subs.keys()), expr, **(lambdify_kw or {}))
num_err2 = abs(f2(*args_float) - approx_ref)
assert num_err2 > abs(num_ref*eps*poorness) # this only ensures that the *test* works as intended
def test_cosm1_apart():
x = Symbol('x')
expr1 = 1/cos(x) - 1
opt1 = optimize(expr1, [cosm1_opt])
assert opt1 == -cosm1(x)/cos(x)
if scipy:
_check_num_lambdify(expr1, opt1, {x: S(10)**-30}, 5e-61, lambdify_kw={"modules": 'scipy'})
expr2 = 2/cos(x) - 2
opt2 = optimize(expr2, optims_scipy)
assert opt2 == -2*cosm1(x)/cos(x)
if scipy:
_check_num_lambdify(expr2, opt2, {x: S(10)**-30}, 1e-60, lambdify_kw={"modules": 'scipy'})
expr3 = pi/cos(3*x) - pi
opt3 = optimize(expr3, [cosm1_opt])
assert opt3 == -pi*cosm1(3*x)/cos(3*x)
if scipy:
_check_num_lambdify(expr3, opt3, {x: S(10)**-30/3}, float(5e-61*pi), lambdify_kw={"modules": 'scipy'})
def test_powm1():
args = x, y = map(Symbol, "xy")
expr1 = x**y - 1
opt1 = optimize(expr1, [powm1_opt])
assert opt1 == powm1(x, y)
for arg in args:
assert expr1.diff(arg) == opt1.diff(arg)
if scipy and tuple(map(int, scipy.version.version.split('.')[:3])) >= (1, 10, 0):
subs1_a = {x: Rational(*(1.0+1e-13).as_integer_ratio()), y: pi}
ref1_f64_a = 3.139081648208105e-13
_check_num_lambdify(expr1, opt1, subs1_a, ref1_f64_a, lambdify_kw={"modules": 'scipy'}, poorness=10**11)
subs1_b = {x: pi, y: Rational(*(1e-10).as_integer_ratio())}
ref1_f64_b = 1.1447298859149205e-10
_check_num_lambdify(expr1, opt1, subs1_b, ref1_f64_b, lambdify_kw={"modules": 'scipy'}, poorness=10**9)
def test_log1p_opt():
x = Symbol('x')
expr1 = log(x + 1)
opt1 = optimize(expr1, [log1p_opt])
assert log1p(x) - opt1 == 0
assert opt1.rewrite(log) == expr1
expr2 = log(3*x + 3)
opt2 = optimize(expr2, [log1p_opt])
assert log1p(x) + log(3) == opt2
assert (opt2.rewrite(log) - expr2).simplify() == 0
expr3 = log(2*x + 1)
opt3 = optimize(expr3, [log1p_opt])
assert log1p(2*x) - opt3 == 0
assert opt3.rewrite(log) == expr3
expr4 = log(x+3)
opt4 = optimize(expr4, [log1p_opt])
assert str(opt4) == 'log(x + 3)'
def test_optims_c99():
x = Symbol('x')
expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1
opt1 = optimize(expr1, optims_c99).simplify()
assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)
assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1
expr2 = log(x)/log(2) + log(x + 1)
opt2 = optimize(expr2, optims_c99)
assert opt2 == log2(x) + log1p(x)
assert opt2.rewrite(log) == expr2
expr3 = log(x)/log(2) + log(17*x + 17)
opt3 = optimize(expr3, optims_c99)
delta3 = opt3 - (log2(x) + log(17) + log1p(x))
assert delta3 == 0
assert (opt3.rewrite(log) - expr3).simplify() == 0
expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17)
opt4 = optimize(expr4, optims_c99).simplify()
delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x))
assert delta4 == 0
assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0
expr5 = 3*exp(2*x) - 3
opt5 = optimize(expr5, optims_c99)
delta5 = opt5 - 3*expm1(2*x)
assert delta5 == 0
assert opt5.rewrite(exp) == expr5
expr6 = exp(2*x) - 3
opt6 = optimize(expr6, optims_c99)
assert opt6 in (expm1(2*x) - 2, expr6) # expm1(2*x) - 2 is not better or worse
expr7 = log(3*x + 3)
opt7 = optimize(expr7, optims_c99)
delta7 = opt7 - (log(3) + log1p(x))
assert delta7 == 0
assert (opt7.rewrite(log) - expr7).simplify() == 0
expr8 = log(2*x + 3)
opt8 = optimize(expr8, optims_c99)
assert opt8 == expr8
def test_create_expand_pow_optimization():
cc = lambda x: ccode(
optimize(x, [create_expand_pow_optimization(4)]))
x = Symbol('x')
assert cc(x**4) == 'x*x*x*x'
assert cc(x**4 + x**2) == 'x*x + x*x*x*x'
assert cc(x**5 + x**4) == 'pow(x, 5) + x*x*x*x'
assert cc(sin(x)**4) == 'pow(sin(x), 4)'
# gh issue 15335
assert cc(x**(-4)) == '1.0/(x*x*x*x)'
assert cc(x**(-5)) == 'pow(x, -5)'
assert cc(-x**4) == '-(x*x*x*x)'
assert cc(x**4 - x**2) == '-(x*x) + x*x*x*x'
i = Symbol('i', integer=True)
assert cc(x**i - x**2) == 'pow(x, i) - (x*x)'
y = Symbol('y', real=True)
assert cc(Abs(exp(y**4))) == "exp(y*y*y*y)"
# gh issue 20753
cc2 = lambda x: ccode(optimize(x, [create_expand_pow_optimization(
4, base_req=lambda b: b.is_Function)]))
assert cc2(x**3 + sin(x)**3) == "pow(x, 3) + sin(x)*sin(x)*sin(x)"
def test_matsolve():
n = Symbol('n', integer=True)
A = MatrixSymbol('A', n, n)
x = MatrixSymbol('x', n, 1)
with assuming(Q.fullrank(A)):
assert optimize(A**(-1) * x, [matinv_opt]) == MatrixSolve(A, x)
assert optimize(A**(-1) * x + x, [matinv_opt]) == MatrixSolve(A, x) + x
def test_logaddexp_opt():
x, y = map(Symbol, 'x y'.split())
expr1 = log(exp(x) + exp(y))
opt1 = optimize(expr1, [logaddexp_opt])
assert logaddexp(x, y) - opt1 == 0
assert logaddexp(y, x) - opt1 == 0
assert opt1.rewrite(log) == expr1
def test_logaddexp2_opt():
x, y = map(Symbol, 'x y'.split())
expr1 = log(2**x + 2**y)/log(2)
opt1 = optimize(expr1, [logaddexp2_opt])
assert logaddexp2(x, y) - opt1 == 0
assert logaddexp2(y, x) - opt1 == 0
assert opt1.rewrite(log) == expr1
def test_sinc_opts():
def check(d):
for k, v in d.items():
assert optimize(k, sinc_opts) == v
x = Symbol('x')
check({
sin(x)/x : sinc(x),
sin(2*x)/(2*x) : sinc(2*x),
sin(3*x)/x : 3*sinc(3*x),
x*sin(x) : x*sin(x)
})
y = Symbol('y')
check({
sin(x*y)/(x*y) : sinc(x*y),
y*sin(x/y)/x : sinc(x/y),
sin(sin(x))/sin(x) : sinc(sin(x)),
sin(3*sin(x))/sin(x) : 3*sinc(3*sin(x)),
sin(x)/y : sin(x)/y
})
def test_optims_numpy():
def check(d):
for k, v in d.items():
assert optimize(k, optims_numpy) == v
x = Symbol('x')
check({
sin(2*x)/(2*x) + exp(2*x) - 1: sinc(2*x) + expm1(2*x),
log(x+3)/log(2) + log(x**2 + 1): log1p(x**2) + log2(x+3)
})
@XFAIL # room for improvement, ideally this test case should pass.
def test_optims_numpy_TODO():
def check(d):
for k, v in d.items():
assert optimize(k, optims_numpy) == v
x, y = map(Symbol, 'x y'.split())
check({
log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y): log2(x*y)*sinc(x*y)*log1p(x*y),
exp(x*sin(y)/y) - 1: expm1(x*sinc(y))
})
@may_xfail
def test_compiled_ccode_with_rewriting():
if not cython:
skip("cython not installed.")
if not has_c():
skip("No C compiler found.")
x = Symbol('x')
about_two = 2**(58/S(117))*3**(97/S(117))*5**(4/S(39))*7**(92/S(117))/S(30)*pi
# about_two: 1.999999999999581826
unchanged = 2*exp(x) - about_two
xval = S(10)**-11
ref = unchanged.subs(x, xval).n(19) # 2.0418173913673213e-11
rewritten = optimize(2*exp(x) - about_two, [expm1_opt])
# Unfortunately, we need to call ``.n()`` on our expressions before we hand them
# to ``ccode``, and we need to request a large number of significant digits.
# In this test, results converged for double precision when the following number
# of significant digits were chosen:
NUMBER_OF_DIGITS = 25 # TODO: this should ideally be automatically handled.
func_c = '''
#include <math.h>
double func_unchanged(double x) {
return %(unchanged)s;
}
double func_rewritten(double x) {
return %(rewritten)s;
}
''' % {"unchanged": ccode(unchanged.n(NUMBER_OF_DIGITS)),
"rewritten": ccode(rewritten.n(NUMBER_OF_DIGITS))}
func_pyx = '''
#cython: language_level=3
cdef extern double func_unchanged(double)
cdef extern double func_rewritten(double)
def py_unchanged(x):
return func_unchanged(x)
def py_rewritten(x):
return func_rewritten(x)
'''
with tempfile.TemporaryDirectory() as folder:
mod, info = compile_link_import_strings(
[('func.c', func_c), ('_func.pyx', func_pyx)],
build_dir=folder, compile_kwargs={"std": 'c99'}
)
err_rewritten = abs(mod.py_rewritten(1e-11) - ref)
err_unchanged = abs(mod.py_unchanged(1e-11) - ref)
assert 1e-27 < err_rewritten < 1e-25 # highly accurate.
assert 1e-19 < err_unchanged < 1e-16 # quite poor.
# Tolerances used above were determined as follows:
# >>> no_opt = unchanged.subs(x, xval.evalf()).evalf()
# >>> with_opt = rewritten.n(25).subs(x, 1e-11).evalf()
# >>> with_opt - ref, no_opt - ref
# (1.1536301877952077e-26, 1.6547074214222335e-18)

View File

@ -0,0 +1,44 @@
from itertools import product
from sympy.core.power import Pow
from sympy.core.symbol import symbols
from sympy.functions.elementary.exponential import exp, log
from sympy.functions.elementary.trigonometric import cos
from sympy.core.numbers import pi
from sympy.codegen.scipy_nodes import cosm1, powm1
x, y, z = symbols('x y z')
def test_cosm1():
cm1_xy = cosm1(x*y)
ref_xy = cos(x*y) - 1
for wrt, deriv_order in product([x, y, z], range(3)):
assert (
cm1_xy.diff(wrt, deriv_order) -
ref_xy.diff(wrt, deriv_order)
).rewrite(cos).simplify() == 0
expr_minus2 = cosm1(pi)
assert expr_minus2.rewrite(cos) == -2
assert cosm1(3.14).simplify() == cosm1(3.14) # cannot simplify with 3.14
assert cosm1(pi/2).simplify() == -1
assert (1/cos(x) - 1 + cosm1(x)/cos(x)).simplify() == 0
def test_powm1():
cases = {
powm1(x, y): x**y - 1,
powm1(x*y, z): (x*y)**z - 1,
powm1(x, y*z): x**(y*z)-1,
powm1(x*y*z, x*y*z): (x*y*z)**(x*y*z)-1
}
for pm1_e, ref_e in cases.items():
for wrt, deriv_order in product([x, y, z], range(3)):
der = pm1_e.diff(wrt, deriv_order)
ref = ref_e.diff(wrt, deriv_order)
delta = (der - ref).rewrite(Pow)
assert delta.simplify() == 0
eulers_constant_m1 = powm1(x, 1/log(x))
assert eulers_constant_m1.rewrite(Pow) == exp(1) - 1
assert eulers_constant_m1.simplify() == exp(1) - 1