I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,626 @@
"""
Important note on tests in this module - the Aesara printing functions use a
global cache by default, which means that tests using it will modify global
state and thus not be independent from each other. Instead of using the "cache"
keyword argument each time, this module uses the aesara_code_ and
aesara_function_ functions defined below which default to using a new, empty
cache instead.
"""
import logging
from sympy.external import import_module
from sympy.testing.pytest import raises, SKIP
from sympy.utilities.exceptions import ignore_warnings
aesaralogger = logging.getLogger('aesara.configdefaults')
aesaralogger.setLevel(logging.CRITICAL)
aesara = import_module('aesara')
aesaralogger.setLevel(logging.WARNING)
if aesara:
import numpy as np
aet = aesara.tensor
from aesara.scalar.basic import ScalarType
from aesara.graph.basic import Variable
from aesara.tensor.var import TensorVariable
from aesara.tensor.elemwise import Elemwise, DimShuffle
from aesara.tensor.math import Dot
from sympy.printing.aesaracode import true_divide
xt, yt, zt = [aet.scalar(name, 'floatX') for name in 'xyz']
Xt, Yt, Zt = [aet.tensor('floatX', (False, False), name=n) for n in 'XYZ']
else:
#bin/test will not execute any tests now
disabled = True
import sympy as sy
from sympy.core.singleton import S
from sympy.abc import x, y, z, t
from sympy.printing.aesaracode import (aesara_code, dim_handling,
aesara_function)
# Default set of matrix symbols for testing - make square so we can both
# multiply and perform elementwise operations between them.
X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ']
# For testing AppliedUndef
f_t = sy.Function('f')(t)
def aesara_code_(expr, **kwargs):
""" Wrapper for aesara_code that uses a new, empty cache by default. """
kwargs.setdefault('cache', {})
return aesara_code(expr, **kwargs)
def aesara_function_(inputs, outputs, **kwargs):
""" Wrapper for aesara_function that uses a new, empty cache by default. """
kwargs.setdefault('cache', {})
return aesara_function(inputs, outputs, **kwargs)
def fgraph_of(*exprs):
""" Transform SymPy expressions into Aesara Computation.
Parameters
==========
exprs
SymPy expressions
Returns
=======
aesara.graph.fg.FunctionGraph
"""
outs = list(map(aesara_code_, exprs))
ins = list(aesara.graph.basic.graph_inputs(outs))
ins, outs = aesara.graph.basic.clone(ins, outs)
return aesara.graph.fg.FunctionGraph(ins, outs)
def aesara_simplify(fgraph):
""" Simplify a Aesara Computation.
Parameters
==========
fgraph : aesara.graph.fg.FunctionGraph
Returns
=======
aesara.graph.fg.FunctionGraph
"""
mode = aesara.compile.get_default_mode().excluding("fusion")
fgraph = fgraph.clone()
mode.optimizer.rewrite(fgraph)
return fgraph
def theq(a, b):
""" Test two Aesara objects for equality.
Also accepts numeric types and lists/tuples of supported types.
Note - debugprint() has a bug where it will accept numeric types but does
not respect the "file" argument and in this case and instead prints the number
to stdout and returns an empty string. This can lead to tests passing where
they should fail because any two numbers will always compare as equal. To
prevent this we treat numbers as a separate case.
"""
numeric_types = (int, float, np.number)
a_is_num = isinstance(a, numeric_types)
b_is_num = isinstance(b, numeric_types)
# Compare numeric types using regular equality
if a_is_num or b_is_num:
if not (a_is_num and b_is_num):
return False
return a == b
# Compare sequences element-wise
a_is_seq = isinstance(a, (tuple, list))
b_is_seq = isinstance(b, (tuple, list))
if a_is_seq or b_is_seq:
if not (a_is_seq and b_is_seq) or type(a) != type(b):
return False
return list(map(theq, a)) == list(map(theq, b))
# Otherwise, assume debugprint() can handle it
astr = aesara.printing.debugprint(a, file='str')
bstr = aesara.printing.debugprint(b, file='str')
# Check for bug mentioned above
for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]:
if argstr == '':
raise TypeError(
'aesara.printing.debugprint(%s) returned empty string '
'(%s is instance of %r)'
% (argname, argname, type(argval))
)
return astr == bstr
def test_example_symbols():
"""
Check that the example symbols in this module print to their Aesara
equivalents, as many of the other tests depend on this.
"""
assert theq(xt, aesara_code_(x))
assert theq(yt, aesara_code_(y))
assert theq(zt, aesara_code_(z))
assert theq(Xt, aesara_code_(X))
assert theq(Yt, aesara_code_(Y))
assert theq(Zt, aesara_code_(Z))
def test_Symbol():
""" Test printing a Symbol to a aesara variable. """
xx = aesara_code_(x)
assert isinstance(xx, Variable)
assert xx.broadcastable == ()
assert xx.name == x.name
xx2 = aesara_code_(x, broadcastables={x: (False,)})
assert xx2.broadcastable == (False,)
assert xx2.name == x.name
def test_MatrixSymbol():
""" Test printing a MatrixSymbol to a aesara variable. """
XX = aesara_code_(X)
assert isinstance(XX, TensorVariable)
assert XX.broadcastable == (False, False)
@SKIP # TODO - this is currently not checked but should be implemented
def test_MatrixSymbol_wrong_dims():
""" Test MatrixSymbol with invalid broadcastable. """
bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)]
for bc in bcs:
with raises(ValueError):
aesara_code_(X, broadcastables={X: bc})
def test_AppliedUndef():
""" Test printing AppliedUndef instance, which works similarly to Symbol. """
ftt = aesara_code_(f_t)
assert isinstance(ftt, TensorVariable)
assert ftt.broadcastable == ()
assert ftt.name == 'f_t'
def test_add():
expr = x + y
comp = aesara_code_(expr)
assert comp.owner.op == aesara.tensor.add
def test_trig():
assert theq(aesara_code_(sy.sin(x)), aet.sin(xt))
assert theq(aesara_code_(sy.tan(x)), aet.tan(xt))
def test_many():
""" Test printing a complex expression with multiple symbols. """
expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z)
comp = aesara_code_(expr)
expected = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt)
assert theq(comp, expected)
def test_dtype():
""" Test specifying specific data types through the dtype argument. """
for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']:
assert aesara_code_(x, dtypes={x: dtype}).type.dtype == dtype
# "floatX" type
assert aesara_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64')
# Type promotion
assert aesara_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32'
assert aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64'
def test_broadcastables():
""" Test the "broadcastables" argument when printing symbol-like objects. """
# No restrictions on shape
for s in [x, f_t]:
for bc in [(), (False,), (True,), (False, False), (True, False)]:
assert aesara_code_(s, broadcastables={s: bc}).broadcastable == bc
# TODO - matrix broadcasting?
def test_broadcasting():
""" Test "broadcastable" attribute after applying element-wise binary op. """
expr = x + y
cases = [
[(), (), ()],
[(False,), (False,), (False,)],
[(True,), (False,), (False,)],
[(False, True), (False, False), (False, False)],
[(True, False), (False, False), (False, False)],
]
for bc1, bc2, bc3 in cases:
comp = aesara_code_(expr, broadcastables={x: bc1, y: bc2})
assert comp.broadcastable == bc3
def test_MatMul():
expr = X*Y*Z
expr_t = aesara_code_(expr)
assert isinstance(expr_t.owner.op, Dot)
assert theq(expr_t, Xt.dot(Yt).dot(Zt))
def test_Transpose():
assert isinstance(aesara_code_(X.T).owner.op, DimShuffle)
def test_MatAdd():
expr = X+Y+Z
assert isinstance(aesara_code_(expr).owner.op, Elemwise)
def test_Rationals():
assert theq(aesara_code_(sy.Integer(2) / 3), true_divide(2, 3))
assert theq(aesara_code_(S.Half), true_divide(1, 2))
def test_Integers():
assert aesara_code_(sy.Integer(3)) == 3
def test_factorial():
n = sy.Symbol('n')
assert aesara_code_(sy.factorial(n))
def test_Derivative():
with ignore_warnings(UserWarning):
simp = lambda expr: aesara_simplify(fgraph_of(expr))
assert theq(simp(aesara_code_(sy.Derivative(sy.sin(x), x, evaluate=False))),
simp(aesara.grad(aet.sin(xt), xt)))
def test_aesara_function_simple():
""" Test aesara_function() with single output. """
f = aesara_function_([x, y], [x+y])
assert f(2, 3) == 5
def test_aesara_function_multi():
""" Test aesara_function() with multiple outputs. """
f = aesara_function_([x, y], [x+y, x-y])
o1, o2 = f(2, 3)
assert o1 == 5
assert o2 == -1
def test_aesara_function_numpy():
""" Test aesara_function() vs Numpy implementation. """
f = aesara_function_([x, y], [x+y], dim=1,
dtypes={x: 'float64', y: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9
f = aesara_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
dim=1)
xx = np.arange(3).astype('float64')
yy = 2*np.arange(3).astype('float64')
assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9
def test_aesara_function_matrix():
m = sy.Matrix([[x, y], [z, x + y + z]])
expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]])
f = aesara_function_([x, y, z], [m])
np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
f = aesara_function_([x, y, z], [m], scalar=True)
np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
f = aesara_function_([x, y, z], [m, m])
assert isinstance(f(1.0, 2.0, 3.0), type([]))
np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected)
np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected)
def test_dim_handling():
assert dim_handling([x], dim=2) == {x: (False, False)}
assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True),
y: (False, False)}
assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)}
def test_aesara_function_kwargs():
"""
Test passing additional kwargs from aesara_function() to aesara.function().
"""
import numpy as np
f = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore',
dtypes={x: 'float64', y: 'float64', z: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9
f = aesara_function_([x, y, z], [x+y],
dtypes={x: 'float64', y: 'float64', z: 'float64'},
dim=1, on_unused_input='ignore')
xx = np.arange(3).astype('float64')
yy = 2*np.arange(3).astype('float64')
zz = 2*np.arange(3).astype('float64')
assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9
def test_aesara_function_scalar():
""" Test the "scalar" argument to aesara_function(). """
from aesara.compile.function.types import Function
args = [
([x, y], [x + y], None, [0]), # Single 0d output
([X, Y], [X + Y], None, [2]), # Single 2d output
([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output
([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs
([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d
]
# Create and test functions with and without the scalar setting
for inputs, outputs, in_dims, out_dims in args:
for scalar in [False, True]:
f = aesara_function_(inputs, outputs, dims=in_dims, scalar=scalar)
# Check the aesara_function attribute is set whether wrapped or not
assert isinstance(f.aesara_function, Function)
# Feed in inputs of the appropriate size and get outputs
in_values = [
np.ones([1 if bc else 5 for bc in i.type.broadcastable])
for i in f.aesara_function.input_storage
]
out_values = f(*in_values)
if not isinstance(out_values, list):
out_values = [out_values]
# Check output types and shapes
assert len(out_dims) == len(out_values)
for d, value in zip(out_dims, out_values):
if scalar and d == 0:
# Should have been converted to a scalar value
assert isinstance(value, np.number)
else:
# Otherwise should be an array
assert isinstance(value, np.ndarray)
assert value.ndim == d
def test_aesara_function_bad_kwarg():
"""
Passing an unknown keyword argument to aesara_function() should raise an
exception.
"""
raises(Exception, lambda : aesara_function_([x], [x+1], foobar=3))
def test_slice():
assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3)
def theq_slice(s1, s2):
for attr in ['start', 'stop', 'step']:
a1 = getattr(s1, attr)
a2 = getattr(s2, attr)
if a1 is None or a2 is None:
if not (a1 is None or a2 is None):
return False
elif not theq(a1, a2):
return False
return True
dtypes = {x: 'int32', y: 'int32'}
assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt))
assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3))
def test_MatrixSlice():
cache = {}
n = sy.Symbol('n', integer=True)
X = sy.MatrixSymbol('X', n, n)
Y = X[1:2:3, 4:5:6]
Yt = aesara_code_(Y, cache=cache)
s = ScalarType('int64')
assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s))
assert Yt.owner.inputs[0] == aesara_code_(X, cache=cache)
# == doesn't work in Aesara like it does in SymPy. You have to use
# equals.
assert all(Yt.owner.inputs[i].data == i for i in range(1, 7))
k = sy.Symbol('k')
aesara_code_(k, dtypes={k: 'int32'})
start, stop, step = 4, k, 2
Y = X[start:stop:step]
Yt = aesara_code_(Y, dtypes={n: 'int32', k: 'int32'})
# assert Yt.owner.op.idx_list[0].stop == kt
def test_BlockMatrix():
n = sy.Symbol('n', integer=True)
A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD']
At, Bt, Ct, Dt = map(aesara_code_, (A, B, C, D))
Block = sy.BlockMatrix([[A, B], [C, D]])
Blockt = aesara_code_(Block)
solutions = [aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt)),
aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))]
assert any(theq(Blockt, solution) for solution in solutions)
@SKIP
def test_BlockMatrix_Inverse_execution():
k, n = 2, 4
dtype = 'float32'
A = sy.MatrixSymbol('A', n, k)
B = sy.MatrixSymbol('B', n, n)
inputs = A, B
output = B.I*A
cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
B: [(n//2, n//2), (n//2, n//2)]}
cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs]
cutoutput = output.subs(dict(zip(inputs, cutinputs)))
dtypes = dict(zip(inputs, [dtype]*len(inputs)))
f = aesara_function_(inputs, [output], dtypes=dtypes, cache={})
fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)],
dtypes=dtypes, cache={})
ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs]
ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype),
np.eye(n).astype(dtype)]
ninputs[1] += np.ones(B.shape)*1e-5
assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5)
def test_DenseMatrix():
from aesara.tensor.basic import Join
t = sy.Symbol('theta')
for MatrixType in [sy.Matrix, sy.ImmutableMatrix]:
X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]])
tX = aesara_code_(X)
assert isinstance(tX, TensorVariable)
assert isinstance(tX.owner.op, Join)
def test_cache_basic():
""" Test single symbol-like objects are cached when printed by themselves. """
# Pairs of objects which should be considered equivalent with respect to caching
pairs = [
(x, sy.Symbol('x')),
(X, sy.MatrixSymbol('X', *X.shape)),
(f_t, sy.Function('f')(sy.Symbol('t'))),
]
for s1, s2 in pairs:
cache = {}
st = aesara_code_(s1, cache=cache)
# Test hit with same instance
assert aesara_code_(s1, cache=cache) is st
# Test miss with same instance but new cache
assert aesara_code_(s1, cache={}) is not st
# Test hit with different but equivalent instance
assert aesara_code_(s2, cache=cache) is st
def test_global_cache():
""" Test use of the global cache. """
from sympy.printing.aesaracode import global_cache
backup = dict(global_cache)
try:
# Temporarily empty global cache
global_cache.clear()
for s in [x, X, f_t]:
st = aesara_code(s)
assert aesara_code(s) is st
finally:
# Restore global cache
global_cache.update(backup)
def test_cache_types_distinct():
"""
Test that symbol-like objects of different types (Symbol, MatrixSymbol,
AppliedUndef) are distinguished by the cache even if they have the same
name.
"""
symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]
cache = {} # Single shared cache
printed = {}
for s in symbols:
st = aesara_code_(s, cache=cache)
assert st not in printed.values()
printed[s] = st
# Check all printed objects are distinct
assert len(set(map(id, printed.values()))) == len(symbols)
# Check retrieving
for s, st in printed.items():
assert aesara_code(s, cache=cache) is st
def test_symbols_are_created_once():
"""
Test that a symbol is cached and reused when it appears in an expression
more than once.
"""
expr = sy.Add(x, x, evaluate=False)
comp = aesara_code_(expr)
assert theq(comp, xt + xt)
assert not theq(comp, xt + aesara_code_(x))
def test_cache_complex():
"""
Test caching on a complicated expression with multiple symbols appearing
multiple times.
"""
expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y)
symbol_names = {s.name for s in expr.free_symbols}
expr_t = aesara_code_(expr)
# Iterate through variables in the Aesara computational graph that the
# printed expression depends on
seen = set()
for v in aesara.graph.basic.ancestors([expr_t]):
# Owner-less, non-constant variables should be our symbols
if v.owner is None and not isinstance(v, aesara.graph.basic.Constant):
# Check it corresponds to a symbol and appears only once
assert v.name in symbol_names
assert v.name not in seen
seen.add(v.name)
# Check all were present
assert seen == symbol_names
def test_Piecewise():
# A piecewise linear
expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III
result = aesara_code_(expr)
assert result.owner.op == aet.switch
expected = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1))
assert theq(result, expected)
expr = sy.Piecewise((x, x < 0))
result = aesara_code_(expr)
expected = aet.switch(xt < 0, xt, np.nan)
assert theq(result, expected)
expr = sy.Piecewise((0, sy.And(x>0, x<2)), \
(x, sy.Or(x>2, x<0)))
result = aesara_code_(expr)
expected = aet.switch(aet.and_(xt>0,xt<2), 0, \
aet.switch(aet.or_(xt>2, xt<0), xt, np.nan))
assert theq(result, expected)
def test_Relationals():
assert theq(aesara_code_(sy.Eq(x, y)), aet.eq(xt, yt))
# assert theq(aesara_code_(sy.Ne(x, y)), aet.neq(xt, yt)) # TODO - implement
assert theq(aesara_code_(x > y), xt > yt)
assert theq(aesara_code_(x < y), xt < yt)
assert theq(aesara_code_(x >= y), xt >= yt)
assert theq(aesara_code_(x <= y), xt <= yt)
def test_complexfunctions():
dtypes = {x:'complex128', y:'complex128'}
xt, yt = aesara_code(x, dtypes=dtypes), aesara_code(y, dtypes=dtypes)
from sympy.functions.elementary.complexes import conjugate
from aesara.tensor import as_tensor_variable as atv
from aesara.tensor import complex as cplx
assert theq(aesara_code(y*conjugate(x), dtypes=dtypes), yt*(xt.conj()))
assert theq(aesara_code((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1)))
def test_constantfunctions():
tf = aesara_function([],[1+1j])
assert(tf()==1+1j)

View File

@ -0,0 +1,883 @@
from sympy.core import (
S, pi, oo, Symbol, symbols, Rational, Integer, Float, Function, Mod, GoldenRatio, EulerGamma, Catalan,
Lambda, Dummy, nan, Mul, Pow, UnevaluatedExpr
)
from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
from sympy.functions import (
Abs, acos, acosh, asin, asinh, atan, atanh, atan2, ceiling, cos, cosh, erf,
erfc, exp, floor, gamma, log, loggamma, Max, Min, Piecewise, sign, sin, sinh,
sqrt, tan, tanh, fibonacci, lucas
)
from sympy.sets import Range
from sympy.logic import ITE, Implies, Equivalent
from sympy.codegen import For, aug_assign, Assignment
from sympy.testing.pytest import raises, XFAIL
from sympy.printing.codeprinter import PrintMethodNotImplementedError
from sympy.printing.c import C89CodePrinter, C99CodePrinter, get_math_macros
from sympy.codegen.ast import (
AddAugmentedAssignment, Element, Type, FloatType, Declaration, Pointer, Variable, value_const, pointer_const,
While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall, Return,
real, float32, float64, float80, float128, intc, Comment, CodeBlock, stderr, QuotedString
)
from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, fma, log10, Cbrt, hypot, Sqrt
from sympy.codegen.cnodes import restrict
from sympy.utilities.lambdify import implemented_function
from sympy.tensor import IndexedBase, Idx
from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix
from sympy.printing.codeprinter import ccode
x, y, z = symbols('x,y,z')
def test_printmethod():
class fabs(Abs):
def _ccode(self, printer):
return "fabs(%s)" % printer._print(self.args[0])
assert ccode(fabs(x)) == "fabs(x)"
def test_ccode_sqrt():
assert ccode(sqrt(x)) == "sqrt(x)"
assert ccode(x**0.5) == "sqrt(x)"
assert ccode(sqrt(x)) == "sqrt(x)"
def test_ccode_Pow():
assert ccode(x**3) == "pow(x, 3)"
assert ccode(x**(y**3)) == "pow(x, pow(y, 3))"
g = implemented_function('g', Lambda(x, 2*x))
assert ccode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
"pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2) + y)"
assert ccode(x**-1.0) == '1.0/x'
assert ccode(x**Rational(2, 3)) == 'pow(x, 2.0/3.0)'
assert ccode(x**Rational(2, 3), type_aliases={real: float80}) == 'powl(x, 2.0L/3.0L)'
_cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"),
(lambda base, exp: not exp.is_integer, "pow")]
assert ccode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)'
assert ccode(x**0.5, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 0.5)'
assert ccode(x**Rational(16, 5), user_functions={'Pow': _cond_cfunc}) == 'pow(x, 16.0/5.0)'
_cond_cfunc2 = [(lambda base, exp: base == 2, lambda base, exp: 'exp2(%s)' % exp),
(lambda base, exp: base != 2, 'pow')]
# Related to gh-11353
assert ccode(2**x, user_functions={'Pow': _cond_cfunc2}) == 'exp2(x)'
assert ccode(x**2, user_functions={'Pow': _cond_cfunc2}) == 'pow(x, 2)'
# For issue 14160
assert ccode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
evaluate=False)) == '-2*x/(y*y)'
def test_ccode_Max():
# Test for gh-11926
assert ccode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))'
def test_ccode_Min_performance():
#Shouldn't take more than a few seconds
big_min = Min(*symbols('a[0:50]'))
for curr_standard in ('c89', 'c99', 'c11'):
output = ccode(big_min, standard=curr_standard)
assert output.count('(') == output.count(')')
def test_ccode_constants_mathh():
assert ccode(exp(1)) == "M_E"
assert ccode(pi) == "M_PI"
assert ccode(oo, standard='c89') == "HUGE_VAL"
assert ccode(-oo, standard='c89') == "-HUGE_VAL"
assert ccode(oo) == "INFINITY"
assert ccode(-oo, standard='c99') == "-INFINITY"
assert ccode(pi, type_aliases={real: float80}) == "M_PIl"
def test_ccode_constants_other():
assert ccode(2*GoldenRatio) == "const double GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
assert ccode(
2*Catalan) == "const double Catalan = %s;\n2*Catalan" % Catalan.evalf(17)
assert ccode(2*EulerGamma) == "const double EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
def test_ccode_Rational():
assert ccode(Rational(3, 7)) == "3.0/7.0"
assert ccode(Rational(3, 7), type_aliases={real: float80}) == "3.0L/7.0L"
assert ccode(Rational(18, 9)) == "2"
assert ccode(Rational(3, -7)) == "-3.0/7.0"
assert ccode(Rational(3, -7), type_aliases={real: float80}) == "-3.0L/7.0L"
assert ccode(Rational(-3, -7)) == "3.0/7.0"
assert ccode(Rational(-3, -7), type_aliases={real: float80}) == "3.0L/7.0L"
assert ccode(x + Rational(3, 7)) == "x + 3.0/7.0"
assert ccode(x + Rational(3, 7), type_aliases={real: float80}) == "x + 3.0L/7.0L"
assert ccode(Rational(3, 7)*x) == "(3.0/7.0)*x"
assert ccode(Rational(3, 7)*x, type_aliases={real: float80}) == "(3.0L/7.0L)*x"
def test_ccode_Integer():
assert ccode(Integer(67)) == "67"
assert ccode(Integer(-1)) == "-1"
def test_ccode_functions():
assert ccode(sin(x) ** cos(x)) == "pow(sin(x), cos(x))"
def test_ccode_inline_function():
x = symbols('x')
g = implemented_function('g', Lambda(x, 2*x))
assert ccode(g(x)) == "2*x"
g = implemented_function('g', Lambda(x, 2*x/Catalan))
assert ccode(
g(x)) == "const double Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17)
A = IndexedBase('A')
i = Idx('i', symbols('n', integer=True))
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
assert ccode(g(A[i]), assign_to=A[i]) == (
"for (int i=0; i<n; i++){\n"
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
"}"
)
def test_ccode_exceptions():
assert ccode(gamma(x), standard='C99') == "tgamma(x)"
with raises(PrintMethodNotImplementedError):
ccode(gamma(x), standard='C89')
with raises(PrintMethodNotImplementedError):
ccode(gamma(x), standard='C89', allow_unknown_functions=False)
ccode(gamma(x), standard='C89', allow_unknown_functions=True)
def test_ccode_functions2():
assert ccode(ceiling(x)) == "ceil(x)"
assert ccode(Abs(x)) == "fabs(x)"
assert ccode(gamma(x)) == "tgamma(x)"
r, s = symbols('r,s', real=True)
assert ccode(Mod(ceiling(r), ceiling(s))) == '((ceil(r) % ceil(s)) + '\
'ceil(s)) % ceil(s)'
assert ccode(Mod(r, s)) == "fmod(r, s)"
p1, p2 = symbols('p1 p2', integer=True, positive=True)
assert ccode(Mod(p1, p2)) == 'p1 % p2'
assert ccode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)'
assert ccode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)'
assert ccode(-Mod(3, 7, evaluate=False)) == '-(3 % 7)'
assert ccode(r*Mod(p1, p2)) == 'r*(p1 % p2)'
assert ccode(Mod(p1, p2)**s) == 'pow(p1 % p2, s)'
n = symbols('n', integer=True, negative=True)
assert ccode(Mod(-n, p2)) == '(-n) % p2'
assert ccode(fibonacci(n)) == '((1.0/5.0)*pow(2, -n)*sqrt(5)*(-pow(1 - sqrt(5), n) + pow(1 + sqrt(5), n)))'
assert ccode(lucas(n)) == '(pow(2, -n)*(pow(1 - sqrt(5), n) + pow(1 + sqrt(5), n)))'
def test_ccode_user_functions():
x = symbols('x', integer=False)
n = symbols('n', integer=True)
custom_functions = {
"ceiling": "ceil",
"Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
}
assert ccode(ceiling(x), user_functions=custom_functions) == "ceil(x)"
assert ccode(Abs(x), user_functions=custom_functions) == "fabs(x)"
assert ccode(Abs(n), user_functions=custom_functions) == "abs(n)"
expr = Symbol('a')
muladd = Function('muladd')
for i in range(0, 100):
# the large number of terms acts as a regression test for gh-23839
expr = muladd(Rational(1, 2), Symbol(f'a{i}'), expr)
out = ccode(expr, user_functions={'muladd':'muladd'})
assert 'a99' in out
assert out.count('muladd') == 100
def test_ccode_boolean():
assert ccode(True) == "true"
assert ccode(S.true) == "true"
assert ccode(False) == "false"
assert ccode(S.false) == "false"
assert ccode(x & y) == "x && y"
assert ccode(x | y) == "x || y"
assert ccode(~x) == "!x"
assert ccode(x & y & z) == "x && y && z"
assert ccode(x | y | z) == "x || y || z"
assert ccode((x & y) | z) == "z || x && y"
assert ccode((x | y) & z) == "z && (x || y)"
# Automatic rewrites
assert ccode(x ^ y) == '(x || y) && (!x || !y)'
assert ccode((x ^ y) ^ z) == '(x || y || z) && (x || !y || !z) && (y || !x || !z) && (z || !x || !y)'
assert ccode(Implies(x, y)) == 'y || !x'
assert ccode(Equivalent(x, z ^ y, Implies(z, x))) == '(x || (y || !z) && (z || !y)) && (z && !x || (y || z) && (!y || !z))'
def test_ccode_Relational():
assert ccode(Eq(x, y)) == "x == y"
assert ccode(Ne(x, y)) == "x != y"
assert ccode(Le(x, y)) == "x <= y"
assert ccode(Lt(x, y)) == "x < y"
assert ccode(Gt(x, y)) == "x > y"
assert ccode(Ge(x, y)) == "x >= y"
def test_ccode_Piecewise():
expr = Piecewise((x, x < 1), (x**2, True))
assert ccode(expr) == (
"((x < 1) ? (\n"
" x\n"
")\n"
": (\n"
" pow(x, 2)\n"
"))")
assert ccode(expr, assign_to="c") == (
"if (x < 1) {\n"
" c = x;\n"
"}\n"
"else {\n"
" c = pow(x, 2);\n"
"}")
expr = Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))
assert ccode(expr) == (
"((x < 1) ? (\n"
" x\n"
")\n"
": ((x < 2) ? (\n"
" x + 1\n"
")\n"
": (\n"
" pow(x, 2)\n"
")))")
assert ccode(expr, assign_to='c') == (
"if (x < 1) {\n"
" c = x;\n"
"}\n"
"else if (x < 2) {\n"
" c = x + 1;\n"
"}\n"
"else {\n"
" c = pow(x, 2);\n"
"}")
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
raises(ValueError, lambda: ccode(expr))
def test_ccode_sinc():
from sympy.functions.elementary.trigonometric import sinc
expr = sinc(x)
assert ccode(expr) == (
"(((x != 0) ? (\n"
" sin(x)/x\n"
")\n"
": (\n"
" 1\n"
")))")
def test_ccode_Piecewise_deep():
p = ccode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)))
assert p == (
"2*((x < 1) ? (\n"
" x\n"
")\n"
": ((x < 2) ? (\n"
" x + 1\n"
")\n"
": (\n"
" pow(x, 2)\n"
")))")
expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1
assert ccode(expr) == (
"pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n"
" 0\n"
")\n"
": (\n"
" 1\n"
")) + cos(z) - 1")
assert ccode(expr, assign_to='c') == (
"c = pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n"
" 0\n"
")\n"
": (\n"
" 1\n"
")) + cos(z) - 1;")
def test_ccode_ITE():
expr = ITE(x < 1, y, z)
assert ccode(expr) == (
"((x < 1) ? (\n"
" y\n"
")\n"
": (\n"
" z\n"
"))")
def test_ccode_settings():
raises(TypeError, lambda: ccode(sin(x), method="garbage"))
def test_ccode_Indexed():
s, n, m, o = symbols('s n m o', integer=True)
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
x = IndexedBase('x')[j]
A = IndexedBase('A')[i, j]
B = IndexedBase('B')[i, j, k]
p = C99CodePrinter()
assert p._print_Indexed(x) == 'x[j]'
assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
A = IndexedBase('A', shape=(5,3))[i, j]
assert p._print_Indexed(A) == 'A[%s]' % (3*i + j)
A = IndexedBase('A', shape=(5,3), strides='F')[i, j]
assert ccode(A) == 'A[%s]' % (i + 5*j)
A = IndexedBase('A', shape=(29,29), strides=(1, s), offset=o)[i, j]
assert ccode(A) == 'A[o + s*j + i]'
Abase = IndexedBase('A', strides=(s, m, n), offset=o)
assert ccode(Abase[i, j, k]) == 'A[m*j + n*k + o + s*i]'
assert ccode(Abase[2, 3, k]) == 'A[3*m + n*k + o + 2*s]'
def test_Element():
assert ccode(Element('x', 'ij')) == 'x[i][j]'
assert ccode(Element('x', 'ij', strides='kl', offset='o')) == 'x[i*k + j*l + o]'
assert ccode(Element('x', (3,))) == 'x[3]'
assert ccode(Element('x', (3,4,5))) == 'x[3][4][5]'
def test_ccode_Indexed_without_looking_for_contraction():
len_y = 5
y = IndexedBase('y', shape=(len_y,))
x = IndexedBase('x', shape=(len_y,))
Dy = IndexedBase('Dy', shape=(len_y-1,))
i = Idx('i', len_y-1)
e = Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
code0 = ccode(e.rhs, assign_to=e.lhs, contract=False)
assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1)
def test_ccode_loops_matrix_vector():
n, m = symbols('n m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
s = (
'for (int i=0; i<m; i++){\n'
' y[i] = 0;\n'
'}\n'
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' y[i] = A[%s]*x[j] + y[i];\n' % (i*n + j) +\
' }\n'
'}'
)
assert ccode(A[i, j]*x[j], assign_to=y[i]) == s
def test_dummy_loops():
i, m = symbols('i m', integer=True, cls=Dummy)
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx(i, m)
expected = (
'for (int i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
' y[i_%(icount)i] = x[i_%(icount)i];\n'
'}'
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
assert ccode(x[i], assign_to=y[i]) == expected
def test_ccode_loops_add():
n, m = symbols('n m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
z = IndexedBase('z')
i = Idx('i', m)
j = Idx('j', n)
s = (
'for (int i=0; i<m; i++){\n'
' y[i] = x[i] + z[i];\n'
'}\n'
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' y[i] = A[%s]*x[j] + y[i];\n' % (i*n + j) +\
' }\n'
'}'
)
assert ccode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == s
def test_ccode_loops_multiple_contractions():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
s = (
'for (int i=0; i<m; i++){\n'
' y[i] = 0;\n'
'}\n'
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' for (int k=0; k<o; k++){\n'
' for (int l=0; l<p; l++){\n'
' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
' }\n'
' }\n'
' }\n'
'}'
)
assert ccode(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == s
def test_ccode_loops_addfactor():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
c = IndexedBase('c')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
s = (
'for (int i=0; i<m; i++){\n'
' y[i] = 0;\n'
'}\n'
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' for (int k=0; k<o; k++){\n'
' for (int l=0; l<p; l++){\n'
' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
' }\n'
' }\n'
' }\n'
'}'
)
assert ccode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) == s
def test_ccode_loops_multiple_terms():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
c = IndexedBase('c')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
s0 = (
'for (int i=0; i<m; i++){\n'
' y[i] = 0;\n'
'}\n'
)
s1 = (
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' for (int k=0; k<o; k++){\n'
' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
' }\n'
' }\n'
'}\n'
)
s2 = (
'for (int i=0; i<m; i++){\n'
' for (int k=0; k<o; k++){\n'
' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
' }\n'
'}\n'
)
s3 = (
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
' }\n'
'}\n'
)
c = ccode(b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
assert (c == s0 + s1 + s2 + s3[:-1] or
c == s0 + s1 + s3 + s2[:-1] or
c == s0 + s2 + s1 + s3[:-1] or
c == s0 + s2 + s3 + s1[:-1] or
c == s0 + s3 + s1 + s2[:-1] or
c == s0 + s3 + s2 + s1[:-1])
def test_dereference_printing():
expr = x + y + sin(z) + z
assert ccode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))"
def test_Matrix_printing():
# Test returning a Matrix
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
A = MatrixSymbol('A', 3, 1)
assert ccode(mat, A) == (
"A[0] = x*y;\n"
"if (y > 0) {\n"
" A[1] = x + 2;\n"
"}\n"
"else {\n"
" A[1] = y;\n"
"}\n"
"A[2] = sin(z);")
# Test using MatrixElements in expressions
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
assert ccode(expr) == (
"((x > 0) ? (\n"
" 2*A[2]\n"
")\n"
": (\n"
" A[2]\n"
")) + sin(A[1]) + A[0]")
# Test using MatrixElements in a Matrix
q = MatrixSymbol('q', 5, 1)
M = MatrixSymbol('M', 3, 3)
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
[q[1,0] + q[2,0], q[3, 0], 5],
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
assert ccode(m, M) == (
"M[0] = sin(q[1]);\n"
"M[1] = 0;\n"
"M[2] = cos(q[2]);\n"
"M[3] = q[1] + q[2];\n"
"M[4] = q[3];\n"
"M[5] = 5;\n"
"M[6] = 2*q[4]/q[1];\n"
"M[7] = sqrt(q[0]) + 4;\n"
"M[8] = 0;")
def test_sparse_matrix():
# gh-15791
with raises(PrintMethodNotImplementedError):
ccode(SparseMatrix([[1, 2, 3]]))
assert 'Not supported in C' in C89CodePrinter({'strict': False}).doprint(SparseMatrix([[1, 2, 3]]))
def test_ccode_reserved_words():
x, y = symbols('x, if')
with raises(ValueError):
ccode(y**2, error_on_reserved=True, standard='C99')
assert ccode(y**2) == 'pow(if_, 2)'
assert ccode(x * y**2, dereference=[y]) == 'pow((*if_), 2)*x'
assert ccode(y**2, reserved_word_suffix='_unreserved') == 'pow(if_unreserved, 2)'
def test_ccode_sign():
expr1, ref1 = sign(x) * y, 'y*(((x) > 0) - ((x) < 0))'
expr2, ref2 = sign(cos(x)), '(((cos(x)) > 0) - ((cos(x)) < 0))'
expr3, ref3 = sign(2 * x + x**2) * x + x**2, 'pow(x, 2) + x*(((pow(x, 2) + 2*x) > 0) - ((pow(x, 2) + 2*x) < 0))'
assert ccode(expr1) == ref1
assert ccode(expr1, 'z') == 'z = %s;' % ref1
assert ccode(expr2) == ref2
assert ccode(expr3) == ref3
def test_ccode_Assignment():
assert ccode(Assignment(x, y + z)) == 'x = y + z;'
assert ccode(aug_assign(x, '+', y + z)) == 'x += y + z;'
def test_ccode_For():
f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)])
assert ccode(f) == ("for (x = 0; x < 10; x += 2) {\n"
" y *= x;\n"
"}")
def test_ccode_Max_Min():
assert ccode(Max(x, 0), standard='C89') == '((0 > x) ? 0 : x)'
assert ccode(Max(x, 0), standard='C99') == 'fmax(0, x)'
assert ccode(Min(x, 0, sqrt(x)), standard='c89') == (
'((0 < ((x < sqrt(x)) ? x : sqrt(x))) ? 0 : ((x < sqrt(x)) ? x : sqrt(x)))'
)
def test_ccode_standard():
assert ccode(expm1(x), standard='c99') == 'expm1(x)'
assert ccode(nan, standard='c99') == 'NAN'
assert ccode(float('nan'), standard='c99') == 'NAN'
def test_C89CodePrinter():
c89printer = C89CodePrinter()
assert c89printer.language == 'C'
assert c89printer.standard == 'C89'
assert 'void' in c89printer.reserved_words
assert 'template' not in c89printer.reserved_words
def test_C99CodePrinter():
assert C99CodePrinter().doprint(expm1(x)) == 'expm1(x)'
assert C99CodePrinter().doprint(log1p(x)) == 'log1p(x)'
assert C99CodePrinter().doprint(exp2(x)) == 'exp2(x)'
assert C99CodePrinter().doprint(log2(x)) == 'log2(x)'
assert C99CodePrinter().doprint(fma(x, y, -z)) == 'fma(x, y, -z)'
assert C99CodePrinter().doprint(log10(x)) == 'log10(x)'
assert C99CodePrinter().doprint(Cbrt(x)) == 'cbrt(x)' # note Cbrt due to cbrt already taken.
assert C99CodePrinter().doprint(hypot(x, y)) == 'hypot(x, y)'
assert C99CodePrinter().doprint(loggamma(x)) == 'lgamma(x)'
assert C99CodePrinter().doprint(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))'
assert C99CodePrinter().doprint(Min(x, 3)) == 'fmin(3, x)'
c99printer = C99CodePrinter()
assert c99printer.language == 'C'
assert c99printer.standard == 'C99'
assert 'restrict' in c99printer.reserved_words
assert 'using' not in c99printer.reserved_words
@XFAIL
def test_C99CodePrinter__precision_f80():
f80_printer = C99CodePrinter({"type_aliases": {real: float80}})
assert f80_printer.doprint(sin(x + Float('2.1'))) == 'sinl(x + 2.1L)'
def test_C99CodePrinter__precision():
n = symbols('n', integer=True)
p = symbols('p', integer=True, positive=True)
f32_printer = C99CodePrinter({"type_aliases": {real: float32}})
f64_printer = C99CodePrinter({"type_aliases": {real: float64}})
f80_printer = C99CodePrinter({"type_aliases": {real: float80}})
assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'
for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']):
def check(expr, ref):
assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper())
check(Abs(n), 'abs(n)')
check(Abs(x + 2.0), 'fabs{s}(x + 2.0{S})')
check(sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))')
check(exp(x*8.0), 'exp{s}(8.0{S}*x)')
check(exp2(x), 'exp2{s}(x)')
check(expm1(x*4.0), 'expm1{s}(4.0{S}*x)')
check(Mod(p, 2), 'p % 2')
check(Mod(2*p + 3, 3*p + 5, evaluate=False), '(2*p + 3) % (3*p + 5)')
check(Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})')
check(Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})')
check(log(x/2), 'log{s}((1.0{S}/2.0{S})*x)')
check(log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)')
check(log2(x*8.0), 'log2{s}(8.0{S}*x)')
check(log1p(x), 'log1p{s}(x)')
check(2**x, 'pow{s}(2, x)')
check(2.0**x, 'pow{s}(2.0{S}, x)')
check(x**3, 'pow{s}(x, 3)')
check(x**4.0, 'pow{s}(x, 4.0{S})')
check(sqrt(3+x), 'sqrt{s}(x + 3)')
check(Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})')
check(hypot(x, y), 'hypot{s}(x, y)')
check(sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})')
check(cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})')
check(tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})')
check(asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})')
check(acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})')
check(atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})')
check(atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)')
check(sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})')
check(cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})')
check(tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})')
check(asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})')
check(acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})')
check(atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})')
check(erf(42.*x), 'erf{s}(42.0{S}*x)')
check(erfc(42.*x), 'erfc{s}(42.0{S}*x)')
check(gamma(x), 'tgamma{s}(x)')
check(loggamma(x), 'lgamma{s}(x)')
check(ceiling(x + 2.), "ceil{s}(x) + 2")
check(floor(x + 2.), "floor{s}(x) + 2")
check(fma(x, y, -z), 'fma{s}(x, y, -z)')
check(Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))')
check(Min(x, 2.0), 'fmin{s}(2.0{S}, x)')
def test_get_math_macros():
macros = get_math_macros()
assert macros[exp(1)] == 'M_E'
assert macros[1/Sqrt(2)] == 'M_SQRT1_2'
def test_ccode_Declaration():
i = symbols('i', integer=True)
var1 = Variable(i, type=Type.from_expr(i))
dcl1 = Declaration(var1)
assert ccode(dcl1) == 'int i'
var2 = Variable(x, type=float32, attrs={value_const})
dcl2a = Declaration(var2)
assert ccode(dcl2a) == 'const float x'
dcl2b = var2.as_Declaration(value=pi)
assert ccode(dcl2b) == 'const float x = M_PI'
var3 = Variable(y, type=Type('bool'))
dcl3 = Declaration(var3)
printer = C89CodePrinter()
assert 'stdbool.h' not in printer.headers
assert printer.doprint(dcl3) == 'bool y'
assert 'stdbool.h' in printer.headers
u = symbols('u', real=True)
ptr4 = Pointer.deduced(u, attrs={pointer_const, restrict})
dcl4 = Declaration(ptr4)
assert ccode(dcl4) == 'double * const restrict u'
var5 = Variable(x, Type('__float128'), attrs={value_const})
dcl5a = Declaration(var5)
assert ccode(dcl5a) == 'const __float128 x'
var5b = Variable(var5.symbol, var5.type, pi, attrs=var5.attrs)
dcl5b = Declaration(var5b)
assert ccode(dcl5b) == 'const __float128 x = M_PI'
def test_C99CodePrinter_custom_type():
# We will look at __float128 (new in glibc 2.26)
f128 = FloatType('_Float128', float128.nbits, float128.nmant, float128.nexp)
p128 = C99CodePrinter({
"type_aliases": {real: f128},
"type_literal_suffixes": {f128: 'Q'},
"type_func_suffixes": {f128: 'f128'},
"type_math_macro_suffixes": {
real: 'f128',
f128: 'f128'
},
"type_macros": {
f128: ('__STDC_WANT_IEC_60559_TYPES_EXT__',)
}
})
assert p128.doprint(x) == 'x'
assert not p128.headers
assert not p128.libraries
assert not p128.macros
assert p128.doprint(2.0) == '2.0Q'
assert not p128.headers
assert not p128.libraries
assert p128.macros == {'__STDC_WANT_IEC_60559_TYPES_EXT__'}
assert p128.doprint(Rational(1, 2)) == '1.0Q/2.0Q'
assert p128.doprint(sin(x)) == 'sinf128(x)'
assert p128.doprint(cos(2., evaluate=False)) == 'cosf128(2.0Q)'
assert p128.doprint(x**-1.0) == '1.0Q/x'
var5 = Variable(x, f128, attrs={value_const})
dcl5a = Declaration(var5)
assert ccode(dcl5a) == 'const _Float128 x'
var5b = Variable(x, f128, pi, attrs={value_const})
dcl5b = Declaration(var5b)
assert p128.doprint(dcl5b) == 'const _Float128 x = M_PIf128'
var5b = Variable(x, f128, value=Catalan.evalf(38), attrs={value_const})
dcl5c = Declaration(var5b)
assert p128.doprint(dcl5c) == 'const _Float128 x = %sQ' % Catalan.evalf(f128.decimal_dig)
def test_MatrixElement_printing():
# test cases for issue #11821
A = MatrixSymbol("A", 1, 3)
B = MatrixSymbol("B", 1, 3)
C = MatrixSymbol("C", 1, 3)
assert(ccode(A[0, 0]) == "A[0]")
assert(ccode(3 * A[0, 0]) == "3*A[0]")
F = C[0, 0].subs(C, A - B)
assert(ccode(F) == "(A - B)[0]")
def test_ccode_math_macros():
assert ccode(z + exp(1)) == 'z + M_E'
assert ccode(z + log2(exp(1))) == 'z + M_LOG2E'
assert ccode(z + 1/log(2)) == 'z + M_LOG2E'
assert ccode(z + log(2)) == 'z + M_LN2'
assert ccode(z + log(10)) == 'z + M_LN10'
assert ccode(z + pi) == 'z + M_PI'
assert ccode(z + pi/2) == 'z + M_PI_2'
assert ccode(z + pi/4) == 'z + M_PI_4'
assert ccode(z + 1/pi) == 'z + M_1_PI'
assert ccode(z + 2/pi) == 'z + M_2_PI'
assert ccode(z + 2/sqrt(pi)) == 'z + M_2_SQRTPI'
assert ccode(z + 2/Sqrt(pi)) == 'z + M_2_SQRTPI'
assert ccode(z + sqrt(2)) == 'z + M_SQRT2'
assert ccode(z + Sqrt(2)) == 'z + M_SQRT2'
assert ccode(z + 1/sqrt(2)) == 'z + M_SQRT1_2'
assert ccode(z + 1/Sqrt(2)) == 'z + M_SQRT1_2'
def test_ccode_Type():
assert ccode(Type('float')) == 'float'
assert ccode(intc) == 'int'
def test_ccode_codegen_ast():
# Note that C only allows comments of the form /* ... */, double forward
# slash is not standard C, and some C compilers will grind to a halt upon
# encountering them.
assert ccode(Comment("this is a comment")) == "/* this is a comment */" # not //
assert ccode(While(abs(x) > 1, [aug_assign(x, '-', 1)])) == (
'while (fabs(x) > 1) {\n'
' x -= 1;\n'
'}'
)
assert ccode(Scope([AddAugmentedAssignment(x, 1)])) == (
'{\n'
' x += 1;\n'
'}'
)
inp_x = Declaration(Variable(x, type=real))
assert ccode(FunctionPrototype(real, 'pwer', [inp_x])) == 'double pwer(double x)'
assert ccode(FunctionDefinition(real, 'pwer', [inp_x], [Assignment(x, x**2)])) == (
'double pwer(double x){\n'
' x = pow(x, 2);\n'
'}'
)
# Elements of CodeBlock are formatted as statements:
block = CodeBlock(
x,
Print([x, y], "%d %d"),
Print([QuotedString('hello'), y], "%s %d", file=stderr),
FunctionCall('pwer', [x]),
Return(x),
)
assert ccode(block) == '\n'.join([
'x;',
'printf("%d %d", x, y);',
'fprintf(stderr, "%s %d", "hello", y);',
'pwer(x);',
'return x;',
])
def test_ccode_UnevaluatedExpr():
assert ccode(UnevaluatedExpr(y * x) + z) == "z + x*y"
assert ccode(UnevaluatedExpr(y + x) + z) == "z + (x + y)" # gh-21955
w = symbols('w')
assert ccode(UnevaluatedExpr(y + x) + UnevaluatedExpr(z + w)) == "(w + z) + (x + y)"
p, q, r = symbols("p q r", real=True)
q_r = UnevaluatedExpr(q + r)
expr = abs(exp(p+q_r))
assert ccode(expr) == "exp(p + (q + r))"
def test_ccode_array_like_containers():
assert ccode([2,3,4]) == "{2, 3, 4}"
assert ccode((2,3,4)) == "{2, 3, 4}"

View File

@ -0,0 +1,55 @@
from sympy.printing.codeprinter import CodePrinter, PrintMethodNotImplementedError
from sympy.core import symbols
from sympy.core.symbol import Dummy
from sympy.testing.pytest import raises
def setup_test_printer(**kwargs):
p = CodePrinter(settings=kwargs)
p._not_supported = set()
p._number_symbols = set()
return p
def test_print_Dummy():
d = Dummy('d')
p = setup_test_printer()
assert p._print_Dummy(d) == "d_%i" % d.dummy_index
def test_print_Symbol():
x, y = symbols('x, if')
p = setup_test_printer()
assert p._print(x) == 'x'
assert p._print(y) == 'if'
p.reserved_words.update(['if'])
assert p._print(y) == 'if_'
p = setup_test_printer(error_on_reserved=True)
p.reserved_words.update(['if'])
with raises(ValueError):
p._print(y)
p = setup_test_printer(reserved_word_suffix='_He_Man')
p.reserved_words.update(['if'])
assert p._print(y) == 'if_He_Man'
def test_issue_15791():
class CrashingCodePrinter(CodePrinter):
def emptyPrinter(self, obj):
raise NotImplementedError
from sympy.matrices import (
MutableSparseMatrix,
ImmutableSparseMatrix,
)
c = CrashingCodePrinter()
# these should not silently succeed
with raises(PrintMethodNotImplementedError):
c.doprint(ImmutableSparseMatrix(2, 2, {}))
with raises(PrintMethodNotImplementedError):
c.doprint(MutableSparseMatrix(2, 2, {}))

View File

@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
from sympy.core.function import (Derivative, Function)
from sympy.core.numbers import oo
from sympy.core.symbol import symbols
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.trigonometric import cos
from sympy.integrals.integrals import Integral
from sympy.functions.special.bessel import besselj
from sympy.functions.special.polynomials import legendre
from sympy.functions.combinatorial.numbers import bell
from sympy.printing.conventions import split_super_sub, requires_partial
from sympy.testing.pytest import XFAIL
def test_super_sub():
assert split_super_sub("beta_13_2") == ("beta", [], ["13", "2"])
assert split_super_sub("beta_132_20") == ("beta", [], ["132", "20"])
assert split_super_sub("beta_13") == ("beta", [], ["13"])
assert split_super_sub("x_a_b") == ("x", [], ["a", "b"])
assert split_super_sub("x_1_2_3") == ("x", [], ["1", "2", "3"])
assert split_super_sub("x_a_b1") == ("x", [], ["a", "b1"])
assert split_super_sub("x_a_1") == ("x", [], ["a", "1"])
assert split_super_sub("x_1_a") == ("x", [], ["1", "a"])
assert split_super_sub("x_1^aa") == ("x", ["aa"], ["1"])
assert split_super_sub("x_1__aa") == ("x", ["aa"], ["1"])
assert split_super_sub("x_11^a") == ("x", ["a"], ["11"])
assert split_super_sub("x_11__a") == ("x", ["a"], ["11"])
assert split_super_sub("x_a_b_c_d") == ("x", [], ["a", "b", "c", "d"])
assert split_super_sub("x_a_b^c^d") == ("x", ["c", "d"], ["a", "b"])
assert split_super_sub("x_a_b__c__d") == ("x", ["c", "d"], ["a", "b"])
assert split_super_sub("x_a^b_c^d") == ("x", ["b", "d"], ["a", "c"])
assert split_super_sub("x_a__b_c__d") == ("x", ["b", "d"], ["a", "c"])
assert split_super_sub("x^a^b_c_d") == ("x", ["a", "b"], ["c", "d"])
assert split_super_sub("x__a__b_c_d") == ("x", ["a", "b"], ["c", "d"])
assert split_super_sub("x^a^b^c^d") == ("x", ["a", "b", "c", "d"], [])
assert split_super_sub("x__a__b__c__d") == ("x", ["a", "b", "c", "d"], [])
assert split_super_sub("alpha_11") == ("alpha", [], ["11"])
assert split_super_sub("alpha_11_11") == ("alpha", [], ["11", "11"])
assert split_super_sub("w1") == ("w", [], ["1"])
assert split_super_sub("w𝟙") == ("w", [], ["𝟙"])
assert split_super_sub("w11") == ("w", [], ["11"])
assert split_super_sub("w𝟙𝟙") == ("w", [], ["𝟙𝟙"])
assert split_super_sub("w𝟙2𝟙") == ("w", [], ["𝟙2𝟙"])
assert split_super_sub("w1^a") == ("w", ["a"], ["1"])
assert split_super_sub("ω1") == ("ω", [], ["1"])
assert split_super_sub("ω11") == ("ω", [], ["11"])
assert split_super_sub("ω1^a") == ("ω", ["a"], ["1"])
assert split_super_sub("ω𝟙^α") == ("ω", ["α"], ["𝟙"])
assert split_super_sub("ω𝟙2^3α") == ("ω", ["3α"], ["𝟙2"])
assert split_super_sub("") == ("", [], [])
def test_requires_partial():
x, y, z, t, nu = symbols('x y z t nu')
n = symbols('n', integer=True)
f = x * y
assert requires_partial(Derivative(f, x)) is True
assert requires_partial(Derivative(f, y)) is True
## integrating out one of the variables
assert requires_partial(Derivative(Integral(exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False
## bessel function with smooth parameter
f = besselj(nu, x)
assert requires_partial(Derivative(f, x)) is True
assert requires_partial(Derivative(f, nu)) is True
## bessel function with integer parameter
f = besselj(n, x)
assert requires_partial(Derivative(f, x)) is False
# this is not really valid (differentiating with respect to an integer)
# but there's no reason to use the partial derivative symbol there. make
# sure we don't throw an exception here, though
assert requires_partial(Derivative(f, n)) is False
## bell polynomial
f = bell(n, x)
assert requires_partial(Derivative(f, x)) is False
# again, invalid
assert requires_partial(Derivative(f, n)) is False
## legendre polynomial
f = legendre(0, x)
assert requires_partial(Derivative(f, x)) is False
f = legendre(n, x)
assert requires_partial(Derivative(f, x)) is False
# again, invalid
assert requires_partial(Derivative(f, n)) is False
f = x ** n
assert requires_partial(Derivative(f, x)) is False
assert requires_partial(Derivative(Integral((x*y) ** n * exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False
# parametric equation
f = (exp(t), cos(t))
g = sum(f)
assert requires_partial(Derivative(g, t)) is False
f = symbols('f', cls=Function)
assert requires_partial(Derivative(f(x), x)) is False
assert requires_partial(Derivative(f(x), y)) is False
assert requires_partial(Derivative(f(x, y), x)) is True
assert requires_partial(Derivative(f(x, y), y)) is True
assert requires_partial(Derivative(f(x, y), z)) is True
assert requires_partial(Derivative(f(x, y), x, y)) is True
@XFAIL
def test_requires_partial_unspecified_variables():
x, y = symbols('x y')
# function of unspecified variables
f = symbols('f', cls=Function)
assert requires_partial(Derivative(f, x)) is False
assert requires_partial(Derivative(f, x, y)) is True

View File

@ -0,0 +1,56 @@
from sympy.concrete.summations import Sum
from sympy.functions.elementary.exponential import log
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.utilities.lambdify import lambdify
from sympy.abc import x, i, a, b
from sympy.codegen.numpy_nodes import logaddexp
from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions
from sympy.testing.pytest import skip, raises
from sympy.external import import_module
cp = import_module('cupy')
def test_cupy_print():
prntr = CuPyPrinter()
assert prntr.doprint(logaddexp(a, b)) == 'cupy.logaddexp(a, b)'
assert prntr.doprint(sqrt(x)) == 'cupy.sqrt(x)'
assert prntr.doprint(log(x)) == 'cupy.log(x)'
assert prntr.doprint("acos(x)") == 'cupy.arccos(x)'
assert prntr.doprint("exp(x)") == 'cupy.exp(x)'
assert prntr.doprint("Abs(x)") == 'abs(x)'
def test_not_cupy_print():
prntr = CuPyPrinter()
with raises(NotImplementedError):
prntr.doprint("abcd(x)")
def test_cupy_sum():
if not cp:
skip("CuPy not installed")
s = Sum(x ** i, (i, a, b))
f = lambdify((a, b, x), s, 'cupy')
a_, b_ = 0, 10
x_ = cp.linspace(-1, +1, 10)
assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
s = Sum(i * x, (i, a, b))
f = lambdify((a, b, x), s, 'numpy')
a_, b_ = 0, 10
x_ = cp.linspace(-1, +1, 10)
assert cp.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
def test_cupy_known_funcs_consts():
assert _cupy_known_constants['NaN'] == 'cupy.nan'
assert _cupy_known_constants['EulerGamma'] == 'cupy.euler_gamma'
assert _cupy_known_functions['acos'] == 'cupy.arccos'
assert _cupy_known_functions['log'] == 'cupy.log'
def test_cupy_print_methods():
prntr = CuPyPrinter()
assert hasattr(prntr, '_print_acos')
assert hasattr(prntr, '_print_log')

View File

@ -0,0 +1,86 @@
from sympy.core.numbers import Float, Integer, Rational
from sympy.core.symbol import symbols
from sympy.functions import beta, Ei, zeta, Max, Min, sqrt, riemann_xi, frac
from sympy.printing.cxx import CXX98CodePrinter, CXX11CodePrinter, CXX17CodePrinter, cxxcode
from sympy.codegen.cfunctions import log1p
x, y, u, v = symbols('x y u v')
def test_CXX98CodePrinter():
assert CXX98CodePrinter().doprint(Max(x, 3)) in ('std::max(x, 3)', 'std::max(3, x)')
assert CXX98CodePrinter().doprint(Min(x, 3, sqrt(x))) == 'std::min(3, std::min(x, std::sqrt(x)))'
cxx98printer = CXX98CodePrinter()
assert cxx98printer.language == 'C++'
assert cxx98printer.standard == 'C++98'
assert 'template' in cxx98printer.reserved_words
assert 'alignas' not in cxx98printer.reserved_words
def test_CXX11CodePrinter():
assert CXX11CodePrinter().doprint(log1p(x)) == 'std::log1p(x)'
cxx11printer = CXX11CodePrinter()
assert cxx11printer.language == 'C++'
assert cxx11printer.standard == 'C++11'
assert 'operator' in cxx11printer.reserved_words
assert 'noexcept' in cxx11printer.reserved_words
assert 'concept' not in cxx11printer.reserved_words
def test_subclass_print_method():
class MyPrinter(CXX11CodePrinter):
def _print_log1p(self, expr):
return 'my_library::log1p(%s)' % ', '.join(map(self._print, expr.args))
assert MyPrinter().doprint(log1p(x)) == 'my_library::log1p(x)'
def test_subclass_print_method__ns():
class MyPrinter(CXX11CodePrinter):
_ns = 'my_library::'
p = CXX11CodePrinter()
myp = MyPrinter()
assert p.doprint(log1p(x)) == 'std::log1p(x)'
assert myp.doprint(log1p(x)) == 'my_library::log1p(x)'
def test_CXX17CodePrinter():
assert CXX17CodePrinter().doprint(beta(x, y)) == 'std::beta(x, y)'
assert CXX17CodePrinter().doprint(Ei(x)) == 'std::expint(x)'
assert CXX17CodePrinter().doprint(zeta(x)) == 'std::riemann_zeta(x)'
# Automatic rewrite
assert CXX17CodePrinter().doprint(frac(x)) == '(x - std::floor(x))'
assert CXX17CodePrinter().doprint(riemann_xi(x)) == '((1.0/2.0)*std::pow(M_PI, -1.0/2.0*x)*x*(x - 1)*std::tgamma((1.0/2.0)*x)*std::riemann_zeta(x))'
def test_cxxcode():
assert sorted(cxxcode(sqrt(x)*.5).split('*')) == sorted(['0.5', 'std::sqrt(x)'])
def test_cxxcode_nested_minmax():
assert cxxcode(Max(Min(x, y), Min(u, v))) \
== 'std::max(std::min(u, v), std::min(x, y))'
assert cxxcode(Min(Max(x, y), Max(u, v))) \
== 'std::min(std::max(u, v), std::max(x, y))'
def test_subclass_Integer_Float():
class MyPrinter(CXX17CodePrinter):
def _print_Integer(self, arg):
return 'bigInt("%s")' % super()._print_Integer(arg)
def _print_Float(self, arg):
rat = Rational(arg)
return 'bigFloat(%s, %s)' % (
self._print(Integer(rat.p)),
self._print(Integer(rat.q))
)
p = MyPrinter()
for i in range(13):
assert p.doprint(i) == 'bigInt("%d")' % i
assert p.doprint(Float(0.5)) == 'bigFloat(bigInt("1"), bigInt("2"))'
assert p.doprint(x**-1.0) == 'bigFloat(bigInt("1"), bigInt("1"))/x'

View File

@ -0,0 +1,134 @@
from sympy.printing.dot import (purestr, styleof, attrprint, dotnode,
dotedges, dotprint)
from sympy.core.basic import Basic
from sympy.core.expr import Expr
from sympy.core.numbers import (Float, Integer)
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, symbols)
from sympy.printing.repr import srepr
from sympy.abc import x
def test_purestr():
assert purestr(Symbol('x')) == "Symbol('x')"
assert purestr(Basic(S(1), S(2))) == "Basic(Integer(1), Integer(2))"
assert purestr(Float(2)) == "Float('2.0', precision=53)"
assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ())
assert purestr(Basic(S(1), S(2)), with_args=True) == \
('Basic(Integer(1), Integer(2))', ('Integer(1)', 'Integer(2)'))
assert purestr(Float(2), with_args=True) == \
("Float('2.0', precision=53)", ())
def test_styleof():
styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
(Expr, {'color': 'black'})]
assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'}
assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'}
def test_attrprint():
assert attrprint({'color': 'blue', 'shape': 'ellipse'}) == \
'"color"="blue", "shape"="ellipse"'
def test_dotnode():
assert dotnode(x, repeat=False) == \
'"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];'
assert dotnode(x+2, repeat=False) == \
'"Add(Integer(2), Symbol(\'x\'))" ' \
'["color"="black", "label"="Add", "shape"="ellipse"];', \
dotnode(x+2,repeat=0)
assert dotnode(x + x**2, repeat=False) == \
'"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))" ' \
'["color"="black", "label"="Add", "shape"="ellipse"];'
assert dotnode(x + x**2, repeat=True) == \
'"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))_()" ' \
'["color"="black", "label"="Add", "shape"="ellipse"];'
def test_dotedges():
assert sorted(dotedges(x+2, repeat=False)) == [
'"Add(Integer(2), Symbol(\'x\'))" -> "Integer(2)";',
'"Add(Integer(2), Symbol(\'x\'))" -> "Symbol(\'x\')";'
]
assert sorted(dotedges(x + 2, repeat=True)) == [
'"Add(Integer(2), Symbol(\'x\'))_()" -> "Integer(2)_(0,)";',
'"Add(Integer(2), Symbol(\'x\'))_()" -> "Symbol(\'x\')_(1,)";'
]
def test_dotprint():
text = dotprint(x+2, repeat=False)
assert all(e in text for e in dotedges(x+2, repeat=False))
assert all(
n in text for n in [dotnode(expr, repeat=False)
for expr in (x, Integer(2), x+2)])
assert 'digraph' in text
text = dotprint(x+x**2, repeat=False)
assert all(e in text for e in dotedges(x+x**2, repeat=False))
assert all(
n in text for n in [dotnode(expr, repeat=False)
for expr in (x, Integer(2), x**2)])
assert 'digraph' in text
text = dotprint(x+x**2, repeat=True)
assert all(e in text for e in dotedges(x+x**2, repeat=True))
assert all(
n in text for n in [dotnode(expr, pos=())
for expr in [x + x**2]])
text = dotprint(x**x, repeat=True)
assert all(e in text for e in dotedges(x**x, repeat=True))
assert all(
n in text for n in [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))])
assert 'digraph' in text
def test_dotprint_depth():
text = dotprint(3*x+2, depth=1)
assert dotnode(3*x+2) in text
assert dotnode(x) not in text
text = dotprint(3*x+2)
assert "depth" not in text
def test_Matrix_and_non_basics():
from sympy.matrices.expressions.matexpr import MatrixSymbol
n = Symbol('n')
assert dotprint(MatrixSymbol('X', n, n)) == \
"""digraph{
# Graph style
"ordering"="out"
"rankdir"="TD"
#########
# Nodes #
#########
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"];
"Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"];
"Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"];
"Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"];
#########
# Edges #
#########
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)";
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)";
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)";
}"""
def test_labelfunc():
text = dotprint(x + 2, labelfunc=srepr)
assert "Symbol('x')" in text
assert "Integer(2)" in text
def test_commutative():
x, y = symbols('x y', commutative=False)
assert dotprint(x + y) == dotprint(y + x)
assert dotprint(x*y) != dotprint(y*x)

View File

@ -0,0 +1,854 @@
from sympy.core.add import Add
from sympy.core.expr import Expr
from sympy.core.function import (Function, Lambda, diff)
from sympy.core.mod import Mod
from sympy.core import (Catalan, EulerGamma, GoldenRatio)
from sympy.core.numbers import (E, Float, I, Integer, Rational, pi)
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, symbols)
from sympy.functions.combinatorial.factorials import factorial
from sympy.functions.elementary.complexes import (conjugate, sign)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (atan2, cos, sin)
from sympy.functions.special.gamma_functions import gamma
from sympy.integrals.integrals import Integral
from sympy.sets.fancysets import Range
from sympy.codegen import For, Assignment, aug_assign
from sympy.codegen.ast import Declaration, Variable, float32, float64, \
value_const, real, bool_, While, FunctionPrototype, FunctionDefinition, \
integer, Return, Element
from sympy.core.expr import UnevaluatedExpr
from sympy.core.relational import Relational
from sympy.logic.boolalg import And, Or, Not, Equivalent, Xor
from sympy.matrices import Matrix, MatrixSymbol
from sympy.printing.fortran import fcode, FCodePrinter
from sympy.tensor import IndexedBase, Idx
from sympy.tensor.array.expressions import ArraySymbol, ArrayElement
from sympy.utilities.lambdify import implemented_function
from sympy.testing.pytest import raises
def test_UnevaluatedExpr():
p, q, r = symbols("p q r", real=True)
q_r = UnevaluatedExpr(q + r)
expr = abs(exp(p+q_r))
assert fcode(expr, source_format="free") == "exp(p + (q + r))"
x, y, z = symbols("x y z")
y_z = UnevaluatedExpr(y + z)
expr2 = abs(exp(x+y_z))
assert fcode(expr2, human=False)[2].lstrip() == "exp(re(x) + re(y + z))"
assert fcode(expr2, user_functions={"re": "realpart"}).lstrip() == "exp(realpart(x) + realpart(y + z))"
def test_printmethod():
x = symbols('x')
class nint(Function):
def _fcode(self, printer):
return "nint(%s)" % printer._print(self.args[0])
assert fcode(nint(x)) == " nint(x)"
def test_fcode_sign(): #issue 12267
x=symbols('x')
y=symbols('y', integer=True)
z=symbols('z', complex=True)
assert fcode(sign(x), standard=95, source_format='free') == "merge(0d0, dsign(1d0, x), x == 0d0)"
assert fcode(sign(y), standard=95, source_format='free') == "merge(0, isign(1, y), y == 0)"
assert fcode(sign(z), standard=95, source_format='free') == "merge(cmplx(0d0, 0d0), z/abs(z), abs(z) == 0d0)"
raises(NotImplementedError, lambda: fcode(sign(x)))
def test_fcode_Pow():
x, y = symbols('x,y')
n = symbols('n', integer=True)
assert fcode(x**3) == " x**3"
assert fcode(x**(y**3)) == " x**(y**3)"
assert fcode(1/(sin(x)*3.5)**(x - y**x)/(x**2 + y)) == \
" (3.5d0*sin(x))**(-x + y**x)/(x**2 + y)"
assert fcode(sqrt(x)) == ' sqrt(x)'
assert fcode(sqrt(n)) == ' sqrt(dble(n))'
assert fcode(x**0.5) == ' sqrt(x)'
assert fcode(sqrt(x)) == ' sqrt(x)'
assert fcode(sqrt(10)) == ' sqrt(10.0d0)'
assert fcode(x**-1.0) == ' 1d0/x'
assert fcode(x**-2.0, 'y', source_format='free') == 'y = x**(-2.0d0)' # 2823
assert fcode(x**Rational(3, 7)) == ' x**(3.0d0/7.0d0)'
def test_fcode_Rational():
x = symbols('x')
assert fcode(Rational(3, 7)) == " 3.0d0/7.0d0"
assert fcode(Rational(18, 9)) == " 2"
assert fcode(Rational(3, -7)) == " -3.0d0/7.0d0"
assert fcode(Rational(-3, -7)) == " 3.0d0/7.0d0"
assert fcode(x + Rational(3, 7)) == " x + 3.0d0/7.0d0"
assert fcode(Rational(3, 7)*x) == " (3.0d0/7.0d0)*x"
def test_fcode_Integer():
assert fcode(Integer(67)) == " 67"
assert fcode(Integer(-1)) == " -1"
def test_fcode_Float():
assert fcode(Float(42.0)) == " 42.0000000000000d0"
assert fcode(Float(-1e20)) == " -1.00000000000000d+20"
def test_fcode_functions():
x, y = symbols('x,y')
assert fcode(sin(x) ** cos(y)) == " sin(x)**cos(y)"
raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=66))
raises(NotImplementedError, lambda: fcode(x % y, standard=66))
raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=77))
raises(NotImplementedError, lambda: fcode(x % y, standard=77))
for standard in [90, 95, 2003, 2008]:
assert fcode(Mod(x, y), standard=standard) == " modulo(x, y)"
assert fcode(x % y, standard=standard) == " modulo(x, y)"
def test_case():
ob = FCodePrinter()
x,x_,x__,y,X,X_,Y = symbols('x,x_,x__,y,X,X_,Y')
assert fcode(exp(x_) + sin(x*y) + cos(X*Y)) == \
' exp(x_) + sin(x*y) + cos(X__*Y_)'
assert fcode(exp(x__) + 2*x*Y*X_**Rational(7, 2)) == \
' 2*X_**(7.0d0/2.0d0)*Y*x + exp(x__)'
assert fcode(exp(x_) + sin(x*y) + cos(X*Y), name_mangling=False) == \
' exp(x_) + sin(x*y) + cos(X*Y)'
assert fcode(x - cos(X), name_mangling=False) == ' x - cos(X)'
assert ob.doprint(X*sin(x) + x_, assign_to='me') == ' me = X*sin(x_) + x__'
assert ob.doprint(X*sin(x), assign_to='mu') == ' mu = X*sin(x_)'
assert ob.doprint(x_, assign_to='ad') == ' ad = x__'
n, m = symbols('n,m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx('i', m)
I = Idx('I', n)
assert fcode(A[i, I]*x[I], assign_to=y[i], source_format='free') == (
"do i = 1, m\n"
" y(i) = 0\n"
"end do\n"
"do i = 1, m\n"
" do I_ = 1, n\n"
" y(i) = A(i, I_)*x(I_) + y(i)\n"
" end do\n"
"end do" )
#issue 6814
def test_fcode_functions_with_integers():
x= symbols('x')
log10_17 = log(10).evalf(17)
loglog10_17 = '0.8340324452479558d0'
assert fcode(x * log(10)) == " x*%sd0" % log10_17
assert fcode(x * log(10)) == " x*%sd0" % log10_17
assert fcode(x * log(S(10))) == " x*%sd0" % log10_17
assert fcode(log(S(10))) == " %sd0" % log10_17
assert fcode(exp(10)) == " %sd0" % exp(10).evalf(17)
assert fcode(x * log(log(10))) == " x*%s" % loglog10_17
assert fcode(x * log(log(S(10)))) == " x*%s" % loglog10_17
def test_fcode_NumberSymbol():
prec = 17
p = FCodePrinter()
assert fcode(Catalan) == ' parameter (Catalan = %sd0)\n Catalan' % Catalan.evalf(prec)
assert fcode(EulerGamma) == ' parameter (EulerGamma = %sd0)\n EulerGamma' % EulerGamma.evalf(prec)
assert fcode(E) == ' parameter (E = %sd0)\n E' % E.evalf(prec)
assert fcode(GoldenRatio) == ' parameter (GoldenRatio = %sd0)\n GoldenRatio' % GoldenRatio.evalf(prec)
assert fcode(pi) == ' parameter (pi = %sd0)\n pi' % pi.evalf(prec)
assert fcode(
pi, precision=5) == ' parameter (pi = %sd0)\n pi' % pi.evalf(5)
assert fcode(Catalan, human=False) == ({
(Catalan, p._print(Catalan.evalf(prec)))}, set(), ' Catalan')
assert fcode(EulerGamma, human=False) == ({(EulerGamma, p._print(
EulerGamma.evalf(prec)))}, set(), ' EulerGamma')
assert fcode(E, human=False) == (
{(E, p._print(E.evalf(prec)))}, set(), ' E')
assert fcode(GoldenRatio, human=False) == ({(GoldenRatio, p._print(
GoldenRatio.evalf(prec)))}, set(), ' GoldenRatio')
assert fcode(pi, human=False) == (
{(pi, p._print(pi.evalf(prec)))}, set(), ' pi')
assert fcode(pi, precision=5, human=False) == (
{(pi, p._print(pi.evalf(5)))}, set(), ' pi')
def test_fcode_complex():
assert fcode(I) == " cmplx(0,1)"
x = symbols('x')
assert fcode(4*I) == " cmplx(0,4)"
assert fcode(3 + 4*I) == " cmplx(3,4)"
assert fcode(3 + 4*I + x) == " cmplx(3,4) + x"
assert fcode(I*x) == " cmplx(0,1)*x"
assert fcode(3 + 4*I - x) == " cmplx(3,4) - x"
x = symbols('x', imaginary=True)
assert fcode(5*x) == " 5*x"
assert fcode(I*x) == " cmplx(0,1)*x"
assert fcode(3 + x) == " x + 3"
def test_implicit():
x, y = symbols('x,y')
assert fcode(sin(x)) == " sin(x)"
assert fcode(atan2(x, y)) == " atan2(x, y)"
assert fcode(conjugate(x)) == " conjg(x)"
def test_not_fortran():
x = symbols('x')
g = Function('g')
with raises(NotImplementedError):
fcode(gamma(x))
assert fcode(Integral(sin(x)), strict=False) == "C Not supported in Fortran:\nC Integral\n Integral(sin(x), x)"
with raises(NotImplementedError):
fcode(g(x))
def test_user_functions():
x = symbols('x')
assert fcode(sin(x), user_functions={"sin": "zsin"}) == " zsin(x)"
x = symbols('x')
assert fcode(
gamma(x), user_functions={"gamma": "mygamma"}) == " mygamma(x)"
g = Function('g')
assert fcode(g(x), user_functions={"g": "great"}) == " great(x)"
n = symbols('n', integer=True)
assert fcode(
factorial(n), user_functions={"factorial": "fct"}) == " fct(n)"
def test_inline_function():
x = symbols('x')
g = implemented_function('g', Lambda(x, 2*x))
assert fcode(g(x)) == " 2*x"
g = implemented_function('g', Lambda(x, 2*pi/x))
assert fcode(g(x)) == (
" parameter (pi = %sd0)\n"
" 2*pi/x"
) % pi.evalf(17)
A = IndexedBase('A')
i = Idx('i', symbols('n', integer=True))
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
assert fcode(g(A[i]), assign_to=A[i]) == (
" do i = 1, n\n"
" A(i) = (A(i) + 1)*(A(i) + 2)*A(i)\n"
" end do"
)
def test_assign_to():
x = symbols('x')
assert fcode(sin(x), assign_to="s") == " s = sin(x)"
def test_line_wrapping():
x, y = symbols('x,y')
assert fcode(((x + y)**10).expand(), assign_to="var") == (
" var = x**10 + 10*x**9*y + 45*x**8*y**2 + 120*x**7*y**3 + 210*x**6*\n"
" @ y**4 + 252*x**5*y**5 + 210*x**4*y**6 + 120*x**3*y**7 + 45*x**2*y\n"
" @ **8 + 10*x*y**9 + y**10"
)
e = [x**i for i in range(11)]
assert fcode(Add(*e)) == (
" x**10 + x**9 + x**8 + x**7 + x**6 + x**5 + x**4 + x**3 + x**2 + x\n"
" @ + 1"
)
def test_fcode_precedence():
x, y = symbols("x y")
assert fcode(And(x < y, y < x + 1), source_format="free") == \
"x < y .and. y < x + 1"
assert fcode(Or(x < y, y < x + 1), source_format="free") == \
"x < y .or. y < x + 1"
assert fcode(Xor(x < y, y < x + 1, evaluate=False),
source_format="free") == "x < y .neqv. y < x + 1"
assert fcode(Equivalent(x < y, y < x + 1), source_format="free") == \
"x < y .eqv. y < x + 1"
def test_fcode_Logical():
x, y, z = symbols("x y z")
# unary Not
assert fcode(Not(x), source_format="free") == ".not. x"
# binary And
assert fcode(And(x, y), source_format="free") == "x .and. y"
assert fcode(And(x, Not(y)), source_format="free") == "x .and. .not. y"
assert fcode(And(Not(x), y), source_format="free") == "y .and. .not. x"
assert fcode(And(Not(x), Not(y)), source_format="free") == \
".not. x .and. .not. y"
assert fcode(Not(And(x, y), evaluate=False), source_format="free") == \
".not. (x .and. y)"
# binary Or
assert fcode(Or(x, y), source_format="free") == "x .or. y"
assert fcode(Or(x, Not(y)), source_format="free") == "x .or. .not. y"
assert fcode(Or(Not(x), y), source_format="free") == "y .or. .not. x"
assert fcode(Or(Not(x), Not(y)), source_format="free") == \
".not. x .or. .not. y"
assert fcode(Not(Or(x, y), evaluate=False), source_format="free") == \
".not. (x .or. y)"
# mixed And/Or
assert fcode(And(Or(y, z), x), source_format="free") == "x .and. (y .or. z)"
assert fcode(And(Or(z, x), y), source_format="free") == "y .and. (x .or. z)"
assert fcode(And(Or(x, y), z), source_format="free") == "z .and. (x .or. y)"
assert fcode(Or(And(y, z), x), source_format="free") == "x .or. y .and. z"
assert fcode(Or(And(z, x), y), source_format="free") == "y .or. x .and. z"
assert fcode(Or(And(x, y), z), source_format="free") == "z .or. x .and. y"
# trinary And
assert fcode(And(x, y, z), source_format="free") == "x .and. y .and. z"
assert fcode(And(x, y, Not(z)), source_format="free") == \
"x .and. y .and. .not. z"
assert fcode(And(x, Not(y), z), source_format="free") == \
"x .and. z .and. .not. y"
assert fcode(And(Not(x), y, z), source_format="free") == \
"y .and. z .and. .not. x"
assert fcode(Not(And(x, y, z), evaluate=False), source_format="free") == \
".not. (x .and. y .and. z)"
# trinary Or
assert fcode(Or(x, y, z), source_format="free") == "x .or. y .or. z"
assert fcode(Or(x, y, Not(z)), source_format="free") == \
"x .or. y .or. .not. z"
assert fcode(Or(x, Not(y), z), source_format="free") == \
"x .or. z .or. .not. y"
assert fcode(Or(Not(x), y, z), source_format="free") == \
"y .or. z .or. .not. x"
assert fcode(Not(Or(x, y, z), evaluate=False), source_format="free") == \
".not. (x .or. y .or. z)"
def test_fcode_Xlogical():
x, y, z = symbols("x y z")
# binary Xor
assert fcode(Xor(x, y, evaluate=False), source_format="free") == \
"x .neqv. y"
assert fcode(Xor(x, Not(y), evaluate=False), source_format="free") == \
"x .neqv. .not. y"
assert fcode(Xor(Not(x), y, evaluate=False), source_format="free") == \
"y .neqv. .not. x"
assert fcode(Xor(Not(x), Not(y), evaluate=False),
source_format="free") == ".not. x .neqv. .not. y"
assert fcode(Not(Xor(x, y, evaluate=False), evaluate=False),
source_format="free") == ".not. (x .neqv. y)"
# binary Equivalent
assert fcode(Equivalent(x, y), source_format="free") == "x .eqv. y"
assert fcode(Equivalent(x, Not(y)), source_format="free") == \
"x .eqv. .not. y"
assert fcode(Equivalent(Not(x), y), source_format="free") == \
"y .eqv. .not. x"
assert fcode(Equivalent(Not(x), Not(y)), source_format="free") == \
".not. x .eqv. .not. y"
assert fcode(Not(Equivalent(x, y), evaluate=False),
source_format="free") == ".not. (x .eqv. y)"
# mixed And/Equivalent
assert fcode(Equivalent(And(y, z), x), source_format="free") == \
"x .eqv. y .and. z"
assert fcode(Equivalent(And(z, x), y), source_format="free") == \
"y .eqv. x .and. z"
assert fcode(Equivalent(And(x, y), z), source_format="free") == \
"z .eqv. x .and. y"
assert fcode(And(Equivalent(y, z), x), source_format="free") == \
"x .and. (y .eqv. z)"
assert fcode(And(Equivalent(z, x), y), source_format="free") == \
"y .and. (x .eqv. z)"
assert fcode(And(Equivalent(x, y), z), source_format="free") == \
"z .and. (x .eqv. y)"
# mixed Or/Equivalent
assert fcode(Equivalent(Or(y, z), x), source_format="free") == \
"x .eqv. y .or. z"
assert fcode(Equivalent(Or(z, x), y), source_format="free") == \
"y .eqv. x .or. z"
assert fcode(Equivalent(Or(x, y), z), source_format="free") == \
"z .eqv. x .or. y"
assert fcode(Or(Equivalent(y, z), x), source_format="free") == \
"x .or. (y .eqv. z)"
assert fcode(Or(Equivalent(z, x), y), source_format="free") == \
"y .or. (x .eqv. z)"
assert fcode(Or(Equivalent(x, y), z), source_format="free") == \
"z .or. (x .eqv. y)"
# mixed Xor/Equivalent
assert fcode(Equivalent(Xor(y, z, evaluate=False), x),
source_format="free") == "x .eqv. (y .neqv. z)"
assert fcode(Equivalent(Xor(z, x, evaluate=False), y),
source_format="free") == "y .eqv. (x .neqv. z)"
assert fcode(Equivalent(Xor(x, y, evaluate=False), z),
source_format="free") == "z .eqv. (x .neqv. y)"
assert fcode(Xor(Equivalent(y, z), x, evaluate=False),
source_format="free") == "x .neqv. (y .eqv. z)"
assert fcode(Xor(Equivalent(z, x), y, evaluate=False),
source_format="free") == "y .neqv. (x .eqv. z)"
assert fcode(Xor(Equivalent(x, y), z, evaluate=False),
source_format="free") == "z .neqv. (x .eqv. y)"
# mixed And/Xor
assert fcode(Xor(And(y, z), x, evaluate=False), source_format="free") == \
"x .neqv. y .and. z"
assert fcode(Xor(And(z, x), y, evaluate=False), source_format="free") == \
"y .neqv. x .and. z"
assert fcode(Xor(And(x, y), z, evaluate=False), source_format="free") == \
"z .neqv. x .and. y"
assert fcode(And(Xor(y, z, evaluate=False), x), source_format="free") == \
"x .and. (y .neqv. z)"
assert fcode(And(Xor(z, x, evaluate=False), y), source_format="free") == \
"y .and. (x .neqv. z)"
assert fcode(And(Xor(x, y, evaluate=False), z), source_format="free") == \
"z .and. (x .neqv. y)"
# mixed Or/Xor
assert fcode(Xor(Or(y, z), x, evaluate=False), source_format="free") == \
"x .neqv. y .or. z"
assert fcode(Xor(Or(z, x), y, evaluate=False), source_format="free") == \
"y .neqv. x .or. z"
assert fcode(Xor(Or(x, y), z, evaluate=False), source_format="free") == \
"z .neqv. x .or. y"
assert fcode(Or(Xor(y, z, evaluate=False), x), source_format="free") == \
"x .or. (y .neqv. z)"
assert fcode(Or(Xor(z, x, evaluate=False), y), source_format="free") == \
"y .or. (x .neqv. z)"
assert fcode(Or(Xor(x, y, evaluate=False), z), source_format="free") == \
"z .or. (x .neqv. y)"
# trinary Xor
assert fcode(Xor(x, y, z, evaluate=False), source_format="free") == \
"x .neqv. y .neqv. z"
assert fcode(Xor(x, y, Not(z), evaluate=False), source_format="free") == \
"x .neqv. y .neqv. .not. z"
assert fcode(Xor(x, Not(y), z, evaluate=False), source_format="free") == \
"x .neqv. z .neqv. .not. y"
assert fcode(Xor(Not(x), y, z, evaluate=False), source_format="free") == \
"y .neqv. z .neqv. .not. x"
def test_fcode_Relational():
x, y = symbols("x y")
assert fcode(Relational(x, y, "=="), source_format="free") == "x == y"
assert fcode(Relational(x, y, "!="), source_format="free") == "x /= y"
assert fcode(Relational(x, y, ">="), source_format="free") == "x >= y"
assert fcode(Relational(x, y, "<="), source_format="free") == "x <= y"
assert fcode(Relational(x, y, ">"), source_format="free") == "x > y"
assert fcode(Relational(x, y, "<"), source_format="free") == "x < y"
def test_fcode_Piecewise():
x = symbols('x')
expr = Piecewise((x, x < 1), (x**2, True))
# Check that inline conditional (merge) fails if standard isn't 95+
raises(NotImplementedError, lambda: fcode(expr))
code = fcode(expr, standard=95)
expected = " merge(x, x**2, x < 1)"
assert code == expected
assert fcode(Piecewise((x, x < 1), (x**2, True)), assign_to="var") == (
" if (x < 1) then\n"
" var = x\n"
" else\n"
" var = x**2\n"
" end if"
)
a = cos(x)/x
b = sin(x)/x
for i in range(10):
a = diff(a, x)
b = diff(b, x)
expected = (
" if (x < 0) then\n"
" weird_name = -cos(x)/x + 10*sin(x)/x**2 + 90*cos(x)/x**3 - 720*\n"
" @ sin(x)/x**4 - 5040*cos(x)/x**5 + 30240*sin(x)/x**6 + 151200*cos(x\n"
" @ )/x**7 - 604800*sin(x)/x**8 - 1814400*cos(x)/x**9 + 3628800*sin(x\n"
" @ )/x**10 + 3628800*cos(x)/x**11\n"
" else\n"
" weird_name = -sin(x)/x - 10*cos(x)/x**2 + 90*sin(x)/x**3 + 720*\n"
" @ cos(x)/x**4 - 5040*sin(x)/x**5 - 30240*cos(x)/x**6 + 151200*sin(x\n"
" @ )/x**7 + 604800*cos(x)/x**8 - 1814400*sin(x)/x**9 - 3628800*cos(x\n"
" @ )/x**10 + 3628800*sin(x)/x**11\n"
" end if"
)
code = fcode(Piecewise((a, x < 0), (b, True)), assign_to="weird_name")
assert code == expected
code = fcode(Piecewise((x, x < 1), (x**2, x > 1), (sin(x), True)), standard=95)
expected = " merge(x, merge(x**2, sin(x), x > 1), x < 1)"
assert code == expected
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
raises(ValueError, lambda: fcode(expr))
def test_wrap_fortran():
# "########################################################################"
printer = FCodePrinter()
lines = [
"C This is a long comment on a single line that must be wrapped properly to produce nice output",
" this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly",
]
wrapped_lines = printer._wrap_fortran(lines)
expected_lines = [
"C This is a long comment on a single line that must be wrapped",
"C properly to produce nice output",
" this = is + a + long + and + nasty + fortran + statement + that *",
" @ must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that *",
" @ must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that",
" @ * must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that*",
" @ must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that*",
" @ must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that",
" @ *must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement +",
" @ that*must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that**",
" @ must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that**",
" @ must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that",
" @ **must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement + that",
" @ **must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement +",
" @ that**must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement(that)/",
" @ must + be + wrapped + properly",
" this = is + a + long + and + nasty + fortran + statement(that)",
" @ /must + be + wrapped + properly",
]
for line in wrapped_lines:
assert len(line) <= 72
for w, e in zip(wrapped_lines, expected_lines):
assert w == e
assert len(wrapped_lines) == len(expected_lines)
def test_wrap_fortran_keep_d0():
printer = FCodePrinter()
lines = [
' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break =1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break = 10.0d0'
]
expected = [
' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break =',
' @ 1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break =',
' @ 1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break =',
' @ 1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break =',
' @ 1.0d0',
' this_variable_is_very_long_because_we_try_to_test_line_break =',
' @ 10.0d0'
]
assert printer._wrap_fortran(lines) == expected
def test_settings():
raises(TypeError, lambda: fcode(S(4), method="garbage"))
def test_free_form_code_line():
x, y = symbols('x,y')
assert fcode(cos(x) + sin(y), source_format='free') == "sin(y) + cos(x)"
def test_free_form_continuation_line():
x, y = symbols('x,y')
result = fcode(((cos(x) + sin(y))**(7)).expand(), source_format='free')
expected = (
'sin(y)**7 + 7*sin(y)**6*cos(x) + 21*sin(y)**5*cos(x)**2 + 35*sin(y)**4* &\n'
' cos(x)**3 + 35*sin(y)**3*cos(x)**4 + 21*sin(y)**2*cos(x)**5 + 7* &\n'
' sin(y)*cos(x)**6 + cos(x)**7'
)
assert result == expected
def test_free_form_comment_line():
printer = FCodePrinter({'source_format': 'free'})
lines = [ "! This is a long comment on a single line that must be wrapped properly to produce nice output"]
expected = [
'! This is a long comment on a single line that must be wrapped properly',
'! to produce nice output']
assert printer._wrap_fortran(lines) == expected
def test_loops():
n, m = symbols('n,m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
expected = (
'do i = 1, m\n'
' y(i) = 0\n'
'end do\n'
'do i = 1, m\n'
' do j = 1, n\n'
' y(i) = %(rhs)s\n'
' end do\n'
'end do'
)
code = fcode(A[i, j]*x[j], assign_to=y[i], source_format='free')
assert (code == expected % {'rhs': 'y(i) + A(i, j)*x(j)'} or
code == expected % {'rhs': 'y(i) + x(j)*A(i, j)'} or
code == expected % {'rhs': 'x(j)*A(i, j) + y(i)'} or
code == expected % {'rhs': 'A(i, j)*x(j) + y(i)'})
def test_dummy_loops():
i, m = symbols('i m', integer=True, cls=Dummy)
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx(i, m)
expected = (
'do i_%(icount)i = 1, m_%(mcount)i\n'
' y(i_%(icount)i) = x(i_%(icount)i)\n'
'end do'
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
code = fcode(x[i], assign_to=y[i], source_format='free')
assert code == expected
def test_fcode_Indexed_without_looking_for_contraction():
len_y = 5
y = IndexedBase('y', shape=(len_y,))
x = IndexedBase('x', shape=(len_y,))
Dy = IndexedBase('Dy', shape=(len_y-1,))
i = Idx('i', len_y-1)
e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
code0 = fcode(e.rhs, assign_to=e.lhs, contract=False)
assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
def test_element_like_objects():
len_y = 5
y = ArraySymbol('y', shape=(len_y,))
x = ArraySymbol('x', shape=(len_y,))
Dy = ArraySymbol('Dy', shape=(len_y-1,))
i = Idx('i', len_y-1)
e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
code0 = fcode(Assignment(e.lhs, e.rhs))
assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
class ElementExpr(Element, Expr):
pass
e = e.subs((a, ElementExpr(a.name, a.indices)) for a in e.atoms(ArrayElement) )
e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
code0 = fcode(Assignment(e.lhs, e.rhs))
assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
def test_derived_classes():
class MyFancyFCodePrinter(FCodePrinter):
_default_settings = FCodePrinter._default_settings.copy()
printer = MyFancyFCodePrinter()
x = symbols('x')
assert printer.doprint(sin(x), "bork") == " bork = sin(x)"
def test_indent():
codelines = (
'subroutine test(a)\n'
'integer :: a, i, j\n'
'\n'
'do\n'
'do \n'
'do j = 1, 5\n'
'if (a>b) then\n'
'if(b>0) then\n'
'a = 3\n'
'donot_indent_me = 2\n'
'do_not_indent_me_either = 2\n'
'ifIam_indented_something_went_wrong = 2\n'
'if_I_am_indented_something_went_wrong = 2\n'
'end should not be unindented here\n'
'end if\n'
'endif\n'
'end do\n'
'end do\n'
'enddo\n'
'end subroutine\n'
'\n'
'subroutine test2(a)\n'
'integer :: a\n'
'do\n'
'a = a + 1\n'
'end do \n'
'end subroutine\n'
)
expected = (
'subroutine test(a)\n'
'integer :: a, i, j\n'
'\n'
'do\n'
' do \n'
' do j = 1, 5\n'
' if (a>b) then\n'
' if(b>0) then\n'
' a = 3\n'
' donot_indent_me = 2\n'
' do_not_indent_me_either = 2\n'
' ifIam_indented_something_went_wrong = 2\n'
' if_I_am_indented_something_went_wrong = 2\n'
' end should not be unindented here\n'
' end if\n'
' endif\n'
' end do\n'
' end do\n'
'enddo\n'
'end subroutine\n'
'\n'
'subroutine test2(a)\n'
'integer :: a\n'
'do\n'
' a = a + 1\n'
'end do \n'
'end subroutine\n'
)
p = FCodePrinter({'source_format': 'free'})
result = p.indent_code(codelines)
assert result == expected
def test_Matrix_printing():
x, y, z = symbols('x,y,z')
# Test returning a Matrix
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
A = MatrixSymbol('A', 3, 1)
assert fcode(mat, A) == (
" A(1, 1) = x*y\n"
" if (y > 0) then\n"
" A(2, 1) = x + 2\n"
" else\n"
" A(2, 1) = y\n"
" end if\n"
" A(3, 1) = sin(z)")
# Test using MatrixElements in expressions
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
assert fcode(expr, standard=95) == (
" merge(2*A(3, 1), A(3, 1), x > 0) + sin(A(2, 1)) + A(1, 1)")
# Test using MatrixElements in a Matrix
q = MatrixSymbol('q', 5, 1)
M = MatrixSymbol('M', 3, 3)
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
[q[1,0] + q[2,0], q[3, 0], 5],
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
assert fcode(m, M) == (
" M(1, 1) = sin(q(2, 1))\n"
" M(2, 1) = q(2, 1) + q(3, 1)\n"
" M(3, 1) = 2*q(5, 1)/q(2, 1)\n"
" M(1, 2) = 0\n"
" M(2, 2) = q(4, 1)\n"
" M(3, 2) = sqrt(q(1, 1)) + 4\n"
" M(1, 3) = cos(q(3, 1))\n"
" M(2, 3) = 5\n"
" M(3, 3) = 0")
def test_fcode_For():
x, y = symbols('x y')
f = For(x, Range(0, 10, 2), [Assignment(y, x * y)])
sol = fcode(f)
assert sol == (" do x = 0, 9, 2\n"
" y = x*y\n"
" end do")
def test_fcode_Declaration():
def check(expr, ref, **kwargs):
assert fcode(expr, standard=95, source_format='free', **kwargs) == ref
i = symbols('i', integer=True)
var1 = Variable.deduced(i)
dcl1 = Declaration(var1)
check(dcl1, "integer*4 :: i")
x, y = symbols('x y')
var2 = Variable(x, float32, value=42, attrs={value_const})
dcl2b = Declaration(var2)
check(dcl2b, 'real*4, parameter :: x = 42')
var3 = Variable(y, type=bool_)
dcl3 = Declaration(var3)
check(dcl3, 'logical :: y')
check(float32, "real*4")
check(float64, "real*8")
check(real, "real*4", type_aliases={real: float32})
check(real, "real*8", type_aliases={real: float64})
def test_MatrixElement_printing():
# test cases for issue #11821
A = MatrixSymbol("A", 1, 3)
B = MatrixSymbol("B", 1, 3)
C = MatrixSymbol("C", 1, 3)
assert(fcode(A[0, 0]) == " A(1, 1)")
assert(fcode(3 * A[0, 0]) == " 3*A(1, 1)")
F = C[0, 0].subs(C, A - B)
assert(fcode(F) == " (A - B)(1, 1)")
def test_aug_assign():
x = symbols('x')
assert fcode(aug_assign(x, '+', 1), source_format='free') == 'x = x + 1'
def test_While():
x = symbols('x')
assert fcode(While(abs(x) > 1, [aug_assign(x, '-', 1)]), source_format='free') == (
'do while (abs(x) > 1)\n'
' x = x - 1\n'
'end do'
)
def test_FunctionPrototype_print():
x = symbols('x')
n = symbols('n', integer=True)
vx = Variable(x, type=real)
vn = Variable(n, type=integer)
fp1 = FunctionPrototype(real, 'power', [vx, vn])
# Should be changed to proper test once multi-line generation is working
# see https://github.com/sympy/sympy/issues/15824
raises(NotImplementedError, lambda: fcode(fp1))
def test_FunctionDefinition_print():
x = symbols('x')
n = symbols('n', integer=True)
vx = Variable(x, type=real)
vn = Variable(n, type=integer)
body = [Assignment(x, x**n), Return(x)]
fd1 = FunctionDefinition(real, 'power', [vx, vn], body)
# Should be changed to proper test once multi-line generation is working
# see https://github.com/sympy/sympy/issues/15824
raises(NotImplementedError, lambda: fcode(fd1))

View File

@ -0,0 +1,998 @@
from sympy.core import (pi, symbols, Rational, Integer, GoldenRatio, EulerGamma,
Catalan, Lambda, Dummy, Eq, Ne, Le, Lt, Gt, Ge)
from sympy.functions import Piecewise, sin, cos, Abs, exp, ceiling, sqrt
from sympy.testing.pytest import raises, warns_deprecated_sympy
from sympy.printing.glsl import GLSLPrinter
from sympy.printing.str import StrPrinter
from sympy.utilities.lambdify import implemented_function
from sympy.tensor import IndexedBase, Idx
from sympy.matrices import Matrix, MatrixSymbol
from sympy.core import Tuple
from sympy.printing.glsl import glsl_code
import textwrap
x, y, z = symbols('x,y,z')
def test_printmethod():
assert glsl_code(Abs(x)) == "abs(x)"
def test_print_without_operators():
assert glsl_code(x*y,use_operators = False) == 'mul(x, y)'
assert glsl_code(x**y+z,use_operators = False) == 'add(pow(x, y), z)'
assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))'
assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))'
assert glsl_code(x*(y+z**y**0.5),use_operators = False) == 'mul(x, add(y, pow(z, sqrt(y))))'
assert glsl_code(-x-y, use_operators=False, zero='zero()') == 'sub(zero(), add(x, y))'
assert glsl_code(-x-y, use_operators=False) == 'sub(0.0, add(x, y))'
def test_glsl_code_sqrt():
assert glsl_code(sqrt(x)) == "sqrt(x)"
assert glsl_code(x**0.5) == "sqrt(x)"
assert glsl_code(sqrt(x)) == "sqrt(x)"
def test_glsl_code_Pow():
g = implemented_function('g', Lambda(x, 2*x))
assert glsl_code(x**3) == "pow(x, 3.0)"
assert glsl_code(x**(y**3)) == "pow(x, pow(y, 3.0))"
assert glsl_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
"pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2.0) + y)"
assert glsl_code(x**-1.0) == '1.0/x'
def test_glsl_code_Relational():
assert glsl_code(Eq(x, y)) == "x == y"
assert glsl_code(Ne(x, y)) == "x != y"
assert glsl_code(Le(x, y)) == "x <= y"
assert glsl_code(Lt(x, y)) == "x < y"
assert glsl_code(Gt(x, y)) == "x > y"
assert glsl_code(Ge(x, y)) == "x >= y"
def test_glsl_code_constants_mathh():
assert glsl_code(exp(1)) == "float E = 2.71828183;\nE"
assert glsl_code(pi) == "float pi = 3.14159265;\npi"
# assert glsl_code(oo) == "Number.POSITIVE_INFINITY"
# assert glsl_code(-oo) == "Number.NEGATIVE_INFINITY"
def test_glsl_code_constants_other():
assert glsl_code(2*GoldenRatio) == "float GoldenRatio = 1.61803399;\n2*GoldenRatio"
assert glsl_code(2*Catalan) == "float Catalan = 0.915965594;\n2*Catalan"
assert glsl_code(2*EulerGamma) == "float EulerGamma = 0.577215665;\n2*EulerGamma"
def test_glsl_code_Rational():
assert glsl_code(Rational(3, 7)) == "3.0/7.0"
assert glsl_code(Rational(18, 9)) == "2"
assert glsl_code(Rational(3, -7)) == "-3.0/7.0"
assert glsl_code(Rational(-3, -7)) == "3.0/7.0"
def test_glsl_code_Integer():
assert glsl_code(Integer(67)) == "67"
assert glsl_code(Integer(-1)) == "-1"
def test_glsl_code_functions():
assert glsl_code(sin(x) ** cos(x)) == "pow(sin(x), cos(x))"
def test_glsl_code_inline_function():
x = symbols('x')
g = implemented_function('g', Lambda(x, 2*x))
assert glsl_code(g(x)) == "2*x"
g = implemented_function('g', Lambda(x, 2*x/Catalan))
assert glsl_code(g(x)) == "float Catalan = 0.915965594;\n2*x/Catalan"
A = IndexedBase('A')
i = Idx('i', symbols('n', integer=True))
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
assert glsl_code(g(A[i]), assign_to=A[i]) == (
"for (int i=0; i<n; i++){\n"
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
"}"
)
def test_glsl_code_exceptions():
assert glsl_code(ceiling(x)) == "ceil(x)"
assert glsl_code(Abs(x)) == "abs(x)"
def test_glsl_code_boolean():
assert glsl_code(x & y) == "x && y"
assert glsl_code(x | y) == "x || y"
assert glsl_code(~x) == "!x"
assert glsl_code(x & y & z) == "x && y && z"
assert glsl_code(x | y | z) == "x || y || z"
assert glsl_code((x & y) | z) == "z || x && y"
assert glsl_code((x | y) & z) == "z && (x || y)"
def test_glsl_code_Piecewise():
expr = Piecewise((x, x < 1), (x**2, True))
p = glsl_code(expr)
s = \
"""\
((x < 1) ? (
x
)
: (
pow(x, 2.0)
))\
"""
assert p == s
assert glsl_code(expr, assign_to="c") == (
"if (x < 1) {\n"
" c = x;\n"
"}\n"
"else {\n"
" c = pow(x, 2.0);\n"
"}")
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
raises(ValueError, lambda: glsl_code(expr))
def test_glsl_code_Piecewise_deep():
p = glsl_code(2*Piecewise((x, x < 1), (x**2, True)))
s = \
"""\
2*((x < 1) ? (
x
)
: (
pow(x, 2.0)
))\
"""
assert p == s
def test_glsl_code_settings():
raises(TypeError, lambda: glsl_code(sin(x), method="garbage"))
def test_glsl_code_Indexed():
n, m, o = symbols('n m o', integer=True)
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
p = GLSLPrinter()
p._not_c = set()
x = IndexedBase('x')[j]
assert p._print_Indexed(x) == 'x[j]'
A = IndexedBase('A')[i, j]
assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
B = IndexedBase('B')[i, j, k]
assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
assert p._not_c == set()
def test_glsl_code_list_tuple_Tuple():
assert glsl_code([1,2,3,4]) == 'vec4(1, 2, 3, 4)'
assert glsl_code([1,2,3],glsl_types=False) == 'float[3](1, 2, 3)'
assert glsl_code([1,2,3]) == glsl_code((1,2,3))
assert glsl_code([1,2,3]) == glsl_code(Tuple(1,2,3))
m = MatrixSymbol('A',3,4)
assert glsl_code([m[0],m[1]])
def test_glsl_code_loops_matrix_vector():
n, m = symbols('n m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
s = (
'for (int i=0; i<m; i++){\n'
' y[i] = 0.0;\n'
'}\n'
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' y[i] = A[n*i + j]*x[j] + y[i];\n'
' }\n'
'}'
)
c = glsl_code(A[i, j]*x[j], assign_to=y[i])
assert c == s
def test_dummy_loops():
i, m = symbols('i m', integer=True, cls=Dummy)
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx(i, m)
expected = (
'for (int i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
' y[i_%(icount)i] = x[i_%(icount)i];\n'
'}'
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
code = glsl_code(x[i], assign_to=y[i])
assert code == expected
def test_glsl_code_loops_add():
n, m = symbols('n m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
z = IndexedBase('z')
i = Idx('i', m)
j = Idx('j', n)
s = (
'for (int i=0; i<m; i++){\n'
' y[i] = x[i] + z[i];\n'
'}\n'
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' y[i] = A[n*i + j]*x[j] + y[i];\n'
' }\n'
'}'
)
c = glsl_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
assert c == s
def test_glsl_code_loops_multiple_contractions():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
s = (
'for (int i=0; i<m; i++){\n'
' y[i] = 0.0;\n'
'}\n'
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' for (int k=0; k<o; k++){\n'
' for (int l=0; l<p; l++){\n'
' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
' }\n'
' }\n'
' }\n'
'}'
)
c = glsl_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
assert c == s
def test_glsl_code_loops_addfactor():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
c = IndexedBase('c')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
s = (
'for (int i=0; i<m; i++){\n'
' y[i] = 0.0;\n'
'}\n'
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' for (int k=0; k<o; k++){\n'
' for (int l=0; l<p; l++){\n'
' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
' }\n'
' }\n'
' }\n'
'}'
)
c = glsl_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
assert c == s
def test_glsl_code_loops_multiple_terms():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
c = IndexedBase('c')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
s0 = (
'for (int i=0; i<m; i++){\n'
' y[i] = 0.0;\n'
'}\n'
)
s1 = (
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' for (int k=0; k<o; k++){\n'
' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
' }\n'
' }\n'
'}\n'
)
s2 = (
'for (int i=0; i<m; i++){\n'
' for (int k=0; k<o; k++){\n'
' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
' }\n'
'}\n'
)
s3 = (
'for (int i=0; i<m; i++){\n'
' for (int j=0; j<n; j++){\n'
' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
' }\n'
'}\n'
)
c = glsl_code(
b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
assert (c == s0 + s1 + s2 + s3[:-1] or
c == s0 + s1 + s3 + s2[:-1] or
c == s0 + s2 + s1 + s3[:-1] or
c == s0 + s2 + s3 + s1[:-1] or
c == s0 + s3 + s1 + s2[:-1] or
c == s0 + s3 + s2 + s1[:-1])
def test_Matrix_printing():
# Test returning a Matrix
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
A = MatrixSymbol('A', 3, 1)
assert glsl_code(mat, assign_to=A) == (
'''A[0][0] = x*y;
if (y > 0) {
A[1][0] = x + 2;
}
else {
A[1][0] = y;
}
A[2][0] = sin(z);''' )
assert glsl_code(Matrix([A[0],A[1]]))
# Test using MatrixElements in expressions
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
assert glsl_code(expr) == (
'''((x > 0) ? (
2*A[2][0]
)
: (
A[2][0]
)) + sin(A[1][0]) + A[0][0]''' )
# Test using MatrixElements in a Matrix
q = MatrixSymbol('q', 5, 1)
M = MatrixSymbol('M', 3, 3)
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
[q[1,0] + q[2,0], q[3, 0], 5],
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
assert glsl_code(m,M) == (
'''M[0][0] = sin(q[1]);
M[0][1] = 0;
M[0][2] = cos(q[2]);
M[1][0] = q[1] + q[2];
M[1][1] = q[3];
M[1][2] = 5;
M[2][0] = 2*q[4]/q[1];
M[2][1] = sqrt(q[0]) + 4;
M[2][2] = 0;'''
)
def test_Matrices_1x7():
gl = glsl_code
A = Matrix([1,2,3,4,5,6,7])
assert gl(A) == 'float[7](1, 2, 3, 4, 5, 6, 7)'
assert gl(A.transpose()) == 'float[7](1, 2, 3, 4, 5, 6, 7)'
def test_Matrices_1x7_array_type_int():
gl = glsl_code
A = Matrix([1,2,3,4,5,6,7])
assert gl(A, array_type='int') == 'int[7](1, 2, 3, 4, 5, 6, 7)'
def test_Tuple_array_type_custom():
gl = glsl_code
A = symbols('a b c')
assert gl(A, array_type='AbcType', glsl_types=False) == 'AbcType[3](a, b, c)'
def test_Matrices_1x7_spread_assign_to_symbols():
gl = glsl_code
A = Matrix([1,2,3,4,5,6,7])
assign_to = symbols('x.a x.b x.c x.d x.e x.f x.g')
assert gl(A, assign_to=assign_to) == textwrap.dedent('''\
x.a = 1;
x.b = 2;
x.c = 3;
x.d = 4;
x.e = 5;
x.f = 6;
x.g = 7;'''
)
def test_spread_assign_to_nested_symbols():
gl = glsl_code
expr = ((1,2,3), (1,2,3))
assign_to = (symbols('a b c'), symbols('x y z'))
assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\
a = 1;
b = 2;
c = 3;
x = 1;
y = 2;
z = 3;'''
)
def test_spread_assign_to_deeply_nested_symbols():
gl = glsl_code
a, b, c, x, y, z = symbols('a b c x y z')
expr = (((1,2),3), ((1,2),3))
assign_to = (((a, b), c), ((x, y), z))
assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\
a = 1;
b = 2;
c = 3;
x = 1;
y = 2;
z = 3;'''
)
def test_matrix_of_tuples_spread_assign_to_symbols():
gl = glsl_code
with warns_deprecated_sympy():
expr = Matrix([[(1,2),(3,4)],[(5,6),(7,8)]])
assign_to = (symbols('a b'), symbols('c d'), symbols('e f'), symbols('g h'))
assert gl(expr, assign_to) == textwrap.dedent('''\
a = 1;
b = 2;
c = 3;
d = 4;
e = 5;
f = 6;
g = 7;
h = 8;'''
)
def test_cannot_assign_to_cause_mismatched_length():
expr = (1, 2)
assign_to = symbols('x y z')
raises(ValueError, lambda: glsl_code(expr, assign_to))
def test_matrix_4x4_assign():
gl = glsl_code
expr = MatrixSymbol('A',4,4) * MatrixSymbol('B',4,4) + MatrixSymbol('C',4,4)
assign_to = MatrixSymbol('X',4,4)
assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\
X[0][0] = A[0][0]*B[0][0] + A[0][1]*B[1][0] + A[0][2]*B[2][0] + A[0][3]*B[3][0] + C[0][0];
X[0][1] = A[0][0]*B[0][1] + A[0][1]*B[1][1] + A[0][2]*B[2][1] + A[0][3]*B[3][1] + C[0][1];
X[0][2] = A[0][0]*B[0][2] + A[0][1]*B[1][2] + A[0][2]*B[2][2] + A[0][3]*B[3][2] + C[0][2];
X[0][3] = A[0][0]*B[0][3] + A[0][1]*B[1][3] + A[0][2]*B[2][3] + A[0][3]*B[3][3] + C[0][3];
X[1][0] = A[1][0]*B[0][0] + A[1][1]*B[1][0] + A[1][2]*B[2][0] + A[1][3]*B[3][0] + C[1][0];
X[1][1] = A[1][0]*B[0][1] + A[1][1]*B[1][1] + A[1][2]*B[2][1] + A[1][3]*B[3][1] + C[1][1];
X[1][2] = A[1][0]*B[0][2] + A[1][1]*B[1][2] + A[1][2]*B[2][2] + A[1][3]*B[3][2] + C[1][2];
X[1][3] = A[1][0]*B[0][3] + A[1][1]*B[1][3] + A[1][2]*B[2][3] + A[1][3]*B[3][3] + C[1][3];
X[2][0] = A[2][0]*B[0][0] + A[2][1]*B[1][0] + A[2][2]*B[2][0] + A[2][3]*B[3][0] + C[2][0];
X[2][1] = A[2][0]*B[0][1] + A[2][1]*B[1][1] + A[2][2]*B[2][1] + A[2][3]*B[3][1] + C[2][1];
X[2][2] = A[2][0]*B[0][2] + A[2][1]*B[1][2] + A[2][2]*B[2][2] + A[2][3]*B[3][2] + C[2][2];
X[2][3] = A[2][0]*B[0][3] + A[2][1]*B[1][3] + A[2][2]*B[2][3] + A[2][3]*B[3][3] + C[2][3];
X[3][0] = A[3][0]*B[0][0] + A[3][1]*B[1][0] + A[3][2]*B[2][0] + A[3][3]*B[3][0] + C[3][0];
X[3][1] = A[3][0]*B[0][1] + A[3][1]*B[1][1] + A[3][2]*B[2][1] + A[3][3]*B[3][1] + C[3][1];
X[3][2] = A[3][0]*B[0][2] + A[3][1]*B[1][2] + A[3][2]*B[2][2] + A[3][3]*B[3][2] + C[3][2];
X[3][3] = A[3][0]*B[0][3] + A[3][1]*B[1][3] + A[3][2]*B[2][3] + A[3][3]*B[3][3] + C[3][3];'''
)
def test_1xN_vecs():
gl = glsl_code
for i in range(1,10):
A = Matrix(range(i))
assert gl(A.transpose()) == gl(A)
assert gl(A,mat_transpose=True) == gl(A)
if i > 1:
if i <= 4:
assert gl(A) == 'vec%s(%s)' % (i,', '.join(str(s) for s in range(i)))
else:
assert gl(A) == 'float[%s](%s)' % (i,', '.join(str(s) for s in range(i)))
def test_MxN_mats():
generatedAssertions='def test_misc_mats():\n'
for i in range(1,6):
for j in range(1,6):
A = Matrix([[x + y*j for x in range(j)] for y in range(i)])
gl = glsl_code(A)
glTransposed = glsl_code(A,mat_transpose=True)
generatedAssertions+=' mat = '+StrPrinter()._print(A)+'\n\n'
generatedAssertions+=' gl = \'\'\''+gl+'\'\'\'\n'
generatedAssertions+=' glTransposed = \'\'\''+glTransposed+'\'\'\'\n\n'
generatedAssertions+=' assert glsl_code(mat) == gl\n'
generatedAssertions+=' assert glsl_code(mat,mat_transpose=True) == glTransposed\n'
if i == 1 and j == 1:
assert gl == '0'
elif i <= 4 and j <= 4 and i>1 and j>1:
assert gl.startswith('mat%s' % j)
assert glTransposed.startswith('mat%s' % i)
elif i == 1 and j <= 4:
assert gl.startswith('vec')
elif j == 1 and i <= 4:
assert gl.startswith('vec')
elif i == 1:
assert gl.startswith('float[%s]('% j*i)
assert glTransposed.startswith('float[%s]('% j*i)
elif j == 1:
assert gl.startswith('float[%s]('% i*j)
assert glTransposed.startswith('float[%s]('% i*j)
else:
assert gl.startswith('float[%s](' % (i*j))
assert glTransposed.startswith('float[%s](' % (i*j))
glNested = glsl_code(A,mat_nested=True)
glNestedTransposed = glsl_code(A,mat_transpose=True,mat_nested=True)
assert glNested.startswith('float[%s][%s]' % (i,j))
assert glNestedTransposed.startswith('float[%s][%s]' % (j,i))
generatedAssertions+=' glNested = \'\'\''+glNested+'\'\'\'\n'
generatedAssertions+=' glNestedTransposed = \'\'\''+glNestedTransposed+'\'\'\'\n\n'
generatedAssertions+=' assert glsl_code(mat,mat_nested=True) == glNested\n'
generatedAssertions+=' assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed\n\n'
generateAssertions = False # set this to true to write bake these generated tests to a file
if generateAssertions:
gen = open('test_glsl_generated_matrices.py','w')
gen.write(generatedAssertions)
gen.close()
# these assertions were generated from the previous function
# glsl has complicated rules and this makes it easier to look over all the cases
def test_misc_mats():
mat = Matrix([[0]])
gl = '''0'''
glTransposed = '''0'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([[0, 1]])
gl = '''vec2(0, 1)'''
glTransposed = '''vec2(0, 1)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([[0, 1, 2]])
gl = '''vec3(0, 1, 2)'''
glTransposed = '''vec3(0, 1, 2)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([[0, 1, 2, 3]])
gl = '''vec4(0, 1, 2, 3)'''
glTransposed = '''vec4(0, 1, 2, 3)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([[0, 1, 2, 3, 4]])
gl = '''float[5](0, 1, 2, 3, 4)'''
glTransposed = '''float[5](0, 1, 2, 3, 4)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0],
[1]])
gl = '''vec2(0, 1)'''
glTransposed = '''vec2(0, 1)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1],
[2, 3]])
gl = '''mat2(0, 1, 2, 3)'''
glTransposed = '''mat2(0, 2, 1, 3)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1, 2],
[3, 4, 5]])
gl = '''mat3x2(0, 1, 2, 3, 4, 5)'''
glTransposed = '''mat2x3(0, 3, 1, 4, 2, 5)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1, 2, 3],
[4, 5, 6, 7]])
gl = '''mat4x2(0, 1, 2, 3, 4, 5, 6, 7)'''
glTransposed = '''mat2x4(0, 4, 1, 5, 2, 6, 3, 7)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
gl = '''float[10](
0, 1, 2, 3, 4,
5, 6, 7, 8, 9
) /* a 2x5 matrix */'''
glTransposed = '''float[10](
0, 5,
1, 6,
2, 7,
3, 8,
4, 9
) /* a 5x2 matrix */'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
glNested = '''float[2][5](
float[](0, 1, 2, 3, 4),
float[](5, 6, 7, 8, 9)
)'''
glNestedTransposed = '''float[5][2](
float[](0, 5),
float[](1, 6),
float[](2, 7),
float[](3, 8),
float[](4, 9)
)'''
assert glsl_code(mat,mat_nested=True) == glNested
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
mat = Matrix([
[0],
[1],
[2]])
gl = '''vec3(0, 1, 2)'''
glTransposed = '''vec3(0, 1, 2)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1],
[2, 3],
[4, 5]])
gl = '''mat2x3(0, 1, 2, 3, 4, 5)'''
glTransposed = '''mat3x2(0, 2, 4, 1, 3, 5)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
gl = '''mat3(0, 1, 2, 3, 4, 5, 6, 7, 8)'''
glTransposed = '''mat3(0, 3, 6, 1, 4, 7, 2, 5, 8)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]])
gl = '''mat4x3(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)'''
glTransposed = '''mat3x4(0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
gl = '''float[15](
0, 1, 2, 3, 4,
5, 6, 7, 8, 9,
10, 11, 12, 13, 14
) /* a 3x5 matrix */'''
glTransposed = '''float[15](
0, 5, 10,
1, 6, 11,
2, 7, 12,
3, 8, 13,
4, 9, 14
) /* a 5x3 matrix */'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
glNested = '''float[3][5](
float[]( 0, 1, 2, 3, 4),
float[]( 5, 6, 7, 8, 9),
float[](10, 11, 12, 13, 14)
)'''
glNestedTransposed = '''float[5][3](
float[](0, 5, 10),
float[](1, 6, 11),
float[](2, 7, 12),
float[](3, 8, 13),
float[](4, 9, 14)
)'''
assert glsl_code(mat,mat_nested=True) == glNested
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
mat = Matrix([
[0],
[1],
[2],
[3]])
gl = '''vec4(0, 1, 2, 3)'''
glTransposed = '''vec4(0, 1, 2, 3)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1],
[2, 3],
[4, 5],
[6, 7]])
gl = '''mat2x4(0, 1, 2, 3, 4, 5, 6, 7)'''
glTransposed = '''mat4x2(0, 2, 4, 6, 1, 3, 5, 7)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[9, 10, 11]])
gl = '''mat3x4(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)'''
glTransposed = '''mat4x3(0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
gl = '''mat4( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)'''
glTransposed = '''mat4(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
gl = '''float[20](
0, 1, 2, 3, 4,
5, 6, 7, 8, 9,
10, 11, 12, 13, 14,
15, 16, 17, 18, 19
) /* a 4x5 matrix */'''
glTransposed = '''float[20](
0, 5, 10, 15,
1, 6, 11, 16,
2, 7, 12, 17,
3, 8, 13, 18,
4, 9, 14, 19
) /* a 5x4 matrix */'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
glNested = '''float[4][5](
float[]( 0, 1, 2, 3, 4),
float[]( 5, 6, 7, 8, 9),
float[](10, 11, 12, 13, 14),
float[](15, 16, 17, 18, 19)
)'''
glNestedTransposed = '''float[5][4](
float[](0, 5, 10, 15),
float[](1, 6, 11, 16),
float[](2, 7, 12, 17),
float[](3, 8, 13, 18),
float[](4, 9, 14, 19)
)'''
assert glsl_code(mat,mat_nested=True) == glNested
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
mat = Matrix([
[0],
[1],
[2],
[3],
[4]])
gl = '''float[5](0, 1, 2, 3, 4)'''
glTransposed = '''float[5](0, 1, 2, 3, 4)'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
mat = Matrix([
[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
gl = '''float[10](
0, 1,
2, 3,
4, 5,
6, 7,
8, 9
) /* a 5x2 matrix */'''
glTransposed = '''float[10](
0, 2, 4, 6, 8,
1, 3, 5, 7, 9
) /* a 2x5 matrix */'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
glNested = '''float[5][2](
float[](0, 1),
float[](2, 3),
float[](4, 5),
float[](6, 7),
float[](8, 9)
)'''
glNestedTransposed = '''float[2][5](
float[](0, 2, 4, 6, 8),
float[](1, 3, 5, 7, 9)
)'''
assert glsl_code(mat,mat_nested=True) == glNested
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
mat = Matrix([
[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14]])
gl = '''float[15](
0, 1, 2,
3, 4, 5,
6, 7, 8,
9, 10, 11,
12, 13, 14
) /* a 5x3 matrix */'''
glTransposed = '''float[15](
0, 3, 6, 9, 12,
1, 4, 7, 10, 13,
2, 5, 8, 11, 14
) /* a 3x5 matrix */'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
glNested = '''float[5][3](
float[]( 0, 1, 2),
float[]( 3, 4, 5),
float[]( 6, 7, 8),
float[]( 9, 10, 11),
float[](12, 13, 14)
)'''
glNestedTransposed = '''float[3][5](
float[](0, 3, 6, 9, 12),
float[](1, 4, 7, 10, 13),
float[](2, 5, 8, 11, 14)
)'''
assert glsl_code(mat,mat_nested=True) == glNested
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
mat = Matrix([
[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]])
gl = '''float[20](
0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19
) /* a 5x4 matrix */'''
glTransposed = '''float[20](
0, 4, 8, 12, 16,
1, 5, 9, 13, 17,
2, 6, 10, 14, 18,
3, 7, 11, 15, 19
) /* a 4x5 matrix */'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
glNested = '''float[5][4](
float[]( 0, 1, 2, 3),
float[]( 4, 5, 6, 7),
float[]( 8, 9, 10, 11),
float[](12, 13, 14, 15),
float[](16, 17, 18, 19)
)'''
glNestedTransposed = '''float[4][5](
float[](0, 4, 8, 12, 16),
float[](1, 5, 9, 13, 17),
float[](2, 6, 10, 14, 18),
float[](3, 7, 11, 15, 19)
)'''
assert glsl_code(mat,mat_nested=True) == glNested
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
mat = Matrix([
[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]])
gl = '''float[25](
0, 1, 2, 3, 4,
5, 6, 7, 8, 9,
10, 11, 12, 13, 14,
15, 16, 17, 18, 19,
20, 21, 22, 23, 24
) /* a 5x5 matrix */'''
glTransposed = '''float[25](
0, 5, 10, 15, 20,
1, 6, 11, 16, 21,
2, 7, 12, 17, 22,
3, 8, 13, 18, 23,
4, 9, 14, 19, 24
) /* a 5x5 matrix */'''
assert glsl_code(mat) == gl
assert glsl_code(mat,mat_transpose=True) == glTransposed
glNested = '''float[5][5](
float[]( 0, 1, 2, 3, 4),
float[]( 5, 6, 7, 8, 9),
float[](10, 11, 12, 13, 14),
float[](15, 16, 17, 18, 19),
float[](20, 21, 22, 23, 24)
)'''
glNestedTransposed = '''float[5][5](
float[](0, 5, 10, 15, 20),
float[](1, 6, 11, 16, 21),
float[](2, 7, 12, 17, 22),
float[](3, 8, 13, 18, 23),
float[](4, 9, 14, 19, 24)
)'''
assert glsl_code(mat,mat_nested=True) == glNested
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed

View File

@ -0,0 +1,18 @@
from sympy.functions.elementary.trigonometric import sin
from sympy.printing.gtk import print_gtk
from sympy.testing.pytest import XFAIL, raises
# this test fails if python-lxml isn't installed. We don't want to depend on
# anything with SymPy
@XFAIL
def test_1():
from sympy.abc import x
print_gtk(x**2, start_viewer=False)
print_gtk(x**2 + sin(x)/4, start_viewer=False)
def test_settings():
from sympy.abc import x
raises(TypeError, lambda: print_gtk(x, method="garbage"))

View File

@ -0,0 +1,370 @@
from sympy.concrete.summations import Sum
from sympy.core.mod import Mod
from sympy.core.relational import (Equality, Unequality)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.matrices.expressions.blockmatrix import BlockMatrix
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.matrices.expressions.special import Identity
from sympy.utilities.lambdify import lambdify
from sympy.abc import x, i, j, a, b, c, d
from sympy.core import Function, Pow, Symbol
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt
from sympy.tensor.array import Array
from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
PermuteDims, ArrayDiagonal
from sympy.printing.numpy import JaxPrinter, _jax_known_constants, _jax_known_functions
from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
from sympy.testing.pytest import skip, raises
from sympy.external import import_module
# Unlike NumPy which will aggressively promote operands to double precision,
# jax always uses single precision. Double precision in jax can be
# configured before the call to `import jax`, however this must be explicitly
# configured and is not fully supported. Thus, the tests here have been modified
# from the tests in test_numpy.py, only in the fact that they assert lambdify
# function accuracy to only single precision accuracy.
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
jax = import_module('jax')
if jax:
deafult_float_info = jax.numpy.finfo(jax.numpy.array([]).dtype)
JAX_DEFAULT_EPSILON = deafult_float_info.eps
def test_jax_piecewise_regression():
"""
NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid
breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+.
See gh-9747 and gh-9749 for details.
"""
printer = JaxPrinter()
p = Piecewise((1, x < 0), (0, True))
assert printer.doprint(p) == \
'jax.numpy.select([jax.numpy.less(x, 0),True], [1,0], default=jax.numpy.nan)'
assert printer.module_imports == {'jax.numpy': {'select', 'less', 'nan'}}
def test_jax_logaddexp():
lae = logaddexp(a, b)
assert JaxPrinter().doprint(lae) == 'jax.numpy.logaddexp(a, b)'
lae2 = logaddexp2(a, b)
assert JaxPrinter().doprint(lae2) == 'jax.numpy.logaddexp2(a, b)'
def test_jax_sum():
if not jax:
skip("JAX not installed")
s = Sum(x ** i, (i, a, b))
f = lambdify((a, b, x), s, 'jax')
a_, b_ = 0, 10
x_ = jax.numpy.linspace(-1, +1, 10)
assert jax.numpy.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
s = Sum(i * x, (i, a, b))
f = lambdify((a, b, x), s, 'jax')
a_, b_ = 0, 10
x_ = jax.numpy.linspace(-1, +1, 10)
assert jax.numpy.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
def test_jax_multiple_sums():
if not jax:
skip("JAX not installed")
s = Sum((x + j) * i, (i, a, b), (j, c, d))
f = lambdify((a, b, c, d, x), s, 'jax')
a_, b_ = 0, 10
c_, d_ = 11, 21
x_ = jax.numpy.linspace(-1, +1, 10)
assert jax.numpy.allclose(f(a_, b_, c_, d_, x_),
sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1)))
def test_jax_codegen_einsum():
if not jax:
skip("JAX not installed")
M = MatrixSymbol("M", 2, 2)
N = MatrixSymbol("N", 2, 2)
cg = convert_matrix_to_array(M * N)
f = lambdify((M, N), cg, 'jax')
ma = jax.numpy.array([[1, 2], [3, 4]])
mb = jax.numpy.array([[1,-2], [-1, 3]])
assert (f(ma, mb) == jax.numpy.matmul(ma, mb)).all()
def test_jax_codegen_extra():
if not jax:
skip("JAX not installed")
M = MatrixSymbol("M", 2, 2)
N = MatrixSymbol("N", 2, 2)
P = MatrixSymbol("P", 2, 2)
Q = MatrixSymbol("Q", 2, 2)
ma = jax.numpy.array([[1, 2], [3, 4]])
mb = jax.numpy.array([[1,-2], [-1, 3]])
mc = jax.numpy.array([[2, 0], [1, 2]])
md = jax.numpy.array([[1,-1], [4, 7]])
cg = ArrayTensorProduct(M, N)
f = lambdify((M, N), cg, 'jax')
assert (f(ma, mb) == jax.numpy.einsum(ma, [0, 1], mb, [2, 3])).all()
cg = ArrayAdd(M, N)
f = lambdify((M, N), cg, 'jax')
assert (f(ma, mb) == ma+mb).all()
cg = ArrayAdd(M, N, P)
f = lambdify((M, N, P), cg, 'jax')
assert (f(ma, mb, mc) == ma+mb+mc).all()
cg = ArrayAdd(M, N, P, Q)
f = lambdify((M, N, P, Q), cg, 'jax')
assert (f(ma, mb, mc, md) == ma+mb+mc+md).all()
cg = PermuteDims(M, [1, 0])
f = lambdify((M,), cg, 'jax')
assert (f(ma) == ma.T).all()
cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
f = lambdify((M, N), cg, 'jax')
assert (f(ma, mb) == jax.numpy.transpose(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all()
cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
f = lambdify((M, N), cg, 'jax')
assert (f(ma, mb) == jax.numpy.diagonal(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all()
def test_jax_relational():
if not jax:
skip("JAX not installed")
e = Equality(x, 1)
f = lambdify((x,), e, 'jax')
x_ = jax.numpy.array([0, 1, 2])
assert jax.numpy.array_equal(f(x_), [False, True, False])
e = Unequality(x, 1)
f = lambdify((x,), e, 'jax')
x_ = jax.numpy.array([0, 1, 2])
assert jax.numpy.array_equal(f(x_), [True, False, True])
e = (x < 1)
f = lambdify((x,), e, 'jax')
x_ = jax.numpy.array([0, 1, 2])
assert jax.numpy.array_equal(f(x_), [True, False, False])
e = (x <= 1)
f = lambdify((x,), e, 'jax')
x_ = jax.numpy.array([0, 1, 2])
assert jax.numpy.array_equal(f(x_), [True, True, False])
e = (x > 1)
f = lambdify((x,), e, 'jax')
x_ = jax.numpy.array([0, 1, 2])
assert jax.numpy.array_equal(f(x_), [False, False, True])
e = (x >= 1)
f = lambdify((x,), e, 'jax')
x_ = jax.numpy.array([0, 1, 2])
assert jax.numpy.array_equal(f(x_), [False, True, True])
# Multi-condition expressions
e = (x >= 1) & (x < 2)
f = lambdify((x,), e, 'jax')
x_ = jax.numpy.array([0, 1, 2])
assert jax.numpy.array_equal(f(x_), [False, True, False])
e = (x >= 1) | (x < 2)
f = lambdify((x,), e, 'jax')
x_ = jax.numpy.array([0, 1, 2])
assert jax.numpy.array_equal(f(x_), [True, True, True])
def test_jax_mod():
if not jax:
skip("JAX not installed")
e = Mod(a, b)
f = lambdify((a, b), e, 'jax')
a_ = jax.numpy.array([0, 1, 2, 3])
b_ = 2
assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1])
a_ = jax.numpy.array([0, 1, 2, 3])
b_ = jax.numpy.array([2, 2, 2, 2])
assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1])
a_ = jax.numpy.array([2, 3, 4, 5])
b_ = jax.numpy.array([2, 3, 4, 5])
assert jax.numpy.array_equal(f(a_, b_), [0, 0, 0, 0])
def test_jax_pow():
if not jax:
skip('JAX not installed')
expr = Pow(2, -1, evaluate=False)
f = lambdify([], expr, 'jax')
assert f() == 0.5
def test_jax_expm1():
if not jax:
skip("JAX not installed")
f = lambdify((a,), expm1(a), 'jax')
assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * JAX_DEFAULT_EPSILON
def test_jax_log1p():
if not jax:
skip("JAX not installed")
f = lambdify((a,), log1p(a), 'jax')
assert abs(f(1e-99) - 1e-99) <= 1e-99 * JAX_DEFAULT_EPSILON
def test_jax_hypot():
if not jax:
skip("JAX not installed")
assert abs(lambdify((a, b), hypot(a, b), 'jax')(3, 4) - 5) <= JAX_DEFAULT_EPSILON
def test_jax_log10():
if not jax:
skip("JAX not installed")
assert abs(lambdify((a,), log10(a), 'jax')(100) - 2) <= JAX_DEFAULT_EPSILON
def test_jax_exp2():
if not jax:
skip("JAX not installed")
assert abs(lambdify((a,), exp2(a), 'jax')(5) - 32) <= JAX_DEFAULT_EPSILON
def test_jax_log2():
if not jax:
skip("JAX not installed")
assert abs(lambdify((a,), log2(a), 'jax')(256) - 8) <= JAX_DEFAULT_EPSILON
def test_jax_Sqrt():
if not jax:
skip("JAX not installed")
assert abs(lambdify((a,), Sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON
def test_jax_sqrt():
if not jax:
skip("JAX not installed")
assert abs(lambdify((a,), sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON
def test_jax_matsolve():
if not jax:
skip("JAX not installed")
M = MatrixSymbol("M", 3, 3)
x = MatrixSymbol("x", 3, 1)
expr = M**(-1) * x + x
matsolve_expr = MatrixSolve(M, x) + x
f = lambdify((M, x), expr, 'jax')
f_matsolve = lambdify((M, x), matsolve_expr, 'jax')
m0 = jax.numpy.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]])
assert jax.numpy.linalg.matrix_rank(m0) == 3
x0 = jax.numpy.array([3, 4, 5])
assert jax.numpy.allclose(f_matsolve(m0, x0), f(m0, x0))
def test_16857():
if not jax:
skip("JAX not installed")
a_1 = MatrixSymbol('a_1', 10, 3)
a_2 = MatrixSymbol('a_2', 10, 3)
a_3 = MatrixSymbol('a_3', 10, 3)
a_4 = MatrixSymbol('a_4', 10, 3)
A = BlockMatrix([[a_1, a_2], [a_3, a_4]])
assert A.shape == (20, 6)
printer = JaxPrinter()
assert printer.doprint(A) == 'jax.numpy.block([[a_1, a_2], [a_3, a_4]])'
def test_issue_17006():
if not jax:
skip("JAX not installed")
M = MatrixSymbol("M", 2, 2)
f = lambdify(M, M + Identity(2), 'jax')
ma = jax.numpy.array([[1, 2], [3, 4]])
mr = jax.numpy.array([[2, 2], [3, 5]])
assert (f(ma) == mr).all()
from sympy.core.symbol import symbols
n = symbols('n', integer=True)
N = MatrixSymbol("M", n, n)
raises(NotImplementedError, lambda: lambdify(N, N + Identity(n), 'jax'))
def test_jax_array():
assert JaxPrinter().doprint(Array(((1, 2), (3, 5)))) == 'jax.numpy.array([[1, 2], [3, 5]])'
assert JaxPrinter().doprint(Array((1, 2))) == 'jax.numpy.array((1, 2))'
def test_jax_known_funcs_consts():
assert _jax_known_constants['NaN'] == 'jax.numpy.nan'
assert _jax_known_constants['EulerGamma'] == 'jax.numpy.euler_gamma'
assert _jax_known_functions['acos'] == 'jax.numpy.arccos'
assert _jax_known_functions['log'] == 'jax.numpy.log'
def test_jax_print_methods():
prntr = JaxPrinter()
assert hasattr(prntr, '_print_acos')
assert hasattr(prntr, '_print_log')
def test_jax_printmethod():
printer = JaxPrinter()
assert hasattr(printer, 'printmethod')
assert printer.printmethod == '_jaxcode'
def test_jax_custom_print_method():
class expm1(Function):
def _jaxcode(self, printer):
x, = self.args
function = f'expm1({printer._print(x)})'
return printer._module_format(printer._module + '.' + function)
printer = JaxPrinter()
assert printer.doprint(expm1(Symbol('x'))) == 'jax.numpy.expm1(x)'

View File

@ -0,0 +1,396 @@
from sympy.core import (pi, oo, symbols, Rational, Integer, GoldenRatio,
EulerGamma, Catalan, Lambda, Dummy, S, Eq, Ne, Le,
Lt, Gt, Ge, Mod)
from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
sinh, cosh, tanh, asin, acos, acosh, Max, Min)
from sympy.testing.pytest import raises
from sympy.printing.jscode import JavascriptCodePrinter
from sympy.utilities.lambdify import implemented_function
from sympy.tensor import IndexedBase, Idx
from sympy.matrices import Matrix, MatrixSymbol
from sympy.printing.jscode import jscode
x, y, z = symbols('x,y,z')
def test_printmethod():
assert jscode(Abs(x)) == "Math.abs(x)"
def test_jscode_sqrt():
assert jscode(sqrt(x)) == "Math.sqrt(x)"
assert jscode(x**0.5) == "Math.sqrt(x)"
assert jscode(x**(S.One/3)) == "Math.cbrt(x)"
def test_jscode_Pow():
g = implemented_function('g', Lambda(x, 2*x))
assert jscode(x**3) == "Math.pow(x, 3)"
assert jscode(x**(y**3)) == "Math.pow(x, Math.pow(y, 3))"
assert jscode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
"Math.pow(3.5*2*x, -x + Math.pow(y, x))/(Math.pow(x, 2) + y)"
assert jscode(x**-1.0) == '1/x'
def test_jscode_constants_mathh():
assert jscode(exp(1)) == "Math.E"
assert jscode(pi) == "Math.PI"
assert jscode(oo) == "Number.POSITIVE_INFINITY"
assert jscode(-oo) == "Number.NEGATIVE_INFINITY"
def test_jscode_constants_other():
assert jscode(
2*GoldenRatio) == "var GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
assert jscode(2*Catalan) == "var Catalan = %s;\n2*Catalan" % Catalan.evalf(17)
assert jscode(
2*EulerGamma) == "var EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
def test_jscode_Rational():
assert jscode(Rational(3, 7)) == "3/7"
assert jscode(Rational(18, 9)) == "2"
assert jscode(Rational(3, -7)) == "-3/7"
assert jscode(Rational(-3, -7)) == "3/7"
def test_Relational():
assert jscode(Eq(x, y)) == "x == y"
assert jscode(Ne(x, y)) == "x != y"
assert jscode(Le(x, y)) == "x <= y"
assert jscode(Lt(x, y)) == "x < y"
assert jscode(Gt(x, y)) == "x > y"
assert jscode(Ge(x, y)) == "x >= y"
def test_Mod():
assert jscode(Mod(x, y)) == '((x % y) + y) % y'
assert jscode(Mod(x, x + y)) == '((x % (x + y)) + (x + y)) % (x + y)'
p1, p2 = symbols('p1 p2', positive=True)
assert jscode(Mod(p1, p2)) == 'p1 % p2'
assert jscode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)'
assert jscode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)'
assert jscode(-Mod(p1, p2)) == '-(p1 % p2)'
assert jscode(x*Mod(p1, p2)) == 'x*(p1 % p2)'
def test_jscode_Integer():
assert jscode(Integer(67)) == "67"
assert jscode(Integer(-1)) == "-1"
def test_jscode_functions():
assert jscode(sin(x) ** cos(x)) == "Math.pow(Math.sin(x), Math.cos(x))"
assert jscode(sinh(x) * cosh(x)) == "Math.sinh(x)*Math.cosh(x)"
assert jscode(Max(x, y) + Min(x, y)) == "Math.max(x, y) + Math.min(x, y)"
assert jscode(tanh(x)*acosh(y)) == "Math.tanh(x)*Math.acosh(y)"
assert jscode(asin(x)-acos(y)) == "-Math.acos(y) + Math.asin(x)"
def test_jscode_inline_function():
x = symbols('x')
g = implemented_function('g', Lambda(x, 2*x))
assert jscode(g(x)) == "2*x"
g = implemented_function('g', Lambda(x, 2*x/Catalan))
assert jscode(g(x)) == "var Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17)
A = IndexedBase('A')
i = Idx('i', symbols('n', integer=True))
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
assert jscode(g(A[i]), assign_to=A[i]) == (
"for (var i=0; i<n; i++){\n"
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
"}"
)
def test_jscode_exceptions():
assert jscode(ceiling(x)) == "Math.ceil(x)"
assert jscode(Abs(x)) == "Math.abs(x)"
def test_jscode_boolean():
assert jscode(x & y) == "x && y"
assert jscode(x | y) == "x || y"
assert jscode(~x) == "!x"
assert jscode(x & y & z) == "x && y && z"
assert jscode(x | y | z) == "x || y || z"
assert jscode((x & y) | z) == "z || x && y"
assert jscode((x | y) & z) == "z && (x || y)"
def test_jscode_Piecewise():
expr = Piecewise((x, x < 1), (x**2, True))
p = jscode(expr)
s = \
"""\
((x < 1) ? (
x
)
: (
Math.pow(x, 2)
))\
"""
assert p == s
assert jscode(expr, assign_to="c") == (
"if (x < 1) {\n"
" c = x;\n"
"}\n"
"else {\n"
" c = Math.pow(x, 2);\n"
"}")
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
raises(ValueError, lambda: jscode(expr))
def test_jscode_Piecewise_deep():
p = jscode(2*Piecewise((x, x < 1), (x**2, True)))
s = \
"""\
2*((x < 1) ? (
x
)
: (
Math.pow(x, 2)
))\
"""
assert p == s
def test_jscode_settings():
raises(TypeError, lambda: jscode(sin(x), method="garbage"))
def test_jscode_Indexed():
n, m, o = symbols('n m o', integer=True)
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
p = JavascriptCodePrinter()
p._not_c = set()
x = IndexedBase('x')[j]
assert p._print_Indexed(x) == 'x[j]'
A = IndexedBase('A')[i, j]
assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
B = IndexedBase('B')[i, j, k]
assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
assert p._not_c == set()
def test_jscode_loops_matrix_vector():
n, m = symbols('n m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
s = (
'for (var i=0; i<m; i++){\n'
' y[i] = 0;\n'
'}\n'
'for (var i=0; i<m; i++){\n'
' for (var j=0; j<n; j++){\n'
' y[i] = A[n*i + j]*x[j] + y[i];\n'
' }\n'
'}'
)
c = jscode(A[i, j]*x[j], assign_to=y[i])
assert c == s
def test_dummy_loops():
i, m = symbols('i m', integer=True, cls=Dummy)
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx(i, m)
expected = (
'for (var i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
' y[i_%(icount)i] = x[i_%(icount)i];\n'
'}'
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
code = jscode(x[i], assign_to=y[i])
assert code == expected
def test_jscode_loops_add():
n, m = symbols('n m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
z = IndexedBase('z')
i = Idx('i', m)
j = Idx('j', n)
s = (
'for (var i=0; i<m; i++){\n'
' y[i] = x[i] + z[i];\n'
'}\n'
'for (var i=0; i<m; i++){\n'
' for (var j=0; j<n; j++){\n'
' y[i] = A[n*i + j]*x[j] + y[i];\n'
' }\n'
'}'
)
c = jscode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
assert c == s
def test_jscode_loops_multiple_contractions():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
s = (
'for (var i=0; i<m; i++){\n'
' y[i] = 0;\n'
'}\n'
'for (var i=0; i<m; i++){\n'
' for (var j=0; j<n; j++){\n'
' for (var k=0; k<o; k++){\n'
' for (var l=0; l<p; l++){\n'
' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
' }\n'
' }\n'
' }\n'
'}'
)
c = jscode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
assert c == s
def test_jscode_loops_addfactor():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
c = IndexedBase('c')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
s = (
'for (var i=0; i<m; i++){\n'
' y[i] = 0;\n'
'}\n'
'for (var i=0; i<m; i++){\n'
' for (var j=0; j<n; j++){\n'
' for (var k=0; k<o; k++){\n'
' for (var l=0; l<p; l++){\n'
' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
' }\n'
' }\n'
' }\n'
'}'
)
c = jscode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
assert c == s
def test_jscode_loops_multiple_terms():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
c = IndexedBase('c')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
s0 = (
'for (var i=0; i<m; i++){\n'
' y[i] = 0;\n'
'}\n'
)
s1 = (
'for (var i=0; i<m; i++){\n'
' for (var j=0; j<n; j++){\n'
' for (var k=0; k<o; k++){\n'
' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
' }\n'
' }\n'
'}\n'
)
s2 = (
'for (var i=0; i<m; i++){\n'
' for (var k=0; k<o; k++){\n'
' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
' }\n'
'}\n'
)
s3 = (
'for (var i=0; i<m; i++){\n'
' for (var j=0; j<n; j++){\n'
' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
' }\n'
'}\n'
)
c = jscode(
b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
assert (c == s0 + s1 + s2 + s3[:-1] or
c == s0 + s1 + s3 + s2[:-1] or
c == s0 + s2 + s1 + s3[:-1] or
c == s0 + s2 + s3 + s1[:-1] or
c == s0 + s3 + s1 + s2[:-1] or
c == s0 + s3 + s2 + s1[:-1])
def test_Matrix_printing():
# Test returning a Matrix
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
A = MatrixSymbol('A', 3, 1)
assert jscode(mat, A) == (
"A[0] = x*y;\n"
"if (y > 0) {\n"
" A[1] = x + 2;\n"
"}\n"
"else {\n"
" A[1] = y;\n"
"}\n"
"A[2] = Math.sin(z);")
# Test using MatrixElements in expressions
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
assert jscode(expr) == (
"((x > 0) ? (\n"
" 2*A[2]\n"
")\n"
": (\n"
" A[2]\n"
")) + Math.sin(A[1]) + A[0]")
# Test using MatrixElements in a Matrix
q = MatrixSymbol('q', 5, 1)
M = MatrixSymbol('M', 3, 3)
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
[q[1,0] + q[2,0], q[3, 0], 5],
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
assert jscode(m, M) == (
"M[0] = Math.sin(q[1]);\n"
"M[1] = 0;\n"
"M[2] = Math.cos(q[2]);\n"
"M[3] = q[1] + q[2];\n"
"M[4] = q[3];\n"
"M[5] = 5;\n"
"M[6] = 2*q[4]/q[1];\n"
"M[7] = Math.sqrt(q[0]) + 4;\n"
"M[8] = 0;")
def test_MatrixElement_printing():
# test cases for issue #11821
A = MatrixSymbol("A", 1, 3)
B = MatrixSymbol("B", 1, 3)
C = MatrixSymbol("C", 1, 3)
assert(jscode(A[0, 0]) == "A[0]")
assert(jscode(3 * A[0, 0]) == "3*A[0]")
F = C[0, 0].subs(C, A - B)
assert(jscode(F) == "(A - B)[0]")

View File

@ -0,0 +1,386 @@
from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer,
Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge)
from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow
from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos
from sympy.testing.pytest import raises
from sympy.utilities.lambdify import implemented_function
from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity,
HadamardProduct, SparseMatrix)
from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli,
besselk, hankel1, hankel2, airyai,
airybi, airyaiprime, airybiprime)
from sympy.testing.pytest import XFAIL
from sympy.printing.julia import julia_code
x, y, z = symbols('x,y,z')
def test_Integer():
assert julia_code(Integer(67)) == "67"
assert julia_code(Integer(-1)) == "-1"
def test_Rational():
assert julia_code(Rational(3, 7)) == "3 // 7"
assert julia_code(Rational(18, 9)) == "2"
assert julia_code(Rational(3, -7)) == "-3 // 7"
assert julia_code(Rational(-3, -7)) == "3 // 7"
assert julia_code(x + Rational(3, 7)) == "x + 3 // 7"
assert julia_code(Rational(3, 7)*x) == "(3 // 7) * x"
def test_Relational():
assert julia_code(Eq(x, y)) == "x == y"
assert julia_code(Ne(x, y)) == "x != y"
assert julia_code(Le(x, y)) == "x <= y"
assert julia_code(Lt(x, y)) == "x < y"
assert julia_code(Gt(x, y)) == "x > y"
assert julia_code(Ge(x, y)) == "x >= y"
def test_Function():
assert julia_code(sin(x) ** cos(x)) == "sin(x) .^ cos(x)"
assert julia_code(abs(x)) == "abs(x)"
assert julia_code(ceiling(x)) == "ceil(x)"
def test_Pow():
assert julia_code(x**3) == "x .^ 3"
assert julia_code(x**(y**3)) == "x .^ (y .^ 3)"
assert julia_code(x**Rational(2, 3)) == 'x .^ (2 // 3)'
g = implemented_function('g', Lambda(x, 2*x))
assert julia_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
"(3.5 * 2 * x) .^ (-x + y .^ x) ./ (x .^ 2 + y)"
# For issue 14160
assert julia_code(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
evaluate=False)) == '-2 * x ./ (y .* y)'
def test_basic_ops():
assert julia_code(x*y) == "x .* y"
assert julia_code(x + y) == "x + y"
assert julia_code(x - y) == "x - y"
assert julia_code(-x) == "-x"
def test_1_over_x_and_sqrt():
# 1.0 and 0.5 would do something different in regular StrPrinter,
# but these are exact in IEEE floating point so no different here.
assert julia_code(1/x) == '1 ./ x'
assert julia_code(x**-1) == julia_code(x**-1.0) == '1 ./ x'
assert julia_code(1/sqrt(x)) == '1 ./ sqrt(x)'
assert julia_code(x**-S.Half) == julia_code(x**-0.5) == '1 ./ sqrt(x)'
assert julia_code(sqrt(x)) == 'sqrt(x)'
assert julia_code(x**S.Half) == julia_code(x**0.5) == 'sqrt(x)'
assert julia_code(1/pi) == '1 / pi'
assert julia_code(pi**-1) == julia_code(pi**-1.0) == '1 / pi'
assert julia_code(pi**-0.5) == '1 / sqrt(pi)'
def test_mix_number_mult_symbols():
assert julia_code(3*x) == "3 * x"
assert julia_code(pi*x) == "pi * x"
assert julia_code(3/x) == "3 ./ x"
assert julia_code(pi/x) == "pi ./ x"
assert julia_code(x/3) == "x / 3"
assert julia_code(x/pi) == "x / pi"
assert julia_code(x*y) == "x .* y"
assert julia_code(3*x*y) == "3 * x .* y"
assert julia_code(3*pi*x*y) == "3 * pi * x .* y"
assert julia_code(x/y) == "x ./ y"
assert julia_code(3*x/y) == "3 * x ./ y"
assert julia_code(x*y/z) == "x .* y ./ z"
assert julia_code(x/y*z) == "x .* z ./ y"
assert julia_code(1/x/y) == "1 ./ (x .* y)"
assert julia_code(2*pi*x/y/z) == "2 * pi * x ./ (y .* z)"
assert julia_code(3*pi/x) == "3 * pi ./ x"
assert julia_code(S(3)/5) == "3 // 5"
assert julia_code(S(3)/5*x) == "(3 // 5) * x"
assert julia_code(x/y/z) == "x ./ (y .* z)"
assert julia_code((x+y)/z) == "(x + y) ./ z"
assert julia_code((x+y)/(z+x)) == "(x + y) ./ (x + z)"
assert julia_code((x+y)/EulerGamma) == "(x + y) / eulergamma"
assert julia_code(x/3/pi) == "x / (3 * pi)"
assert julia_code(S(3)/5*x*y/pi) == "(3 // 5) * x .* y / pi"
def test_mix_number_pow_symbols():
assert julia_code(pi**3) == 'pi ^ 3'
assert julia_code(x**2) == 'x .^ 2'
assert julia_code(x**(pi**3)) == 'x .^ (pi ^ 3)'
assert julia_code(x**y) == 'x .^ y'
assert julia_code(x**(y**z)) == 'x .^ (y .^ z)'
assert julia_code((x**y)**z) == '(x .^ y) .^ z'
def test_imag():
I = S('I')
assert julia_code(I) == "im"
assert julia_code(5*I) == "5im"
assert julia_code((S(3)/2)*I) == "(3 // 2) * im"
assert julia_code(3+4*I) == "3 + 4im"
def test_constants():
assert julia_code(pi) == "pi"
assert julia_code(oo) == "Inf"
assert julia_code(-oo) == "-Inf"
assert julia_code(S.NegativeInfinity) == "-Inf"
assert julia_code(S.NaN) == "NaN"
assert julia_code(S.Exp1) == "e"
assert julia_code(exp(1)) == "e"
def test_constants_other():
assert julia_code(2*GoldenRatio) == "2 * golden"
assert julia_code(2*Catalan) == "2 * catalan"
assert julia_code(2*EulerGamma) == "2 * eulergamma"
def test_boolean():
assert julia_code(x & y) == "x && y"
assert julia_code(x | y) == "x || y"
assert julia_code(~x) == "!x"
assert julia_code(x & y & z) == "x && y && z"
assert julia_code(x | y | z) == "x || y || z"
assert julia_code((x & y) | z) == "z || x && y"
assert julia_code((x | y) & z) == "z && (x || y)"
def test_Matrices():
assert julia_code(Matrix(1, 1, [10])) == "[10]"
A = Matrix([[1, sin(x/2), abs(x)],
[0, 1, pi],
[0, exp(1), ceiling(x)]]);
expected = ("[1 sin(x / 2) abs(x);\n"
"0 1 pi;\n"
"0 e ceil(x)]")
assert julia_code(A) == expected
# row and columns
assert julia_code(A[:,0]) == "[1, 0, 0]"
assert julia_code(A[0,:]) == "[1 sin(x / 2) abs(x)]"
# empty matrices
assert julia_code(Matrix(0, 0, [])) == 'zeros(0, 0)'
assert julia_code(Matrix(0, 3, [])) == 'zeros(0, 3)'
# annoying to read but correct
assert julia_code(Matrix([[x, x - y, -y]])) == "[x x - y -y]"
def test_vector_entries_hadamard():
# For a row or column, user might to use the other dimension
A = Matrix([[1, sin(2/x), 3*pi/x/5]])
assert julia_code(A) == "[1 sin(2 ./ x) (3 // 5) * pi ./ x]"
assert julia_code(A.T) == "[1, sin(2 ./ x), (3 // 5) * pi ./ x]"
@XFAIL
def test_Matrices_entries_not_hadamard():
# For Matrix with col >= 2, row >= 2, they need to be scalars
# FIXME: is it worth worrying about this? Its not wrong, just
# leave it user's responsibility to put scalar data for x.
A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]])
expected = ("[1 sin(2/x) 3*pi/(5*x);\n"
"1 2 x*y]") # <- we give x.*y
assert julia_code(A) == expected
def test_MatrixSymbol():
n = Symbol('n', integer=True)
A = MatrixSymbol('A', n, n)
B = MatrixSymbol('B', n, n)
assert julia_code(A*B) == "A * B"
assert julia_code(B*A) == "B * A"
assert julia_code(2*A*B) == "2 * A * B"
assert julia_code(B*2*A) == "2 * B * A"
assert julia_code(A*(B + 3*Identity(n))) == "A * (3 * eye(n) + B)"
assert julia_code(A**(x**2)) == "A ^ (x .^ 2)"
assert julia_code(A**3) == "A ^ 3"
assert julia_code(A**S.Half) == "A ^ (1 // 2)"
def test_special_matrices():
assert julia_code(6*Identity(3)) == "6 * eye(3)"
def test_containers():
assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
"Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
assert julia_code([1]) == "Any[1]"
assert julia_code((1,)) == "(1,)"
assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
assert julia_code((1, x*y, (3, x**2))) == "(1, x .* y, (3, x .^ 2))"
# scalar, matrix, empty matrix and empty list
assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"
def test_julia_noninline():
source = julia_code((x+y)/Catalan, assign_to='me', inline=False)
expected = (
"const Catalan = %s\n"
"me = (x + y) / Catalan"
) % Catalan.evalf(17)
assert source == expected
def test_julia_piecewise():
expr = Piecewise((x, x < 1), (x**2, True))
assert julia_code(expr) == "((x < 1) ? (x) : (x .^ 2))"
assert julia_code(expr, assign_to="r") == (
"r = ((x < 1) ? (x) : (x .^ 2))")
assert julia_code(expr, assign_to="r", inline=False) == (
"if (x < 1)\n"
" r = x\n"
"else\n"
" r = x .^ 2\n"
"end")
expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True))
expected = ("((x < 1) ? (x .^ 2) :\n"
"(x < 2) ? (x .^ 3) :\n"
"(x < 3) ? (x .^ 4) : (x .^ 5))")
assert julia_code(expr) == expected
assert julia_code(expr, assign_to="r") == "r = " + expected
assert julia_code(expr, assign_to="r", inline=False) == (
"if (x < 1)\n"
" r = x .^ 2\n"
"elseif (x < 2)\n"
" r = x .^ 3\n"
"elseif (x < 3)\n"
" r = x .^ 4\n"
"else\n"
" r = x .^ 5\n"
"end")
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
raises(ValueError, lambda: julia_code(expr))
def test_julia_piecewise_times_const():
pw = Piecewise((x, x < 1), (x**2, True))
assert julia_code(2*pw) == "2 * ((x < 1) ? (x) : (x .^ 2))"
assert julia_code(pw/x) == "((x < 1) ? (x) : (x .^ 2)) ./ x"
assert julia_code(pw/(x*y)) == "((x < 1) ? (x) : (x .^ 2)) ./ (x .* y)"
assert julia_code(pw/3) == "((x < 1) ? (x) : (x .^ 2)) / 3"
def test_julia_matrix_assign_to():
A = Matrix([[1, 2, 3]])
assert julia_code(A, assign_to='a') == "a = [1 2 3]"
A = Matrix([[1, 2], [3, 4]])
assert julia_code(A, assign_to='A') == "A = [1 2;\n3 4]"
def test_julia_matrix_assign_to_more():
# assigning to Symbol or MatrixSymbol requires lhs/rhs match
A = Matrix([[1, 2, 3]])
B = MatrixSymbol('B', 1, 3)
C = MatrixSymbol('C', 2, 3)
assert julia_code(A, assign_to=B) == "B = [1 2 3]"
raises(ValueError, lambda: julia_code(A, assign_to=x))
raises(ValueError, lambda: julia_code(A, assign_to=C))
def test_julia_matrix_1x1():
A = Matrix([[3]])
B = MatrixSymbol('B', 1, 1)
C = MatrixSymbol('C', 1, 2)
assert julia_code(A, assign_to=B) == "B = [3]"
# FIXME?
#assert julia_code(A, assign_to=x) == "x = [3]"
raises(ValueError, lambda: julia_code(A, assign_to=C))
def test_julia_matrix_elements():
A = Matrix([[x, 2, x*y]])
assert julia_code(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
A = MatrixSymbol('AA', 1, 3)
assert julia_code(A) == "AA"
assert julia_code(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \
"sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"
def test_julia_boolean():
assert julia_code(True) == "true"
assert julia_code(S.true) == "true"
assert julia_code(False) == "false"
assert julia_code(S.false) == "false"
def test_julia_not_supported():
with raises(NotImplementedError):
julia_code(S.ComplexInfinity)
f = Function('f')
assert julia_code(f(x).diff(x), strict=False) == (
"# Not supported in Julia:\n"
"# Derivative\n"
"Derivative(f(x), x)"
)
def test_trick_indent_with_end_else_words():
# words starting with "end" or "else" do not confuse the indenter
t1 = S('endless');
t2 = S('elsewhere');
pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True))
assert julia_code(pw, inline=False) == (
"if (x < 0)\n"
" endless\n"
"elseif (x <= 1)\n"
" elsewhere\n"
"else\n"
" 1\n"
"end")
def test_haramard():
A = MatrixSymbol('A', 3, 3)
B = MatrixSymbol('B', 3, 3)
v = MatrixSymbol('v', 3, 1)
h = MatrixSymbol('h', 1, 3)
C = HadamardProduct(A, B)
assert julia_code(C) == "A .* B"
assert julia_code(C*v) == "(A .* B) * v"
assert julia_code(h*C*v) == "h * (A .* B) * v"
assert julia_code(C*A) == "(A .* B) * A"
# mixing Hadamard and scalar strange b/c we vectorize scalars
assert julia_code(C*x*y) == "(x .* y) * (A .* B)"
def test_sparse():
M = SparseMatrix(5, 6, {})
M[2, 2] = 10;
M[1, 2] = 20;
M[1, 3] = 22;
M[0, 3] = 30;
M[3, 0] = x*y;
assert julia_code(M) == (
"sparse([4, 2, 3, 1, 2], [1, 3, 3, 4, 4], [x .* y, 20, 10, 30, 22], 5, 6)"
)
def test_specfun():
n = Symbol('n')
for f in [besselj, bessely, besseli, besselk]:
assert julia_code(f(n, x)) == f.__name__ + '(n, x)'
for f in [airyai, airyaiprime, airybi, airybiprime]:
assert julia_code(f(x)) == f.__name__ + '(x)'
assert julia_code(hankel1(n, x)) == 'hankelh1(n, x)'
assert julia_code(hankel2(n, x)) == 'hankelh2(n, x)'
assert julia_code(jn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* besselj(n + 1 // 2, x) / 2'
assert julia_code(yn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* bessely(n + 1 // 2, x) / 2'
def test_MatrixElement_printing():
# test cases for issue #11821
A = MatrixSymbol("A", 1, 3)
B = MatrixSymbol("B", 1, 3)
C = MatrixSymbol("C", 1, 3)
assert(julia_code(A[0, 0]) == "A[1,1]")
assert(julia_code(3 * A[0, 0]) == "3 * A[1,1]")
F = C[0, 0].subs(C, A - B)
assert(julia_code(F) == "(A - B)[1,1]")

View File

@ -0,0 +1,246 @@
from sympy.concrete.summations import Sum
from sympy.core.expr import Expr
from sympy.core.symbol import symbols
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import sin
from sympy.matrices.dense import MutableDenseMatrix as Matrix
from sympy.sets.sets import Interval
from sympy.utilities.lambdify import lambdify
from sympy.testing.pytest import raises
from sympy.printing.tensorflow import TensorflowPrinter
from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, NumExprPrinter
x, y, z = symbols("x,y,z")
i, a, b = symbols("i,a,b")
j, c, d = symbols("j,c,d")
def test_basic():
assert lambdarepr(x*y) == "x*y"
assert lambdarepr(x + y) in ["y + x", "x + y"]
assert lambdarepr(x**y) == "x**y"
def test_matrix():
# Test printing a Matrix that has an element that is printed differently
# with the LambdaPrinter than with the StrPrinter.
e = x % 2
assert lambdarepr(e) != str(e)
assert lambdarepr(Matrix([e])) == 'ImmutableDenseMatrix([[x % 2]])'
def test_piecewise():
# In each case, test eval() the lambdarepr() to make sure there are a
# correct number of parentheses. It will give a SyntaxError if there aren't.
h = "lambda x: "
p = Piecewise((x, x < 0))
l = lambdarepr(p)
eval(h + l)
assert l == "((x) if (x < 0) else None)"
p = Piecewise(
(1, x < 1),
(2, x < 2),
(0, True)
)
l = lambdarepr(p)
eval(h + l)
assert l == "((1) if (x < 1) else (2) if (x < 2) else (0))"
p = Piecewise(
(1, x < 1),
(2, x < 2),
)
l = lambdarepr(p)
eval(h + l)
assert l == "((1) if (x < 1) else (2) if (x < 2) else None)"
p = Piecewise(
(x, x < 1),
(x**2, Interval(3, 4, True, False).contains(x)),
(0, True),
)
l = lambdarepr(p)
eval(h + l)
assert l == "((x) if (x < 1) else (x**2) if (((x <= 4)) and ((x > 3))) else (0))"
p = Piecewise(
(x**2, x < 0),
(x, x < 1),
(2 - x, x >= 1),
(0, True), evaluate=False
)
l = lambdarepr(p)
eval(h + l)
assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\
" else (2 - x) if (x >= 1) else (0))"
p = Piecewise(
(x**2, x < 0),
(x, x < 1),
(2 - x, x >= 1), evaluate=False
)
l = lambdarepr(p)
eval(h + l)
assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\
" else (2 - x) if (x >= 1) else None)"
p = Piecewise(
(1, x >= 1),
(2, x >= 2),
(3, x >= 3),
(4, x >= 4),
(5, x >= 5),
(6, True)
)
l = lambdarepr(p)
eval(h + l)
assert l == "((1) if (x >= 1) else (2) if (x >= 2) else (3) if (x >= 3)"\
" else (4) if (x >= 4) else (5) if (x >= 5) else (6))"
p = Piecewise(
(1, x <= 1),
(2, x <= 2),
(3, x <= 3),
(4, x <= 4),
(5, x <= 5),
(6, True)
)
l = lambdarepr(p)
eval(h + l)
assert l == "((1) if (x <= 1) else (2) if (x <= 2) else (3) if (x <= 3)"\
" else (4) if (x <= 4) else (5) if (x <= 5) else (6))"
p = Piecewise(
(1, x > 1),
(2, x > 2),
(3, x > 3),
(4, x > 4),
(5, x > 5),
(6, True)
)
l = lambdarepr(p)
eval(h + l)
assert l =="((1) if (x > 1) else (2) if (x > 2) else (3) if (x > 3)"\
" else (4) if (x > 4) else (5) if (x > 5) else (6))"
p = Piecewise(
(1, x < 1),
(2, x < 2),
(3, x < 3),
(4, x < 4),
(5, x < 5),
(6, True)
)
l = lambdarepr(p)
eval(h + l)
assert l == "((1) if (x < 1) else (2) if (x < 2) else (3) if (x < 3)"\
" else (4) if (x < 4) else (5) if (x < 5) else (6))"
p = Piecewise(
(Piecewise(
(1, x > 0),
(2, True)
), y > 0),
(3, True)
)
l = lambdarepr(p)
eval(h + l)
assert l == "((((1) if (x > 0) else (2))) if (y > 0) else (3))"
def test_sum__1():
# In each case, test eval() the lambdarepr() to make sure that
# it evaluates to the same results as the symbolic expression
s = Sum(x ** i, (i, a, b))
l = lambdarepr(s)
assert l == "(builtins.sum(x**i for i in range(a, b+1)))"
args = x, a, b
f = lambdify(args, s)
v = 2, 3, 8
assert f(*v) == s.subs(zip(args, v)).doit()
def test_sum__2():
s = Sum(i * x, (i, a, b))
l = lambdarepr(s)
assert l == "(builtins.sum(i*x for i in range(a, b+1)))"
args = x, a, b
f = lambdify(args, s)
v = 2, 3, 8
assert f(*v) == s.subs(zip(args, v)).doit()
def test_multiple_sums():
s = Sum(i * x + j, (i, a, b), (j, c, d))
l = lambdarepr(s)
assert l == "(builtins.sum(i*x + j for i in range(a, b+1) for j in range(c, d+1)))"
args = x, a, b, c, d
f = lambdify(args, s)
vals = 2, 3, 4, 5, 6
f_ref = s.subs(zip(args, vals)).doit()
f_res = f(*vals)
assert f_res == f_ref
def test_sqrt():
prntr = LambdaPrinter({'standard' : 'python3'})
assert prntr._print_Pow(sqrt(x), rational=False) == 'sqrt(x)'
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
def test_settings():
raises(TypeError, lambda: lambdarepr(sin(x), method="garbage"))
def test_numexpr():
# test ITE rewrite as Piecewise
from sympy.logic.boolalg import ITE
expr = ITE(x > 0, True, False, evaluate=False)
assert NumExprPrinter().doprint(expr) == \
"numexpr.evaluate('where((x > 0), True, False)', truediv=True)"
from sympy.codegen.ast import Return, FunctionDefinition, Variable, Assignment
func_def = FunctionDefinition(None, 'foo', [Variable(x)], [Assignment(y,x), Return(y**2)])
expected = "def foo(x):\n"\
" y = numexpr.evaluate('x', truediv=True)\n"\
" return numexpr.evaluate('y**2', truediv=True)"
assert NumExprPrinter().doprint(func_def) == expected
class CustomPrintedObject(Expr):
def _lambdacode(self, printer):
return 'lambda'
def _tensorflowcode(self, printer):
return 'tensorflow'
def _numpycode(self, printer):
return 'numpy'
def _numexprcode(self, printer):
return 'numexpr'
def _mpmathcode(self, printer):
return 'mpmath'
def test_printmethod():
# In each case, printmethod is called to test
# its working
obj = CustomPrintedObject()
assert LambdaPrinter().doprint(obj) == 'lambda'
assert TensorflowPrinter().doprint(obj) == 'tensorflow'
assert NumExprPrinter().doprint(obj) == "numexpr.evaluate('numexpr', truediv=True)"
assert NumExprPrinter().doprint(Piecewise((y, x >= 0), (z, x < 0))) == \
"numexpr.evaluate('where((x >= 0), y, z)', truediv=True)"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,224 @@
from sympy.external import import_module
from sympy.testing.pytest import raises
import ctypes
if import_module('llvmlite'):
import sympy.printing.llvmjitcode as g
else:
disabled = True
import sympy
from sympy.abc import a, b, n
# copied from numpy.isclose documentation
def isclose(a, b):
rtol = 1e-5
atol = 1e-8
return abs(a-b) <= atol + rtol*abs(b)
def test_simple_expr():
e = a + 1.0
f = g.llvm_callable([a], e)
res = float(e.subs({a: 4.0}).evalf())
jit_res = f(4.0)
assert isclose(jit_res, res)
def test_two_arg():
e = 4.0*a + b + 3.0
f = g.llvm_callable([a, b], e)
res = float(e.subs({a: 4.0, b: 3.0}).evalf())
jit_res = f(4.0, 3.0)
assert isclose(jit_res, res)
def test_func():
e = 4.0*sympy.exp(-a)
f = g.llvm_callable([a], e)
res = float(e.subs({a: 1.5}).evalf())
jit_res = f(1.5)
assert isclose(jit_res, res)
def test_two_func():
e = 4.0*sympy.exp(-a) + sympy.exp(b)
f = g.llvm_callable([a, b], e)
res = float(e.subs({a: 1.5, b: 2.0}).evalf())
jit_res = f(1.5, 2.0)
assert isclose(jit_res, res)
def test_two_sqrt():
e = 4.0*sympy.sqrt(a) + sympy.sqrt(b)
f = g.llvm_callable([a, b], e)
res = float(e.subs({a: 1.5, b: 2.0}).evalf())
jit_res = f(1.5, 2.0)
assert isclose(jit_res, res)
def test_two_pow():
e = a**1.5 + b**7
f = g.llvm_callable([a, b], e)
res = float(e.subs({a: 1.5, b: 2.0}).evalf())
jit_res = f(1.5, 2.0)
assert isclose(jit_res, res)
def test_callback():
e = a + 1.2
f = g.llvm_callable([a], e, callback_type='scipy.integrate.test')
m = ctypes.c_int(1)
array_type = ctypes.c_double * 1
inp = {a: 2.2}
array = array_type(inp[a])
jit_res = f(m, array)
res = float(e.subs(inp).evalf())
assert isclose(jit_res, res)
def test_callback_cubature():
e = a + 1.2
f = g.llvm_callable([a], e, callback_type='cubature')
m = ctypes.c_int(1)
array_type = ctypes.c_double * 1
inp = {a: 2.2}
array = array_type(inp[a])
out_array = array_type(0.0)
jit_ret = f(m, array, None, m, out_array)
assert jit_ret == 0
res = float(e.subs(inp).evalf())
assert isclose(out_array[0], res)
def test_callback_two():
e = 3*a*b
f = g.llvm_callable([a, b], e, callback_type='scipy.integrate.test')
m = ctypes.c_int(2)
array_type = ctypes.c_double * 2
inp = {a: 0.2, b: 1.7}
array = array_type(inp[a], inp[b])
jit_res = f(m, array)
res = float(e.subs(inp).evalf())
assert isclose(jit_res, res)
def test_callback_alt_two():
d = sympy.IndexedBase('d')
e = 3*d[0]*d[1]
f = g.llvm_callable([n, d], e, callback_type='scipy.integrate.test')
m = ctypes.c_int(2)
array_type = ctypes.c_double * 2
inp = {d[0]: 0.2, d[1]: 1.7}
array = array_type(inp[d[0]], inp[d[1]])
jit_res = f(m, array)
res = float(e.subs(inp).evalf())
assert isclose(jit_res, res)
def test_multiple_statements():
# Match return from CSE
e = [[(b, 4.0*a)], [b + 5]]
f = g.llvm_callable([a], e)
b_val = e[0][0][1].subs({a: 1.5})
res = float(e[1][0].subs({b: b_val}).evalf())
jit_res = f(1.5)
assert isclose(jit_res, res)
f_callback = g.llvm_callable([a], e, callback_type='scipy.integrate.test')
m = ctypes.c_int(1)
array_type = ctypes.c_double * 1
array = array_type(1.5)
jit_callback_res = f_callback(m, array)
assert isclose(jit_callback_res, res)
def test_cse():
e = a*a + b*b + sympy.exp(-a*a - b*b)
e2 = sympy.cse(e)
f = g.llvm_callable([a, b], e2)
res = float(e.subs({a: 2.3, b: 0.1}).evalf())
jit_res = f(2.3, 0.1)
assert isclose(jit_res, res)
def eval_cse(e, sub_dict):
tmp_dict = {}
for tmp_name, tmp_expr in e[0]:
e2 = tmp_expr.subs(sub_dict)
e3 = e2.subs(tmp_dict)
tmp_dict[tmp_name] = e3
return [e.subs(sub_dict).subs(tmp_dict) for e in e[1]]
def test_cse_multiple():
e1 = a*a
e2 = a*a + b*b
e3 = sympy.cse([e1, e2])
raises(NotImplementedError,
lambda: g.llvm_callable([a, b], e3, callback_type='scipy.integrate'))
f = g.llvm_callable([a, b], e3)
jit_res = f(0.1, 1.5)
assert len(jit_res) == 2
res = eval_cse(e3, {a: 0.1, b: 1.5})
assert isclose(res[0], jit_res[0])
assert isclose(res[1], jit_res[1])
def test_callback_cubature_multiple():
e1 = a*a
e2 = a*a + b*b
e3 = sympy.cse([e1, e2, 4*e2])
f = g.llvm_callable([a, b], e3, callback_type='cubature')
# Number of input variables
ndim = 2
# Number of output expression values
outdim = 3
m = ctypes.c_int(ndim)
fdim = ctypes.c_int(outdim)
array_type = ctypes.c_double * ndim
out_array_type = ctypes.c_double * outdim
inp = {a: 0.2, b: 1.5}
array = array_type(inp[a], inp[b])
out_array = out_array_type()
jit_ret = f(m, array, None, fdim, out_array)
assert jit_ret == 0
res = eval_cse(e3, inp)
assert isclose(out_array[0], res[0])
assert isclose(out_array[1], res[1])
assert isclose(out_array[2], res[2])
def test_symbol_not_found():
e = a*a + b
raises(LookupError, lambda: g.llvm_callable([a], e))
def test_bad_callback():
e = a
raises(ValueError, lambda: g.llvm_callable([a], e, callback_type='bad_callback'))

View File

@ -0,0 +1,381 @@
from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer,
Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge)
from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow
from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc, lucas
from sympy.testing.pytest import raises
from sympy.utilities.lambdify import implemented_function
from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity,
HadamardProduct, SparseMatrix)
from sympy.functions.special.bessel import besseli
from sympy.printing.maple import maple_code
x, y, z = symbols('x,y,z')
def test_Integer():
assert maple_code(Integer(67)) == "67"
assert maple_code(Integer(-1)) == "-1"
def test_Rational():
assert maple_code(Rational(3, 7)) == "3/7"
assert maple_code(Rational(18, 9)) == "2"
assert maple_code(Rational(3, -7)) == "-3/7"
assert maple_code(Rational(-3, -7)) == "3/7"
assert maple_code(x + Rational(3, 7)) == "x + 3/7"
assert maple_code(Rational(3, 7) * x) == '(3/7)*x'
def test_Relational():
assert maple_code(Eq(x, y)) == "x = y"
assert maple_code(Ne(x, y)) == "x <> y"
assert maple_code(Le(x, y)) == "x <= y"
assert maple_code(Lt(x, y)) == "x < y"
assert maple_code(Gt(x, y)) == "x > y"
assert maple_code(Ge(x, y)) == "x >= y"
def test_Function():
assert maple_code(sin(x) ** cos(x)) == "sin(x)^cos(x)"
assert maple_code(abs(x)) == "abs(x)"
assert maple_code(ceiling(x)) == "ceil(x)"
def test_Pow():
assert maple_code(x ** 3) == "x^3"
assert maple_code(x ** (y ** 3)) == "x^(y^3)"
assert maple_code((x ** 3) ** y) == "(x^3)^y"
assert maple_code(x ** Rational(2, 3)) == 'x^(2/3)'
g = implemented_function('g', Lambda(x, 2 * x))
assert maple_code(1 / (g(x) * 3.5) ** (x - y ** x) / (x ** 2 + y)) == \
"(3.5*2*x)^(-x + y^x)/(x^2 + y)"
# For issue 14160
assert maple_code(Mul(-2, x, Pow(Mul(y, y, evaluate=False), -1, evaluate=False),
evaluate=False)) == '-2*x/(y*y)'
def test_basic_ops():
assert maple_code(x * y) == "x*y"
assert maple_code(x + y) == "x + y"
assert maple_code(x - y) == "x - y"
assert maple_code(-x) == "-x"
def test_1_over_x_and_sqrt():
# 1.0 and 0.5 would do something different in regular StrPrinter,
# but these are exact in IEEE floating point so no different here.
assert maple_code(1 / x) == '1/x'
assert maple_code(x ** -1) == maple_code(x ** -1.0) == '1/x'
assert maple_code(1 / sqrt(x)) == '1/sqrt(x)'
assert maple_code(x ** -S.Half) == maple_code(x ** -0.5) == '1/sqrt(x)'
assert maple_code(sqrt(x)) == 'sqrt(x)'
assert maple_code(x ** S.Half) == maple_code(x ** 0.5) == 'sqrt(x)'
assert maple_code(1 / pi) == '1/Pi'
assert maple_code(pi ** -1) == maple_code(pi ** -1.0) == '1/Pi'
assert maple_code(pi ** -0.5) == '1/sqrt(Pi)'
def test_mix_number_mult_symbols():
assert maple_code(3 * x) == "3*x"
assert maple_code(pi * x) == "Pi*x"
assert maple_code(3 / x) == "3/x"
assert maple_code(pi / x) == "Pi/x"
assert maple_code(x / 3) == '(1/3)*x'
assert maple_code(x / pi) == "x/Pi"
assert maple_code(x * y) == "x*y"
assert maple_code(3 * x * y) == "3*x*y"
assert maple_code(3 * pi * x * y) == "3*Pi*x*y"
assert maple_code(x / y) == "x/y"
assert maple_code(3 * x / y) == "3*x/y"
assert maple_code(x * y / z) == "x*y/z"
assert maple_code(x / y * z) == "x*z/y"
assert maple_code(1 / x / y) == "1/(x*y)"
assert maple_code(2 * pi * x / y / z) == "2*Pi*x/(y*z)"
assert maple_code(3 * pi / x) == "3*Pi/x"
assert maple_code(S(3) / 5) == "3/5"
assert maple_code(S(3) / 5 * x) == '(3/5)*x'
assert maple_code(x / y / z) == "x/(y*z)"
assert maple_code((x + y) / z) == "(x + y)/z"
assert maple_code((x + y) / (z + x)) == "(x + y)/(x + z)"
assert maple_code((x + y) / EulerGamma) == '(x + y)/gamma'
assert maple_code(x / 3 / pi) == '(1/3)*x/Pi'
assert maple_code(S(3) / 5 * x * y / pi) == '(3/5)*x*y/Pi'
def test_mix_number_pow_symbols():
assert maple_code(pi ** 3) == 'Pi^3'
assert maple_code(x ** 2) == 'x^2'
assert maple_code(x ** (pi ** 3)) == 'x^(Pi^3)'
assert maple_code(x ** y) == 'x^y'
assert maple_code(x ** (y ** z)) == 'x^(y^z)'
assert maple_code((x ** y) ** z) == '(x^y)^z'
def test_imag():
I = S('I')
assert maple_code(I) == "I"
assert maple_code(5 * I) == "5*I"
assert maple_code((S(3) / 2) * I) == "(3/2)*I"
assert maple_code(3 + 4 * I) == "3 + 4*I"
def test_constants():
assert maple_code(pi) == "Pi"
assert maple_code(oo) == "infinity"
assert maple_code(-oo) == "-infinity"
assert maple_code(S.NegativeInfinity) == "-infinity"
assert maple_code(S.NaN) == "undefined"
assert maple_code(S.Exp1) == "exp(1)"
assert maple_code(exp(1)) == "exp(1)"
def test_constants_other():
assert maple_code(2 * GoldenRatio) == '2*(1/2 + (1/2)*sqrt(5))'
assert maple_code(2 * Catalan) == '2*Catalan'
assert maple_code(2 * EulerGamma) == "2*gamma"
def test_boolean():
assert maple_code(x & y) == "x and y"
assert maple_code(x | y) == "x or y"
assert maple_code(~x) == "not x"
assert maple_code(x & y & z) == "x and y and z"
assert maple_code(x | y | z) == "x or y or z"
assert maple_code((x & y) | z) == "z or x and y"
assert maple_code((x | y) & z) == "z and (x or y)"
def test_Matrices():
assert maple_code(Matrix(1, 1, [10])) == \
'Matrix([[10]], storage = rectangular)'
A = Matrix([[1, sin(x / 2), abs(x)],
[0, 1, pi],
[0, exp(1), ceiling(x)]])
expected = \
'Matrix(' \
'[[1, sin((1/2)*x), abs(x)],' \
' [0, 1, Pi],' \
' [0, exp(1), ceil(x)]], ' \
'storage = rectangular)'
assert maple_code(A) == expected
# row and columns
assert maple_code(A[:, 0]) == \
'Matrix([[1], [0], [0]], storage = rectangular)'
assert maple_code(A[0, :]) == \
'Matrix([[1, sin((1/2)*x), abs(x)]], storage = rectangular)'
assert maple_code(Matrix([[x, x - y, -y]])) == \
'Matrix([[x, x - y, -y]], storage = rectangular)'
# empty matrices
assert maple_code(Matrix(0, 0, [])) == \
'Matrix([], storage = rectangular)'
assert maple_code(Matrix(0, 3, [])) == \
'Matrix([], storage = rectangular)'
def test_SparseMatrices():
assert maple_code(SparseMatrix(Identity(2))) == 'Matrix([[1, 0], [0, 1]], storage = sparse)'
def test_vector_entries_hadamard():
# For a row or column, user might to use the other dimension
A = Matrix([[1, sin(2 / x), 3 * pi / x / 5]])
assert maple_code(A) == \
'Matrix([[1, sin(2/x), (3/5)*Pi/x]], storage = rectangular)'
assert maple_code(A.T) == \
'Matrix([[1], [sin(2/x)], [(3/5)*Pi/x]], storage = rectangular)'
def test_Matrices_entries_not_hadamard():
A = Matrix([[1, sin(2 / x), 3 * pi / x / 5], [1, 2, x * y]])
expected = \
'Matrix([[1, sin(2/x), (3/5)*Pi/x], [1, 2, x*y]], ' \
'storage = rectangular)'
assert maple_code(A) == expected
def test_MatrixSymbol():
n = Symbol('n', integer=True)
A = MatrixSymbol('A', n, n)
B = MatrixSymbol('B', n, n)
assert maple_code(A * B) == "A.B"
assert maple_code(B * A) == "B.A"
assert maple_code(2 * A * B) == "2*A.B"
assert maple_code(B * 2 * A) == "2*B.A"
assert maple_code(
A * (B + 3 * Identity(n))) == "A.(3*Matrix(n, shape = identity) + B)"
assert maple_code(A ** (x ** 2)) == "MatrixPower(A, x^2)"
assert maple_code(A ** 3) == "MatrixPower(A, 3)"
assert maple_code(A ** (S.Half)) == "MatrixPower(A, 1/2)"
def test_special_matrices():
assert maple_code(6 * Identity(3)) == "6*Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = sparse)"
assert maple_code(Identity(x)) == 'Matrix(x, shape = identity)'
def test_containers():
assert maple_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
"[1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]"
assert maple_code((1, 2, (3, 4))) == "[1, 2, [3, 4]]"
assert maple_code([1]) == "[1]"
assert maple_code((1,)) == "[1]"
assert maple_code(Tuple(*[1, 2, 3])) == "[1, 2, 3]"
assert maple_code((1, x * y, (3, x ** 2))) == "[1, x*y, [3, x^2]]"
# scalar, matrix, empty matrix and empty list
assert maple_code((1, eye(3), Matrix(0, 0, []), [])) == \
"[1, Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = rectangular), Matrix([], storage = rectangular), []]"
def test_maple_noninline():
source = maple_code((x + y)/Catalan, assign_to='me', inline=False)
expected = "me := (x + y)/Catalan"
assert source == expected
def test_maple_matrix_assign_to():
A = Matrix([[1, 2, 3]])
assert maple_code(A, assign_to='a') == "a := Matrix([[1, 2, 3]], storage = rectangular)"
A = Matrix([[1, 2], [3, 4]])
assert maple_code(A, assign_to='A') == "A := Matrix([[1, 2], [3, 4]], storage = rectangular)"
def test_maple_matrix_assign_to_more():
# assigning to Symbol or MatrixSymbol requires lhs/rhs match
A = Matrix([[1, 2, 3]])
B = MatrixSymbol('B', 1, 3)
C = MatrixSymbol('C', 2, 3)
assert maple_code(A, assign_to=B) == "B := Matrix([[1, 2, 3]], storage = rectangular)"
raises(ValueError, lambda: maple_code(A, assign_to=x))
raises(ValueError, lambda: maple_code(A, assign_to=C))
def test_maple_matrix_1x1():
A = Matrix([[3]])
assert maple_code(A, assign_to='B') == "B := Matrix([[3]], storage = rectangular)"
def test_maple_matrix_elements():
A = Matrix([[x, 2, x * y]])
assert maple_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x^2 + x*y + 2"
AA = MatrixSymbol('AA', 1, 3)
assert maple_code(AA) == "AA"
assert maple_code(AA[0, 0] ** 2 + sin(AA[0, 1]) + AA[0, 2]) == \
"sin(AA[1, 2]) + AA[1, 1]^2 + AA[1, 3]"
assert maple_code(sum(AA)) == "AA[1, 1] + AA[1, 2] + AA[1, 3]"
def test_maple_boolean():
assert maple_code(True) == "true"
assert maple_code(S.true) == "true"
assert maple_code(False) == "false"
assert maple_code(S.false) == "false"
def test_sparse():
M = SparseMatrix(5, 6, {})
M[2, 2] = 10
M[1, 2] = 20
M[1, 3] = 22
M[0, 3] = 30
M[3, 0] = x * y
assert maple_code(M) == \
'Matrix([[0, 0, 0, 30, 0, 0],' \
' [0, 0, 20, 22, 0, 0],' \
' [0, 0, 10, 0, 0, 0],' \
' [x*y, 0, 0, 0, 0, 0],' \
' [0, 0, 0, 0, 0, 0]], ' \
'storage = sparse)'
# Not an important point.
def test_maple_not_supported():
with raises(NotImplementedError):
maple_code(S.ComplexInfinity)
def test_MatrixElement_printing():
# test cases for issue #11821
A = MatrixSymbol("A", 1, 3)
B = MatrixSymbol("B", 1, 3)
assert (maple_code(A[0, 0]) == "A[1, 1]")
assert (maple_code(3 * A[0, 0]) == "3*A[1, 1]")
F = A-B
assert (maple_code(F[0,0]) == "A[1, 1] - B[1, 1]")
def test_hadamard():
A = MatrixSymbol('A', 3, 3)
B = MatrixSymbol('B', 3, 3)
v = MatrixSymbol('v', 3, 1)
h = MatrixSymbol('h', 1, 3)
C = HadamardProduct(A, B)
assert maple_code(C) == "A*B"
assert maple_code(C * v) == "(A*B).v"
# HadamardProduct is higher than dot product.
assert maple_code(h * C * v) == "h.(A*B).v"
assert maple_code(C * A) == "(A*B).A"
# mixing Hadamard and scalar strange b/c we vectorize scalars
assert maple_code(C * x * y) == "x*y*(A*B)"
def test_maple_piecewise():
expr = Piecewise((x, x < 1), (x ** 2, True))
assert maple_code(expr) == "piecewise(x < 1, x, x^2)"
assert maple_code(expr, assign_to="r") == (
"r := piecewise(x < 1, x, x^2)")
expr = Piecewise((x ** 2, x < 1), (x ** 3, x < 2), (x ** 4, x < 3), (x ** 5, True))
expected = "piecewise(x < 1, x^2, x < 2, x^3, x < 3, x^4, x^5)"
assert maple_code(expr) == expected
assert maple_code(expr, assign_to="r") == "r := " + expected
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
raises(ValueError, lambda: maple_code(expr))
def test_maple_piecewise_times_const():
pw = Piecewise((x, x < 1), (x ** 2, True))
assert maple_code(2 * pw) == "2*piecewise(x < 1, x, x^2)"
assert maple_code(pw / x) == "piecewise(x < 1, x, x^2)/x"
assert maple_code(pw / (x * y)) == "piecewise(x < 1, x, x^2)/(x*y)"
assert maple_code(pw / 3) == "(1/3)*piecewise(x < 1, x, x^2)"
def test_maple_derivatives():
f = Function('f')
assert maple_code(f(x).diff(x)) == 'diff(f(x), x)'
assert maple_code(f(x).diff(x, 2)) == 'diff(f(x), x$2)'
def test_automatic_rewrites():
assert maple_code(lucas(x)) == '(2^(-x)*((1 - sqrt(5))^x + (1 + sqrt(5))^x))'
assert maple_code(sinc(x)) == '(piecewise(x <> 0, sin(x)/x, 1))'
def test_specfun():
assert maple_code('asin(x)') == 'arcsin(x)'
assert maple_code(besseli(x, y)) == 'BesselI(x, y)'

View File

@ -0,0 +1,287 @@
from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, Tuple,
Derivative, Eq, Ne, Le, Lt, Gt, Ge)
from sympy.integrals import Integral
from sympy.concrete import Sum
from sympy.functions import (exp, sin, cos, fresnelc, fresnels, conjugate, Max,
Min, gamma, polygamma, loggamma, erf, erfi, erfc,
erf2, expint, erfinv, erfcinv, Ei, Si, Ci, li,
Shi, Chi, uppergamma, beta, subfactorial, erf2inv,
factorial, factorial2, catalan, RisingFactorial,
FallingFactorial, harmonic, atan2, sec, acsc,
hermite, laguerre, assoc_laguerre, jacobi,
gegenbauer, chebyshevt, chebyshevu, legendre,
assoc_legendre, Li, LambertW)
from sympy.printing.mathematica import mathematica_code as mcode
x, y, z, w = symbols('x,y,z,w')
f = Function('f')
def test_Integer():
assert mcode(Integer(67)) == "67"
assert mcode(Integer(-1)) == "-1"
def test_Rational():
assert mcode(Rational(3, 7)) == "3/7"
assert mcode(Rational(18, 9)) == "2"
assert mcode(Rational(3, -7)) == "-3/7"
assert mcode(Rational(-3, -7)) == "3/7"
assert mcode(x + Rational(3, 7)) == "x + 3/7"
assert mcode(Rational(3, 7)*x) == "(3/7)*x"
def test_Relational():
assert mcode(Eq(x, y)) == "x == y"
assert mcode(Ne(x, y)) == "x != y"
assert mcode(Le(x, y)) == "x <= y"
assert mcode(Lt(x, y)) == "x < y"
assert mcode(Gt(x, y)) == "x > y"
assert mcode(Ge(x, y)) == "x >= y"
def test_Function():
assert mcode(f(x, y, z)) == "f[x, y, z]"
assert mcode(sin(x) ** cos(x)) == "Sin[x]^Cos[x]"
assert mcode(sec(x) * acsc(x)) == "ArcCsc[x]*Sec[x]"
assert mcode(atan2(x, y)) == "ArcTan[x, y]"
assert mcode(conjugate(x)) == "Conjugate[x]"
assert mcode(Max(x, y, z)*Min(y, z)) == "Max[x, y, z]*Min[y, z]"
assert mcode(fresnelc(x)) == "FresnelC[x]"
assert mcode(fresnels(x)) == "FresnelS[x]"
assert mcode(gamma(x)) == "Gamma[x]"
assert mcode(uppergamma(x, y)) == "Gamma[x, y]"
assert mcode(polygamma(x, y)) == "PolyGamma[x, y]"
assert mcode(loggamma(x)) == "LogGamma[x]"
assert mcode(erf(x)) == "Erf[x]"
assert mcode(erfc(x)) == "Erfc[x]"
assert mcode(erfi(x)) == "Erfi[x]"
assert mcode(erf2(x, y)) == "Erf[x, y]"
assert mcode(expint(x, y)) == "ExpIntegralE[x, y]"
assert mcode(erfcinv(x)) == "InverseErfc[x]"
assert mcode(erfinv(x)) == "InverseErf[x]"
assert mcode(erf2inv(x, y)) == "InverseErf[x, y]"
assert mcode(Ei(x)) == "ExpIntegralEi[x]"
assert mcode(Ci(x)) == "CosIntegral[x]"
assert mcode(li(x)) == "LogIntegral[x]"
assert mcode(Si(x)) == "SinIntegral[x]"
assert mcode(Shi(x)) == "SinhIntegral[x]"
assert mcode(Chi(x)) == "CoshIntegral[x]"
assert mcode(beta(x, y)) == "Beta[x, y]"
assert mcode(factorial(x)) == "Factorial[x]"
assert mcode(factorial2(x)) == "Factorial2[x]"
assert mcode(subfactorial(x)) == "Subfactorial[x]"
assert mcode(FallingFactorial(x, y)) == "FactorialPower[x, y]"
assert mcode(RisingFactorial(x, y)) == "Pochhammer[x, y]"
assert mcode(catalan(x)) == "CatalanNumber[x]"
assert mcode(harmonic(x)) == "HarmonicNumber[x]"
assert mcode(harmonic(x, y)) == "HarmonicNumber[x, y]"
assert mcode(Li(x)) == "LogIntegral[x] - LogIntegral[2]"
assert mcode(LambertW(x)) == "ProductLog[x]"
assert mcode(LambertW(x, -1)) == "ProductLog[-1, x]"
assert mcode(LambertW(x, y)) == "ProductLog[y, x]"
def test_special_polynomials():
assert mcode(hermite(x, y)) == "HermiteH[x, y]"
assert mcode(laguerre(x, y)) == "LaguerreL[x, y]"
assert mcode(assoc_laguerre(x, y, z)) == "LaguerreL[x, y, z]"
assert mcode(jacobi(x, y, z, w)) == "JacobiP[x, y, z, w]"
assert mcode(gegenbauer(x, y, z)) == "GegenbauerC[x, y, z]"
assert mcode(chebyshevt(x, y)) == "ChebyshevT[x, y]"
assert mcode(chebyshevu(x, y)) == "ChebyshevU[x, y]"
assert mcode(legendre(x, y)) == "LegendreP[x, y]"
assert mcode(assoc_legendre(x, y, z)) == "LegendreP[x, y, z]"
def test_Pow():
assert mcode(x**3) == "x^3"
assert mcode(x**(y**3)) == "x^(y^3)"
assert mcode(1/(f(x)*3.5)**(x - y**x)/(x**2 + y)) == \
"(3.5*f[x])^(-x + y^x)/(x^2 + y)"
assert mcode(x**-1.0) == 'x^(-1.0)'
assert mcode(x**Rational(2, 3)) == 'x^(2/3)'
def test_Mul():
A, B, C, D = symbols('A B C D', commutative=False)
assert mcode(x*y*z) == "x*y*z"
assert mcode(x*y*A) == "x*y*A"
assert mcode(x*y*A*B) == "x*y*A**B"
assert mcode(x*y*A*B*C) == "x*y*A**B**C"
assert mcode(x*A*B*(C + D)*A*y) == "x*y*A**B**(C + D)**A"
def test_constants():
assert mcode(S.Zero) == "0"
assert mcode(S.One) == "1"
assert mcode(S.NegativeOne) == "-1"
assert mcode(S.Half) == "1/2"
assert mcode(S.ImaginaryUnit) == "I"
assert mcode(oo) == "Infinity"
assert mcode(S.NegativeInfinity) == "-Infinity"
assert mcode(S.ComplexInfinity) == "ComplexInfinity"
assert mcode(S.NaN) == "Indeterminate"
assert mcode(S.Exp1) == "E"
assert mcode(pi) == "Pi"
assert mcode(S.GoldenRatio) == "GoldenRatio"
assert mcode(S.TribonacciConstant) == \
"(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \
"(1/3)*(3*33^(1/2) + 19)^(1/3))"
assert mcode(2*S.TribonacciConstant) == \
"2*(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \
"(1/3)*(3*33^(1/2) + 19)^(1/3))"
assert mcode(S.EulerGamma) == "EulerGamma"
assert mcode(S.Catalan) == "Catalan"
def test_containers():
assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
"{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}"
assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}"
assert mcode([1]) == "{1}"
assert mcode((1,)) == "{1}"
assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}"
def test_matrices():
from sympy.matrices import MutableDenseMatrix, MutableSparseMatrix, \
ImmutableDenseMatrix, ImmutableSparseMatrix
A = MutableDenseMatrix(
[[1, -1, 0, 0],
[0, 1, -1, 0],
[0, 0, 1, -1],
[0, 0, 0, 1]]
)
B = MutableSparseMatrix(A)
C = ImmutableDenseMatrix(A)
D = ImmutableSparseMatrix(A)
assert mcode(C) == mcode(A) == \
"{{1, -1, 0, 0}, " \
"{0, 1, -1, 0}, " \
"{0, 0, 1, -1}, " \
"{0, 0, 0, 1}}"
assert mcode(D) == mcode(B) == \
"SparseArray[{" \
"{1, 1} -> 1, {1, 2} -> -1, {2, 2} -> 1, {2, 3} -> -1, " \
"{3, 3} -> 1, {3, 4} -> -1, {4, 4} -> 1" \
"}, {4, 4}]"
# Trivial cases of matrices
assert mcode(MutableDenseMatrix(0, 0, [])) == '{}'
assert mcode(MutableSparseMatrix(0, 0, [])) == 'SparseArray[{}, {0, 0}]'
assert mcode(MutableDenseMatrix(0, 3, [])) == '{}'
assert mcode(MutableSparseMatrix(0, 3, [])) == 'SparseArray[{}, {0, 3}]'
assert mcode(MutableDenseMatrix(3, 0, [])) == '{{}, {}, {}}'
assert mcode(MutableSparseMatrix(3, 0, [])) == 'SparseArray[{}, {3, 0}]'
def test_NDArray():
from sympy.tensor.array import (
MutableDenseNDimArray, ImmutableDenseNDimArray,
MutableSparseNDimArray, ImmutableSparseNDimArray)
example = MutableDenseNDimArray(
[[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]]
)
assert mcode(example) == \
"{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \
"{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}"
example = ImmutableDenseNDimArray(example)
assert mcode(example) == \
"{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \
"{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}"
example = MutableSparseNDimArray(example)
assert mcode(example) == \
"SparseArray[{" \
"{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \
"{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \
"{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \
"{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \
"{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \
"{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \
"{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \
"{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \
"}, {2, 3, 4}]"
example = ImmutableSparseNDimArray(example)
assert mcode(example) == \
"SparseArray[{" \
"{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \
"{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \
"{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \
"{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \
"{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \
"{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \
"{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \
"{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \
"}, {2, 3, 4}]"
def test_Integral():
assert mcode(Integral(sin(sin(x)), x)) == "Hold[Integrate[Sin[Sin[x]], x]]"
assert mcode(Integral(exp(-x**2 - y**2),
(x, -oo, oo),
(y, -oo, oo))) == \
"Hold[Integrate[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \
"{y, -Infinity, Infinity}]]"
def test_Derivative():
assert mcode(Derivative(sin(x), x)) == "Hold[D[Sin[x], x]]"
assert mcode(Derivative(x, x)) == "Hold[D[x, x]]"
assert mcode(Derivative(sin(x)*y**4, x, 2)) == "Hold[D[y^4*Sin[x], {x, 2}]]"
assert mcode(Derivative(sin(x)*y**4, x, y, x)) == "Hold[D[y^4*Sin[x], x, y, x]]"
assert mcode(Derivative(sin(x)*y**4, x, y, 3, x)) == "Hold[D[y^4*Sin[x], x, {y, 3}, x]]"
def test_Sum():
assert mcode(Sum(sin(x), (x, 0, 10))) == "Hold[Sum[Sin[x], {x, 0, 10}]]"
assert mcode(Sum(exp(-x**2 - y**2),
(x, -oo, oo),
(y, -oo, oo))) == \
"Hold[Sum[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \
"{y, -Infinity, Infinity}]]"
def test_comment():
from sympy.printing.mathematica import MCodePrinter
assert MCodePrinter()._get_comment("Hello World") == \
"(* Hello World *)"
def test_userfuncs():
# Dictionary mutation test
some_function = symbols("some_function", cls=Function)
my_user_functions = {"some_function": "SomeFunction"}
assert mcode(
some_function(z),
user_functions=my_user_functions) == \
'SomeFunction[z]'
assert mcode(
some_function(z),
user_functions=my_user_functions) == \
'SomeFunction[z]'
# List argument test
my_user_functions = \
{"some_function": [(lambda x: True, "SomeOtherFunction")]}
assert mcode(
some_function(z),
user_functions=my_user_functions) == \
'SomeOtherFunction[z]'

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,365 @@
from sympy.concrete.summations import Sum
from sympy.core.mod import Mod
from sympy.core.relational import (Equality, Unequality)
from sympy.core.symbol import Symbol
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.special.gamma_functions import polygamma
from sympy.functions.special.error_functions import (Si, Ci)
from sympy.matrices.expressions.blockmatrix import BlockMatrix
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.matrices.expressions.special import Identity
from sympy.utilities.lambdify import lambdify
from sympy import symbols, Min, Max
from sympy.abc import x, i, j, a, b, c, d
from sympy.core import Pow
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt
from sympy.tensor.array import Array
from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
PermuteDims, ArrayDiagonal
from sympy.printing.numpy import NumPyPrinter, SciPyPrinter, _numpy_known_constants, \
_numpy_known_functions, _scipy_known_constants, _scipy_known_functions
from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
from sympy.testing.pytest import skip, raises
from sympy.external import import_module
np = import_module('numpy')
jax = import_module('jax')
if np:
deafult_float_info = np.finfo(np.array([]).dtype)
NUMPY_DEFAULT_EPSILON = deafult_float_info.eps
def test_numpy_piecewise_regression():
"""
NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid
breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+.
See gh-9747 and gh-9749 for details.
"""
printer = NumPyPrinter()
p = Piecewise((1, x < 0), (0, True))
assert printer.doprint(p) == \
'numpy.select([numpy.less(x, 0),True], [1,0], default=numpy.nan)'
assert printer.module_imports == {'numpy': {'select', 'less', 'nan'}}
def test_numpy_logaddexp():
lae = logaddexp(a, b)
assert NumPyPrinter().doprint(lae) == 'numpy.logaddexp(a, b)'
lae2 = logaddexp2(a, b)
assert NumPyPrinter().doprint(lae2) == 'numpy.logaddexp2(a, b)'
def test_sum():
if not np:
skip("NumPy not installed")
s = Sum(x ** i, (i, a, b))
f = lambdify((a, b, x), s, 'numpy')
a_, b_ = 0, 10
x_ = np.linspace(-1, +1, 10)
assert np.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
s = Sum(i * x, (i, a, b))
f = lambdify((a, b, x), s, 'numpy')
a_, b_ = 0, 10
x_ = np.linspace(-1, +1, 10)
assert np.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
def test_multiple_sums():
if not np:
skip("NumPy not installed")
s = Sum((x + j) * i, (i, a, b), (j, c, d))
f = lambdify((a, b, c, d, x), s, 'numpy')
a_, b_ = 0, 10
c_, d_ = 11, 21
x_ = np.linspace(-1, +1, 10)
assert np.allclose(f(a_, b_, c_, d_, x_),
sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1)))
def test_codegen_einsum():
if not np:
skip("NumPy not installed")
M = MatrixSymbol("M", 2, 2)
N = MatrixSymbol("N", 2, 2)
cg = convert_matrix_to_array(M * N)
f = lambdify((M, N), cg, 'numpy')
ma = np.array([[1, 2], [3, 4]])
mb = np.array([[1,-2], [-1, 3]])
assert (f(ma, mb) == np.matmul(ma, mb)).all()
def test_codegen_extra():
if not np:
skip("NumPy not installed")
M = MatrixSymbol("M", 2, 2)
N = MatrixSymbol("N", 2, 2)
P = MatrixSymbol("P", 2, 2)
Q = MatrixSymbol("Q", 2, 2)
ma = np.array([[1, 2], [3, 4]])
mb = np.array([[1,-2], [-1, 3]])
mc = np.array([[2, 0], [1, 2]])
md = np.array([[1,-1], [4, 7]])
cg = ArrayTensorProduct(M, N)
f = lambdify((M, N), cg, 'numpy')
assert (f(ma, mb) == np.einsum(ma, [0, 1], mb, [2, 3])).all()
cg = ArrayAdd(M, N)
f = lambdify((M, N), cg, 'numpy')
assert (f(ma, mb) == ma+mb).all()
cg = ArrayAdd(M, N, P)
f = lambdify((M, N, P), cg, 'numpy')
assert (f(ma, mb, mc) == ma+mb+mc).all()
cg = ArrayAdd(M, N, P, Q)
f = lambdify((M, N, P, Q), cg, 'numpy')
assert (f(ma, mb, mc, md) == ma+mb+mc+md).all()
cg = PermuteDims(M, [1, 0])
f = lambdify((M,), cg, 'numpy')
assert (f(ma) == ma.T).all()
cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
f = lambdify((M, N), cg, 'numpy')
assert (f(ma, mb) == np.transpose(np.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all()
cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
f = lambdify((M, N), cg, 'numpy')
assert (f(ma, mb) == np.diagonal(np.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all()
def test_relational():
if not np:
skip("NumPy not installed")
e = Equality(x, 1)
f = lambdify((x,), e)
x_ = np.array([0, 1, 2])
assert np.array_equal(f(x_), [False, True, False])
e = Unequality(x, 1)
f = lambdify((x,), e)
x_ = np.array([0, 1, 2])
assert np.array_equal(f(x_), [True, False, True])
e = (x < 1)
f = lambdify((x,), e)
x_ = np.array([0, 1, 2])
assert np.array_equal(f(x_), [True, False, False])
e = (x <= 1)
f = lambdify((x,), e)
x_ = np.array([0, 1, 2])
assert np.array_equal(f(x_), [True, True, False])
e = (x > 1)
f = lambdify((x,), e)
x_ = np.array([0, 1, 2])
assert np.array_equal(f(x_), [False, False, True])
e = (x >= 1)
f = lambdify((x,), e)
x_ = np.array([0, 1, 2])
assert np.array_equal(f(x_), [False, True, True])
def test_mod():
if not np:
skip("NumPy not installed")
e = Mod(a, b)
f = lambdify((a, b), e)
a_ = np.array([0, 1, 2, 3])
b_ = 2
assert np.array_equal(f(a_, b_), [0, 1, 0, 1])
a_ = np.array([0, 1, 2, 3])
b_ = np.array([2, 2, 2, 2])
assert np.array_equal(f(a_, b_), [0, 1, 0, 1])
a_ = np.array([2, 3, 4, 5])
b_ = np.array([2, 3, 4, 5])
assert np.array_equal(f(a_, b_), [0, 0, 0, 0])
def test_pow():
if not np:
skip('NumPy not installed')
expr = Pow(2, -1, evaluate=False)
f = lambdify([], expr, 'numpy')
assert f() == 0.5
def test_expm1():
if not np:
skip("NumPy not installed")
f = lambdify((a,), expm1(a), 'numpy')
assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * NUMPY_DEFAULT_EPSILON
def test_log1p():
if not np:
skip("NumPy not installed")
f = lambdify((a,), log1p(a), 'numpy')
assert abs(f(1e-99) - 1e-99) <= 1e-99 * NUMPY_DEFAULT_EPSILON
def test_hypot():
if not np:
skip("NumPy not installed")
assert abs(lambdify((a, b), hypot(a, b), 'numpy')(3, 4) - 5) <= NUMPY_DEFAULT_EPSILON
def test_log10():
if not np:
skip("NumPy not installed")
assert abs(lambdify((a,), log10(a), 'numpy')(100) - 2) <= NUMPY_DEFAULT_EPSILON
def test_exp2():
if not np:
skip("NumPy not installed")
assert abs(lambdify((a,), exp2(a), 'numpy')(5) - 32) <= NUMPY_DEFAULT_EPSILON
def test_log2():
if not np:
skip("NumPy not installed")
assert abs(lambdify((a,), log2(a), 'numpy')(256) - 8) <= NUMPY_DEFAULT_EPSILON
def test_Sqrt():
if not np:
skip("NumPy not installed")
assert abs(lambdify((a,), Sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON
def test_sqrt():
if not np:
skip("NumPy not installed")
assert abs(lambdify((a,), sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON
def test_matsolve():
if not np:
skip("NumPy not installed")
M = MatrixSymbol("M", 3, 3)
x = MatrixSymbol("x", 3, 1)
expr = M**(-1) * x + x
matsolve_expr = MatrixSolve(M, x) + x
f = lambdify((M, x), expr)
f_matsolve = lambdify((M, x), matsolve_expr)
m0 = np.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]])
assert np.linalg.matrix_rank(m0) == 3
x0 = np.array([3, 4, 5])
assert np.allclose(f_matsolve(m0, x0), f(m0, x0))
def test_16857():
if not np:
skip("NumPy not installed")
a_1 = MatrixSymbol('a_1', 10, 3)
a_2 = MatrixSymbol('a_2', 10, 3)
a_3 = MatrixSymbol('a_3', 10, 3)
a_4 = MatrixSymbol('a_4', 10, 3)
A = BlockMatrix([[a_1, a_2], [a_3, a_4]])
assert A.shape == (20, 6)
printer = NumPyPrinter()
assert printer.doprint(A) == 'numpy.block([[a_1, a_2], [a_3, a_4]])'
def test_issue_17006():
if not np:
skip("NumPy not installed")
M = MatrixSymbol("M", 2, 2)
f = lambdify(M, M + Identity(2))
ma = np.array([[1, 2], [3, 4]])
mr = np.array([[2, 2], [3, 5]])
assert (f(ma) == mr).all()
from sympy.core.symbol import symbols
n = symbols('n', integer=True)
N = MatrixSymbol("M", n, n)
raises(NotImplementedError, lambda: lambdify(N, N + Identity(n)))
def test_jax_tuple_compatibility():
if not jax:
skip("Jax not installed")
x, y, z = symbols('x y z')
expr = Max(x, y, z) + Min(x, y, z)
func = lambdify((x, y, z), expr, 'jax')
input_tuple1, input_tuple2 = (1, 2, 3), (4, 5, 6)
input_array1, input_array2 = jax.numpy.asarray(input_tuple1), jax.numpy.asarray(input_tuple2)
assert np.allclose(func(*input_tuple1), func(*input_array1))
assert np.allclose(func(*input_tuple2), func(*input_array2))
def test_numpy_array():
assert NumPyPrinter().doprint(Array(((1, 2), (3, 5)))) == 'numpy.array([[1, 2], [3, 5]])'
assert NumPyPrinter().doprint(Array((1, 2))) == 'numpy.array((1, 2))'
def test_numpy_known_funcs_consts():
assert _numpy_known_constants['NaN'] == 'numpy.nan'
assert _numpy_known_constants['EulerGamma'] == 'numpy.euler_gamma'
assert _numpy_known_functions['acos'] == 'numpy.arccos'
assert _numpy_known_functions['log'] == 'numpy.log'
def test_scipy_known_funcs_consts():
assert _scipy_known_constants['GoldenRatio'] == 'scipy.constants.golden_ratio'
assert _scipy_known_constants['Pi'] == 'scipy.constants.pi'
assert _scipy_known_functions['erf'] == 'scipy.special.erf'
assert _scipy_known_functions['factorial'] == 'scipy.special.factorial'
def test_numpy_print_methods():
prntr = NumPyPrinter()
assert hasattr(prntr, '_print_acos')
assert hasattr(prntr, '_print_log')
def test_scipy_print_methods():
prntr = SciPyPrinter()
assert hasattr(prntr, '_print_acos')
assert hasattr(prntr, '_print_log')
assert hasattr(prntr, '_print_erf')
assert hasattr(prntr, '_print_factorial')
assert hasattr(prntr, '_print_chebyshevt')
k = Symbol('k', integer=True, nonnegative=True)
x = Symbol('x', real=True)
assert prntr.doprint(polygamma(k, x)) == "scipy.special.polygamma(k, x)"
assert prntr.doprint(Si(x)) == "scipy.special.sici(x)[0]"
assert prntr.doprint(Ci(x)) == "scipy.special.sici(x)[1]"

View File

@ -0,0 +1,515 @@
from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer,
Tuple, Symbol, EulerGamma, GoldenRatio, Catalan,
Lambda, Mul, Pow, Mod, Eq, Ne, Le, Lt, Gt, Ge)
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.functions import (arg, atan2, bernoulli, beta, ceiling, chebyshevu,
chebyshevt, conjugate, DiracDelta, exp, expint,
factorial, floor, harmonic, Heaviside, im,
laguerre, LambertW, log, Max, Min, Piecewise,
polylog, re, RisingFactorial, sign, sinc, sqrt,
zeta, binomial, legendre, dirichlet_eta,
riemann_xi)
from sympy.functions import (sin, cos, tan, cot, sec, csc, asin, acos, acot,
atan, asec, acsc, sinh, cosh, tanh, coth, csch,
sech, asinh, acosh, atanh, acoth, asech, acsch)
from sympy.testing.pytest import raises, XFAIL
from sympy.utilities.lambdify import implemented_function
from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity,
HadamardProduct, SparseMatrix, HadamardPower)
from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli,
besselk, hankel1, hankel2, airyai,
airybi, airyaiprime, airybiprime)
from sympy.functions.special.gamma_functions import (gamma, lowergamma,
uppergamma, loggamma,
polygamma)
from sympy.functions.special.error_functions import (Chi, Ci, erf, erfc, erfi,
erfcinv, erfinv, fresnelc,
fresnels, li, Shi, Si, Li,
erf2, Ei)
from sympy.printing.octave import octave_code, octave_code as mcode
x, y, z = symbols('x,y,z')
def test_Integer():
assert mcode(Integer(67)) == "67"
assert mcode(Integer(-1)) == "-1"
def test_Rational():
assert mcode(Rational(3, 7)) == "3/7"
assert mcode(Rational(18, 9)) == "2"
assert mcode(Rational(3, -7)) == "-3/7"
assert mcode(Rational(-3, -7)) == "3/7"
assert mcode(x + Rational(3, 7)) == "x + 3/7"
assert mcode(Rational(3, 7)*x) == "3*x/7"
def test_Relational():
assert mcode(Eq(x, y)) == "x == y"
assert mcode(Ne(x, y)) == "x != y"
assert mcode(Le(x, y)) == "x <= y"
assert mcode(Lt(x, y)) == "x < y"
assert mcode(Gt(x, y)) == "x > y"
assert mcode(Ge(x, y)) == "x >= y"
def test_Function():
assert mcode(sin(x) ** cos(x)) == "sin(x).^cos(x)"
assert mcode(sign(x)) == "sign(x)"
assert mcode(exp(x)) == "exp(x)"
assert mcode(log(x)) == "log(x)"
assert mcode(factorial(x)) == "factorial(x)"
assert mcode(floor(x)) == "floor(x)"
assert mcode(atan2(y, x)) == "atan2(y, x)"
assert mcode(beta(x, y)) == 'beta(x, y)'
assert mcode(polylog(x, y)) == 'polylog(x, y)'
assert mcode(harmonic(x)) == 'harmonic(x)'
assert mcode(bernoulli(x)) == "bernoulli(x)"
assert mcode(bernoulli(x, y)) == "bernoulli(x, y)"
assert mcode(legendre(x, y)) == "legendre(x, y)"
def test_Function_change_name():
assert mcode(abs(x)) == "abs(x)"
assert mcode(ceiling(x)) == "ceil(x)"
assert mcode(arg(x)) == "angle(x)"
assert mcode(im(x)) == "imag(x)"
assert mcode(re(x)) == "real(x)"
assert mcode(conjugate(x)) == "conj(x)"
assert mcode(chebyshevt(y, x)) == "chebyshevT(y, x)"
assert mcode(chebyshevu(y, x)) == "chebyshevU(y, x)"
assert mcode(laguerre(x, y)) == "laguerreL(x, y)"
assert mcode(Chi(x)) == "coshint(x)"
assert mcode(Shi(x)) == "sinhint(x)"
assert mcode(Ci(x)) == "cosint(x)"
assert mcode(Si(x)) == "sinint(x)"
assert mcode(li(x)) == "logint(x)"
assert mcode(loggamma(x)) == "gammaln(x)"
assert mcode(polygamma(x, y)) == "psi(x, y)"
assert mcode(RisingFactorial(x, y)) == "pochhammer(x, y)"
assert mcode(DiracDelta(x)) == "dirac(x)"
assert mcode(DiracDelta(x, 3)) == "dirac(3, x)"
assert mcode(Heaviside(x)) == "heaviside(x, 1/2)"
assert mcode(Heaviside(x, y)) == "heaviside(x, y)"
assert mcode(binomial(x, y)) == "bincoeff(x, y)"
assert mcode(Mod(x, y)) == "mod(x, y)"
def test_minmax():
assert mcode(Max(x, y) + Min(x, y)) == "max(x, y) + min(x, y)"
assert mcode(Max(x, y, z)) == "max(x, max(y, z))"
assert mcode(Min(x, y, z)) == "min(x, min(y, z))"
def test_Pow():
assert mcode(x**3) == "x.^3"
assert mcode(x**(y**3)) == "x.^(y.^3)"
assert mcode(x**Rational(2, 3)) == 'x.^(2/3)'
g = implemented_function('g', Lambda(x, 2*x))
assert mcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
"(3.5*2*x).^(-x + y.^x)./(x.^2 + y)"
# For issue 14160
assert mcode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
evaluate=False)) == '-2*x./(y.*y)'
def test_basic_ops():
assert mcode(x*y) == "x.*y"
assert mcode(x + y) == "x + y"
assert mcode(x - y) == "x - y"
assert mcode(-x) == "-x"
def test_1_over_x_and_sqrt():
# 1.0 and 0.5 would do something different in regular StrPrinter,
# but these are exact in IEEE floating point so no different here.
assert mcode(1/x) == '1./x'
assert mcode(x**-1) == mcode(x**-1.0) == '1./x'
assert mcode(1/sqrt(x)) == '1./sqrt(x)'
assert mcode(x**-S.Half) == mcode(x**-0.5) == '1./sqrt(x)'
assert mcode(sqrt(x)) == 'sqrt(x)'
assert mcode(x**S.Half) == mcode(x**0.5) == 'sqrt(x)'
assert mcode(1/pi) == '1/pi'
assert mcode(pi**-1) == mcode(pi**-1.0) == '1/pi'
assert mcode(pi**-0.5) == '1/sqrt(pi)'
def test_mix_number_mult_symbols():
assert mcode(3*x) == "3*x"
assert mcode(pi*x) == "pi*x"
assert mcode(3/x) == "3./x"
assert mcode(pi/x) == "pi./x"
assert mcode(x/3) == "x/3"
assert mcode(x/pi) == "x/pi"
assert mcode(x*y) == "x.*y"
assert mcode(3*x*y) == "3*x.*y"
assert mcode(3*pi*x*y) == "3*pi*x.*y"
assert mcode(x/y) == "x./y"
assert mcode(3*x/y) == "3*x./y"
assert mcode(x*y/z) == "x.*y./z"
assert mcode(x/y*z) == "x.*z./y"
assert mcode(1/x/y) == "1./(x.*y)"
assert mcode(2*pi*x/y/z) == "2*pi*x./(y.*z)"
assert mcode(3*pi/x) == "3*pi./x"
assert mcode(S(3)/5) == "3/5"
assert mcode(S(3)/5*x) == "3*x/5"
assert mcode(x/y/z) == "x./(y.*z)"
assert mcode((x+y)/z) == "(x + y)./z"
assert mcode((x+y)/(z+x)) == "(x + y)./(x + z)"
assert mcode((x+y)/EulerGamma) == "(x + y)/%s" % EulerGamma.evalf(17)
assert mcode(x/3/pi) == "x/(3*pi)"
assert mcode(S(3)/5*x*y/pi) == "3*x.*y/(5*pi)"
def test_mix_number_pow_symbols():
assert mcode(pi**3) == 'pi^3'
assert mcode(x**2) == 'x.^2'
assert mcode(x**(pi**3)) == 'x.^(pi^3)'
assert mcode(x**y) == 'x.^y'
assert mcode(x**(y**z)) == 'x.^(y.^z)'
assert mcode((x**y)**z) == '(x.^y).^z'
def test_imag():
I = S('I')
assert mcode(I) == "1i"
assert mcode(5*I) == "5i"
assert mcode((S(3)/2)*I) == "3*1i/2"
assert mcode(3+4*I) == "3 + 4i"
assert mcode(sqrt(3)*I) == "sqrt(3)*1i"
def test_constants():
assert mcode(pi) == "pi"
assert mcode(oo) == "inf"
assert mcode(-oo) == "-inf"
assert mcode(S.NegativeInfinity) == "-inf"
assert mcode(S.NaN) == "NaN"
assert mcode(S.Exp1) == "exp(1)"
assert mcode(exp(1)) == "exp(1)"
def test_constants_other():
assert mcode(2*GoldenRatio) == "2*(1+sqrt(5))/2"
assert mcode(2*Catalan) == "2*%s" % Catalan.evalf(17)
assert mcode(2*EulerGamma) == "2*%s" % EulerGamma.evalf(17)
def test_boolean():
assert mcode(x & y) == "x & y"
assert mcode(x | y) == "x | y"
assert mcode(~x) == "~x"
assert mcode(x & y & z) == "x & y & z"
assert mcode(x | y | z) == "x | y | z"
assert mcode((x & y) | z) == "z | x & y"
assert mcode((x | y) & z) == "z & (x | y)"
def test_KroneckerDelta():
from sympy.functions import KroneckerDelta
assert mcode(KroneckerDelta(x, y)) == "double(x == y)"
assert mcode(KroneckerDelta(x, y + 1)) == "double(x == (y + 1))"
assert mcode(KroneckerDelta(2**x, y)) == "double((2.^x) == y)"
def test_Matrices():
assert mcode(Matrix(1, 1, [10])) == "10"
A = Matrix([[1, sin(x/2), abs(x)],
[0, 1, pi],
[0, exp(1), ceiling(x)]]);
expected = "[1 sin(x/2) abs(x); 0 1 pi; 0 exp(1) ceil(x)]"
assert mcode(A) == expected
# row and columns
assert mcode(A[:,0]) == "[1; 0; 0]"
assert mcode(A[0,:]) == "[1 sin(x/2) abs(x)]"
# empty matrices
assert mcode(Matrix(0, 0, [])) == '[]'
assert mcode(Matrix(0, 3, [])) == 'zeros(0, 3)'
# annoying to read but correct
assert mcode(Matrix([[x, x - y, -y]])) == "[x x - y -y]"
def test_vector_entries_hadamard():
# For a row or column, user might to use the other dimension
A = Matrix([[1, sin(2/x), 3*pi/x/5]])
assert mcode(A) == "[1 sin(2./x) 3*pi./(5*x)]"
assert mcode(A.T) == "[1; sin(2./x); 3*pi./(5*x)]"
@XFAIL
def test_Matrices_entries_not_hadamard():
# For Matrix with col >= 2, row >= 2, they need to be scalars
# FIXME: is it worth worrying about this? Its not wrong, just
# leave it user's responsibility to put scalar data for x.
A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]])
expected = ("[1 sin(2/x) 3*pi/(5*x);\n"
"1 2 x*y]") # <- we give x.*y
assert mcode(A) == expected
def test_MatrixSymbol():
n = Symbol('n', integer=True)
A = MatrixSymbol('A', n, n)
B = MatrixSymbol('B', n, n)
assert mcode(A*B) == "A*B"
assert mcode(B*A) == "B*A"
assert mcode(2*A*B) == "2*A*B"
assert mcode(B*2*A) == "2*B*A"
assert mcode(A*(B + 3*Identity(n))) == "A*(3*eye(n) + B)"
assert mcode(A**(x**2)) == "A^(x.^2)"
assert mcode(A**3) == "A^3"
assert mcode(A**S.Half) == "A^(1/2)"
def test_MatrixSolve():
n = Symbol('n', integer=True)
A = MatrixSymbol('A', n, n)
x = MatrixSymbol('x', n, 1)
assert mcode(MatrixSolve(A, x)) == "A \\ x"
def test_special_matrices():
assert mcode(6*Identity(3)) == "6*eye(3)"
def test_containers():
assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
"{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}"
assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}"
assert mcode([1]) == "{1}"
assert mcode((1,)) == "{1}"
assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}"
assert mcode((1, x*y, (3, x**2))) == "{1, x.*y, {3, x.^2}}"
# scalar, matrix, empty matrix and empty list
assert mcode((1, eye(3), Matrix(0, 0, []), [])) == "{1, [1 0 0; 0 1 0; 0 0 1], [], {}}"
def test_octave_noninline():
source = mcode((x+y)/Catalan, assign_to='me', inline=False)
expected = (
"Catalan = %s;\n"
"me = (x + y)/Catalan;"
) % Catalan.evalf(17)
assert source == expected
def test_octave_piecewise():
expr = Piecewise((x, x < 1), (x**2, True))
assert mcode(expr) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))"
assert mcode(expr, assign_to="r") == (
"r = ((x < 1).*(x) + (~(x < 1)).*(x.^2));")
assert mcode(expr, assign_to="r", inline=False) == (
"if (x < 1)\n"
" r = x;\n"
"else\n"
" r = x.^2;\n"
"end")
expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True))
expected = ("((x < 1).*(x.^2) + (~(x < 1)).*( ...\n"
"(x < 2).*(x.^3) + (~(x < 2)).*( ...\n"
"(x < 3).*(x.^4) + (~(x < 3)).*(x.^5))))")
assert mcode(expr) == expected
assert mcode(expr, assign_to="r") == "r = " + expected + ";"
assert mcode(expr, assign_to="r", inline=False) == (
"if (x < 1)\n"
" r = x.^2;\n"
"elseif (x < 2)\n"
" r = x.^3;\n"
"elseif (x < 3)\n"
" r = x.^4;\n"
"else\n"
" r = x.^5;\n"
"end")
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
raises(ValueError, lambda: mcode(expr))
def test_octave_piecewise_times_const():
pw = Piecewise((x, x < 1), (x**2, True))
assert mcode(2*pw) == "2*((x < 1).*(x) + (~(x < 1)).*(x.^2))"
assert mcode(pw/x) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./x"
assert mcode(pw/(x*y)) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./(x.*y)"
assert mcode(pw/3) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))/3"
def test_octave_matrix_assign_to():
A = Matrix([[1, 2, 3]])
assert mcode(A, assign_to='a') == "a = [1 2 3];"
A = Matrix([[1, 2], [3, 4]])
assert mcode(A, assign_to='A') == "A = [1 2; 3 4];"
def test_octave_matrix_assign_to_more():
# assigning to Symbol or MatrixSymbol requires lhs/rhs match
A = Matrix([[1, 2, 3]])
B = MatrixSymbol('B', 1, 3)
C = MatrixSymbol('C', 2, 3)
assert mcode(A, assign_to=B) == "B = [1 2 3];"
raises(ValueError, lambda: mcode(A, assign_to=x))
raises(ValueError, lambda: mcode(A, assign_to=C))
def test_octave_matrix_1x1():
A = Matrix([[3]])
B = MatrixSymbol('B', 1, 1)
C = MatrixSymbol('C', 1, 2)
assert mcode(A, assign_to=B) == "B = 3;"
# FIXME?
#assert mcode(A, assign_to=x) == "x = 3;"
raises(ValueError, lambda: mcode(A, assign_to=C))
def test_octave_matrix_elements():
A = Matrix([[x, 2, x*y]])
assert mcode(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x.^2 + x.*y + 2"
A = MatrixSymbol('AA', 1, 3)
assert mcode(A) == "AA"
assert mcode(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \
"sin(AA(1, 2)) + AA(1, 1).^2 + AA(1, 3)"
assert mcode(sum(A)) == "AA(1, 1) + AA(1, 2) + AA(1, 3)"
def test_octave_boolean():
assert mcode(True) == "true"
assert mcode(S.true) == "true"
assert mcode(False) == "false"
assert mcode(S.false) == "false"
def test_octave_not_supported():
with raises(NotImplementedError):
mcode(S.ComplexInfinity)
f = Function('f')
assert mcode(f(x).diff(x), strict=False) == (
"% Not supported in Octave:\n"
"% Derivative\n"
"Derivative(f(x), x)"
)
def test_octave_not_supported_not_on_whitelist():
from sympy.functions.special.polynomials import assoc_laguerre
with raises(NotImplementedError):
mcode(assoc_laguerre(x, y, z))
def test_octave_expint():
assert mcode(expint(1, x)) == "expint(x)"
with raises(NotImplementedError):
mcode(expint(2, x))
assert mcode(expint(y, x), strict=False) == (
"% Not supported in Octave:\n"
"% expint\n"
"expint(y, x)"
)
def test_trick_indent_with_end_else_words():
# words starting with "end" or "else" do not confuse the indenter
t1 = S('endless');
t2 = S('elsewhere');
pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True))
assert mcode(pw, inline=False) == (
"if (x < 0)\n"
" endless\n"
"elseif (x <= 1)\n"
" elsewhere\n"
"else\n"
" 1\n"
"end")
def test_hadamard():
A = MatrixSymbol('A', 3, 3)
B = MatrixSymbol('B', 3, 3)
v = MatrixSymbol('v', 3, 1)
h = MatrixSymbol('h', 1, 3)
C = HadamardProduct(A, B)
n = Symbol('n')
assert mcode(C) == "A.*B"
assert mcode(C*v) == "(A.*B)*v"
assert mcode(h*C*v) == "h*(A.*B)*v"
assert mcode(C*A) == "(A.*B)*A"
# mixing Hadamard and scalar strange b/c we vectorize scalars
assert mcode(C*x*y) == "(x.*y)*(A.*B)"
# Testing HadamardPower:
assert mcode(HadamardPower(A, n)) == "A.**n"
assert mcode(HadamardPower(A, 1+n)) == "A.**(n + 1)"
assert mcode(HadamardPower(A*B.T, 1+n)) == "(A*B.T).**(n + 1)"
def test_sparse():
M = SparseMatrix(5, 6, {})
M[2, 2] = 10;
M[1, 2] = 20;
M[1, 3] = 22;
M[0, 3] = 30;
M[3, 0] = x*y;
assert mcode(M) == (
"sparse([4 2 3 1 2], [1 3 3 4 4], [x.*y 20 10 30 22], 5, 6)"
)
def test_sinc():
assert mcode(sinc(x)) == 'sinc(x/pi)'
assert mcode(sinc(x + 3)) == 'sinc((x + 3)/pi)'
assert mcode(sinc(pi*(x + 3))) == 'sinc(x + 3)'
def test_trigfun():
for f in (sin, cos, tan, cot, sec, csc, asin, acos, acot, atan, asec, acsc,
sinh, cosh, tanh, coth, csch, sech, asinh, acosh, atanh, acoth,
asech, acsch):
assert octave_code(f(x) == f.__name__ + '(x)')
def test_specfun():
n = Symbol('n')
for f in [besselj, bessely, besseli, besselk]:
assert octave_code(f(n, x)) == f.__name__ + '(n, x)'
for f in (erfc, erfi, erf, erfinv, erfcinv, fresnelc, fresnels, gamma):
assert octave_code(f(x)) == f.__name__ + '(x)'
assert octave_code(hankel1(n, x)) == 'besselh(n, 1, x)'
assert octave_code(hankel2(n, x)) == 'besselh(n, 2, x)'
assert octave_code(airyai(x)) == 'airy(0, x)'
assert octave_code(airyaiprime(x)) == 'airy(1, x)'
assert octave_code(airybi(x)) == 'airy(2, x)'
assert octave_code(airybiprime(x)) == 'airy(3, x)'
assert octave_code(uppergamma(n, x)) == '(gammainc(x, n, \'upper\').*gamma(n))'
assert octave_code(lowergamma(n, x)) == '(gammainc(x, n).*gamma(n))'
assert octave_code(z**lowergamma(n, x)) == 'z.^(gammainc(x, n).*gamma(n))'
assert octave_code(jn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*besselj(n + 1/2, x)/2'
assert octave_code(yn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*bessely(n + 1/2, x)/2'
assert octave_code(LambertW(x)) == 'lambertw(x)'
assert octave_code(LambertW(x, n)) == 'lambertw(n, x)'
# Automatic rewrite
assert octave_code(Ei(x)) == '(logint(exp(x)))'
assert octave_code(dirichlet_eta(x)) == '(((x == 1).*(log(2)) + (~(x == 1)).*((1 - 2.^(1 - x)).*zeta(x))))'
assert octave_code(riemann_xi(x)) == '(pi.^(-x/2).*x.*(x - 1).*gamma(x/2).*zeta(x)/2)'
def test_MatrixElement_printing():
# test cases for issue #11821
A = MatrixSymbol("A", 1, 3)
B = MatrixSymbol("B", 1, 3)
C = MatrixSymbol("C", 1, 3)
assert mcode(A[0, 0]) == "A(1, 1)"
assert mcode(3 * A[0, 0]) == "3*A(1, 1)"
F = C[0, 0].subs(C, A - B)
assert mcode(F) == "(A - B)(1, 1)"
def test_zeta_printing_issue_14820():
assert octave_code(zeta(x)) == 'zeta(x)'
with raises(NotImplementedError):
octave_code(zeta(x, y))
def test_automatic_rewrite():
assert octave_code(Li(x)) == '(logint(x) - logint(2))'
assert octave_code(erf2(x, y)) == '(-erf(x) + erf(y))'

View File

@ -0,0 +1,89 @@
from sympy.concrete.products import Product
from sympy.concrete.summations import Sum
from sympy.core.function import Derivative
from sympy.core.numbers import Integer, Rational, Float, oo
from sympy.core.relational import Rel
from sympy.core.symbol import symbols
from sympy.functions import sin
from sympy.integrals.integrals import Integral
from sympy.series.order import Order
from sympy.printing.precedence import precedence, PRECEDENCE
x, y = symbols("x,y")
def test_Add():
assert precedence(x + y) == PRECEDENCE["Add"]
assert precedence(x*y + 1) == PRECEDENCE["Add"]
def test_Function():
assert precedence(sin(x)) == PRECEDENCE["Func"]
def test_Derivative():
assert precedence(Derivative(x, y)) == PRECEDENCE["Atom"]
def test_Integral():
assert precedence(Integral(x, y)) == PRECEDENCE["Atom"]
def test_Mul():
assert precedence(x*y) == PRECEDENCE["Mul"]
assert precedence(-x*y) == PRECEDENCE["Add"]
def test_Number():
assert precedence(Integer(0)) == PRECEDENCE["Atom"]
assert precedence(Integer(1)) == PRECEDENCE["Atom"]
assert precedence(Integer(-1)) == PRECEDENCE["Add"]
assert precedence(Integer(10)) == PRECEDENCE["Atom"]
assert precedence(Rational(5, 2)) == PRECEDENCE["Mul"]
assert precedence(Rational(-5, 2)) == PRECEDENCE["Add"]
assert precedence(Float(5)) == PRECEDENCE["Atom"]
assert precedence(Float(-5)) == PRECEDENCE["Add"]
assert precedence(oo) == PRECEDENCE["Atom"]
assert precedence(-oo) == PRECEDENCE["Add"]
def test_Order():
assert precedence(Order(x)) == PRECEDENCE["Atom"]
def test_Pow():
assert precedence(x**y) == PRECEDENCE["Pow"]
assert precedence(-x**y) == PRECEDENCE["Add"]
assert precedence(x**-y) == PRECEDENCE["Pow"]
def test_Product():
assert precedence(Product(x, (x, y, y + 1))) == PRECEDENCE["Atom"]
def test_Relational():
assert precedence(Rel(x + y, y, "<")) == PRECEDENCE["Relational"]
def test_Sum():
assert precedence(Sum(x, (x, y, y + 1))) == PRECEDENCE["Atom"]
def test_Symbol():
assert precedence(x) == PRECEDENCE["Atom"]
def test_And_Or():
# precedence relations between logical operators, ...
assert precedence(x & y) > precedence(x | y)
assert precedence(~y) > precedence(x & y)
# ... and with other operators (cfr. other programming languages)
assert precedence(x + y) > precedence(x | y)
assert precedence(x + y) > precedence(x & y)
assert precedence(x*y) > precedence(x | y)
assert precedence(x*y) > precedence(x & y)
assert precedence(~y) > precedence(x*y)
assert precedence(~y) > precedence(x - y)
# double checks
assert precedence(x & y) == PRECEDENCE["And"]
assert precedence(x | y) == PRECEDENCE["Or"]
assert precedence(~y) == PRECEDENCE["Not"]

View File

@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
from sympy.core.relational import Eq
from sympy.core.symbol import Symbol
from sympy.functions.elementary.piecewise import Piecewise
from sympy.printing.preview import preview
from io import BytesIO
def test_preview():
x = Symbol('x')
obj = BytesIO()
try:
preview(x, output='png', viewer='BytesIO', outputbuffer=obj)
except RuntimeError:
pass # latex not installed on CI server
def test_preview_unicode_symbol():
# issue 9107
a = Symbol('α')
obj = BytesIO()
try:
preview(a, output='png', viewer='BytesIO', outputbuffer=obj)
except RuntimeError:
pass # latex not installed on CI server
def test_preview_latex_construct_in_expr():
# see PR 9801
x = Symbol('x')
pw = Piecewise((1, Eq(x, 0)), (0, True))
obj = BytesIO()
try:
preview(pw, output='png', viewer='BytesIO', outputbuffer=obj)
except RuntimeError:
pass # latex not installed on CI server

View File

@ -0,0 +1,429 @@
from sympy.codegen import Assignment
from sympy.codegen.ast import none
from sympy.codegen.cfunctions import expm1, log1p
from sympy.codegen.scipy_nodes import cosm1
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow
from sympy.core.numbers import pi
from sympy.core.singleton import S
from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt, Min, Max, cot, acsch, asec, coth, sec
from sympy.logic import And, Or
from sympy.matrices import SparseMatrix, MatrixSymbol, Identity
from sympy.printing.pycode import (
MpmathPrinter, PythonCodePrinter, pycode, SymPyPrinter
)
from sympy.printing.tensorflow import TensorflowPrinter
from sympy.printing.numpy import NumPyPrinter, SciPyPrinter
from sympy.testing.pytest import raises, skip
from sympy.tensor import IndexedBase, Idx
from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayDiagonal, ArrayContraction, ZeroArray, OneArray
from sympy.external import import_module
from sympy.functions.special.gamma_functions import loggamma
x, y, z = symbols('x y z')
p = IndexedBase("p")
def test_PythonCodePrinter():
prntr = PythonCodePrinter()
assert not prntr.module_imports
assert prntr.doprint(x**y) == 'x**y'
assert prntr.doprint(Mod(x, 2)) == 'x % 2'
assert prntr.doprint(-Mod(x, y)) == '-(x % y)'
assert prntr.doprint(Mod(-x, y)) == '(-x) % y'
assert prntr.doprint(And(x, y)) == 'x and y'
assert prntr.doprint(Or(x, y)) == 'x or y'
assert prntr.doprint(1/(x+y)) == '1/(x + y)'
assert not prntr.module_imports
assert prntr.doprint(pi) == 'math.pi'
assert prntr.module_imports == {'math': {'pi'}}
assert prntr.doprint(x**Rational(1, 2)) == 'math.sqrt(x)'
assert prntr.doprint(sqrt(x)) == 'math.sqrt(x)'
assert prntr.module_imports == {'math': {'pi', 'sqrt'}}
assert prntr.doprint(acos(x)) == 'math.acos(x)'
assert prntr.doprint(cot(x)) == '(1/math.tan(x))'
assert prntr.doprint(coth(x)) == '((math.exp(x) + math.exp(-x))/(math.exp(x) - math.exp(-x)))'
assert prntr.doprint(asec(x)) == '(math.acos(1/x))'
assert prntr.doprint(acsch(x)) == '(math.log(math.sqrt(1 + x**(-2)) + 1/x))'
assert prntr.doprint(Assignment(x, 2)) == 'x = 2'
assert prntr.doprint(Piecewise((1, Eq(x, 0)),
(2, x>6))) == '((1) if (x == 0) else (2) if (x > 6) else None)'
assert prntr.doprint(Piecewise((2, Le(x, 0)),
(3, Gt(x, 0)), evaluate=False)) == '((2) if (x <= 0) else'\
' (3) if (x > 0) else None)'
assert prntr.doprint(sign(x)) == '(0.0 if x == 0 else math.copysign(1, x))'
assert prntr.doprint(p[0, 1]) == 'p[0, 1]'
assert prntr.doprint(KroneckerDelta(x,y)) == '(1 if x == y else 0)'
assert prntr.doprint((2,3)) == "(2, 3)"
assert prntr.doprint([2,3]) == "[2, 3]"
assert prntr.doprint(Min(x, y)) == "min(x, y)"
assert prntr.doprint(Max(x, y)) == "max(x, y)"
def test_PythonCodePrinter_standard():
prntr = PythonCodePrinter()
assert prntr.standard == 'python3'
raises(ValueError, lambda: PythonCodePrinter({'standard':'python4'}))
def test_MpmathPrinter():
p = MpmathPrinter()
assert p.doprint(sign(x)) == 'mpmath.sign(x)'
assert p.doprint(Rational(1, 2)) == 'mpmath.mpf(1)/mpmath.mpf(2)'
assert p.doprint(S.Exp1) == 'mpmath.e'
assert p.doprint(S.Pi) == 'mpmath.pi'
assert p.doprint(S.GoldenRatio) == 'mpmath.phi'
assert p.doprint(S.EulerGamma) == 'mpmath.euler'
assert p.doprint(S.NaN) == 'mpmath.nan'
assert p.doprint(S.Infinity) == 'mpmath.inf'
assert p.doprint(S.NegativeInfinity) == 'mpmath.ninf'
assert p.doprint(loggamma(x)) == 'mpmath.loggamma(x)'
def test_NumPyPrinter():
from sympy.core.function import Lambda
from sympy.matrices.expressions.adjoint import Adjoint
from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix, DiagonalOf)
from sympy.matrices.expressions.funcmatrix import FunctionMatrix
from sympy.matrices.expressions.hadamard import HadamardProduct
from sympy.matrices.expressions.kronecker import KroneckerProduct
from sympy.matrices.expressions.special import (OneMatrix, ZeroMatrix)
from sympy.abc import a, b
p = NumPyPrinter()
assert p.doprint(sign(x)) == 'numpy.sign(x)'
A = MatrixSymbol("A", 2, 2)
B = MatrixSymbol("B", 2, 2)
C = MatrixSymbol("C", 1, 5)
D = MatrixSymbol("D", 3, 4)
assert p.doprint(A**(-1)) == "numpy.linalg.inv(A)"
assert p.doprint(A**5) == "numpy.linalg.matrix_power(A, 5)"
assert p.doprint(Identity(3)) == "numpy.eye(3)"
u = MatrixSymbol('x', 2, 1)
v = MatrixSymbol('y', 2, 1)
assert p.doprint(MatrixSolve(A, u)) == 'numpy.linalg.solve(A, x)'
assert p.doprint(MatrixSolve(A, u) + v) == 'numpy.linalg.solve(A, x) + y'
assert p.doprint(ZeroMatrix(2, 3)) == "numpy.zeros((2, 3))"
assert p.doprint(OneMatrix(2, 3)) == "numpy.ones((2, 3))"
assert p.doprint(FunctionMatrix(4, 5, Lambda((a, b), a + b))) == \
"numpy.fromfunction(lambda a, b: a + b, (4, 5))"
assert p.doprint(HadamardProduct(A, B)) == "numpy.multiply(A, B)"
assert p.doprint(KroneckerProduct(A, B)) == "numpy.kron(A, B)"
assert p.doprint(Adjoint(A)) == "numpy.conjugate(numpy.transpose(A))"
assert p.doprint(DiagonalOf(A)) == "numpy.reshape(numpy.diag(A), (-1, 1))"
assert p.doprint(DiagMatrix(C)) == "numpy.diagflat(C)"
assert p.doprint(DiagonalMatrix(D)) == "numpy.multiply(D, numpy.eye(3, 4))"
# Workaround for numpy negative integer power errors
assert p.doprint(x**-1) == 'x**(-1.0)'
assert p.doprint(x**-2) == 'x**(-2.0)'
expr = Pow(2, -1, evaluate=False)
assert p.doprint(expr) == "2**(-1.0)"
assert p.doprint(S.Exp1) == 'numpy.e'
assert p.doprint(S.Pi) == 'numpy.pi'
assert p.doprint(S.EulerGamma) == 'numpy.euler_gamma'
assert p.doprint(S.NaN) == 'numpy.nan'
assert p.doprint(S.Infinity) == 'numpy.inf'
assert p.doprint(S.NegativeInfinity) == '-numpy.inf'
# Function rewriting operator precedence fix
assert p.doprint(sec(x)**2) == '(numpy.cos(x)**(-1.0))**2'
def test_issue_18770():
numpy = import_module('numpy')
if not numpy:
skip("numpy not installed.")
from sympy.functions.elementary.miscellaneous import (Max, Min)
from sympy.utilities.lambdify import lambdify
expr1 = Min(0.1*x + 3, x + 1, 0.5*x + 1)
func = lambdify(x, expr1, "numpy")
assert (func(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5 ]).all()
assert func(4) == 3
expr1 = Max(x**2, x**3)
func = lambdify(x,expr1, "numpy")
assert (func(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8] ).all()
assert func(4) == 64
def test_SciPyPrinter():
p = SciPyPrinter()
expr = acos(x)
assert 'numpy' not in p.module_imports
assert p.doprint(expr) == 'numpy.arccos(x)'
assert 'numpy' in p.module_imports
assert not any(m.startswith('scipy') for m in p.module_imports)
smat = SparseMatrix(2, 5, {(0, 1): 3})
assert p.doprint(smat) == \
'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))'
assert 'scipy.sparse' in p.module_imports
assert p.doprint(S.GoldenRatio) == 'scipy.constants.golden_ratio'
assert p.doprint(S.Pi) == 'scipy.constants.pi'
assert p.doprint(S.Exp1) == 'numpy.e'
def test_pycode_reserved_words():
s1, s2 = symbols('if else')
raises(ValueError, lambda: pycode(s1 + s2, error_on_reserved=True))
py_str = pycode(s1 + s2)
assert py_str in ('else_ + if_', 'if_ + else_')
def test_issue_20762():
# Make sure pycode removes curly braces from subscripted variables
a_b, b, a_11 = symbols('a_{b} b a_{11}')
expr = a_b*b
assert pycode(expr) == 'a_b*b'
expr = a_11*b
assert pycode(expr) == 'a_11*b'
def test_sqrt():
prntr = PythonCodePrinter()
assert prntr._print_Pow(sqrt(x), rational=False) == 'math.sqrt(x)'
assert prntr._print_Pow(1/sqrt(x), rational=False) == '1/math.sqrt(x)'
prntr = PythonCodePrinter({'standard' : 'python3'})
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
assert prntr._print_Pow(1/sqrt(x), rational=True) == 'x**(-1/2)'
prntr = MpmathPrinter()
assert prntr._print_Pow(sqrt(x), rational=False) == 'mpmath.sqrt(x)'
assert prntr._print_Pow(sqrt(x), rational=True) == \
"x**(mpmath.mpf(1)/mpmath.mpf(2))"
prntr = NumPyPrinter()
assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)'
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
prntr = SciPyPrinter()
assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)'
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
prntr = SymPyPrinter()
assert prntr._print_Pow(sqrt(x), rational=False) == 'sympy.sqrt(x)'
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
def test_frac():
from sympy.functions.elementary.integers import frac
expr = frac(x)
prntr = NumPyPrinter()
assert prntr.doprint(expr) == 'numpy.mod(x, 1)'
prntr = SciPyPrinter()
assert prntr.doprint(expr) == 'numpy.mod(x, 1)'
prntr = PythonCodePrinter()
assert prntr.doprint(expr) == 'x % 1'
prntr = MpmathPrinter()
assert prntr.doprint(expr) == 'mpmath.frac(x)'
prntr = SymPyPrinter()
assert prntr.doprint(expr) == 'sympy.functions.elementary.integers.frac(x)'
class CustomPrintedObject(Expr):
def _numpycode(self, printer):
return 'numpy'
def _mpmathcode(self, printer):
return 'mpmath'
def test_printmethod():
obj = CustomPrintedObject()
assert NumPyPrinter().doprint(obj) == 'numpy'
assert MpmathPrinter().doprint(obj) == 'mpmath'
def test_codegen_ast_nodes():
assert pycode(none) == 'None'
def test_issue_14283():
prntr = PythonCodePrinter()
assert prntr.doprint(zoo) == "math.nan"
assert prntr.doprint(-oo) == "float('-inf')"
def test_NumPyPrinter_print_seq():
n = NumPyPrinter()
assert n._print_seq(range(2)) == '(0, 1,)'
def test_issue_16535_16536():
from sympy.functions.special.gamma_functions import (lowergamma, uppergamma)
a = symbols('a')
expr1 = lowergamma(a, x)
expr2 = uppergamma(a, x)
prntr = SciPyPrinter()
assert prntr.doprint(expr1) == 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)'
assert prntr.doprint(expr2) == 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)'
p_numpy = NumPyPrinter()
p_pycode = PythonCodePrinter({'strict': False})
for expr in [expr1, expr2]:
with raises(NotImplementedError):
p_numpy.doprint(expr1)
assert "Not supported" in p_pycode.doprint(expr)
def test_Integral():
from sympy.functions.elementary.exponential import exp
from sympy.integrals.integrals import Integral
single = Integral(exp(-x), (x, 0, oo))
double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z))
indefinite = Integral(x**2, x)
evaluateat = Integral(x**2, (x, 1))
prntr = SciPyPrinter()
assert prntr.doprint(single) == 'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.inf)[0]'
assert prntr.doprint(double) == 'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]'
raises(NotImplementedError, lambda: prntr.doprint(indefinite))
raises(NotImplementedError, lambda: prntr.doprint(evaluateat))
prntr = MpmathPrinter()
assert prntr.doprint(single) == 'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))'
assert prntr.doprint(double) == 'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))'
raises(NotImplementedError, lambda: prntr.doprint(indefinite))
raises(NotImplementedError, lambda: prntr.doprint(evaluateat))
def test_fresnel_integrals():
from sympy.functions.special.error_functions import (fresnelc, fresnels)
expr1 = fresnelc(x)
expr2 = fresnels(x)
prntr = SciPyPrinter()
assert prntr.doprint(expr1) == 'scipy.special.fresnel(x)[1]'
assert prntr.doprint(expr2) == 'scipy.special.fresnel(x)[0]'
p_numpy = NumPyPrinter()
p_pycode = PythonCodePrinter()
p_mpmath = MpmathPrinter()
for expr in [expr1, expr2]:
with raises(NotImplementedError):
p_numpy.doprint(expr)
with raises(NotImplementedError):
p_pycode.doprint(expr)
assert p_mpmath.doprint(expr1) == 'mpmath.fresnelc(x)'
assert p_mpmath.doprint(expr2) == 'mpmath.fresnels(x)'
def test_beta():
from sympy.functions.special.beta_functions import beta
expr = beta(x, y)
prntr = SciPyPrinter()
assert prntr.doprint(expr) == 'scipy.special.beta(x, y)'
prntr = NumPyPrinter()
assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
prntr = PythonCodePrinter()
assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
prntr = PythonCodePrinter({'allow_unknown_functions': True})
assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
prntr = MpmathPrinter()
assert prntr.doprint(expr) == 'mpmath.beta(x, y)'
def test_airy():
from sympy.functions.special.bessel import (airyai, airybi)
expr1 = airyai(x)
expr2 = airybi(x)
prntr = SciPyPrinter()
assert prntr.doprint(expr1) == 'scipy.special.airy(x)[0]'
assert prntr.doprint(expr2) == 'scipy.special.airy(x)[2]'
prntr = NumPyPrinter({'strict': False})
assert "Not supported" in prntr.doprint(expr1)
assert "Not supported" in prntr.doprint(expr2)
prntr = PythonCodePrinter({'strict': False})
assert "Not supported" in prntr.doprint(expr1)
assert "Not supported" in prntr.doprint(expr2)
def test_airy_prime():
from sympy.functions.special.bessel import (airyaiprime, airybiprime)
expr1 = airyaiprime(x)
expr2 = airybiprime(x)
prntr = SciPyPrinter()
assert prntr.doprint(expr1) == 'scipy.special.airy(x)[1]'
assert prntr.doprint(expr2) == 'scipy.special.airy(x)[3]'
prntr = NumPyPrinter({'strict': False})
assert "Not supported" in prntr.doprint(expr1)
assert "Not supported" in prntr.doprint(expr2)
prntr = PythonCodePrinter({'strict': False})
assert "Not supported" in prntr.doprint(expr1)
assert "Not supported" in prntr.doprint(expr2)
def test_numerical_accuracy_functions():
prntr = SciPyPrinter()
assert prntr.doprint(expm1(x)) == 'numpy.expm1(x)'
assert prntr.doprint(log1p(x)) == 'numpy.log1p(x)'
assert prntr.doprint(cosm1(x)) == 'scipy.special.cosm1(x)'
def test_array_printer():
A = ArraySymbol('A', (4,4,6,6,6))
I = IndexedBase('I')
i,j,k = Idx('i', (0,1)), Idx('j', (2,3)), Idx('k', (4,5))
prntr = NumPyPrinter()
assert prntr.doprint(ZeroArray(5)) == 'numpy.zeros((5,))'
assert prntr.doprint(OneArray(5)) == 'numpy.ones((5,))'
assert prntr.doprint(ArrayContraction(A, [2,3])) == 'numpy.einsum("abccd->abd", A)'
assert prntr.doprint(I) == 'I'
assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'numpy.einsum("abccc->abc", A)'
assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'numpy.einsum("aabbc->cab", A)'
assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'numpy.einsum("abcde->abe", A)'
assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I'
prntr = TensorflowPrinter()
assert prntr.doprint(ZeroArray(5)) == 'tensorflow.zeros((5,))'
assert prntr.doprint(OneArray(5)) == 'tensorflow.ones((5,))'
assert prntr.doprint(ArrayContraction(A, [2,3])) == 'tensorflow.linalg.einsum("abccd->abd", A)'
assert prntr.doprint(I) == 'I'
assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'tensorflow.linalg.einsum("abccc->abc", A)'
assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'tensorflow.linalg.einsum("aabbc->cab", A)'
assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'tensorflow.linalg.einsum("abcde->abe", A)'
assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I'

View File

@ -0,0 +1,203 @@
from sympy.core.function import (Derivative, Function)
from sympy.core.numbers import (I, Rational, oo, pi)
from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.complexes import (Abs, conjugate)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import sin
from sympy.integrals.integrals import Integral
from sympy.matrices.dense import Matrix
from sympy.series.limits import limit
from sympy.printing.python import python
from sympy.testing.pytest import raises, XFAIL
x, y = symbols('x,y')
th = Symbol('theta')
ph = Symbol('phi')
def test_python_basic():
# Simple numbers/symbols
assert python(-Rational(1)/2) == "e = Rational(-1, 2)"
assert python(-Rational(13)/22) == "e = Rational(-13, 22)"
assert python(oo) == "e = oo"
# Powers
assert python(x**2) == "x = Symbol(\'x\')\ne = x**2"
assert python(1/x) == "x = Symbol('x')\ne = 1/x"
assert python(y*x**-2) == "y = Symbol('y')\nx = Symbol('x')\ne = y/x**2"
assert python(
x**Rational(-5, 2)) == "x = Symbol('x')\ne = x**Rational(-5, 2)"
# Sums of terms
assert python(x**2 + x + 1) in [
"x = Symbol('x')\ne = 1 + x + x**2",
"x = Symbol('x')\ne = x + x**2 + 1",
"x = Symbol('x')\ne = x**2 + x + 1", ]
assert python(1 - x) in [
"x = Symbol('x')\ne = 1 - x",
"x = Symbol('x')\ne = -x + 1"]
assert python(1 - 2*x) in [
"x = Symbol('x')\ne = 1 - 2*x",
"x = Symbol('x')\ne = -2*x + 1"]
assert python(1 - Rational(3, 2)*y/x) in [
"y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3/2*y/x",
"y = Symbol('y')\nx = Symbol('x')\ne = -3/2*y/x + 1",
"y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3*y/(2*x)"]
# Multiplication
assert python(x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = x/y"
assert python(-x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = -x/y"
assert python((x + 2)/y) in [
"y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(2 + x)",
"y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(x + 2)",
"x = Symbol('x')\ny = Symbol('y')\ne = 1/y*(2 + x)",
"x = Symbol('x')\ny = Symbol('y')\ne = (2 + x)/y",
"x = Symbol('x')\ny = Symbol('y')\ne = (x + 2)/y"]
assert python((1 + x)*y) in [
"y = Symbol('y')\nx = Symbol('x')\ne = y*(1 + x)",
"y = Symbol('y')\nx = Symbol('x')\ne = y*(x + 1)", ]
# Check for proper placement of negative sign
assert python(-5*x/(x + 10)) == "x = Symbol('x')\ne = -5*x/(x + 10)"
assert python(1 - Rational(3, 2)*(x + 1)) in [
"x = Symbol('x')\ne = Rational(-3, 2)*x + Rational(-1, 2)",
"x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)",
"x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)"
]
def test_python_keyword_symbol_name_escaping():
# Check for escaping of keywords
assert python(
5*Symbol("lambda")) == "lambda_ = Symbol('lambda')\ne = 5*lambda_"
assert (python(5*Symbol("lambda") + 7*Symbol("lambda_")) ==
"lambda__ = Symbol('lambda')\nlambda_ = Symbol('lambda_')\ne = 7*lambda_ + 5*lambda__")
assert (python(5*Symbol("for") + Function("for_")(8)) ==
"for__ = Symbol('for')\nfor_ = Function('for_')\ne = 5*for__ + for_(8)")
def test_python_keyword_function_name_escaping():
assert python(
5*Function("for")(8)) == "for_ = Function('for')\ne = 5*for_(8)"
def test_python_relational():
assert python(Eq(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = Eq(x, y)"
assert python(Ge(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x >= y"
assert python(Le(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x <= y"
assert python(Gt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x > y"
assert python(Lt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x < y"
assert python(Ne(x/(y + 1), y**2)) in [
"x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(1 + y), y**2)",
"x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(y + 1), y**2)"]
def test_python_functions():
# Simple
assert python(2*x + exp(x)) in "x = Symbol('x')\ne = 2*x + exp(x)"
assert python(sqrt(2)) == 'e = sqrt(2)'
assert python(2**Rational(1, 3)) == 'e = 2**Rational(1, 3)'
assert python(sqrt(2 + pi)) == 'e = sqrt(2 + pi)'
assert python((2 + pi)**Rational(1, 3)) == 'e = (2 + pi)**Rational(1, 3)'
assert python(2**Rational(1, 4)) == 'e = 2**Rational(1, 4)'
assert python(Abs(x)) == "x = Symbol('x')\ne = Abs(x)"
assert python(
Abs(x/(x**2 + 1))) in ["x = Symbol('x')\ne = Abs(x/(1 + x**2))",
"x = Symbol('x')\ne = Abs(x/(x**2 + 1))"]
# Univariate/Multivariate functions
f = Function('f')
assert python(f(x)) == "x = Symbol('x')\nf = Function('f')\ne = f(x)"
assert python(f(x, y)) == "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x, y)"
assert python(f(x/(y + 1), y)) in [
"x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(1 + y), y)",
"x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(y + 1), y)"]
# Nesting of square roots
assert python(sqrt((sqrt(x + 1)) + 1)) in [
"x = Symbol('x')\ne = sqrt(1 + sqrt(1 + x))",
"x = Symbol('x')\ne = sqrt(sqrt(x + 1) + 1)"]
# Nesting of powers
assert python((((x + 1)**Rational(1, 3)) + 1)**Rational(1, 3)) in [
"x = Symbol('x')\ne = (1 + (1 + x)**Rational(1, 3))**Rational(1, 3)",
"x = Symbol('x')\ne = ((x + 1)**Rational(1, 3) + 1)**Rational(1, 3)"]
# Function powers
assert python(sin(x)**2) == "x = Symbol('x')\ne = sin(x)**2"
@XFAIL
def test_python_functions_conjugates():
a, b = map(Symbol, 'ab')
assert python( conjugate(a + b*I) ) == '_ _\na - I*b'
assert python( conjugate(exp(a + b*I)) ) == ' _ _\n a - I*b\ne '
def test_python_derivatives():
# Simple
f_1 = Derivative(log(x), x, evaluate=False)
assert python(f_1) == "x = Symbol('x')\ne = Derivative(log(x), x)"
f_2 = Derivative(log(x), x, evaluate=False) + x
assert python(f_2) == "x = Symbol('x')\ne = x + Derivative(log(x), x)"
# Multiple symbols
f_3 = Derivative(log(x) + x**2, x, y, evaluate=False)
assert python(f_3) == \
"x = Symbol('x')\ny = Symbol('y')\ne = Derivative(x**2 + log(x), x, y)"
f_4 = Derivative(2*x*y, y, x, evaluate=False) + x**2
assert python(f_4) in [
"x = Symbol('x')\ny = Symbol('y')\ne = x**2 + Derivative(2*x*y, y, x)",
"x = Symbol('x')\ny = Symbol('y')\ne = Derivative(2*x*y, y, x) + x**2"]
def test_python_integrals():
# Simple
f_1 = Integral(log(x), x)
assert python(f_1) == "x = Symbol('x')\ne = Integral(log(x), x)"
f_2 = Integral(x**2, x)
assert python(f_2) == "x = Symbol('x')\ne = Integral(x**2, x)"
# Double nesting of pow
f_3 = Integral(x**(2**x), x)
assert python(f_3) == "x = Symbol('x')\ne = Integral(x**(2**x), x)"
# Definite integrals
f_4 = Integral(x**2, (x, 1, 2))
assert python(f_4) == "x = Symbol('x')\ne = Integral(x**2, (x, 1, 2))"
f_5 = Integral(x**2, (x, Rational(1, 2), 10))
assert python(
f_5) == "x = Symbol('x')\ne = Integral(x**2, (x, Rational(1, 2), 10))"
# Nested integrals
f_6 = Integral(x**2*y**2, x, y)
assert python(f_6) == "x = Symbol('x')\ny = Symbol('y')\ne = Integral(x**2*y**2, x, y)"
def test_python_matrix():
p = python(Matrix([[x**2+1, 1], [y, x+y]]))
s = "x = Symbol('x')\ny = Symbol('y')\ne = MutableDenseMatrix([[x**2 + 1, 1], [y, x + y]])"
assert p == s
def test_python_limits():
assert python(limit(x, x, oo)) == 'e = oo'
assert python(limit(x**2, x, 0)) == 'e = 0'
def test_issue_20762():
# Make sure Python removes curly braces from subscripted variables
a_b = Symbol('a_{b}')
b = Symbol('b')
expr = a_b*b
assert python(expr) == "a_b = Symbol('a_{b}')\nb = Symbol('b')\ne = a_b*b"
def test_settings():
raises(TypeError, lambda: python(x, method="garbage"))

View File

@ -0,0 +1,476 @@
from sympy.core import (S, pi, oo, Symbol, symbols, Rational, Integer,
GoldenRatio, EulerGamma, Catalan, Lambda, Dummy)
from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
gamma, sign, Max, Min, factorial, beta)
from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
from sympy.sets import Range
from sympy.logic import ITE
from sympy.codegen import For, aug_assign, Assignment
from sympy.testing.pytest import raises
from sympy.printing.rcode import RCodePrinter
from sympy.utilities.lambdify import implemented_function
from sympy.tensor import IndexedBase, Idx
from sympy.matrices import Matrix, MatrixSymbol
from sympy.printing.rcode import rcode
x, y, z = symbols('x,y,z')
def test_printmethod():
class fabs(Abs):
def _rcode(self, printer):
return "abs(%s)" % printer._print(self.args[0])
assert rcode(fabs(x)) == "abs(x)"
def test_rcode_sqrt():
assert rcode(sqrt(x)) == "sqrt(x)"
assert rcode(x**0.5) == "sqrt(x)"
assert rcode(sqrt(x)) == "sqrt(x)"
def test_rcode_Pow():
assert rcode(x**3) == "x^3"
assert rcode(x**(y**3)) == "x^(y^3)"
g = implemented_function('g', Lambda(x, 2*x))
assert rcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
"(3.5*2*x)^(-x + y^x)/(x^2 + y)"
assert rcode(x**-1.0) == '1.0/x'
assert rcode(x**Rational(2, 3)) == 'x^(2.0/3.0)'
_cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"),
(lambda base, exp: not exp.is_integer, "pow")]
assert rcode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)'
assert rcode(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 3.2)'
def test_rcode_Max():
# Test for gh-11926
assert rcode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))'
def test_rcode_constants_mathh():
assert rcode(exp(1)) == "exp(1)"
assert rcode(pi) == "pi"
assert rcode(oo) == "Inf"
assert rcode(-oo) == "-Inf"
def test_rcode_constants_other():
assert rcode(2*GoldenRatio) == "GoldenRatio = 1.61803398874989;\n2*GoldenRatio"
assert rcode(
2*Catalan) == "Catalan = 0.915965594177219;\n2*Catalan"
assert rcode(2*EulerGamma) == "EulerGamma = 0.577215664901533;\n2*EulerGamma"
def test_rcode_Rational():
assert rcode(Rational(3, 7)) == "3.0/7.0"
assert rcode(Rational(18, 9)) == "2"
assert rcode(Rational(3, -7)) == "-3.0/7.0"
assert rcode(Rational(-3, -7)) == "3.0/7.0"
assert rcode(x + Rational(3, 7)) == "x + 3.0/7.0"
assert rcode(Rational(3, 7)*x) == "(3.0/7.0)*x"
def test_rcode_Integer():
assert rcode(Integer(67)) == "67"
assert rcode(Integer(-1)) == "-1"
def test_rcode_functions():
assert rcode(sin(x) ** cos(x)) == "sin(x)^cos(x)"
assert rcode(factorial(x) + gamma(y)) == "factorial(x) + gamma(y)"
assert rcode(beta(Min(x, y), Max(x, y))) == "beta(min(x, y), max(x, y))"
def test_rcode_inline_function():
x = symbols('x')
g = implemented_function('g', Lambda(x, 2*x))
assert rcode(g(x)) == "2*x"
g = implemented_function('g', Lambda(x, 2*x/Catalan))
assert rcode(
g(x)) == "Catalan = %s;\n2*x/Catalan" % Catalan.n()
A = IndexedBase('A')
i = Idx('i', symbols('n', integer=True))
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
res=rcode(g(A[i]), assign_to=A[i])
ref=(
"for (i in 1:n){\n"
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
"}"
)
assert res == ref
def test_rcode_exceptions():
assert rcode(ceiling(x)) == "ceiling(x)"
assert rcode(Abs(x)) == "abs(x)"
assert rcode(gamma(x)) == "gamma(x)"
def test_rcode_user_functions():
x = symbols('x', integer=False)
n = symbols('n', integer=True)
custom_functions = {
"ceiling": "myceil",
"Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
}
assert rcode(ceiling(x), user_functions=custom_functions) == "myceil(x)"
assert rcode(Abs(x), user_functions=custom_functions) == "fabs(x)"
assert rcode(Abs(n), user_functions=custom_functions) == "abs(n)"
def test_rcode_boolean():
assert rcode(True) == "True"
assert rcode(S.true) == "True"
assert rcode(False) == "False"
assert rcode(S.false) == "False"
assert rcode(x & y) == "x & y"
assert rcode(x | y) == "x | y"
assert rcode(~x) == "!x"
assert rcode(x & y & z) == "x & y & z"
assert rcode(x | y | z) == "x | y | z"
assert rcode((x & y) | z) == "z | x & y"
assert rcode((x | y) & z) == "z & (x | y)"
def test_rcode_Relational():
assert rcode(Eq(x, y)) == "x == y"
assert rcode(Ne(x, y)) == "x != y"
assert rcode(Le(x, y)) == "x <= y"
assert rcode(Lt(x, y)) == "x < y"
assert rcode(Gt(x, y)) == "x > y"
assert rcode(Ge(x, y)) == "x >= y"
def test_rcode_Piecewise():
expr = Piecewise((x, x < 1), (x**2, True))
res=rcode(expr)
ref="ifelse(x < 1,x,x^2)"
assert res == ref
tau=Symbol("tau")
res=rcode(expr,tau)
ref="tau = ifelse(x < 1,x,x^2);"
assert res == ref
expr = 2*Piecewise((x, x < 1), (x**2, x<2), (x**3,True))
assert rcode(expr) == "2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3))"
res = rcode(expr, assign_to='c')
assert res == "c = 2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3));"
# Check that Piecewise without a True (default) condition error
#expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
#raises(ValueError, lambda: rcode(expr))
expr = 2*Piecewise((x, x < 1), (x**2, x<2))
assert(rcode(expr))== "2*ifelse(x < 1,x,ifelse(x < 2,x^2,NA))"
def test_rcode_sinc():
from sympy.functions.elementary.trigonometric import sinc
expr = sinc(x)
res = rcode(expr)
ref = "(ifelse(x != 0,sin(x)/x,1))"
assert res == ref
def test_rcode_Piecewise_deep():
p = rcode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)))
assert p == "2*ifelse(x < 1,x,ifelse(x < 2,x + 1,x^2))"
expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1
p = rcode(expr)
ref="x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1"
assert p == ref
ref="c = x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1;"
p = rcode(expr, assign_to='c')
assert p == ref
def test_rcode_ITE():
expr = ITE(x < 1, y, z)
p = rcode(expr)
ref="ifelse(x < 1,y,z)"
assert p == ref
def test_rcode_settings():
raises(TypeError, lambda: rcode(sin(x), method="garbage"))
def test_rcode_Indexed():
n, m, o = symbols('n m o', integer=True)
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
p = RCodePrinter()
p._not_r = set()
x = IndexedBase('x')[j]
assert p._print_Indexed(x) == 'x[j]'
A = IndexedBase('A')[i, j]
assert p._print_Indexed(A) == 'A[i, j]'
B = IndexedBase('B')[i, j, k]
assert p._print_Indexed(B) == 'B[i, j, k]'
assert p._not_r == set()
def test_rcode_Indexed_without_looking_for_contraction():
len_y = 5
y = IndexedBase('y', shape=(len_y,))
x = IndexedBase('x', shape=(len_y,))
Dy = IndexedBase('Dy', shape=(len_y-1,))
i = Idx('i', len_y-1)
e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
code0 = rcode(e.rhs, assign_to=e.lhs, contract=False)
assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1)
def test_rcode_loops_matrix_vector():
n, m = symbols('n m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
s = (
'for (i in 1:m){\n'
' y[i] = 0;\n'
'}\n'
'for (i in 1:m){\n'
' for (j in 1:n){\n'
' y[i] = A[i, j]*x[j] + y[i];\n'
' }\n'
'}'
)
c = rcode(A[i, j]*x[j], assign_to=y[i])
assert c == s
def test_dummy_loops():
# the following line could also be
# [Dummy(s, integer=True) for s in 'im']
# or [Dummy(integer=True) for s in 'im']
i, m = symbols('i m', integer=True, cls=Dummy)
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx(i, m)
expected = (
'for (i_%(icount)i in 1:m_%(mcount)i){\n'
' y[i_%(icount)i] = x[i_%(icount)i];\n'
'}'
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
code = rcode(x[i], assign_to=y[i])
assert code == expected
def test_rcode_loops_add():
n, m = symbols('n m', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
z = IndexedBase('z')
i = Idx('i', m)
j = Idx('j', n)
s = (
'for (i in 1:m){\n'
' y[i] = x[i] + z[i];\n'
'}\n'
'for (i in 1:m){\n'
' for (j in 1:n){\n'
' y[i] = A[i, j]*x[j] + y[i];\n'
' }\n'
'}'
)
c = rcode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
assert c == s
def test_rcode_loops_multiple_contractions():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
s = (
'for (i in 1:m){\n'
' y[i] = 0;\n'
'}\n'
'for (i in 1:m){\n'
' for (j in 1:n){\n'
' for (k in 1:o){\n'
' for (l in 1:p){\n'
' y[i] = a[i, j, k, l]*b[j, k, l] + y[i];\n'
' }\n'
' }\n'
' }\n'
'}'
)
c = rcode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
assert c == s
def test_rcode_loops_addfactor():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
c = IndexedBase('c')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
s = (
'for (i in 1:m){\n'
' y[i] = 0;\n'
'}\n'
'for (i in 1:m){\n'
' for (j in 1:n){\n'
' for (k in 1:o){\n'
' for (l in 1:p){\n'
' y[i] = (a[i, j, k, l] + b[i, j, k, l])*c[j, k, l] + y[i];\n'
' }\n'
' }\n'
' }\n'
'}'
)
c = rcode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
assert c == s
def test_rcode_loops_multiple_terms():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
c = IndexedBase('c')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
s0 = (
'for (i in 1:m){\n'
' y[i] = 0;\n'
'}\n'
)
s1 = (
'for (i in 1:m){\n'
' for (j in 1:n){\n'
' for (k in 1:o){\n'
' y[i] = b[j]*b[k]*c[i, j, k] + y[i];\n'
' }\n'
' }\n'
'}\n'
)
s2 = (
'for (i in 1:m){\n'
' for (k in 1:o){\n'
' y[i] = a[i, k]*b[k] + y[i];\n'
' }\n'
'}\n'
)
s3 = (
'for (i in 1:m){\n'
' for (j in 1:n){\n'
' y[i] = a[i, j]*b[j] + y[i];\n'
' }\n'
'}\n'
)
c = rcode(
b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
ref={}
ref[0] = s0 + s1 + s2 + s3[:-1]
ref[1] = s0 + s1 + s3 + s2[:-1]
ref[2] = s0 + s2 + s1 + s3[:-1]
ref[3] = s0 + s2 + s3 + s1[:-1]
ref[4] = s0 + s3 + s1 + s2[:-1]
ref[5] = s0 + s3 + s2 + s1[:-1]
assert (c == ref[0] or
c == ref[1] or
c == ref[2] or
c == ref[3] or
c == ref[4] or
c == ref[5])
def test_dereference_printing():
expr = x + y + sin(z) + z
assert rcode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))"
def test_Matrix_printing():
# Test returning a Matrix
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
A = MatrixSymbol('A', 3, 1)
p = rcode(mat, A)
assert p == (
"A[0] = x*y;\n"
"A[1] = ifelse(y > 0,x + 2,y);\n"
"A[2] = sin(z);")
# Test using MatrixElements in expressions
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
p = rcode(expr)
assert p == ("ifelse(x > 0,2*A[2],A[2]) + sin(A[1]) + A[0]")
# Test using MatrixElements in a Matrix
q = MatrixSymbol('q', 5, 1)
M = MatrixSymbol('M', 3, 3)
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
[q[1,0] + q[2,0], q[3, 0], 5],
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
assert rcode(m, M) == (
"M[0] = sin(q[1]);\n"
"M[1] = 0;\n"
"M[2] = cos(q[2]);\n"
"M[3] = q[1] + q[2];\n"
"M[4] = q[3];\n"
"M[5] = 5;\n"
"M[6] = 2*q[4]/q[1];\n"
"M[7] = sqrt(q[0]) + 4;\n"
"M[8] = 0;")
def test_rcode_sgn():
expr = sign(x) * y
assert rcode(expr) == 'y*sign(x)'
p = rcode(expr, 'z')
assert p == 'z = y*sign(x);'
p = rcode(sign(2 * x + x**2) * x + x**2)
assert p == "x^2 + x*sign(x^2 + 2*x)"
expr = sign(cos(x))
p = rcode(expr)
assert p == 'sign(cos(x))'
def test_rcode_Assignment():
assert rcode(Assignment(x, y + z)) == 'x = y + z;'
assert rcode(aug_assign(x, '+', y + z)) == 'x += y + z;'
def test_rcode_For():
f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)])
sol = rcode(f)
assert sol == ("for(x in seq(from=0, to=9, by=2){\n"
" y *= x;\n"
"}")
def test_MatrixElement_printing():
# test cases for issue #11821
A = MatrixSymbol("A", 1, 3)
B = MatrixSymbol("B", 1, 3)
C = MatrixSymbol("C", 1, 3)
assert(rcode(A[0, 0]) == "A[0]")
assert(rcode(3 * A[0, 0]) == "3*A[0]")
F = C[0, 0].subs(C, A - B)
assert(rcode(F) == "(A - B)[0]")

View File

@ -0,0 +1,382 @@
from __future__ import annotations
from typing import Any
from sympy.external.gmpy import GROUND_TYPES
from sympy.testing.pytest import raises, warns_deprecated_sympy
from sympy.assumptions.ask import Q
from sympy.core.function import (Function, WildFunction)
from sympy.core.numbers import (AlgebraicNumber, Float, Integer, Rational)
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, Symbol, Wild, symbols)
from sympy.core.sympify import sympify
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import sin
from sympy.functions.special.delta_functions import Heaviside
from sympy.logic.boolalg import (false, true)
from sympy.matrices.dense import (Matrix, ones)
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.matrices.immutable import ImmutableDenseMatrix
from sympy.combinatorics import Cycle, Permutation
from sympy.core.symbol import Str
from sympy.geometry import Point, Ellipse
from sympy.printing import srepr
from sympy.polys import ring, field, ZZ, QQ, lex, grlex, Poly
from sympy.polys.polyclasses import DMP
from sympy.polys.agca.extensions import FiniteExtension
x, y = symbols('x,y')
# eval(srepr(expr)) == expr has to succeed in the right environment. The right
# environment is the scope of "from sympy import *" for most cases.
ENV: dict[str, Any] = {"Str": Str}
exec("from sympy import *", ENV)
def sT(expr, string, import_stmt=None, **kwargs):
"""
sT := sreprTest
Tests that srepr delivers the expected string and that
the condition eval(srepr(expr))==expr holds.
"""
if import_stmt is None:
ENV2 = ENV
else:
ENV2 = ENV.copy()
exec(import_stmt, ENV2)
assert srepr(expr, **kwargs) == string
assert eval(string, ENV2) == expr
def test_printmethod():
class R(Abs):
def _sympyrepr(self, printer):
return "foo(%s)" % printer._print(self.args[0])
assert srepr(R(x)) == "foo(Symbol('x'))"
def test_Add():
sT(x + y, "Add(Symbol('x'), Symbol('y'))")
assert srepr(x**2 + 1, order='lex') == "Add(Pow(Symbol('x'), Integer(2)), Integer(1))"
assert srepr(x**2 + 1, order='old') == "Add(Integer(1), Pow(Symbol('x'), Integer(2)))"
assert srepr(sympify('x + 3 - 2', evaluate=False), order='none') == "Add(Symbol('x'), Integer(3), Mul(Integer(-1), Integer(2)))"
def test_more_than_255_args_issue_10259():
from sympy.core.add import Add
from sympy.core.mul import Mul
for op in (Add, Mul):
expr = op(*symbols('x:256'))
assert eval(srepr(expr)) == expr
def test_Function():
sT(Function("f")(x), "Function('f')(Symbol('x'))")
# test unapplied Function
sT(Function('f'), "Function('f')")
sT(sin(x), "sin(Symbol('x'))")
sT(sin, "sin")
def test_Heaviside():
sT(Heaviside(x), "Heaviside(Symbol('x'))")
sT(Heaviside(x, 1), "Heaviside(Symbol('x'), Integer(1))")
def test_Geometry():
sT(Point(0, 0), "Point2D(Integer(0), Integer(0))")
sT(Ellipse(Point(0, 0), 5, 1),
"Ellipse(Point2D(Integer(0), Integer(0)), Integer(5), Integer(1))")
# TODO more tests
def test_Singletons():
sT(S.Catalan, 'Catalan')
sT(S.ComplexInfinity, 'zoo')
sT(S.EulerGamma, 'EulerGamma')
sT(S.Exp1, 'E')
sT(S.GoldenRatio, 'GoldenRatio')
sT(S.TribonacciConstant, 'TribonacciConstant')
sT(S.Half, 'Rational(1, 2)')
sT(S.ImaginaryUnit, 'I')
sT(S.Infinity, 'oo')
sT(S.NaN, 'nan')
sT(S.NegativeInfinity, '-oo')
sT(S.NegativeOne, 'Integer(-1)')
sT(S.One, 'Integer(1)')
sT(S.Pi, 'pi')
sT(S.Zero, 'Integer(0)')
sT(S.Complexes, 'Complexes')
sT(S.EmptySequence, 'EmptySequence')
sT(S.EmptySet, 'EmptySet')
# sT(S.IdentityFunction, 'Lambda(_x, _x)')
sT(S.Naturals, 'Naturals')
sT(S.Naturals0, 'Naturals0')
sT(S.Rationals, 'Rationals')
sT(S.Reals, 'Reals')
sT(S.UniversalSet, 'UniversalSet')
def test_Integer():
sT(Integer(4), "Integer(4)")
def test_list():
sT([x, Integer(4)], "[Symbol('x'), Integer(4)]")
def test_Matrix():
for cls, name in [(Matrix, "MutableDenseMatrix"), (ImmutableDenseMatrix, "ImmutableDenseMatrix")]:
sT(cls([[x**+1, 1], [y, x + y]]),
"%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name)
sT(cls(), "%s([])" % name)
sT(cls([[x**+1, 1], [y, x + y]]), "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name)
def test_empty_Matrix():
sT(ones(0, 3), "MutableDenseMatrix(0, 3, [])")
sT(ones(4, 0), "MutableDenseMatrix(4, 0, [])")
sT(ones(0, 0), "MutableDenseMatrix([])")
def test_Rational():
sT(Rational(1, 3), "Rational(1, 3)")
sT(Rational(-1, 3), "Rational(-1, 3)")
def test_Float():
sT(Float('1.23', dps=3), "Float('1.22998', precision=13)")
sT(Float('1.23456789', dps=9), "Float('1.23456788994', precision=33)")
sT(Float('1.234567890123456789', dps=19),
"Float('1.234567890123456789013', precision=66)")
sT(Float('0.60038617995049726', dps=15),
"Float('0.60038617995049726', precision=53)")
sT(Float('1.23', precision=13), "Float('1.22998', precision=13)")
sT(Float('1.23456789', precision=33),
"Float('1.23456788994', precision=33)")
sT(Float('1.234567890123456789', precision=66),
"Float('1.234567890123456789013', precision=66)")
sT(Float('0.60038617995049726', precision=53),
"Float('0.60038617995049726', precision=53)")
sT(Float('0.60038617995049726', 15),
"Float('0.60038617995049726', precision=53)")
def test_Symbol():
sT(x, "Symbol('x')")
sT(y, "Symbol('y')")
sT(Symbol('x', negative=True), "Symbol('x', negative=True)")
def test_Symbol_two_assumptions():
x = Symbol('x', negative=0, integer=1)
# order could vary
s1 = "Symbol('x', integer=True, negative=False)"
s2 = "Symbol('x', negative=False, integer=True)"
assert srepr(x) in (s1, s2)
assert eval(srepr(x), ENV) == x
def test_Symbol_no_special_commutative_treatment():
sT(Symbol('x'), "Symbol('x')")
sT(Symbol('x', commutative=False), "Symbol('x', commutative=False)")
sT(Symbol('x', commutative=0), "Symbol('x', commutative=False)")
sT(Symbol('x', commutative=True), "Symbol('x', commutative=True)")
sT(Symbol('x', commutative=1), "Symbol('x', commutative=True)")
def test_Wild():
sT(Wild('x', even=True), "Wild('x', even=True)")
def test_Dummy():
d = Dummy('d')
sT(d, "Dummy('d', dummy_index=%s)" % str(d.dummy_index))
def test_Dummy_assumption():
d = Dummy('d', nonzero=True)
assert d == eval(srepr(d))
s1 = "Dummy('d', dummy_index=%s, nonzero=True)" % str(d.dummy_index)
s2 = "Dummy('d', nonzero=True, dummy_index=%s)" % str(d.dummy_index)
assert srepr(d) in (s1, s2)
def test_Dummy_from_Symbol():
# should not get the full dictionary of assumptions
n = Symbol('n', integer=True)
d = n.as_dummy()
assert srepr(d
) == "Dummy('n', dummy_index=%s)" % str(d.dummy_index)
def test_tuple():
sT((x,), "(Symbol('x'),)")
sT((x, y), "(Symbol('x'), Symbol('y'))")
def test_WildFunction():
sT(WildFunction('w'), "WildFunction('w')")
def test_settins():
raises(TypeError, lambda: srepr(x, method="garbage"))
def test_Mul():
sT(3*x**3*y, "Mul(Integer(3), Pow(Symbol('x'), Integer(3)), Symbol('y'))")
assert srepr(3*x**3*y, order='old') == "Mul(Integer(3), Symbol('y'), Pow(Symbol('x'), Integer(3)))"
assert srepr(sympify('(x+4)*2*x*7', evaluate=False), order='none') == "Mul(Add(Symbol('x'), Integer(4)), Integer(2), Symbol('x'), Integer(7))"
def test_AlgebraicNumber():
a = AlgebraicNumber(sqrt(2))
sT(a, "AlgebraicNumber(Pow(Integer(2), Rational(1, 2)), [Integer(1), Integer(0)])")
a = AlgebraicNumber(root(-2, 3))
sT(a, "AlgebraicNumber(Pow(Integer(-2), Rational(1, 3)), [Integer(1), Integer(0)])")
def test_PolyRing():
assert srepr(ring("x", ZZ, lex)[0]) == "PolyRing((Symbol('x'),), ZZ, lex)"
assert srepr(ring("x,y", QQ, grlex)[0]) == "PolyRing((Symbol('x'), Symbol('y')), QQ, grlex)"
assert srepr(ring("x,y,z", ZZ["t"], lex)[0]) == "PolyRing((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)"
def test_FracField():
assert srepr(field("x", ZZ, lex)[0]) == "FracField((Symbol('x'),), ZZ, lex)"
assert srepr(field("x,y", QQ, grlex)[0]) == "FracField((Symbol('x'), Symbol('y')), QQ, grlex)"
assert srepr(field("x,y,z", ZZ["t"], lex)[0]) == "FracField((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)"
def test_PolyElement():
R, x, y = ring("x,y", ZZ)
assert srepr(3*x**2*y + 1) == "PolyElement(PolyRing((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)])"
def test_FracElement():
F, x, y = field("x,y", ZZ)
assert srepr((3*x**2*y + 1)/(x - y**2)) == "FracElement(FracField((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)], [((1, 0), 1), ((0, 2), -1)])"
def test_FractionField():
assert srepr(QQ.frac_field(x)) == \
"FractionField(FracField((Symbol('x'),), QQ, lex))"
assert srepr(QQ.frac_field(x, y, order=grlex)) == \
"FractionField(FracField((Symbol('x'), Symbol('y')), QQ, grlex))"
def test_PolynomialRingBase():
assert srepr(ZZ.old_poly_ring(x)) == \
"GlobalPolynomialRing(ZZ, Symbol('x'))"
assert srepr(ZZ[x].old_poly_ring(y)) == \
"GlobalPolynomialRing(ZZ[x], Symbol('y'))"
assert srepr(QQ.frac_field(x).old_poly_ring(y)) == \
"GlobalPolynomialRing(FractionField(FracField((Symbol('x'),), QQ, lex)), Symbol('y'))"
def test_DMP():
p1 = DMP([1, 2], ZZ)
p2 = ZZ.old_poly_ring(x)([1, 2])
if GROUND_TYPES != 'flint':
assert srepr(p1) == "DMP_Python([1, 2], ZZ)"
assert srepr(p2) == "DMP_Python([1, 2], ZZ)"
else:
assert srepr(p1) == "DUP_Flint([1, 2], ZZ)"
assert srepr(p2) == "DUP_Flint([1, 2], ZZ)"
def test_FiniteExtension():
assert srepr(FiniteExtension(Poly(x**2 + 1, x))) == \
"FiniteExtension(Poly(x**2 + 1, x, domain='ZZ'))"
def test_ExtensionElement():
A = FiniteExtension(Poly(x**2 + 1, x))
if GROUND_TYPES != 'flint':
ans = "ExtElem(DMP_Python([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))"
else:
ans = "ExtElem(DUP_Flint([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))"
assert srepr(A.generator) == ans
def test_BooleanAtom():
assert srepr(true) == "true"
assert srepr(false) == "false"
def test_Integers():
sT(S.Integers, "Integers")
def test_Naturals():
sT(S.Naturals, "Naturals")
def test_Naturals0():
sT(S.Naturals0, "Naturals0")
def test_Reals():
sT(S.Reals, "Reals")
def test_matrix_expressions():
n = symbols('n', integer=True)
A = MatrixSymbol("A", n, n)
B = MatrixSymbol("B", n, n)
sT(A, "MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True))")
sT(A*B, "MatMul(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))")
sT(A + B, "MatAdd(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))")
def test_Cycle():
# FIXME: sT fails because Cycle is not immutable and calling srepr(Cycle(1, 2))
# adds keys to the Cycle dict (GH-17661)
#import_stmt = "from sympy.combinatorics import Cycle"
#sT(Cycle(1, 2), "Cycle(1, 2)", import_stmt)
assert srepr(Cycle(1, 2)) == "Cycle(1, 2)"
def test_Permutation():
import_stmt = "from sympy.combinatorics import Permutation"
sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt, perm_cyclic=False)
sT(Permutation(1, 2)(3, 4), "Permutation(1, 2)(3, 4)", import_stmt, perm_cyclic=True)
with warns_deprecated_sympy():
old_print_cyclic = Permutation.print_cyclic
Permutation.print_cyclic = False
sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt)
Permutation.print_cyclic = old_print_cyclic
def test_dict():
from sympy.abc import x, y, z
d = {}
assert srepr(d) == "{}"
d = {x: y}
assert srepr(d) == "{Symbol('x'): Symbol('y')}"
d = {x: y, y: z}
assert srepr(d) in (
"{Symbol('x'): Symbol('y'), Symbol('y'): Symbol('z')}",
"{Symbol('y'): Symbol('z'), Symbol('x'): Symbol('y')}",
)
d = {x: {y: z}}
assert srepr(d) == "{Symbol('x'): {Symbol('y'): Symbol('z')}}"
def test_set():
from sympy.abc import x, y
s = set()
assert srepr(s) == "set()"
s = {x, y}
assert srepr(s) in ("{Symbol('x'), Symbol('y')}", "{Symbol('y'), Symbol('x')}")
def test_Predicate():
sT(Q.even, "Q.even")
def test_AppliedPredicate():
sT(Q.even(Symbol('z')), "AppliedPredicate(Q.even, Symbol('z'))")

View File

@ -0,0 +1,360 @@
from sympy.core import (S, pi, oo, symbols, Rational, Integer,
GoldenRatio, EulerGamma, Catalan, Lambda, Dummy,
Eq, Ne, Le, Lt, Gt, Ge, Mod)
from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
sign, floor)
from sympy.logic import ITE
from sympy.testing.pytest import raises
from sympy.utilities.lambdify import implemented_function
from sympy.tensor import IndexedBase, Idx
from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix
from sympy.printing.rust import rust_code
x, y, z = symbols('x,y,z')
def test_Integer():
assert rust_code(Integer(42)) == "42"
assert rust_code(Integer(-56)) == "-56"
def test_Relational():
assert rust_code(Eq(x, y)) == "x == y"
assert rust_code(Ne(x, y)) == "x != y"
assert rust_code(Le(x, y)) == "x <= y"
assert rust_code(Lt(x, y)) == "x < y"
assert rust_code(Gt(x, y)) == "x > y"
assert rust_code(Ge(x, y)) == "x >= y"
def test_Rational():
assert rust_code(Rational(3, 7)) == "3_f64/7.0"
assert rust_code(Rational(18, 9)) == "2"
assert rust_code(Rational(3, -7)) == "-3_f64/7.0"
assert rust_code(Rational(-3, -7)) == "3_f64/7.0"
assert rust_code(x + Rational(3, 7)) == "x + 3_f64/7.0"
assert rust_code(Rational(3, 7)*x) == "(3_f64/7.0)*x"
def test_basic_ops():
assert rust_code(x + y) == "x + y"
assert rust_code(x - y) == "x - y"
assert rust_code(x * y) == "x*y"
assert rust_code(x / y) == "x/y"
assert rust_code(-x) == "-x"
def test_printmethod():
class fabs(Abs):
def _rust_code(self, printer):
return "%s.fabs()" % printer._print(self.args[0])
assert rust_code(fabs(x)) == "x.fabs()"
a = MatrixSymbol("a", 1, 3)
assert rust_code(a[0,0]) == 'a[0]'
def test_Functions():
assert rust_code(sin(x) ** cos(x)) == "x.sin().powf(x.cos())"
assert rust_code(abs(x)) == "x.abs()"
assert rust_code(ceiling(x)) == "x.ceil()"
assert rust_code(floor(x)) == "x.floor()"
# Automatic rewrite
assert rust_code(Mod(x, 3)) == 'x - 3*((1_f64/3.0)*x).floor()'
def test_Pow():
assert rust_code(1/x) == "x.recip()"
assert rust_code(x**-1) == rust_code(x**-1.0) == "x.recip()"
assert rust_code(sqrt(x)) == "x.sqrt()"
assert rust_code(x**S.Half) == rust_code(x**0.5) == "x.sqrt()"
assert rust_code(1/sqrt(x)) == "x.sqrt().recip()"
assert rust_code(x**-S.Half) == rust_code(x**-0.5) == "x.sqrt().recip()"
assert rust_code(1/pi) == "PI.recip()"
assert rust_code(pi**-1) == rust_code(pi**-1.0) == "PI.recip()"
assert rust_code(pi**-0.5) == "PI.sqrt().recip()"
assert rust_code(x**Rational(1, 3)) == "x.cbrt()"
assert rust_code(2**x) == "x.exp2()"
assert rust_code(exp(x)) == "x.exp()"
assert rust_code(x**3) == "x.powi(3)"
assert rust_code(x**(y**3)) == "x.powf(y.powi(3))"
assert rust_code(x**Rational(2, 3)) == "x.powf(2_f64/3.0)"
g = implemented_function('g', Lambda(x, 2*x))
assert rust_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
"(3.5*2*x).powf(-x + y.powf(x))/(x.powi(2) + y)"
_cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1),
(lambda base, exp: not exp.is_integer, "pow", 1)]
assert rust_code(x**3, user_functions={'Pow': _cond_cfunc}) == 'x.dpowi(3)'
assert rust_code(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'x.pow(3.2)'
def test_constants():
assert rust_code(pi) == "PI"
assert rust_code(oo) == "INFINITY"
assert rust_code(S.Infinity) == "INFINITY"
assert rust_code(-oo) == "NEG_INFINITY"
assert rust_code(S.NegativeInfinity) == "NEG_INFINITY"
assert rust_code(S.NaN) == "NAN"
assert rust_code(exp(1)) == "E"
assert rust_code(S.Exp1) == "E"
def test_constants_other():
assert rust_code(2*GoldenRatio) == "const GoldenRatio: f64 = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
assert rust_code(
2*Catalan) == "const Catalan: f64 = %s;\n2*Catalan" % Catalan.evalf(17)
assert rust_code(2*EulerGamma) == "const EulerGamma: f64 = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
def test_boolean():
assert rust_code(True) == "true"
assert rust_code(S.true) == "true"
assert rust_code(False) == "false"
assert rust_code(S.false) == "false"
assert rust_code(x & y) == "x && y"
assert rust_code(x | y) == "x || y"
assert rust_code(~x) == "!x"
assert rust_code(x & y & z) == "x && y && z"
assert rust_code(x | y | z) == "x || y || z"
assert rust_code((x & y) | z) == "z || x && y"
assert rust_code((x | y) & z) == "z && (x || y)"
def test_Piecewise():
expr = Piecewise((x, x < 1), (x + 2, True))
assert rust_code(expr) == (
"if (x < 1) {\n"
" x\n"
"} else {\n"
" x + 2\n"
"}")
assert rust_code(expr, assign_to="r") == (
"r = if (x < 1) {\n"
" x\n"
"} else {\n"
" x + 2\n"
"};")
assert rust_code(expr, assign_to="r", inline=True) == (
"r = if (x < 1) { x } else { x + 2 };")
expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
assert rust_code(expr, inline=True) == (
"if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }")
assert rust_code(expr, assign_to="r", inline=True) == (
"r = if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 };")
assert rust_code(expr, assign_to="r") == (
"r = if (x < 1) {\n"
" x\n"
"} else if (x < 5) {\n"
" x + 1\n"
"} else {\n"
" x + 2\n"
"};")
expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
assert rust_code(expr, inline=True) == (
"2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }")
expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42
assert rust_code(expr, inline=True) == (
"2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 } - 42")
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
raises(ValueError, lambda: rust_code(expr))
def test_dereference_printing():
expr = x + y + sin(z) + z
assert rust_code(expr, dereference=[z]) == "x + y + (*z) + (*z).sin()"
def test_sign():
expr = sign(x) * y
assert rust_code(expr) == "y*x.signum()"
assert rust_code(expr, assign_to='r') == "r = y*x.signum();"
expr = sign(x + y) + 42
assert rust_code(expr) == "(x + y).signum() + 42"
assert rust_code(expr, assign_to='r') == "r = (x + y).signum() + 42;"
expr = sign(cos(x))
assert rust_code(expr) == "x.cos().signum()"
def test_reserved_words():
x, y = symbols("x if")
expr = sin(y)
assert rust_code(expr) == "if_.sin()"
assert rust_code(expr, dereference=[y]) == "(*if_).sin()"
assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()"
with raises(ValueError):
rust_code(expr, error_on_reserved=True)
def test_ITE():
expr = ITE(x < 1, y, z)
assert rust_code(expr) == (
"if (x < 1) {\n"
" y\n"
"} else {\n"
" z\n"
"}")
def test_Indexed():
n, m, o = symbols('n m o', integer=True)
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
x = IndexedBase('x')[j]
assert rust_code(x) == "x[j]"
A = IndexedBase('A')[i, j]
assert rust_code(A) == "A[m*i + j]"
B = IndexedBase('B')[i, j, k]
assert rust_code(B) == "B[m*o*i + o*j + k]"
def test_dummy_loops():
i, m = symbols('i m', integer=True, cls=Dummy)
x = IndexedBase('x')
y = IndexedBase('y')
i = Idx(i, m)
assert rust_code(x[i], assign_to=y[i]) == (
"for i in 0..m {\n"
" y[i] = x[i];\n"
"}")
def test_loops():
m, n = symbols('m n', integer=True)
A = IndexedBase('A')
x = IndexedBase('x')
y = IndexedBase('y')
z = IndexedBase('z')
i = Idx('i', m)
j = Idx('j', n)
assert rust_code(A[i, j]*x[j], assign_to=y[i]) == (
"for i in 0..m {\n"
" y[i] = 0;\n"
"}\n"
"for i in 0..m {\n"
" for j in 0..n {\n"
" y[i] = A[n*i + j]*x[j] + y[i];\n"
" }\n"
"}")
assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == (
"for i in 0..m {\n"
" y[i] = x[i] + z[i];\n"
"}\n"
"for i in 0..m {\n"
" for j in 0..n {\n"
" y[i] = A[n*i + j]*x[j] + y[i];\n"
" }\n"
"}")
def test_loops_multiple_contractions():
n, m, o, p = symbols('n m o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == (
"for i in 0..m {\n"
" y[i] = 0;\n"
"}\n"
"for i in 0..m {\n"
" for j in 0..n {\n"
" for k in 0..o {\n"
" for l in 0..p {\n"
" y[i] = a[%s]*b[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
" }\n"
" }\n"
" }\n"
"}")
def test_loops_addfactor():
m, n, o, p = symbols('m n o p', integer=True)
a = IndexedBase('a')
b = IndexedBase('b')
c = IndexedBase('c')
y = IndexedBase('y')
i = Idx('i', m)
j = Idx('j', n)
k = Idx('k', o)
l = Idx('l', p)
code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
assert code == (
"for i in 0..m {\n"
" y[i] = 0;\n"
"}\n"
"for i in 0..m {\n"
" for j in 0..n {\n"
" for k in 0..o {\n"
" for l in 0..p {\n"
" y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
" }\n"
" }\n"
" }\n"
"}")
def test_settings():
raises(TypeError, lambda: rust_code(sin(x), method="garbage"))
def test_inline_function():
x = symbols('x')
g = implemented_function('g', Lambda(x, 2*x))
assert rust_code(g(x)) == "2*x"
g = implemented_function('g', Lambda(x, 2*x/Catalan))
assert rust_code(g(x)) == (
"const Catalan: f64 = %s;\n2*x/Catalan" % Catalan.evalf(17))
A = IndexedBase('A')
i = Idx('i', symbols('n', integer=True))
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
assert rust_code(g(A[i]), assign_to=A[i]) == (
"for i in 0..n {\n"
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
"}")
def test_user_functions():
x = symbols('x', integer=False)
n = symbols('n', integer=True)
custom_functions = {
"ceiling": "ceil",
"Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)],
}
assert rust_code(ceiling(x), user_functions=custom_functions) == "x.ceil()"
assert rust_code(Abs(x), user_functions=custom_functions) == "fabs(x)"
assert rust_code(Abs(n), user_functions=custom_functions) == "abs(n)"
def test_matrix():
assert rust_code(Matrix([1, 2, 3])) == '[1, 2, 3]'
with raises(ValueError):
rust_code(Matrix([[1, 2, 3]]))
def test_sparse_matrix():
# gh-15791
with raises(NotImplementedError):
rust_code(SparseMatrix([[1, 2, 3]]))

View File

@ -0,0 +1,553 @@
import contextlib
import itertools
import re
import typing
from enum import Enum
from typing import Callable
import sympy
from sympy import Add, Implies, sqrt
from sympy.core import Mul, Pow
from sympy.core import (S, pi, symbols, Function, Rational, Integer,
Symbol, Eq, Ne, Le, Lt, Gt, Ge)
from sympy.functions import Piecewise, exp, sin, cos
from sympy.assumptions.ask import Q
from sympy.printing.smtlib import smtlib_code
from sympy.testing.pytest import raises, Failed
x, y, z = symbols('x,y,z')
class _W(Enum):
DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.I)
WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.I)
WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.I)
@contextlib.contextmanager
def _check_warns(expected: typing.Iterable[_W]):
warns: typing.List[str] = []
log_warn = warns.append
yield log_warn
errors = []
for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)):
if not e:
errors += [f"[{i}] Received unexpected warning `{w}`."]
elif not w:
errors += [f"[{i}] Did not receive expected warning `{e.name}`."]
elif not e.value.match(w):
errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."]
if errors: raise Failed('\n'.join(errors))
def test_Integer():
with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w:
assert smtlib_code(Integer(67), log_warn=w) == "67"
assert smtlib_code(Integer(-1), log_warn=w) == "-1"
with _check_warns([]) as w:
assert smtlib_code(Integer(67)) == "67"
assert smtlib_code(Integer(-1)) == "-1"
def test_Rational():
with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w:
assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)"
assert smtlib_code(Rational(18, 9), log_warn=w) == "2"
assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)"
assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)"
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w:
assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)"
assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \
"(* (/ 3 7) x)"
def test_Relational():
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
def test_AppliedBinaryRelation():
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
assert smtlib_code(Q.eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
assert smtlib_code(Q.ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
assert smtlib_code(Q.lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
assert smtlib_code(Q.le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
assert smtlib_code(Q.gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
assert smtlib_code(Q.ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
raises(ValueError, lambda: smtlib_code(Q.complex(x), log_warn=w))
def test_AppliedPredicate():
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 6) as w:
assert smtlib_code(Q.positive(x), auto_declare=False, log_warn=w) == "(assert (> x 0))"
assert smtlib_code(Q.negative(x), auto_declare=False, log_warn=w) == "(assert (< x 0))"
assert smtlib_code(Q.zero(x), auto_declare=False, log_warn=w) == "(assert (= x 0))"
assert smtlib_code(Q.nonpositive(x), auto_declare=False, log_warn=w) == "(assert (<= x 0))"
assert smtlib_code(Q.nonnegative(x), auto_declare=False, log_warn=w) == "(assert (>= x 0))"
assert smtlib_code(Q.nonzero(x), auto_declare=False, log_warn=w) == "(assert (not (= x 0)))"
def test_Function():
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))"
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
abs(x),
symbol_table={x: int, y: bool},
known_types={int: "INTEGER_TYPE"},
known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"},
log_warn=w
) == "(declare-const x INTEGER_TYPE)\n" \
"(ABSOLUTE_VALUE_OF x)"
my_fun1 = Function('f1')
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
my_fun1(x),
symbol_table={my_fun1: Callable[[bool], float]},
log_warn=w
) == "(declare-const x Bool)\n" \
"(declare-fun f1 (Bool) Real)\n" \
"(f1 x)"
with _check_warns([]) as w:
assert smtlib_code(
my_fun1(x),
symbol_table={my_fun1: Callable[[bool], bool]},
log_warn=w
) == "(declare-const x Bool)\n" \
"(declare-fun f1 (Bool) Bool)\n" \
"(assert (f1 x))"
assert smtlib_code(
Eq(my_fun1(x, z), y),
symbol_table={my_fun1: Callable[[int, bool], bool]},
log_warn=w
) == "(declare-const x Int)\n" \
"(declare-const y Bool)\n" \
"(declare-const z Bool)\n" \
"(declare-fun f1 (Int Bool) Bool)\n" \
"(assert (= (f1 x z) y))"
assert smtlib_code(
Eq(my_fun1(x, z), y),
symbol_table={my_fun1: Callable[[int, bool], bool]},
known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
log_warn=w
) == "(declare-const x Int)\n" \
"(declare-const y Bool)\n" \
"(declare-const z Bool)\n" \
"(assert (== (MY_KNOWN_FUN x z) y))"
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w:
assert smtlib_code(
Eq(my_fun1(x, z), y),
known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
log_warn=w
) == "(declare-const x Real)\n" \
"(declare-const y Real)\n" \
"(declare-const z Real)\n" \
"(assert (== (MY_KNOWN_FUN x z) y))"
def test_Pow():
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)"
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))"
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))'
a = Symbol('a', integer=True)
b = Symbol('b', real=True)
c = Symbol('c')
def g(x): return 2 * x
# if x=1, y=2, then expr=2.333...
expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b)
with _check_warns([]) as w:
assert smtlib_code(
[
Eq(a < 2, c),
Eq(b > a, c),
c & True,
Eq(expr, 2 + Rational(1, 3))
],
log_warn=w
) == '(declare-const a Int)\n' \
'(declare-const b Real)\n' \
'(declare-const c Bool)\n' \
'(assert (= (< a 2) c))\n' \
'(assert (= (> b a) c))\n' \
'(assert c)\n' \
'(assert (= ' \
'(* (pow (* 7.0 a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \
'(/ 7 3)' \
'))'
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False),
log_warn=w
) == '(declare-const b Real)\n' \
'(declare-const c Real)\n' \
'(* -2 c (pow (* b b) -1))'
def test_basic_ops():
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)"
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)"
# with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w:
# todo: implement re-write, currently does '(+ x (* -1 y))' instead
# assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)"
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)"
def test_quantifier_extensions():
from sympy.logic.boolalg import Boolean
from sympy import Interval, Tuple, sympify
# start For-all quantifier class example
class ForAll(Boolean):
def _smtlib(self, printer):
bound_symbol_declarations = [
printer._s_expr(sym.name, [
printer._known_types[printer.symbol_table[sym]],
Interval(start, end)
]) for sym, start, end in self.limits
]
return printer._s_expr('forall', [
printer._s_expr('', bound_symbol_declarations),
self.function
])
@property
def bound_symbols(self):
return {s for s, _, _ in self.limits}
@property
def free_symbols(self):
bound_symbol_names = {s.name for s in self.bound_symbols}
return {
s for s in self.function.free_symbols
if s.name not in bound_symbol_names
}
def __new__(cls, *args):
limits = [sympify(a) for a in args if isinstance(a, (tuple, Tuple))]
function = [sympify(a) for a in args if isinstance(a, Boolean)]
assert len(limits) + len(function) == len(args)
assert len(function) == 1
function = function[0]
if isinstance(function, ForAll): return ForAll.__new__(
ForAll, *(limits + function.limits), function.function
)
inst = Boolean.__new__(cls)
inst._args = tuple(limits + [function])
inst.limits = limits
inst.function = function
return inst
# end For-All Quantifier class example
f = Function('f')
with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
assert smtlib_code(
ForAll((x, -42, +21), Eq(f(x), f(x))),
symbol_table={f: Callable[[float], float]},
log_warn=w
) == '(assert (forall ( (x Real [-42, 21])) true))'
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w:
assert smtlib_code(
ForAll(
(x, -42, +21), (y, -100, 3),
Implies(Eq(x, y), Eq(f(x), f(y)))
),
symbol_table={f: Callable[[float], float]},
log_warn=w
) == '(declare-fun f (Real) Real)\n' \
'(assert (' \
'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \
'(=> (= x y) (= (f x) (f y)))' \
'))'
a = Symbol('a', integer=True)
b = Symbol('b', real=True)
c = Symbol('c')
with _check_warns([]) as w:
assert smtlib_code(
ForAll(
(a, 2, 100), ForAll(
(b, 2, 100),
Implies(a < b, sqrt(a) < b) | c
)),
log_warn=w
) == '(declare-const c Bool)\n' \
'(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \
'(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \
'))'
def test_mix_number_mult_symbols():
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
1 / pi,
known_constants={pi: "MY_PI"},
log_warn=w
) == '(pow MY_PI -1)'
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
[
Eq(pi, 3.14, evaluate=False),
1 / pi,
],
known_constants={pi: "MY_PI"},
log_warn=w
) == '(assert (= MY_PI 3.14))\n' \
'(pow MY_PI -1)'
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
Add(S.Zero, S.One, S.NegativeOne, S.Half,
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
known_constants={
S.Pi: 'p', S.GoldenRatio: 'g',
S.Exp1: 'e'
},
known_functions={
Add: 'plus',
exp: 'exp'
},
precision=3,
log_warn=w
) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)'
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
Add(S.Zero, S.One, S.NegativeOne, S.Half,
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
known_constants={
S.Pi: 'p'
},
known_functions={
Add: 'plus',
exp: 'exp'
},
precision=3,
log_warn=w
) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)'
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
Add(S.Zero, S.One, S.NegativeOne, S.Half,
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
known_functions={Add: 'plus'},
precision=3,
log_warn=w
) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)'
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
Add(S.Zero, S.One, S.NegativeOne, S.Half,
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
known_constants={S.Exp1: 'e'},
known_functions={Add: 'plus'},
precision=3,
log_warn=w
) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)'
def test_boolean():
with _check_warns([]) as w:
assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \
'(declare-const y Bool)\n' \
'(assert (and x y))'
assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \
'(declare-const y Bool)\n' \
'(assert (or x y))'
assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \
'(assert (not x))'
assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \
'(declare-const y Bool)\n' \
'(declare-const z Bool)\n' \
'(assert (and x y z))'
with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \
'(declare-const y Bool)\n' \
'(declare-const z Real)\n' \
'(assert (or (> z 3) (and x (not y))))'
f = Function('f')
g = Function('g')
h = Function('h')
with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
assert smtlib_code(
[Gt(f(x), y),
Lt(y, g(z))],
symbol_table={
f: Callable[[bool], int], g: Callable[[bool], int],
}, log_warn=w
) == '(declare-const x Bool)\n' \
'(declare-const y Real)\n' \
'(declare-const z Bool)\n' \
'(declare-fun f (Bool) Int)\n' \
'(declare-fun g (Bool) Int)\n' \
'(assert (> (f x) y))\n' \
'(assert (< y (g z)))'
with _check_warns([]) as w:
assert smtlib_code(
[Eq(f(x), y),
Lt(y, g(z))],
symbol_table={
f: Callable[[bool], int], g: Callable[[bool], int],
}, log_warn=w
) == '(declare-const x Bool)\n' \
'(declare-const y Int)\n' \
'(declare-const z Bool)\n' \
'(declare-fun f (Bool) Int)\n' \
'(declare-fun g (Bool) Int)\n' \
'(assert (= (f x) y))\n' \
'(assert (< y (g z)))'
with _check_warns([]) as w:
assert smtlib_code(
[Eq(f(x), y),
Eq(g(f(x)), z),
Eq(h(g(f(x))), x)],
symbol_table={
f: Callable[[float], int],
g: Callable[[int], bool],
h: Callable[[bool], float]
},
log_warn=w
) == '(declare-const x Real)\n' \
'(declare-const y Int)\n' \
'(declare-const z Bool)\n' \
'(declare-fun f (Real) Int)\n' \
'(declare-fun g (Int) Bool)\n' \
'(declare-fun h (Bool) Real)\n' \
'(assert (= (f x) y))\n' \
'(assert (= (g (f x)) z))\n' \
'(assert (= (h (g (f x))) x))'
# todo: make smtlib_code support arrays
# def test_containers():
# assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
# "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
# assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
# assert julia_code([1]) == "Any[1]"
# assert julia_code((1,)) == "(1,)"
# assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
# assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))"
# # scalar, matrix, empty matrix and empty list
# assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"
def test_smtlib_piecewise():
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
Piecewise((x, x < 1),
(x ** 2, True)),
auto_declare=False,
log_warn=w
) == '(ite (< x 1) x (pow x 2))'
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(
Piecewise((x ** 2, x < 1),
(x ** 3, x < 2),
(x ** 4, x < 3),
(x ** 5, True)),
auto_declare=False,
log_warn=w
) == '(ite (< x 1) (pow x 2) ' \
'(ite (< x 2) (pow x 3) ' \
'(ite (< x 3) (pow x 4) ' \
'(pow x 5))))'
# Check that Piecewise without a True (default) condition error
expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
raises(AssertionError, lambda: smtlib_code(expr, log_warn=w))
def test_smtlib_piecewise_times_const():
pw = Piecewise((x, x < 1), (x ** 2, True))
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))'
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))'
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))'
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))'
# todo: make smtlib_code support arrays / matrices ?
# def test_smtlib_matrix_assign_to():
# A = Matrix([[1, 2, 3]])
# assert smtlib_code(A, assign_to='a') == "a = [1 2 3]"
# A = Matrix([[1, 2], [3, 4]])
# assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]"
# def test_julia_matrix_1x1():
# A = Matrix([[3]])
# B = MatrixSymbol('B', 1, 1)
# C = MatrixSymbol('C', 1, 2)
# assert julia_code(A, assign_to=B) == "B = [3]"
# raises(ValueError, lambda: julia_code(A, assign_to=C))
# def test_julia_matrix_elements():
# A = Matrix([[x, 2, x * y]])
# assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
# A = MatrixSymbol('AA', 1, 3)
# assert julia_code(A) == "AA"
# assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \
# "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
# assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"
def test_smtlib_boolean():
with _check_warns([]) as w:
assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true'
assert smtlib_code(True, log_warn=w) == '(assert true)'
assert smtlib_code(S.true, log_warn=w) == '(assert true)'
assert smtlib_code(S.false, log_warn=w) == '(assert false)'
assert smtlib_code(False, log_warn=w) == '(assert false)'
assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false'
def test_not_supported():
f = Function('f')
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w))
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w))
def test_Float():
assert smtlib_code(0.0) == "0.0"
assert smtlib_code(0.000000000000000003) == '(* 3.0 (pow 10 -18))'
assert smtlib_code(5.3) == "5.3"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,182 @@
from sympy.core.singleton import S
from sympy.printing.tableform import TableForm
from sympy.printing.latex import latex
from sympy.abc import x
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import sin
from sympy.testing.pytest import raises
from textwrap import dedent
def test_TableForm():
s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]],
headings="automatic"))
assert s == (
' | 1 2\n'
'-------\n'
'1 | a b\n'
'2 | c d\n'
'3 | e '
)
s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]],
headings="automatic", wipe_zeros=False))
assert s == dedent('''\
| 1 2
-------
1 | a b
2 | c d
3 | e 0''')
s = str(TableForm([[x**2, "b"], ["c", x**2], ["e", "f"]],
headings=("automatic", None)))
assert s == (
'1 | x**2 b \n'
'2 | c x**2\n'
'3 | e f '
)
s = str(TableForm([["a", "b"], ["c", "d"], ["e", "f"]],
headings=(None, "automatic")))
assert s == dedent('''\
1 2
---
a b
c d
e f''')
s = str(TableForm([[5, 7], [4, 2], [10, 3]],
headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]]))
assert s == (
' | y1 y2\n'
'---------------\n'
'Group A | 5 7 \n'
'Group B | 4 2 \n'
'Group C | 10 3 '
)
raises(
ValueError,
lambda:
TableForm(
[[5, 7], [4, 2], [10, 3]],
headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]],
alignments="middle")
)
s = str(TableForm([[5, 7], [4, 2], [10, 3]],
headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]],
alignments="right"))
assert s == dedent('''\
| y1 y2
---------------
Group A | 5 7
Group B | 4 2
Group C | 10 3''')
# other alignment permutations
d = [[1, 100], [100, 1]]
s = TableForm(d, headings=(('xxx', 'x'), None), alignments='l')
assert str(s) == (
'xxx | 1 100\n'
' x | 100 1 '
)
s = TableForm(d, headings=(('xxx', 'x'), None), alignments='lr')
assert str(s) == dedent('''\
xxx | 1 100
x | 100 1''')
s = TableForm(d, headings=(('xxx', 'x'), None), alignments='clr')
assert str(s) == dedent('''\
xxx | 1 100
x | 100 1''')
s = TableForm(d, headings=(('xxx', 'x'), None))
assert str(s) == (
'xxx | 1 100\n'
' x | 100 1 '
)
raises(ValueError, lambda: TableForm(d, alignments='clr'))
#pad
s = str(TableForm([[None, "-", 2], [1]], pad='?'))
assert s == dedent('''\
? - 2
1 ? ?''')
def test_TableForm_latex():
s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
wipe_zeros=True, headings=("automatic", "automatic")))
assert s == (
'\\begin{tabular}{r l l}\n'
' & 1 & 2 \\\\\n'
'\\hline\n'
'1 & & $x^{3}$ \\\\\n'
'2 & $c$ & $\\frac{1}{4}$ \\\\\n'
'3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
'\\end{tabular}'
)
s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'))
assert s == (
'\\begin{tabular}{r l l}\n'
' & 1 & 2 \\\\\n'
'\\hline\n'
'1 & & $x^{3}$ \\\\\n'
'2 & $c$ & $\\frac{1}{4}$ \\\\\n'
'3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
'\\end{tabular}'
)
s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'*3))
assert s == (
'\\begin{tabular}{l l l}\n'
' & 1 & 2 \\\\\n'
'\\hline\n'
'1 & & $x^{3}$ \\\\\n'
'2 & $c$ & $\\frac{1}{4}$ \\\\\n'
'3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
'\\end{tabular}'
)
s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
headings=("automatic", "automatic")))
assert s == (
'\\begin{tabular}{r l l}\n'
' & 1 & 2 \\\\\n'
'\\hline\n'
'1 & $a$ & $x^{3}$ \\\\\n'
'2 & $c$ & $\\frac{1}{4}$ \\\\\n'
'3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
'\\end{tabular}'
)
s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
formats=['(%s)', None], headings=("automatic", "automatic")))
assert s == (
'\\begin{tabular}{r l l}\n'
' & 1 & 2 \\\\\n'
'\\hline\n'
'1 & (a) & $x^{3}$ \\\\\n'
'2 & (c) & $\\frac{1}{4}$ \\\\\n'
'3 & (sqrt(x)) & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
'\\end{tabular}'
)
def neg_in_paren(x, i, j):
if i % 2:
return ('(%s)' if x < 0 else '%s') % x
else:
pass # use default print
s = latex(TableForm([[-1, 2], [-3, 4]],
formats=[neg_in_paren]*2, headings=("automatic", "automatic")))
assert s == (
'\\begin{tabular}{r l l}\n'
' & 1 & 2 \\\\\n'
'\\hline\n'
'1 & -1 & 2 \\\\\n'
'2 & (-3) & 4 \\\\\n'
'\\end{tabular}'
)
s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]]))
assert s == (
'\\begin{tabular}{l l}\n'
'$a$ & $x^{3}$ \\\\\n'
'$c$ & $\\frac{1}{4}$ \\\\\n'
'$\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
'\\end{tabular}'
)

View File

@ -0,0 +1,465 @@
import random
from sympy.core.function import Derivative
from sympy.core.symbol import symbols
from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
PermuteDims, ArrayDiagonal
from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt
from sympy.external import import_module
from sympy.functions import \
Abs, ceiling, exp, floor, sign, sin, asin, sqrt, cos, \
acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \
re, im, arg, erf, loggamma, log
from sympy.matrices import Matrix, MatrixBase, eye, randMatrix
from sympy.matrices.expressions import \
Determinant, HadamardProduct, Inverse, MatrixSymbol, Trace
from sympy.printing.tensorflow import tensorflow_code
from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
from sympy.utilities.lambdify import lambdify
from sympy.testing.pytest import skip
from sympy.testing.pytest import XFAIL
tf = tensorflow = import_module("tensorflow")
if tensorflow:
# Hide Tensorflow warnings
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
M = MatrixSymbol("M", 3, 3)
N = MatrixSymbol("N", 3, 3)
P = MatrixSymbol("P", 3, 3)
Q = MatrixSymbol("Q", 3, 3)
x, y, z, t = symbols("x y z t")
if tf is not None:
llo = [list(range(i, i+3)) for i in range(0, 9, 3)]
m3x3 = tf.constant(llo)
m3x3sympy = Matrix(llo)
def _compare_tensorflow_matrix(variables, expr, use_float=False):
f = lambdify(variables, expr, 'tensorflow')
if not use_float:
random_matrices = [randMatrix(v.rows, v.cols) for v in variables]
else:
random_matrices = [randMatrix(v.rows, v.cols)/100. for v in variables]
graph = tf.Graph()
r = None
with graph.as_default():
random_variables = [eval(tensorflow_code(i)) for i in random_matrices]
session = tf.compat.v1.Session(graph=graph)
r = session.run(f(*random_variables))
e = expr.subs(dict(zip(variables, random_matrices)))
e = e.doit()
if e.is_Matrix:
if not isinstance(e, MatrixBase):
e = e.as_explicit()
e = e.tolist()
if not use_float:
assert (r == e).all()
else:
r = [i for row in r for i in row]
e = [i for row in e for i in row]
assert all(
abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e))
# Creating a custom inverse test.
# See https://github.com/sympy/sympy/issues/18469
def _compare_tensorflow_matrix_inverse(variables, expr, use_float=False):
f = lambdify(variables, expr, 'tensorflow')
if not use_float:
random_matrices = [eye(v.rows, v.cols)*4 for v in variables]
else:
random_matrices = [eye(v.rows, v.cols)*3.14 for v in variables]
graph = tf.Graph()
r = None
with graph.as_default():
random_variables = [eval(tensorflow_code(i)) for i in random_matrices]
session = tf.compat.v1.Session(graph=graph)
r = session.run(f(*random_variables))
e = expr.subs(dict(zip(variables, random_matrices)))
e = e.doit()
if e.is_Matrix:
if not isinstance(e, MatrixBase):
e = e.as_explicit()
e = e.tolist()
if not use_float:
assert (r == e).all()
else:
r = [i for row in r for i in row]
e = [i for row in e for i in row]
assert all(
abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e))
def _compare_tensorflow_matrix_scalar(variables, expr):
f = lambdify(variables, expr, 'tensorflow')
random_matrices = [
randMatrix(v.rows, v.cols).evalf() / 100 for v in variables]
graph = tf.Graph()
r = None
with graph.as_default():
random_variables = [eval(tensorflow_code(i)) for i in random_matrices]
session = tf.compat.v1.Session(graph=graph)
r = session.run(f(*random_variables))
e = expr.subs(dict(zip(variables, random_matrices)))
e = e.doit()
assert abs(r-e) < 10**-6
def _compare_tensorflow_scalar(
variables, expr, rng=lambda: random.randint(0, 10)):
f = lambdify(variables, expr, 'tensorflow')
rvs = [rng() for v in variables]
graph = tf.Graph()
r = None
with graph.as_default():
tf_rvs = [eval(tensorflow_code(i)) for i in rvs]
session = tf.compat.v1.Session(graph=graph)
r = session.run(f(*tf_rvs))
e = expr.subs(dict(zip(variables, rvs))).evalf().doit()
assert abs(r-e) < 10**-6
def _compare_tensorflow_relational(
variables, expr, rng=lambda: random.randint(0, 10)):
f = lambdify(variables, expr, 'tensorflow')
rvs = [rng() for v in variables]
graph = tf.Graph()
r = None
with graph.as_default():
tf_rvs = [eval(tensorflow_code(i)) for i in rvs]
session = tf.compat.v1.Session(graph=graph)
r = session.run(f(*tf_rvs))
e = expr.subs(dict(zip(variables, rvs))).doit()
assert r == e
def test_tensorflow_printing():
assert tensorflow_code(eye(3)) == \
"tensorflow.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])"
expr = Matrix([[x, sin(y)], [exp(z), -t]])
assert tensorflow_code(expr) == \
"tensorflow.Variable(" \
"[[x, tensorflow.math.sin(y)]," \
" [tensorflow.math.exp(z), -t]])"
# This (random) test is XFAIL because it fails occasionally
# See https://github.com/sympy/sympy/issues/18469
@XFAIL
def test_tensorflow_math():
if not tf:
skip("TensorFlow not installed")
expr = Abs(x)
assert tensorflow_code(expr) == "tensorflow.math.abs(x)"
_compare_tensorflow_scalar((x,), expr)
expr = sign(x)
assert tensorflow_code(expr) == "tensorflow.math.sign(x)"
_compare_tensorflow_scalar((x,), expr)
expr = ceiling(x)
assert tensorflow_code(expr) == "tensorflow.math.ceil(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = floor(x)
assert tensorflow_code(expr) == "tensorflow.math.floor(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = exp(x)
assert tensorflow_code(expr) == "tensorflow.math.exp(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = sqrt(x)
assert tensorflow_code(expr) == "tensorflow.math.sqrt(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = x ** 4
assert tensorflow_code(expr) == "tensorflow.math.pow(x, 4)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = cos(x)
assert tensorflow_code(expr) == "tensorflow.math.cos(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = acos(x)
assert tensorflow_code(expr) == "tensorflow.math.acos(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(0, 0.95))
expr = sin(x)
assert tensorflow_code(expr) == "tensorflow.math.sin(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = asin(x)
assert tensorflow_code(expr) == "tensorflow.math.asin(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = tan(x)
assert tensorflow_code(expr) == "tensorflow.math.tan(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = atan(x)
assert tensorflow_code(expr) == "tensorflow.math.atan(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = atan2(y, x)
assert tensorflow_code(expr) == "tensorflow.math.atan2(y, x)"
_compare_tensorflow_scalar((y, x), expr, rng=lambda: random.random())
expr = cosh(x)
assert tensorflow_code(expr) == "tensorflow.math.cosh(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
expr = acosh(x)
assert tensorflow_code(expr) == "tensorflow.math.acosh(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
expr = sinh(x)
assert tensorflow_code(expr) == "tensorflow.math.sinh(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
expr = asinh(x)
assert tensorflow_code(expr) == "tensorflow.math.asinh(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
expr = tanh(x)
assert tensorflow_code(expr) == "tensorflow.math.tanh(x)"
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
expr = atanh(x)
assert tensorflow_code(expr) == "tensorflow.math.atanh(x)"
_compare_tensorflow_scalar(
(x,), expr, rng=lambda: random.uniform(-.5, .5))
expr = erf(x)
assert tensorflow_code(expr) == "tensorflow.math.erf(x)"
_compare_tensorflow_scalar(
(x,), expr, rng=lambda: random.random())
expr = loggamma(x)
assert tensorflow_code(expr) == "tensorflow.math.lgamma(x)"
_compare_tensorflow_scalar(
(x,), expr, rng=lambda: random.random())
def test_tensorflow_complexes():
assert tensorflow_code(re(x)) == "tensorflow.math.real(x)"
assert tensorflow_code(im(x)) == "tensorflow.math.imag(x)"
assert tensorflow_code(arg(x)) == "tensorflow.math.angle(x)"
def test_tensorflow_relational():
if not tf:
skip("TensorFlow not installed")
expr = Eq(x, y)
assert tensorflow_code(expr) == "tensorflow.math.equal(x, y)"
_compare_tensorflow_relational((x, y), expr)
expr = Ne(x, y)
assert tensorflow_code(expr) == "tensorflow.math.not_equal(x, y)"
_compare_tensorflow_relational((x, y), expr)
expr = Ge(x, y)
assert tensorflow_code(expr) == "tensorflow.math.greater_equal(x, y)"
_compare_tensorflow_relational((x, y), expr)
expr = Gt(x, y)
assert tensorflow_code(expr) == "tensorflow.math.greater(x, y)"
_compare_tensorflow_relational((x, y), expr)
expr = Le(x, y)
assert tensorflow_code(expr) == "tensorflow.math.less_equal(x, y)"
_compare_tensorflow_relational((x, y), expr)
expr = Lt(x, y)
assert tensorflow_code(expr) == "tensorflow.math.less(x, y)"
_compare_tensorflow_relational((x, y), expr)
# This (random) test is XFAIL because it fails occasionally
# See https://github.com/sympy/sympy/issues/18469
@XFAIL
def test_tensorflow_matrices():
if not tf:
skip("TensorFlow not installed")
expr = M
assert tensorflow_code(expr) == "M"
_compare_tensorflow_matrix((M,), expr)
expr = M + N
assert tensorflow_code(expr) == "tensorflow.math.add(M, N)"
_compare_tensorflow_matrix((M, N), expr)
expr = M * N
assert tensorflow_code(expr) == "tensorflow.linalg.matmul(M, N)"
_compare_tensorflow_matrix((M, N), expr)
expr = HadamardProduct(M, N)
assert tensorflow_code(expr) == "tensorflow.math.multiply(M, N)"
_compare_tensorflow_matrix((M, N), expr)
expr = M*N*P*Q
assert tensorflow_code(expr) == \
"tensorflow.linalg.matmul(" \
"tensorflow.linalg.matmul(" \
"tensorflow.linalg.matmul(M, N), P), Q)"
_compare_tensorflow_matrix((M, N, P, Q), expr)
expr = M**3
assert tensorflow_code(expr) == \
"tensorflow.linalg.matmul(tensorflow.linalg.matmul(M, M), M)"
_compare_tensorflow_matrix((M,), expr)
expr = Trace(M)
assert tensorflow_code(expr) == "tensorflow.linalg.trace(M)"
_compare_tensorflow_matrix((M,), expr)
expr = Determinant(M)
assert tensorflow_code(expr) == "tensorflow.linalg.det(M)"
_compare_tensorflow_matrix_scalar((M,), expr)
expr = Inverse(M)
assert tensorflow_code(expr) == "tensorflow.linalg.inv(M)"
_compare_tensorflow_matrix_inverse((M,), expr, use_float=True)
expr = M.T
assert tensorflow_code(expr, tensorflow_version='1.14') == \
"tensorflow.linalg.matrix_transpose(M)"
assert tensorflow_code(expr, tensorflow_version='1.13') == \
"tensorflow.matrix_transpose(M)"
_compare_tensorflow_matrix((M,), expr)
def test_codegen_einsum():
if not tf:
skip("TensorFlow not installed")
graph = tf.Graph()
with graph.as_default():
session = tf.compat.v1.Session(graph=graph)
M = MatrixSymbol("M", 2, 2)
N = MatrixSymbol("N", 2, 2)
cg = convert_matrix_to_array(M * N)
f = lambdify((M, N), cg, 'tensorflow')
ma = tf.constant([[1, 2], [3, 4]])
mb = tf.constant([[1,-2], [-1, 3]])
y = session.run(f(ma, mb))
c = session.run(tf.matmul(ma, mb))
assert (y == c).all()
def test_codegen_extra():
if not tf:
skip("TensorFlow not installed")
graph = tf.Graph()
with graph.as_default():
session = tf.compat.v1.Session()
M = MatrixSymbol("M", 2, 2)
N = MatrixSymbol("N", 2, 2)
P = MatrixSymbol("P", 2, 2)
Q = MatrixSymbol("Q", 2, 2)
ma = tf.constant([[1, 2], [3, 4]])
mb = tf.constant([[1,-2], [-1, 3]])
mc = tf.constant([[2, 0], [1, 2]])
md = tf.constant([[1,-1], [4, 7]])
cg = ArrayTensorProduct(M, N)
assert tensorflow_code(cg) == \
'tensorflow.linalg.einsum("ab,cd", M, N)'
f = lambdify((M, N), cg, 'tensorflow')
y = session.run(f(ma, mb))
c = session.run(tf.einsum("ij,kl", ma, mb))
assert (y == c).all()
cg = ArrayAdd(M, N)
assert tensorflow_code(cg) == 'tensorflow.math.add(M, N)'
f = lambdify((M, N), cg, 'tensorflow')
y = session.run(f(ma, mb))
c = session.run(ma + mb)
assert (y == c).all()
cg = ArrayAdd(M, N, P)
assert tensorflow_code(cg) == \
'tensorflow.math.add(tensorflow.math.add(M, N), P)'
f = lambdify((M, N, P), cg, 'tensorflow')
y = session.run(f(ma, mb, mc))
c = session.run(ma + mb + mc)
assert (y == c).all()
cg = ArrayAdd(M, N, P, Q)
assert tensorflow_code(cg) == \
'tensorflow.math.add(' \
'tensorflow.math.add(tensorflow.math.add(M, N), P), Q)'
f = lambdify((M, N, P, Q), cg, 'tensorflow')
y = session.run(f(ma, mb, mc, md))
c = session.run(ma + mb + mc + md)
assert (y == c).all()
cg = PermuteDims(M, [1, 0])
assert tensorflow_code(cg) == 'tensorflow.transpose(M, [1, 0])'
f = lambdify((M,), cg, 'tensorflow')
y = session.run(f(ma))
c = session.run(tf.transpose(ma))
assert (y == c).all()
cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
assert tensorflow_code(cg) == \
'tensorflow.transpose(' \
'tensorflow.linalg.einsum("ab,cd", M, N), [1, 2, 3, 0])'
f = lambdify((M, N), cg, 'tensorflow')
y = session.run(f(ma, mb))
c = session.run(tf.transpose(tf.einsum("ab,cd", ma, mb), [1, 2, 3, 0]))
assert (y == c).all()
cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
assert tensorflow_code(cg) == \
'tensorflow.linalg.einsum("ab,bc->acb", M, N)'
f = lambdify((M, N), cg, 'tensorflow')
y = session.run(f(ma, mb))
c = session.run(tf.einsum("ab,bc->acb", ma, mb))
assert (y == c).all()
def test_MatrixElement_printing():
A = MatrixSymbol("A", 1, 3)
B = MatrixSymbol("B", 1, 3)
C = MatrixSymbol("C", 1, 3)
assert tensorflow_code(A[0, 0]) == "A[0, 0]"
assert tensorflow_code(3 * A[0, 0]) == "3*A[0, 0]"
F = C[0, 0].subs(C, A - B)
assert tensorflow_code(F) == "(tensorflow.math.add((-1)*B, A))[0, 0]"
def test_tensorflow_Derivative():
expr = Derivative(sin(x), x)
assert tensorflow_code(expr) == \
"tensorflow.gradients(tensorflow.math.sin(x), x)[0]"

View File

@ -0,0 +1,639 @@
"""
Important note on tests in this module - the Theano printing functions use a
global cache by default, which means that tests using it will modify global
state and thus not be independent from each other. Instead of using the "cache"
keyword argument each time, this module uses the theano_code_ and
theano_function_ functions defined below which default to using a new, empty
cache instead.
"""
import logging
from sympy.external import import_module
from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy
theanologger = logging.getLogger('theano.configdefaults')
theanologger.setLevel(logging.CRITICAL)
theano = import_module('theano')
theanologger.setLevel(logging.WARNING)
if theano:
import numpy as np
ts = theano.scalar
tt = theano.tensor
xt, yt, zt = [tt.scalar(name, 'floatX') for name in 'xyz']
Xt, Yt, Zt = [tt.tensor('floatX', (False, False), name=n) for n in 'XYZ']
else:
#bin/test will not execute any tests now
disabled = True
import sympy as sy
from sympy.core.singleton import S
from sympy.abc import x, y, z, t
from sympy.printing.theanocode import (theano_code, dim_handling,
theano_function)
# Default set of matrix symbols for testing - make square so we can both
# multiply and perform elementwise operations between them.
X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ']
# For testing AppliedUndef
f_t = sy.Function('f')(t)
def theano_code_(expr, **kwargs):
""" Wrapper for theano_code that uses a new, empty cache by default. """
kwargs.setdefault('cache', {})
with warns_deprecated_sympy():
return theano_code(expr, **kwargs)
def theano_function_(inputs, outputs, **kwargs):
""" Wrapper for theano_function that uses a new, empty cache by default. """
kwargs.setdefault('cache', {})
with warns_deprecated_sympy():
return theano_function(inputs, outputs, **kwargs)
def fgraph_of(*exprs):
""" Transform SymPy expressions into Theano Computation.
Parameters
==========
exprs
SymPy expressions
Returns
=======
theano.gof.FunctionGraph
"""
outs = list(map(theano_code_, exprs))
ins = theano.gof.graph.inputs(outs)
ins, outs = theano.gof.graph.clone(ins, outs)
return theano.gof.FunctionGraph(ins, outs)
def theano_simplify(fgraph):
""" Simplify a Theano Computation.
Parameters
==========
fgraph : theano.gof.FunctionGraph
Returns
=======
theano.gof.FunctionGraph
"""
mode = theano.compile.get_default_mode().excluding("fusion")
fgraph = fgraph.clone()
mode.optimizer.optimize(fgraph)
return fgraph
def theq(a, b):
""" Test two Theano objects for equality.
Also accepts numeric types and lists/tuples of supported types.
Note - debugprint() has a bug where it will accept numeric types but does
not respect the "file" argument and in this case and instead prints the number
to stdout and returns an empty string. This can lead to tests passing where
they should fail because any two numbers will always compare as equal. To
prevent this we treat numbers as a separate case.
"""
numeric_types = (int, float, np.number)
a_is_num = isinstance(a, numeric_types)
b_is_num = isinstance(b, numeric_types)
# Compare numeric types using regular equality
if a_is_num or b_is_num:
if not (a_is_num and b_is_num):
return False
return a == b
# Compare sequences element-wise
a_is_seq = isinstance(a, (tuple, list))
b_is_seq = isinstance(b, (tuple, list))
if a_is_seq or b_is_seq:
if not (a_is_seq and b_is_seq) or type(a) != type(b):
return False
return list(map(theq, a)) == list(map(theq, b))
# Otherwise, assume debugprint() can handle it
astr = theano.printing.debugprint(a, file='str')
bstr = theano.printing.debugprint(b, file='str')
# Check for bug mentioned above
for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]:
if argstr == '':
raise TypeError(
'theano.printing.debugprint(%s) returned empty string '
'(%s is instance of %r)'
% (argname, argname, type(argval))
)
return astr == bstr
def test_example_symbols():
"""
Check that the example symbols in this module print to their Theano
equivalents, as many of the other tests depend on this.
"""
assert theq(xt, theano_code_(x))
assert theq(yt, theano_code_(y))
assert theq(zt, theano_code_(z))
assert theq(Xt, theano_code_(X))
assert theq(Yt, theano_code_(Y))
assert theq(Zt, theano_code_(Z))
def test_Symbol():
""" Test printing a Symbol to a theano variable. """
xx = theano_code_(x)
assert isinstance(xx, (tt.TensorVariable, ts.ScalarVariable))
assert xx.broadcastable == ()
assert xx.name == x.name
xx2 = theano_code_(x, broadcastables={x: (False,)})
assert xx2.broadcastable == (False,)
assert xx2.name == x.name
def test_MatrixSymbol():
""" Test printing a MatrixSymbol to a theano variable. """
XX = theano_code_(X)
assert isinstance(XX, tt.TensorVariable)
assert XX.broadcastable == (False, False)
@SKIP # TODO - this is currently not checked but should be implemented
def test_MatrixSymbol_wrong_dims():
""" Test MatrixSymbol with invalid broadcastable. """
bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)]
for bc in bcs:
with raises(ValueError):
theano_code_(X, broadcastables={X: bc})
def test_AppliedUndef():
""" Test printing AppliedUndef instance, which works similarly to Symbol. """
ftt = theano_code_(f_t)
assert isinstance(ftt, tt.TensorVariable)
assert ftt.broadcastable == ()
assert ftt.name == 'f_t'
def test_add():
expr = x + y
comp = theano_code_(expr)
assert comp.owner.op == theano.tensor.add
def test_trig():
assert theq(theano_code_(sy.sin(x)), tt.sin(xt))
assert theq(theano_code_(sy.tan(x)), tt.tan(xt))
def test_many():
""" Test printing a complex expression with multiple symbols. """
expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z)
comp = theano_code_(expr)
expected = tt.exp(xt**2 + tt.cos(yt)) * tt.log(2*zt)
assert theq(comp, expected)
def test_dtype():
""" Test specifying specific data types through the dtype argument. """
for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']:
assert theano_code_(x, dtypes={x: dtype}).type.dtype == dtype
# "floatX" type
assert theano_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64')
# Type promotion
assert theano_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32'
assert theano_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64'
def test_broadcastables():
""" Test the "broadcastables" argument when printing symbol-like objects. """
# No restrictions on shape
for s in [x, f_t]:
for bc in [(), (False,), (True,), (False, False), (True, False)]:
assert theano_code_(s, broadcastables={s: bc}).broadcastable == bc
# TODO - matrix broadcasting?
def test_broadcasting():
""" Test "broadcastable" attribute after applying element-wise binary op. """
expr = x + y
cases = [
[(), (), ()],
[(False,), (False,), (False,)],
[(True,), (False,), (False,)],
[(False, True), (False, False), (False, False)],
[(True, False), (False, False), (False, False)],
]
for bc1, bc2, bc3 in cases:
comp = theano_code_(expr, broadcastables={x: bc1, y: bc2})
assert comp.broadcastable == bc3
def test_MatMul():
expr = X*Y*Z
expr_t = theano_code_(expr)
assert isinstance(expr_t.owner.op, tt.Dot)
assert theq(expr_t, Xt.dot(Yt).dot(Zt))
def test_Transpose():
assert isinstance(theano_code_(X.T).owner.op, tt.DimShuffle)
def test_MatAdd():
expr = X+Y+Z
assert isinstance(theano_code_(expr).owner.op, tt.Elemwise)
def test_Rationals():
assert theq(theano_code_(sy.Integer(2) / 3), tt.true_div(2, 3))
assert theq(theano_code_(S.Half), tt.true_div(1, 2))
def test_Integers():
assert theano_code_(sy.Integer(3)) == 3
def test_factorial():
n = sy.Symbol('n')
assert theano_code_(sy.factorial(n))
def test_Derivative():
simp = lambda expr: theano_simplify(fgraph_of(expr))
assert theq(simp(theano_code_(sy.Derivative(sy.sin(x), x, evaluate=False))),
simp(theano.grad(tt.sin(xt), xt)))
def test_theano_function_simple():
""" Test theano_function() with single output. """
f = theano_function_([x, y], [x+y])
assert f(2, 3) == 5
def test_theano_function_multi():
""" Test theano_function() with multiple outputs. """
f = theano_function_([x, y], [x+y, x-y])
o1, o2 = f(2, 3)
assert o1 == 5
assert o2 == -1
def test_theano_function_numpy():
""" Test theano_function() vs Numpy implementation. """
f = theano_function_([x, y], [x+y], dim=1,
dtypes={x: 'float64', y: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9
f = theano_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
dim=1)
xx = np.arange(3).astype('float64')
yy = 2*np.arange(3).astype('float64')
assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9
def test_theano_function_matrix():
m = sy.Matrix([[x, y], [z, x + y + z]])
expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]])
f = theano_function_([x, y, z], [m])
np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
f = theano_function_([x, y, z], [m], scalar=True)
np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
f = theano_function_([x, y, z], [m, m])
assert isinstance(f(1.0, 2.0, 3.0), type([]))
np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected)
np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected)
def test_dim_handling():
assert dim_handling([x], dim=2) == {x: (False, False)}
assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True),
y: (False, False)}
assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)}
def test_theano_function_kwargs():
"""
Test passing additional kwargs from theano_function() to theano.function().
"""
import numpy as np
f = theano_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore',
dtypes={x: 'float64', y: 'float64', z: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9
f = theano_function_([x, y, z], [x+y],
dtypes={x: 'float64', y: 'float64', z: 'float64'},
dim=1, on_unused_input='ignore')
xx = np.arange(3).astype('float64')
yy = 2*np.arange(3).astype('float64')
zz = 2*np.arange(3).astype('float64')
assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9
def test_theano_function_scalar():
""" Test the "scalar" argument to theano_function(). """
args = [
([x, y], [x + y], None, [0]), # Single 0d output
([X, Y], [X + Y], None, [2]), # Single 2d output
([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output
([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs
([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d
]
# Create and test functions with and without the scalar setting
for inputs, outputs, in_dims, out_dims in args:
for scalar in [False, True]:
f = theano_function_(inputs, outputs, dims=in_dims, scalar=scalar)
# Check the theano_function attribute is set whether wrapped or not
assert isinstance(f.theano_function, theano.compile.function_module.Function)
# Feed in inputs of the appropriate size and get outputs
in_values = [
np.ones([1 if bc else 5 for bc in i.type.broadcastable])
for i in f.theano_function.input_storage
]
out_values = f(*in_values)
if not isinstance(out_values, list):
out_values = [out_values]
# Check output types and shapes
assert len(out_dims) == len(out_values)
for d, value in zip(out_dims, out_values):
if scalar and d == 0:
# Should have been converted to a scalar value
assert isinstance(value, np.number)
else:
# Otherwise should be an array
assert isinstance(value, np.ndarray)
assert value.ndim == d
def test_theano_function_bad_kwarg():
"""
Passing an unknown keyword argument to theano_function() should raise an
exception.
"""
raises(Exception, lambda : theano_function_([x], [x+1], foobar=3))
def test_slice():
assert theano_code_(slice(1, 2, 3)) == slice(1, 2, 3)
def theq_slice(s1, s2):
for attr in ['start', 'stop', 'step']:
a1 = getattr(s1, attr)
a2 = getattr(s2, attr)
if a1 is None or a2 is None:
if not (a1 is None or a2 is None):
return False
elif not theq(a1, a2):
return False
return True
dtypes = {x: 'int32', y: 'int32'}
assert theq_slice(theano_code_(slice(x, y), dtypes=dtypes), slice(xt, yt))
assert theq_slice(theano_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3))
def test_MatrixSlice():
from theano import Constant
cache = {}
n = sy.Symbol('n', integer=True)
X = sy.MatrixSymbol('X', n, n)
Y = X[1:2:3, 4:5:6]
Yt = theano_code_(Y, cache=cache)
s = ts.Scalar('int64')
assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s))
assert Yt.owner.inputs[0] == theano_code_(X, cache=cache)
# == doesn't work in theano like it does in SymPy. You have to use
# equals.
assert all(Yt.owner.inputs[i].equals(Constant(s, i)) for i in range(1, 7))
k = sy.Symbol('k')
theano_code_(k, dtypes={k: 'int32'})
start, stop, step = 4, k, 2
Y = X[start:stop:step]
Yt = theano_code_(Y, dtypes={n: 'int32', k: 'int32'})
# assert Yt.owner.op.idx_list[0].stop == kt
def test_BlockMatrix():
n = sy.Symbol('n', integer=True)
A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD']
At, Bt, Ct, Dt = map(theano_code_, (A, B, C, D))
Block = sy.BlockMatrix([[A, B], [C, D]])
Blockt = theano_code_(Block)
solutions = [tt.join(0, tt.join(1, At, Bt), tt.join(1, Ct, Dt)),
tt.join(1, tt.join(0, At, Ct), tt.join(0, Bt, Dt))]
assert any(theq(Blockt, solution) for solution in solutions)
@SKIP
def test_BlockMatrix_Inverse_execution():
k, n = 2, 4
dtype = 'float32'
A = sy.MatrixSymbol('A', n, k)
B = sy.MatrixSymbol('B', n, n)
inputs = A, B
output = B.I*A
cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
B: [(n//2, n//2), (n//2, n//2)]}
cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs]
cutoutput = output.subs(dict(zip(inputs, cutinputs)))
dtypes = dict(zip(inputs, [dtype]*len(inputs)))
f = theano_function_(inputs, [output], dtypes=dtypes, cache={})
fblocked = theano_function_(inputs, [sy.block_collapse(cutoutput)],
dtypes=dtypes, cache={})
ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs]
ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype),
np.eye(n).astype(dtype)]
ninputs[1] += np.ones(B.shape)*1e-5
assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5)
def test_DenseMatrix():
t = sy.Symbol('theta')
for MatrixType in [sy.Matrix, sy.ImmutableMatrix]:
X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]])
tX = theano_code_(X)
assert isinstance(tX, tt.TensorVariable)
assert tX.owner.op == tt.join_
def test_cache_basic():
""" Test single symbol-like objects are cached when printed by themselves. """
# Pairs of objects which should be considered equivalent with respect to caching
pairs = [
(x, sy.Symbol('x')),
(X, sy.MatrixSymbol('X', *X.shape)),
(f_t, sy.Function('f')(sy.Symbol('t'))),
]
for s1, s2 in pairs:
cache = {}
st = theano_code_(s1, cache=cache)
# Test hit with same instance
assert theano_code_(s1, cache=cache) is st
# Test miss with same instance but new cache
assert theano_code_(s1, cache={}) is not st
# Test hit with different but equivalent instance
assert theano_code_(s2, cache=cache) is st
def test_global_cache():
""" Test use of the global cache. """
from sympy.printing.theanocode import global_cache
backup = dict(global_cache)
try:
# Temporarily empty global cache
global_cache.clear()
for s in [x, X, f_t]:
with warns_deprecated_sympy():
st = theano_code(s)
assert theano_code(s) is st
finally:
# Restore global cache
global_cache.update(backup)
def test_cache_types_distinct():
"""
Test that symbol-like objects of different types (Symbol, MatrixSymbol,
AppliedUndef) are distinguished by the cache even if they have the same
name.
"""
symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]
cache = {} # Single shared cache
printed = {}
for s in symbols:
st = theano_code_(s, cache=cache)
assert st not in printed.values()
printed[s] = st
# Check all printed objects are distinct
assert len(set(map(id, printed.values()))) == len(symbols)
# Check retrieving
for s, st in printed.items():
with warns_deprecated_sympy():
assert theano_code(s, cache=cache) is st
def test_symbols_are_created_once():
"""
Test that a symbol is cached and reused when it appears in an expression
more than once.
"""
expr = sy.Add(x, x, evaluate=False)
comp = theano_code_(expr)
assert theq(comp, xt + xt)
assert not theq(comp, xt + theano_code_(x))
def test_cache_complex():
"""
Test caching on a complicated expression with multiple symbols appearing
multiple times.
"""
expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y)
symbol_names = {s.name for s in expr.free_symbols}
expr_t = theano_code_(expr)
# Iterate through variables in the Theano computational graph that the
# printed expression depends on
seen = set()
for v in theano.gof.graph.ancestors([expr_t]):
# Owner-less, non-constant variables should be our symbols
if v.owner is None and not isinstance(v, theano.gof.graph.Constant):
# Check it corresponds to a symbol and appears only once
assert v.name in symbol_names
assert v.name not in seen
seen.add(v.name)
# Check all were present
assert seen == symbol_names
def test_Piecewise():
# A piecewise linear
expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III
result = theano_code_(expr)
assert result.owner.op == tt.switch
expected = tt.switch(xt<0, 0, tt.switch(xt<2, xt, 1))
assert theq(result, expected)
expr = sy.Piecewise((x, x < 0))
result = theano_code_(expr)
expected = tt.switch(xt < 0, xt, np.nan)
assert theq(result, expected)
expr = sy.Piecewise((0, sy.And(x>0, x<2)), \
(x, sy.Or(x>2, x<0)))
result = theano_code_(expr)
expected = tt.switch(tt.and_(xt>0,xt<2), 0, \
tt.switch(tt.or_(xt>2, xt<0), xt, np.nan))
assert theq(result, expected)
def test_Relationals():
assert theq(theano_code_(sy.Eq(x, y)), tt.eq(xt, yt))
# assert theq(theano_code_(sy.Ne(x, y)), tt.neq(xt, yt)) # TODO - implement
assert theq(theano_code_(x > y), xt > yt)
assert theq(theano_code_(x < y), xt < yt)
assert theq(theano_code_(x >= y), xt >= yt)
assert theq(theano_code_(x <= y), xt <= yt)
def test_complexfunctions():
with warns_deprecated_sympy():
xt, yt = theano_code_(x, dtypes={x:'complex128'}), theano_code_(y, dtypes={y: 'complex128'})
from sympy.functions.elementary.complexes import conjugate
from theano.tensor import as_tensor_variable as atv
from theano.tensor import complex as cplx
with warns_deprecated_sympy():
assert theq(theano_code_(y*conjugate(x)), yt*(xt.conj()))
assert theq(theano_code_((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1)))
def test_constantfunctions():
with warns_deprecated_sympy():
tf = theano_function_([],[1+1j])
assert(tf()==1+1j)
def test_Exp1():
"""
Test that exp(1) prints without error and evaluates close to SymPy's E
"""
# sy.exp(1) should yield same instance of E as sy.E (singleton), but extra
# check added for sanity
e_a = sy.exp(1)
e_b = sy.E
np.testing.assert_allclose(float(e_a), np.e)
np.testing.assert_allclose(float(e_b), np.e)
e = theano_code_(e_a)
np.testing.assert_allclose(float(e_a), e.eval())
e = theano_code_(e_b)
np.testing.assert_allclose(float(e_b), e.eval())

View File

@ -0,0 +1,196 @@
from sympy.printing.tree import tree
from sympy.testing.pytest import XFAIL
# Remove this flag after making _assumptions cache deterministic.
@XFAIL
def test_print_tree_MatAdd():
from sympy.matrices.expressions import MatrixSymbol
A = MatrixSymbol('A', 3, 3)
B = MatrixSymbol('B', 3, 3)
test_str = [
'MatAdd: A + B\n',
'algebraic: False\n',
'commutative: False\n',
'complex: False\n',
'composite: False\n',
'even: False\n',
'extended_negative: False\n',
'extended_nonnegative: False\n',
'extended_nonpositive: False\n',
'extended_nonzero: False\n',
'extended_positive: False\n',
'extended_real: False\n',
'imaginary: False\n',
'integer: False\n',
'irrational: False\n',
'negative: False\n',
'noninteger: False\n',
'nonnegative: False\n',
'nonpositive: False\n',
'nonzero: False\n',
'odd: False\n',
'positive: False\n',
'prime: False\n',
'rational: False\n',
'real: False\n',
'transcendental: False\n',
'zero: False\n',
'+-MatrixSymbol: A\n',
'| algebraic: False\n',
'| commutative: False\n',
'| complex: False\n',
'| composite: False\n',
'| even: False\n',
'| extended_negative: False\n',
'| extended_nonnegative: False\n',
'| extended_nonpositive: False\n',
'| extended_nonzero: False\n',
'| extended_positive: False\n',
'| extended_real: False\n',
'| imaginary: False\n',
'| integer: False\n',
'| irrational: False\n',
'| negative: False\n',
'| noninteger: False\n',
'| nonnegative: False\n',
'| nonpositive: False\n',
'| nonzero: False\n',
'| odd: False\n',
'| positive: False\n',
'| prime: False\n',
'| rational: False\n',
'| real: False\n',
'| transcendental: False\n',
'| zero: False\n',
'| +-Symbol: A\n',
'| | commutative: True\n',
'| +-Integer: 3\n',
'| | algebraic: True\n',
'| | commutative: True\n',
'| | complex: True\n',
'| | extended_negative: False\n',
'| | extended_nonnegative: True\n',
'| | extended_real: True\n',
'| | finite: True\n',
'| | hermitian: True\n',
'| | imaginary: False\n',
'| | infinite: False\n',
'| | integer: True\n',
'| | irrational: False\n',
'| | negative: False\n',
'| | noninteger: False\n',
'| | nonnegative: True\n',
'| | rational: True\n',
'| | real: True\n',
'| | transcendental: False\n',
'| +-Integer: 3\n',
'| algebraic: True\n',
'| commutative: True\n',
'| complex: True\n',
'| extended_negative: False\n',
'| extended_nonnegative: True\n',
'| extended_real: True\n',
'| finite: True\n',
'| hermitian: True\n',
'| imaginary: False\n',
'| infinite: False\n',
'| integer: True\n',
'| irrational: False\n',
'| negative: False\n',
'| noninteger: False\n',
'| nonnegative: True\n',
'| rational: True\n',
'| real: True\n',
'| transcendental: False\n',
'+-MatrixSymbol: B\n',
' algebraic: False\n',
' commutative: False\n',
' complex: False\n',
' composite: False\n',
' even: False\n',
' extended_negative: False\n',
' extended_nonnegative: False\n',
' extended_nonpositive: False\n',
' extended_nonzero: False\n',
' extended_positive: False\n',
' extended_real: False\n',
' imaginary: False\n',
' integer: False\n',
' irrational: False\n',
' negative: False\n',
' noninteger: False\n',
' nonnegative: False\n',
' nonpositive: False\n',
' nonzero: False\n',
' odd: False\n',
' positive: False\n',
' prime: False\n',
' rational: False\n',
' real: False\n',
' transcendental: False\n',
' zero: False\n',
' +-Symbol: B\n',
' | commutative: True\n',
' +-Integer: 3\n',
' | algebraic: True\n',
' | commutative: True\n',
' | complex: True\n',
' | extended_negative: False\n',
' | extended_nonnegative: True\n',
' | extended_real: True\n',
' | finite: True\n',
' | hermitian: True\n',
' | imaginary: False\n',
' | infinite: False\n',
' | integer: True\n',
' | irrational: False\n',
' | negative: False\n',
' | noninteger: False\n',
' | nonnegative: True\n',
' | rational: True\n',
' | real: True\n',
' | transcendental: False\n',
' +-Integer: 3\n',
' algebraic: True\n',
' commutative: True\n',
' complex: True\n',
' extended_negative: False\n',
' extended_nonnegative: True\n',
' extended_real: True\n',
' finite: True\n',
' hermitian: True\n',
' imaginary: False\n',
' infinite: False\n',
' integer: True\n',
' irrational: False\n',
' negative: False\n',
' noninteger: False\n',
' nonnegative: True\n',
' rational: True\n',
' real: True\n',
' transcendental: False\n'
]
assert tree(A + B) == "".join(test_str)
def test_print_tree_MatAdd_noassumptions():
from sympy.matrices.expressions import MatrixSymbol
A = MatrixSymbol('A', 3, 3)
B = MatrixSymbol('B', 3, 3)
test_str = \
"""MatAdd: A + B
+-MatrixSymbol: A
| +-Str: A
| +-Integer: 3
| +-Integer: 3
+-MatrixSymbol: B
+-Str: B
+-Integer: 3
+-Integer: 3
"""
assert tree(A + B, assumptions=False) == test_str