I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,75 @@
from sympy.core.numbers import Rational
from sympy.core.symbol import symbols
from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial)
from sympy.functions.special.gamma_functions import gamma
from sympy.simplify.combsimp import combsimp
from sympy.abc import x
def test_combsimp():
k, m, n = symbols('k m n', integer = True)
assert combsimp(factorial(n)) == factorial(n)
assert combsimp(binomial(n, k)) == binomial(n, k)
assert combsimp(factorial(n)/factorial(n - 3)) == n*(-1 + n)*(-2 + n)
assert combsimp(binomial(n + 1, k + 1)/binomial(n, k)) == (1 + n)/(1 + k)
assert combsimp(binomial(3*n + 4, n + 1)/binomial(3*n + 1, n)) == \
Rational(3, 2)*((3*n + 2)*(3*n + 4)/((n + 1)*(2*n + 3)))
assert combsimp(factorial(n)**2/factorial(n - 3)) == \
factorial(n)*n*(-1 + n)*(-2 + n)
assert combsimp(factorial(n)*binomial(n + 1, k + 1)/binomial(n, k)) == \
factorial(n + 1)/(1 + k)
assert combsimp(gamma(n + 3)) == factorial(n + 2)
assert combsimp(factorial(x)) == gamma(x + 1)
# issue 9699
assert combsimp((n + 1)*factorial(n)) == factorial(n + 1)
assert combsimp(factorial(n)/n) == factorial(n-1)
# issue 6658
assert combsimp(binomial(n, n - k)) == binomial(n, k)
# issue 6341, 7135
assert combsimp(factorial(n)/(factorial(k)*factorial(n - k))) == \
binomial(n, k)
assert combsimp(factorial(k)*factorial(n - k)/factorial(n)) == \
1/binomial(n, k)
assert combsimp(factorial(2*n)/factorial(n)**2) == binomial(2*n, n)
assert combsimp(factorial(2*n)*factorial(k)*factorial(n - k)/
factorial(n)**3) == binomial(2*n, n)/binomial(n, k)
assert combsimp(factorial(n*(1 + n) - n**2 - n)) == 1
assert combsimp(6*FallingFactorial(-4, n)/factorial(n)) == \
(-1)**n*(n + 1)*(n + 2)*(n + 3)
assert combsimp(6*FallingFactorial(-4, n - 1)/factorial(n - 1)) == \
(-1)**(n - 1)*n*(n + 1)*(n + 2)
assert combsimp(6*FallingFactorial(-4, n - 3)/factorial(n - 3)) == \
(-1)**(n - 3)*n*(n - 1)*(n - 2)
assert combsimp(6*FallingFactorial(-4, -n - 1)/factorial(-n - 1)) == \
-(-1)**(-n - 1)*n*(n - 1)*(n - 2)
assert combsimp(6*RisingFactorial(4, n)/factorial(n)) == \
(n + 1)*(n + 2)*(n + 3)
assert combsimp(6*RisingFactorial(4, n - 1)/factorial(n - 1)) == \
n*(n + 1)*(n + 2)
assert combsimp(6*RisingFactorial(4, n - 3)/factorial(n - 3)) == \
n*(n - 1)*(n - 2)
assert combsimp(6*RisingFactorial(4, -n - 1)/factorial(-n - 1)) == \
-n*(n - 1)*(n - 2)
def test_issue_6878():
n = symbols('n', integer=True)
assert combsimp(RisingFactorial(-10, n)) == 3628800*(-1)**n/factorial(10 - n)
def test_issue_14528():
p = symbols("p", integer=True, positive=True)
assert combsimp(binomial(1,p)) == 1/(factorial(p)*factorial(1-p))
assert combsimp(factorial(2-p)) == factorial(2-p)

View File

@ -0,0 +1,758 @@
from functools import reduce
import itertools
from operator import add
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.core.add import Add
from sympy.core.containers import Tuple
from sympy.core.expr import UnevaluatedExpr
from sympy.core.function import Function
from sympy.core.mul import Mul
from sympy.core.power import Pow
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, symbols)
from sympy.core.sympify import sympify
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.matrices.dense import Matrix
from sympy.matrices.expressions import Inverse, MatAdd, MatMul, Transpose
from sympy.polys.rootoftools import CRootOf
from sympy.series.order import O
from sympy.simplify.cse_main import cse
from sympy.simplify.simplify import signsimp
from sympy.tensor.indexed import (Idx, IndexedBase)
from sympy.core.function import count_ops
from sympy.simplify.cse_opts import sub_pre, sub_post
from sympy.functions.special.hyper import meijerg
from sympy.simplify import cse_main, cse_opts
from sympy.utilities.iterables import subsets
from sympy.testing.pytest import XFAIL, raises
from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix,
ImmutableDenseMatrix, ImmutableSparseMatrix)
from sympy.matrices.expressions import MatrixSymbol
w, x, y, z = symbols('w,x,y,z')
x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13')
def test_numbered_symbols():
ns = cse_main.numbered_symbols(prefix='y')
assert list(itertools.islice(
ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)]
ns = cse_main.numbered_symbols(prefix='y')
assert list(itertools.islice(
ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)]
ns = cse_main.numbered_symbols()
assert list(itertools.islice(
ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)]
# Dummy "optimization" functions for testing.
def opt1(expr):
return expr + y
def opt2(expr):
return expr*z
def test_preprocess_for_cse():
assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y
assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x
assert cse_main.preprocess_for_cse(x, [(None, None)]) == x
assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y
assert cse_main.preprocess_for_cse(
x, [(opt1, None), (opt2, None)]) == (x + y)*z
def test_postprocess_for_cse():
assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x
assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y
assert cse_main.postprocess_for_cse(x, [(None, None)]) == x
assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z
# Note the reverse order of application.
assert cse_main.postprocess_for_cse(
x, [(None, opt1), (None, opt2)]) == x*z + y
def test_cse_single():
# Simple substitution.
e = Add(Pow(x + y, 2), sqrt(x + y))
substs, reduced = cse([e])
assert substs == [(x0, x + y)]
assert reduced == [sqrt(x0) + x0**2]
subst42, (red42,) = cse([42]) # issue_15082
assert len(subst42) == 0 and red42 == 42
subst_half, (red_half,) = cse([0.5])
assert len(subst_half) == 0 and red_half == 0.5
def test_cse_single2():
# Simple substitution, test for being able to pass the expression directly
e = Add(Pow(x + y, 2), sqrt(x + y))
substs, reduced = cse(e)
assert substs == [(x0, x + y)]
assert reduced == [sqrt(x0) + x0**2]
substs, reduced = cse(Matrix([[1]]))
assert isinstance(reduced[0], Matrix)
subst42, (red42,) = cse(42) # issue 15082
assert len(subst42) == 0 and red42 == 42
subst_half, (red_half,) = cse(0.5) # issue 15082
assert len(subst_half) == 0 and red_half == 0.5
def test_cse_not_possible():
# No substitution possible.
e = Add(x, y)
substs, reduced = cse([e])
assert substs == []
assert reduced == [x + y]
# issue 6329
eq = (meijerg((1, 2), (y, 4), (5,), [], x) +
meijerg((1, 3), (y, 4), (5,), [], x))
assert cse(eq) == ([], [eq])
def test_nested_substitution():
# Substitution within a substitution.
e = Add(Pow(w*x + y, 2), sqrt(w*x + y))
substs, reduced = cse([e])
assert substs == [(x0, w*x + y)]
assert reduced == [sqrt(x0) + x0**2]
def test_subtraction_opt():
# Make sure subtraction is optimized.
e = (x - y)*(z - y) + exp((x - y)*(z - y))
substs, reduced = cse(
[e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
assert substs == [(x0, (x - y)*(y - z))]
assert reduced == [-x0 + exp(-x0)]
e = -(x - y)*(z - y) + exp(-(x - y)*(z - y))
substs, reduced = cse(
[e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
assert substs == [(x0, (x - y)*(y - z))]
assert reduced == [x0 + exp(x0)]
# issue 4077
n = -1 + 1/x
e = n/x/(-n)**2 - 1/n/x
assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \
([], [0])
assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \
([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3])
def test_multiple_expressions():
e1 = (x + y)*z
e2 = (x + y)*w
substs, reduced = cse([e1, e2])
assert substs == [(x0, x + y)]
assert reduced == [x0*z, x0*w]
l = [w*x*y + z, w*y]
substs, reduced = cse(l)
rsubsts, _ = cse(reversed(l))
assert substs == rsubsts
assert reduced == [z + x*x0, x0]
l = [w*x*y, w*x*y + z, w*y]
substs, reduced = cse(l)
rsubsts, _ = cse(reversed(l))
assert substs == rsubsts
assert reduced == [x1, x1 + z, x0]
l = [(x - z)*(y - z), x - z, y - z]
substs, reduced = cse(l)
rsubsts, _ = cse(reversed(l))
assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
assert reduced == [x1*x2, x1, x2]
l = [w*y + w + x + y + z, w*x*y]
assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
assert cse([x + y, x + z]) == ([], [x + y, x + z])
assert cse([x*y, z + x*y, x*y*z + 3]) == \
([(x0, x*y)], [x0, z + x0, 3 + x0*z])
@XFAIL # CSE of non-commutative Mul terms is disabled
def test_non_commutative_cse():
A, B, C = symbols('A B C', commutative=False)
l = [A*B*C, A*C]
assert cse(l) == ([], l)
l = [A*B*C, A*B]
assert cse(l) == ([(x0, A*B)], [x0*C, x0])
# Test if CSE of non-commutative Mul terms is disabled
def test_bypass_non_commutatives():
A, B, C = symbols('A B C', commutative=False)
l = [A*B*C, A*C]
assert cse(l) == ([], l)
l = [A*B*C, A*B]
assert cse(l) == ([], l)
l = [B*C, A*B*C]
assert cse(l) == ([], l)
@XFAIL # CSE fails when replacing non-commutative sub-expressions
def test_non_commutative_order():
A, B, C = symbols('A B C', commutative=False)
x0 = symbols('x0', commutative=False)
l = [B+C, A*(B+C)]
assert cse(l) == ([(x0, B+C)], [x0, A*x0])
@XFAIL # Worked in gh-11232, but was reverted due to performance considerations
def test_issue_10228():
assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0])
assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0])
assert cse((w + 2*x + y + z, w + x + 1)) == (
[(x0, w + x)], [x0 + x + y + z, x0 + 1])
assert cse(((w + x + y + z)*(w - x))/(w + x)) == (
[(x0, w + x)], [(x0 + y + z)*(w - x)/x0])
a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m')
exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2)
assert cse(exprs) == (
[(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1]
)
@XFAIL
def test_powers():
assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0])
def test_issue_4498():
assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \
([], [(w - z)/(x - y)])
def test_issue_4020():
assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \
== ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)])
def test_issue_4203():
assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0])
def test_issue_6263():
e = Eq(x*(-x + 1) + x*(x - 1), 0)
assert cse(e, optimizations='basic') == ([], [True])
def test_issue_25043():
c = symbols("c")
x = symbols("x0", real=True)
cse_expr = cse(c*x**2 + c*(x**4 - x**2))[-1][-1]
free = cse_expr.free_symbols
assert len(free) == len({i.name for i in free})
def test_dont_cse_tuples():
from sympy.core.function import Subs
f = Function("f")
g = Function("g")
name_val, (expr,) = cse(
Subs(f(x, y), (x, y), (0, 1))
+ Subs(g(x, y), (x, y), (0, 1)))
assert name_val == []
assert expr == (Subs(f(x, y), (x, y), (0, 1))
+ Subs(g(x, y), (x, y), (0, 1)))
name_val, (expr,) = cse(
Subs(f(x, y), (x, y), (0, x + y))
+ Subs(g(x, y), (x, y), (0, x + y)))
assert name_val == [(x0, x + y)]
assert expr == Subs(f(x, y), (x, y), (0, x0)) + \
Subs(g(x, y), (x, y), (0, x0))
def test_pow_invpow():
assert cse(1/x**2 + x**2) == \
([(x0, x**2)], [x0 + 1/x0])
assert cse(x**2 + (1 + 1/x**2)/x**2) == \
([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)])
assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \
([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1])
assert cse(cos(1/x**2) + sin(1/x**2)) == \
([(x0, x**(-2))], [sin(x0) + cos(x0)])
assert cse(cos(x**2) + sin(x**2)) == \
([(x0, x**2)], [sin(x0) + cos(x0)])
assert cse(y/(2 + x**2) + z/x**2/y) == \
([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)])
assert cse(exp(x**2) + x**2*cos(1/x**2)) == \
([(x0, x**2)], [x0*cos(1/x0) + exp(x0)])
assert cse((1 + 1/x**2)/x**2) == \
([(x0, x**(-2))], [x0*(x0 + 1)])
assert cse(x**(2*y) + x**(-2*y)) == \
([(x0, x**(2*y))], [x0 + 1/x0])
def test_postprocess():
eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)],
postprocess=cse_main.cse_separate) == \
[[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)],
[x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]]
def test_issue_4499():
# previously, this gave 16 constants
from sympy.abc import a, b
B = Function('B')
G = Function('G')
t = Tuple(*
(a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a -
b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1),
sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b,
sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1,
sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1),
(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1,
sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b,
-2*a))
c = cse(t)
ans = (
[(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)),
(x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)),
(x8, B(b, x4)), (x9, x6*B(x2, x4))],
[(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9,
1, 0, S.Half, z/2, -x3, -x1, -x0)])
assert ans == c
def test_issue_6169():
r = CRootOf(x**6 - 4*x**5 - 2, 1)
assert cse(r) == ([], [r])
# and a check that the right thing is done with the new
# mechanism
assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y
def test_cse_Indexed():
len_y = 5
y = IndexedBase('y', shape=(len_y,))
x = IndexedBase('x', shape=(len_y,))
i = Idx('i', len_y-1)
expr1 = (y[i+1]-y[i])/(x[i+1]-x[i])
expr2 = 1/(x[i+1]-x[i])
replacements, reduced_exprs = cse([expr1, expr2])
assert len(replacements) > 0
def test_cse_MatrixSymbol():
# MatrixSymbols have non-Basic args, so make sure that works
A = MatrixSymbol("A", 3, 3)
assert cse(A) == ([], [A])
n = symbols('n', integer=True)
B = MatrixSymbol("B", n, n)
assert cse(B) == ([], [B])
assert cse(A[0] * A[0]) == ([], [A[0]*A[0]])
assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0])
def test_cse_MatrixExpr():
A = MatrixSymbol('A', 3, 3)
y = MatrixSymbol('y', 3, 1)
expr1 = (A.T*A).I * A * y
expr2 = (A.T*A) * A * y
replacements, reduced_exprs = cse([expr1, expr2])
assert len(replacements) > 0
replacements, reduced_exprs = cse([expr1 + expr2, expr1])
assert replacements
replacements, reduced_exprs = cse([A**2, A + A**2])
assert replacements
def test_Piecewise():
f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True))
ans = cse(f)
actual_ans = ([(x0, x*y)],
[Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))])
assert ans == actual_ans
def test_ignore_order_terms():
eq = exp(x).series(x,0,3) + sin(y+x**3) - 1
assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)])
def test_name_conflict():
z1 = x0 + y
z2 = x2 + x3
l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
substs, reduced = cse(l)
assert [e.subs(reversed(substs)) for e in reduced] == l
def test_name_conflict_cust_symbols():
z1 = x0 + y
z2 = x2 + x3
l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
substs, reduced = cse(l, symbols("x:10"))
assert [e.subs(reversed(substs)) for e in reduced] == l
def test_symbols_exhausted_error():
l = cos(x+y)+x+y+cos(w+y)+sin(w+y)
sym = [x, y, z]
with raises(ValueError):
cse(l, symbols=sym)
def test_issue_7840():
# daveknippers' example
C393 = sympify( \
'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \
C391 > 2.35), (C392, True)), True))'
)
C391 = sympify( \
'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))'
)
C393 = C393.subs('C391',C391)
# simple substitution
sub = {}
sub['C390'] = 0.703451854
sub['C392'] = 1.01417794
ss_answer = C393.subs(sub)
# cse
substitutions,new_eqn = cse(C393)
for pair in substitutions:
sub[pair[0].name] = pair[1].subs(sub)
cse_answer = new_eqn[0].subs(sub)
# both methods should be the same
assert ss_answer == cse_answer
# GitRay's example
expr = sympify(
"Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \
(Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \
Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \
Symbol('AUTO'))), (Symbol('OFF'), true)), true))"
)
substitutions, new_eqn = cse(expr)
# this Piecewise should be exactly the same
assert new_eqn[0] == expr
# there should not be any replacements
assert len(substitutions) < 1
def test_issue_8891():
for cls in (MutableDenseMatrix, MutableSparseMatrix,
ImmutableDenseMatrix, ImmutableSparseMatrix):
m = cls(2, 2, [x + y, 0, 0, 0])
res = cse([x + y, m])
ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])])
assert res == ans
assert isinstance(res[1][-1], cls)
def test_issue_11230():
# a specific test that always failed
a, b, f, k, l, i = symbols('a b f k l i')
p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l]
R, C = cse(p)
assert not any(i.is_Mul for a in C for i in a.args)
# random tests for the issue
from sympy.core.random import choice
from sympy.core.function import expand_mul
s = symbols('a:m')
# 35 Mul tests, none of which should ever fail
ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)]
for p in subsets(ex, 3):
p = list(p)
R, C = cse(p)
assert not any(i.is_Mul for a in C for i in a.args)
for ri in reversed(R):
for i in range(len(C)):
C[i] = C[i].subs(*ri)
assert p == C
# 35 Add tests, none of which should ever fail
ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)]
for p in subsets(ex, 3):
p = list(p)
R, C = cse(p)
assert not any(i.is_Add for a in C for i in a.args)
for ri in reversed(R):
for i in range(len(C)):
C[i] = C[i].subs(*ri)
# use expand_mul to handle cases like this:
# p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g]
# x0 = 2*(b + e) is identified giving a rebuilt p that
# is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]`
assert p == [expand_mul(i) for i in C]
@XFAIL
def test_issue_11577():
def check(eq):
r, c = cse(eq)
assert eq.count_ops() >= \
len(r) + sum(i[1].count_ops() for i in r) + \
count_ops(c)
eq = x**5*y**2 + x**5*y + x**5
assert cse(eq) == (
[(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1])
# ([(x0, x**5*y)], [x0*y + x0 + x**5]) or
# ([(x0, x**5)], [x0*y**2 + x0*y + x0])
check(eq)
eq = x**2/(y + 1)**2 + x/(y + 1)
assert cse(eq) == (
[(x0, y + 1)], [x**2/x0**2 + x/x0])
# ([(x0, x/(y + 1))], [x0**2 + x0])
check(eq)
def test_hollow_rejection():
eq = [x + 3, x + 4]
assert cse(eq) == ([], eq)
def test_cse_ignore():
exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))]
subst1, red1 = cse(exprs)
assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y"
subst2, red2 = cse(exprs, ignore=(y,)) # y is not allowed in substitutions
assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored"
assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression"
def test_cse_ignore_issue_15002():
l = [
w*exp(x)*exp(-z),
exp(y)*exp(x)*exp(-z)
]
substs, reduced = cse(l, ignore=(x,))
rl = [e.subs(reversed(substs)) for e in reduced]
assert rl == l
def test_cse_unevaluated():
xp1 = UnevaluatedExpr(x + 1)
# This used to cause RecursionError
[(x0, ue)], [red] = cse([(-1 - xp1) / (1 - xp1)])
if ue == xp1:
assert red == (-1 - x0) / (1 - x0)
elif ue == -xp1:
assert red == (-1 + x0) / (1 + x0)
else:
msg = f'Expected common subexpression {xp1} or {-xp1}, instead got {ue}'
assert False, msg
def test_cse__performance():
nexprs, nterms = 3, 20
x = symbols('x:%d' % nterms)
exprs = [
reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)])
for i in range(nexprs)
]
assert (exprs[0] + exprs[1]).simplify() == 0
subst, red = cse(exprs)
assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE"
for i, e in enumerate(red):
assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0
def test_issue_12070():
exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z]
subst, red = cse(exprs)
assert 6 >= (len(subst) + sum(v.count_ops() for k, v in subst) +
count_ops(red))
def test_issue_13000():
eq = x/(-4*x**2 + y**2)
cse_eq = cse(eq)[1][0]
assert cse_eq == eq
def test_issue_18203():
eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1)
assert cse(eq) == ([], [eq])
def test_unevaluated_mul():
eq = Mul(x + y, x + y, evaluate=False)
assert cse(eq) == ([(x0, x + y)], [x0**2])
def test_cse_release_variables():
from sympy.simplify.cse_main import cse_release_variables
_0, _1, _2, _3, _4 = symbols('_:5')
eqs = [(x + y - 1)**2, x,
x + y, (x + y)/(2*x + 1) + (x + y - 1)**2,
(2*x + 1)**(x + y)]
r, e = cse(eqs, postprocess=cse_release_variables)
# this can change in keeping with the intention of the function
assert r, e == ([
(x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1),
(_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1),
(x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4))
r.reverse()
r = [(s, v) for s, v in r if v is not None]
assert eqs == [i.subs(r) for i in e]
def test_cse_list():
_cse = lambda x: cse(x, list=False)
assert _cse(x) == ([], x)
assert _cse('x') == ([], 'x')
it = [x]
for c in (list, tuple, set):
assert _cse(c(it)) == ([], c(it))
#Tuple works different from tuple:
assert _cse(Tuple(*it)) == ([], Tuple(*it))
d = {x: 1}
assert _cse(d) == ([], d)
def test_issue_18991():
A = MatrixSymbol('A', 2, 2)
assert signsimp(-A * A - A) == -A * A - A
def test_unevaluated_Mul():
m = [Mul(1, 2, evaluate=False)]
assert cse(m) == ([], m)
def test_cse_matrix_expression_inverse():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = Inverse(A)
cse_expr = cse(x)
assert cse_expr == ([], [Inverse(A)])
def test_cse_matrix_expression_matmul_inverse():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
b = ImmutableDenseMatrix(symbols('b:2'))
x = MatMul(Inverse(A), b)
cse_expr = cse(x)
assert cse_expr == ([], [x])
def test_cse_matrix_negate_matrix():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatMul(S.NegativeOne, A)
cse_expr = cse(x)
assert cse_expr == ([], [x])
def test_cse_matrix_negate_matmul_not_extracted():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
x = MatMul(S.NegativeOne, A, B)
cse_expr = cse(x)
assert cse_expr == ([], [x])
@XFAIL # No simplification rule for nested associative operations
def test_cse_matrix_nested_matmul_collapsed():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
x = MatMul(S.NegativeOne, MatMul(A, B))
cse_expr = cse(x)
assert cse_expr == ([], [MatMul(S.NegativeOne, A, B)])
def test_cse_matrix_optimize_out_single_argument_mul():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatMul(MatMul(MatMul(A)))
cse_expr = cse(x)
assert cse_expr == ([], [A])
@XFAIL # Multiple simplification passed not supported in CSE
def test_cse_matrix_optimize_out_single_argument_mul_combined():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatAdd(MatMul(MatMul(MatMul(A))), MatMul(MatMul(A)), MatMul(A), A)
cse_expr = cse(x)
assert cse_expr == ([], [MatMul(4, A)])
def test_cse_matrix_optimize_out_single_argument_add():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatAdd(MatAdd(MatAdd(MatAdd(A))))
cse_expr = cse(x)
assert cse_expr == ([], [A])
@XFAIL # Multiple simplification passed not supported in CSE
def test_cse_matrix_optimize_out_single_argument_add_combined():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatMul(MatAdd(MatAdd(MatAdd(A))), MatAdd(MatAdd(A)), MatAdd(A), A)
cse_expr = cse(x)
assert cse_expr == ([], [MatMul(4, A)])
def test_cse_matrix_expression_matrix_solve():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
b = ImmutableDenseMatrix(symbols('b:2'))
x = MatrixSolve(A, b)
cse_expr = cse(x)
assert cse_expr == ([], [x])
def test_cse_matrix_matrix_expression():
X = ImmutableDenseMatrix(symbols('X:4')).reshape(2, 2)
y = ImmutableDenseMatrix(symbols('y:2'))
b = MatMul(Inverse(MatMul(Transpose(X), X)), Transpose(X), y)
cse_expr = cse(b)
x0 = MatrixSymbol('x0', 2, 2)
reduced_expr_expected = MatMul(Inverse(MatMul(x0, X)), x0, y)
assert cse_expr == ([(x0, Transpose(X))], [reduced_expr_expected])
def test_cse_matrix_kalman_filter():
"""Kalman Filter example from Matthew Rocklin's SciPy 2013 talk.
Talk titled: "Matrix Expressions and BLAS/LAPACK; SciPy 2013 Presentation"
Video: https://pyvideo.org/scipy-2013/matrix-expressions-and-blaslapack-scipy-2013-pr.html
Notes
=====
Equations are:
new_mu = mu + Sigma*H.T * (R + H*Sigma*H.T).I * (H*mu - data)
= MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
new_Sigma = Sigma - Sigma*H.T * (R + H*Sigma*H.T).I * H * Sigma
= MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H)), Inverse(MatAdd(R, MatMul(H*Sigma*Transpose(H)))), H, Sigma))
"""
N = 2
mu = ImmutableDenseMatrix(symbols(f'mu:{N}'))
Sigma = ImmutableDenseMatrix(symbols(f'Sigma:{N * N}')).reshape(N, N)
H = ImmutableDenseMatrix(symbols(f'H:{N * N}')).reshape(N, N)
R = ImmutableDenseMatrix(symbols(f'R:{N * N}')).reshape(N, N)
data = ImmutableDenseMatrix(symbols(f'data:{N}'))
new_mu = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
new_Sigma = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), H, Sigma))
cse_expr = cse([new_mu, new_Sigma])
x0 = MatrixSymbol('x0', N, N)
x1 = MatrixSymbol('x1', N, N)
replacements_expected = [
(x0, Transpose(H)),
(x1, Inverse(MatAdd(R, MatMul(H, Sigma, x0)))),
]
reduced_exprs_expected = [
MatAdd(mu, MatMul(Sigma, x0, x1, MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))),
MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, x0, x1, H, Sigma)),
]
assert cse_expr == (replacements_expected, reduced_exprs_expected)

View File

@ -0,0 +1,90 @@
"""Tests for tools for manipulation of expressions using paths. """
from sympy.simplify.epathtools import epath, EPath
from sympy.testing.pytest import raises
from sympy.core.numbers import E
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.abc import x, y, z, t
def test_epath_select():
expr = [((x, 1, t), 2), ((3, y, 4), z)]
assert epath("/*", expr) == [((x, 1, t), 2), ((3, y, 4), z)]
assert epath("/*/*", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/*/*/*", expr) == [x, 1, t, 3, y, 4]
assert epath("/*/*/*/*", expr) == []
assert epath("/[:]", expr) == [((x, 1, t), 2), ((3, y, 4), z)]
assert epath("/[:]/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/[:]/[:]/[:]", expr) == [x, 1, t, 3, y, 4]
assert epath("/[:]/[:]/[:]/[:]", expr) == []
assert epath("/*/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/*/[0]", expr) == [(x, 1, t), (3, y, 4)]
assert epath("/*/[1]", expr) == [2, z]
assert epath("/*/[2]", expr) == []
assert epath("/*/int", expr) == [2]
assert epath("/*/Symbol", expr) == [z]
assert epath("/*/tuple", expr) == [(x, 1, t), (3, y, 4)]
assert epath("/*/__iter__?", expr) == [(x, 1, t), (3, y, 4)]
assert epath("/*/int|tuple", expr) == [(x, 1, t), 2, (3, y, 4)]
assert epath("/*/Symbol|tuple", expr) == [(x, 1, t), (3, y, 4), z]
assert epath("/*/int|Symbol|tuple", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/*/int|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4)]
assert epath("/*/Symbol|__iter__?", expr) == [(x, 1, t), (3, y, 4), z]
assert epath(
"/*/int|Symbol|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/*/[0]/int", expr) == [1, 3, 4]
assert epath("/*/[0]/Symbol", expr) == [x, t, y]
assert epath("/*/[0]/int[1:]", expr) == [1, 4]
assert epath("/*/[0]/Symbol[1:]", expr) == [t, y]
assert epath("/Symbol", x + y + z + 1) == [x, y, z]
assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E)) == [x, x, y]
def test_epath_apply():
expr = [((x, 1, t), 2), ((3, y, 4), z)]
func = lambda expr: expr**2
assert epath("/*", expr, list) == [[(x, 1, t), 2], [(3, y, 4), z]]
assert epath("/*/[0]", expr, list) == [([x, 1, t], 2), ([3, y, 4], z)]
assert epath("/*/[1]", expr, func) == [((x, 1, t), 4), ((3, y, 4), z**2)]
assert epath("/*/[2]", expr, list) == expr
assert epath("/*/[0]/int", expr, func) == [((x, 1, t), 2), ((9, y, 16), z)]
assert epath("/*/[0]/Symbol", expr, func) == [((x**2, 1, t**2), 2),
((3, y**2, 4), z)]
assert epath(
"/*/[0]/int[1:]", expr, func) == [((x, 1, t), 2), ((3, y, 16), z)]
assert epath("/*/[0]/Symbol[1:]", expr, func) == [((x, 1, t**2),
2), ((3, y**2, 4), z)]
assert epath("/Symbol", x + y + z + 1, func) == x**2 + y**2 + z**2 + 1
assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E), func) == \
t + sin(x**2 + 1) + cos(x**2 + y**2 + E)
def test_EPath():
assert EPath("/*/[0]")._path == "/*/[0]"
assert EPath(EPath("/*/[0]"))._path == "/*/[0]"
assert isinstance(epath("/*/[0]"), EPath) is True
assert repr(EPath("/*/[0]")) == "EPath('/*/[0]')"
raises(ValueError, lambda: EPath(""))
raises(ValueError, lambda: EPath("/"))
raises(ValueError, lambda: EPath("/|x"))
raises(ValueError, lambda: EPath("/["))
raises(ValueError, lambda: EPath("/[0]%"))
raises(NotImplementedError, lambda: EPath("Symbol"))

View File

@ -0,0 +1,492 @@
from sympy.core.add import Add
from sympy.core.mul import Mul
from sympy.core.numbers import (I, Rational, pi)
from sympy.core.parameters import evaluate
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, Symbol, symbols)
from sympy.functions.elementary.hyperbolic import (cosh, coth, csch, sech, sinh, tanh)
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import (cos, cot, csc, sec, sin, tan)
from sympy.simplify.powsimp import powsimp
from sympy.simplify.fu import (
L, TR1, TR10, TR10i, TR11, _TR11, TR12, TR12i, TR13, TR14, TR15, TR16,
TR111, TR2, TR2i, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TRmorrie, _TR56 as T,
TRpower, hyper_as_trig, fu, process_common_addends, trig_split,
as_f_sign_1)
from sympy.core.random import verify_numerically
from sympy.abc import a, b, c, x, y, z
def test_TR1():
assert TR1(2*csc(x) + sec(x)) == 1/cos(x) + 2/sin(x)
def test_TR2():
assert TR2(tan(x)) == sin(x)/cos(x)
assert TR2(cot(x)) == cos(x)/sin(x)
assert TR2(tan(tan(x) - sin(x)/cos(x))) == 0
def test_TR2i():
# just a reminder that ratios of powers only simplify if both
# numerator and denominator satisfy the condition that each
# has a positive base or an integer exponent; e.g. the following,
# at y=-1, x=1/2 gives sqrt(2)*I != -sqrt(2)*I
assert powsimp(2**x/y**x) != (2/y)**x
assert TR2i(sin(x)/cos(x)) == tan(x)
assert TR2i(sin(x)*sin(y)/cos(x)) == tan(x)*sin(y)
assert TR2i(1/(sin(x)/cos(x))) == 1/tan(x)
assert TR2i(1/(sin(x)*sin(y)/cos(x))) == 1/tan(x)/sin(y)
assert TR2i(sin(x)/2/(cos(x) + 1)) == sin(x)/(cos(x) + 1)/2
assert TR2i(sin(x)/2/(cos(x) + 1), half=True) == tan(x/2)/2
assert TR2i(sin(1)/(cos(1) + 1), half=True) == tan(S.Half)
assert TR2i(sin(2)/(cos(2) + 1), half=True) == tan(1)
assert TR2i(sin(4)/(cos(4) + 1), half=True) == tan(2)
assert TR2i(sin(5)/(cos(5) + 1), half=True) == tan(5*S.Half)
assert TR2i((cos(1) + 1)/sin(1), half=True) == 1/tan(S.Half)
assert TR2i((cos(2) + 1)/sin(2), half=True) == 1/tan(1)
assert TR2i((cos(4) + 1)/sin(4), half=True) == 1/tan(2)
assert TR2i((cos(5) + 1)/sin(5), half=True) == 1/tan(5*S.Half)
assert TR2i((cos(1) + 1)**(-a)*sin(1)**a, half=True) == tan(S.Half)**a
assert TR2i((cos(2) + 1)**(-a)*sin(2)**a, half=True) == tan(1)**a
assert TR2i((cos(4) + 1)**(-a)*sin(4)**a, half=True) == (cos(4) + 1)**(-a)*sin(4)**a
assert TR2i((cos(5) + 1)**(-a)*sin(5)**a, half=True) == (cos(5) + 1)**(-a)*sin(5)**a
assert TR2i((cos(1) + 1)**a*sin(1)**(-a), half=True) == tan(S.Half)**(-a)
assert TR2i((cos(2) + 1)**a*sin(2)**(-a), half=True) == tan(1)**(-a)
assert TR2i((cos(4) + 1)**a*sin(4)**(-a), half=True) == (cos(4) + 1)**a*sin(4)**(-a)
assert TR2i((cos(5) + 1)**a*sin(5)**(-a), half=True) == (cos(5) + 1)**a*sin(5)**(-a)
i = symbols('i', integer=True)
assert TR2i(((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**(-i)
assert TR2i(1/((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**i
def test_TR3():
assert TR3(cos(y - x*(y - x))) == cos(x*(x - y) + y)
assert cos(pi/2 + x) == -sin(x)
assert cos(30*pi/2 + x) == -cos(x)
for f in (cos, sin, tan, cot, csc, sec):
i = f(pi*Rational(3, 7))
j = TR3(i)
assert verify_numerically(i, j) and i.func != j.func
with evaluate(False):
eq = cos(9*pi/22)
assert eq.has(9*pi) and TR3(eq) == sin(pi/11)
def test_TR4():
for i in [0, pi/6, pi/4, pi/3, pi/2]:
with evaluate(False):
eq = cos(i)
assert isinstance(eq, cos) and TR4(eq) == cos(i)
def test__TR56():
h = lambda x: 1 - x
assert T(sin(x)**3, sin, cos, h, 4, False) == sin(x)*(-cos(x)**2 + 1)
assert T(sin(x)**10, sin, cos, h, 4, False) == sin(x)**10
assert T(sin(x)**6, sin, cos, h, 6, False) == (-cos(x)**2 + 1)**3
assert T(sin(x)**6, sin, cos, h, 6, True) == sin(x)**6
assert T(sin(x)**8, sin, cos, h, 10, True) == (-cos(x)**2 + 1)**4
# issue 17137
assert T(sin(x)**I, sin, cos, h, 4, True) == sin(x)**I
assert T(sin(x)**(2*I + 1), sin, cos, h, 4, True) == sin(x)**(2*I + 1)
def test_TR5():
assert TR5(sin(x)**2) == -cos(x)**2 + 1
assert TR5(sin(x)**-2) == sin(x)**(-2)
assert TR5(sin(x)**4) == (-cos(x)**2 + 1)**2
def test_TR6():
assert TR6(cos(x)**2) == -sin(x)**2 + 1
assert TR6(cos(x)**-2) == cos(x)**(-2)
assert TR6(cos(x)**4) == (-sin(x)**2 + 1)**2
def test_TR7():
assert TR7(cos(x)**2) == cos(2*x)/2 + S.Half
assert TR7(cos(x)**2 + 1) == cos(2*x)/2 + Rational(3, 2)
def test_TR8():
assert TR8(cos(2)*cos(3)) == cos(5)/2 + cos(1)/2
assert TR8(cos(2)*sin(3)) == sin(5)/2 + sin(1)/2
assert TR8(sin(2)*sin(3)) == -cos(5)/2 + cos(1)/2
assert TR8(sin(1)*sin(2)*sin(3)) == sin(4)/4 - sin(6)/4 + sin(2)/4
assert TR8(cos(2)*cos(3)*cos(4)*cos(5)) == \
cos(4)/4 + cos(10)/8 + cos(2)/8 + cos(8)/8 + cos(14)/8 + \
cos(6)/8 + Rational(1, 8)
assert TR8(cos(2)*cos(3)*cos(4)*cos(5)*cos(6)) == \
cos(10)/8 + cos(4)/8 + 3*cos(2)/16 + cos(16)/16 + cos(8)/8 + \
cos(14)/16 + cos(20)/16 + cos(12)/16 + Rational(1, 16) + cos(6)/8
assert TR8(sin(pi*Rational(3, 7))**2*cos(pi*Rational(3, 7))**2/(16*sin(pi/7)**2)) == Rational(1, 64)
def test_TR9():
a = S.Half
b = 3*a
assert TR9(a) == a
assert TR9(cos(1) + cos(2)) == 2*cos(a)*cos(b)
assert TR9(cos(1) - cos(2)) == 2*sin(a)*sin(b)
assert TR9(sin(1) - sin(2)) == -2*sin(a)*cos(b)
assert TR9(sin(1) + sin(2)) == 2*sin(b)*cos(a)
assert TR9(cos(1) + 2*sin(1) + 2*sin(2)) == cos(1) + 4*sin(b)*cos(a)
assert TR9(cos(4) + cos(2) + 2*cos(1)*cos(3)) == 4*cos(1)*cos(3)
assert TR9((cos(4) + cos(2))/cos(3)/2 + cos(3)) == 2*cos(1)*cos(2)
assert TR9(cos(3) + cos(4) + cos(5) + cos(6)) == \
4*cos(S.Half)*cos(1)*cos(Rational(9, 2))
assert TR9(cos(3) + cos(3)*cos(2)) == cos(3) + cos(2)*cos(3)
assert TR9(-cos(y) + cos(x*y)) == -2*sin(x*y/2 - y/2)*sin(x*y/2 + y/2)
assert TR9(-sin(y) + sin(x*y)) == 2*sin(x*y/2 - y/2)*cos(x*y/2 + y/2)
c = cos(x)
s = sin(x)
for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
for a in ((c, s), (s, c), (cos(x), cos(x*y)), (sin(x), sin(x*y))):
args = zip(si, a)
ex = Add(*[Mul(*ai) for ai in args])
t = TR9(ex)
assert not (a[0].func == a[1].func and (
not verify_numerically(ex, t.expand(trig=True)) or t.is_Add)
or a[1].func != a[0].func and ex != t)
def test_TR10():
assert TR10(cos(a + b)) == -sin(a)*sin(b) + cos(a)*cos(b)
assert TR10(sin(a + b)) == sin(a)*cos(b) + sin(b)*cos(a)
assert TR10(sin(a + b + c)) == \
(-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \
(sin(a)*cos(b) + sin(b)*cos(a))*cos(c)
assert TR10(cos(a + b + c)) == \
(-sin(a)*sin(b) + cos(a)*cos(b))*cos(c) - \
(sin(a)*cos(b) + sin(b)*cos(a))*sin(c)
def test_TR10i():
assert TR10i(cos(1)*cos(3) + sin(1)*sin(3)) == cos(2)
assert TR10i(cos(1)*cos(3) - sin(1)*sin(3)) == cos(4)
assert TR10i(cos(1)*sin(3) - sin(1)*cos(3)) == sin(2)
assert TR10i(cos(1)*sin(3) + sin(1)*cos(3)) == sin(4)
assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + 7) == sin(4) + 7
assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) == cos(3) + sin(4)
assert TR10i(2*cos(1)*sin(3) + 2*sin(1)*cos(3) + cos(3)) == \
2*sin(4) + cos(3)
assert TR10i(cos(2)*cos(3) + sin(2)*(cos(1)*sin(2) + cos(2)*sin(1))) == \
cos(1)
eq = (cos(2)*cos(3) + sin(2)*(
cos(1)*sin(2) + cos(2)*sin(1)))*cos(5) + sin(1)*sin(5)
assert TR10i(eq) == TR10i(eq.expand()) == cos(4)
assert TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) == \
2*sqrt(2)*x*sin(x + pi/6)
assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) +
cos(x)/sqrt(6)/3 + sin(x)/sqrt(2)/3) == 4*sqrt(6)*sin(x + pi/6)/9
assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) +
cos(y)/sqrt(6)/3 + sin(y)/sqrt(2)/3) == \
sqrt(6)*sin(x + pi/6)/3 + sqrt(6)*sin(y + pi/6)/9
assert TR10i(cos(x) + sqrt(3)*sin(x) + 2*sqrt(3)*cos(x + pi/6)) == 4*cos(x)
assert TR10i(cos(x) + sqrt(3)*sin(x) +
2*sqrt(3)*cos(x + pi/6) + 4*sin(x)) == 4*sqrt(2)*sin(x + pi/4)
assert TR10i(cos(2)*sin(3) + sin(2)*cos(4)) == \
sin(2)*cos(4) + sin(3)*cos(2)
A = Symbol('A', commutative=False)
assert TR10i(sqrt(2)*cos(x)*A + sqrt(6)*sin(x)*A) == \
2*sqrt(2)*sin(x + pi/6)*A
c = cos(x)
s = sin(x)
h = sin(y)
r = cos(y)
for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
for argsi in ((c*r, s*h), (c*h, s*r)): # explicit 2-args
args = zip(si, argsi)
ex = Add(*[Mul(*ai) for ai in args])
t = TR10i(ex)
assert not (ex - t.expand(trig=True) or t.is_Add)
c = cos(x)
s = sin(x)
h = sin(pi/6)
r = cos(pi/6)
for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
for argsi in ((c*r, s*h), (c*h, s*r)): # induced
args = zip(si, argsi)
ex = Add(*[Mul(*ai) for ai in args])
t = TR10i(ex)
assert not (ex - t.expand(trig=True) or t.is_Add)
def test_TR11():
assert TR11(sin(2*x)) == 2*sin(x)*cos(x)
assert TR11(sin(4*x)) == 4*((-sin(x)**2 + cos(x)**2)*sin(x)*cos(x))
assert TR11(sin(x*Rational(4, 3))) == \
4*((-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3))
assert TR11(cos(2*x)) == -sin(x)**2 + cos(x)**2
assert TR11(cos(4*x)) == \
(-sin(x)**2 + cos(x)**2)**2 - 4*sin(x)**2*cos(x)**2
assert TR11(cos(2)) == cos(2)
assert TR11(cos(pi*Rational(3, 7)), pi*Rational(2, 7)) == -cos(pi*Rational(2, 7))**2 + sin(pi*Rational(2, 7))**2
assert TR11(cos(4), 2) == -sin(2)**2 + cos(2)**2
assert TR11(cos(6), 2) == cos(6)
assert TR11(sin(x)/cos(x/2), x/2) == 2*sin(x/2)
def test__TR11():
assert _TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) == \
4*sin(x/8)*sin(x/6)*sin(2*x),_TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8)))
assert _TR11(sin(x/3)/cos(x/6)) == 2*sin(x/6)
assert _TR11(cos(x/6)/sin(x/3)) == 1/(2*sin(x/6))
assert _TR11(sin(2*x)*cos(x/8)/sin(x/4)) == sin(2*x)/(2*sin(x/8)), _TR11(sin(2*x)*cos(x/8)/sin(x/4))
assert _TR11(sin(x)/sin(x/2)) == 2*cos(x/2)
def test_TR12():
assert TR12(tan(x + y)) == (tan(x) + tan(y))/(-tan(x)*tan(y) + 1)
assert TR12(tan(x + y + z)) ==\
(tan(z) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1))/(
1 - (tan(x) + tan(y))*tan(z)/(-tan(x)*tan(y) + 1))
assert TR12(tan(x*y)) == tan(x*y)
def test_TR13():
assert TR13(tan(3)*tan(2)) == -tan(2)/tan(5) - tan(3)/tan(5) + 1
assert TR13(cot(3)*cot(2)) == 1 + cot(3)*cot(5) + cot(2)*cot(5)
assert TR13(tan(1)*tan(2)*tan(3)) == \
(-tan(2)/tan(5) - tan(3)/tan(5) + 1)*tan(1)
assert TR13(tan(1)*tan(2)*cot(3)) == \
(-tan(2)/tan(3) + 1 - tan(1)/tan(3))*cot(3)
def test_L():
assert L(cos(x) + sin(x)) == 2
def test_fu():
assert fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) == Rational(3, 2)
assert fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) == 2*sqrt(2)*sin(x + pi/3)
eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2
assert fu(eq) == cos(x)**4 - 2*cos(y)**2 + 2
assert fu(S.Half - cos(2*x)/2) == sin(x)**2
assert fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) == \
sqrt(2)*sin(a + b + pi/4)
assert fu(sqrt(3)*cos(x)/2 + sin(x)/2) == sin(x + pi/3)
assert fu(1 - sin(2*x)**2/4 - sin(y)**2 - cos(x)**4) == \
-cos(x)**2 + cos(y)**2
assert fu(cos(pi*Rational(4, 9))) == sin(pi/18)
assert fu(cos(pi/9)*cos(pi*Rational(2, 9))*cos(pi*Rational(3, 9))*cos(pi*Rational(4, 9))) == Rational(1, 16)
assert fu(
tan(pi*Rational(7, 18)) + tan(pi*Rational(5, 18)) - sqrt(3)*tan(pi*Rational(5, 18))*tan(pi*Rational(7, 18))) == \
-sqrt(3)
assert fu(tan(1)*tan(2)) == tan(1)*tan(2)
expr = Mul(*[cos(2**i) for i in range(10)])
assert fu(expr) == sin(1024)/(1024*sin(1))
# issue #18059:
assert fu(cos(x) + sqrt(sin(x)**2)) == cos(x) + sqrt(sin(x)**2)
assert fu((-14*sin(x)**3 + 35*sin(x) + 6*sqrt(3)*cos(x)**3 + 9*sqrt(3)*cos(x))/((cos(2*x) + 4))) == \
7*sin(x) + 3*sqrt(3)*cos(x)
def test_objective():
assert fu(sin(x)/cos(x), measure=lambda x: x.count_ops()) == \
tan(x)
assert fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) == \
sin(x)/cos(x)
def test_process_common_addends():
# this tests that the args are not evaluated as they are given to do
# and that key2 works when key1 is False
do = lambda x: Add(*[i**(i%2) for i in x.args])
assert process_common_addends(Add(*[1, 2, 3, 4], evaluate=False), do,
key2=lambda x: x%2, key1=False) == 1**1 + 3**1 + 2**0 + 4**0
def test_trig_split():
assert trig_split(cos(x), cos(y)) == (1, 1, 1, x, y, True)
assert trig_split(2*cos(x), -2*cos(y)) == (2, 1, -1, x, y, True)
assert trig_split(cos(x)*sin(y), cos(y)*sin(y)) == \
(sin(y), 1, 1, x, y, True)
assert trig_split(cos(x), -sqrt(3)*sin(x), two=True) == \
(2, 1, -1, x, pi/6, False)
assert trig_split(cos(x), sin(x), two=True) == \
(sqrt(2), 1, 1, x, pi/4, False)
assert trig_split(cos(x), -sin(x), two=True) == \
(sqrt(2), 1, -1, x, pi/4, False)
assert trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) == \
(2*sqrt(2), 1, -1, x, pi/6, False)
assert trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) == \
(-2*sqrt(2), 1, 1, x, pi/3, False)
assert trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) == \
(sqrt(6)/3, 1, 1, x, pi/6, False)
assert trig_split(-sqrt(6)*cos(x)*sin(y),
-sqrt(2)*sin(x)*sin(y), two=True) == \
(-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False)
assert trig_split(cos(x), sin(x)) is None
assert trig_split(cos(x), sin(z)) is None
assert trig_split(2*cos(x), -sin(x)) is None
assert trig_split(cos(x), -sqrt(3)*sin(x)) is None
assert trig_split(cos(x)*cos(y), sin(x)*sin(z)) is None
assert trig_split(cos(x)*cos(y), sin(x)*sin(y)) is None
assert trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) is \
None
assert trig_split(sqrt(3)*sqrt(x), cos(3), two=True) is None
assert trig_split(sqrt(3)*root(x, 3), sin(3)*cos(2), two=True) is None
assert trig_split(cos(5)*cos(6), cos(7)*sin(5), two=True) is None
def test_TRmorrie():
assert TRmorrie(7*Mul(*[cos(i) for i in range(10)])) == \
7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3))
assert TRmorrie(x) == x
assert TRmorrie(2*x) == 2*x
e = cos(pi/7)*cos(pi*Rational(2, 7))*cos(pi*Rational(4, 7))
assert TR8(TRmorrie(e)) == Rational(-1, 8)
e = Mul(*[cos(2**i*pi/17) for i in range(1, 17)])
assert TR8(TR3(TRmorrie(e))) == Rational(1, 65536)
# issue 17063
eq = cos(x)/cos(x/2)
assert TRmorrie(eq) == eq
# issue #20430
eq = cos(x/2)*sin(x/2)*cos(x)**3
assert TRmorrie(eq) == sin(2*x)*cos(x)**2/4
def test_TRpower():
assert TRpower(1/sin(x)**2) == 1/sin(x)**2
assert TRpower(cos(x)**3*sin(x/2)**4) == \
(3*cos(x)/4 + cos(3*x)/4)*(-cos(x)/2 + cos(2*x)/8 + Rational(3, 8))
for k in range(2, 8):
assert verify_numerically(sin(x)**k, TRpower(sin(x)**k))
assert verify_numerically(cos(x)**k, TRpower(cos(x)**k))
def test_hyper_as_trig():
from sympy.simplify.fu import _osborne, _osbornei
eq = sinh(x)**2 + cosh(x)**2
t, f = hyper_as_trig(eq)
assert f(fu(t)) == cosh(2*x)
e, f = hyper_as_trig(tanh(x + y))
assert f(TR12(e)) == (tanh(x) + tanh(y))/(tanh(x)*tanh(y) + 1)
d = Dummy()
assert _osborne(sinh(x), d) == I*sin(x*d)
assert _osborne(tanh(x), d) == I*tan(x*d)
assert _osborne(coth(x), d) == cot(x*d)/I
assert _osborne(cosh(x), d) == cos(x*d)
assert _osborne(sech(x), d) == sec(x*d)
assert _osborne(csch(x), d) == csc(x*d)/I
for func in (sinh, cosh, tanh, coth, sech, csch):
h = func(pi)
assert _osbornei(_osborne(h, d), d) == h
# /!\ the _osborne functions are not meant to work
# in the o(i(trig, d), d) direction so we just check
# that they work as they are supposed to work
assert _osbornei(cos(x*y + z), y) == cosh(x + z*I)
assert _osbornei(sin(x*y + z), y) == sinh(x + z*I)/I
assert _osbornei(tan(x*y + z), y) == tanh(x + z*I)/I
assert _osbornei(cot(x*y + z), y) == coth(x + z*I)*I
assert _osbornei(sec(x*y + z), y) == sech(x + z*I)
assert _osbornei(csc(x*y + z), y) == csch(x + z*I)*I
def test_TR12i():
ta, tb, tc = [tan(i) for i in (a, b, c)]
assert TR12i((ta + tb)/(-ta*tb + 1)) == tan(a + b)
assert TR12i((ta + tb)/(ta*tb - 1)) == -tan(a + b)
assert TR12i((-ta - tb)/(ta*tb - 1)) == tan(a + b)
eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1))
assert TR12i(eq.expand()) == \
-3*tan(a + b)*tan(a + c)/(tan(a) + tan(b) - 1)/2
assert TR12i(tan(x)/sin(x)) == tan(x)/sin(x)
eq = (ta + cos(2))/(-ta*tb + 1)
assert TR12i(eq) == eq
eq = (ta + tb + 2)**2/(-ta*tb + 1)
assert TR12i(eq) == eq
eq = ta/(-ta*tb + 1)
assert TR12i(eq) == eq
eq = (((ta + tb)*(a + 1)).expand())**2/(ta*tb - 1)
assert TR12i(eq) == -(a + 1)**2*tan(a + b)
def test_TR14():
eq = (cos(x) - 1)*(cos(x) + 1)
ans = -sin(x)**2
assert TR14(eq) == ans
assert TR14(1/eq) == 1/ans
assert TR14((cos(x) - 1)**2*(cos(x) + 1)**2) == ans**2
assert TR14((cos(x) - 1)**2*(cos(x) + 1)**3) == ans**2*(cos(x) + 1)
assert TR14((cos(x) - 1)**3*(cos(x) + 1)**2) == ans**2*(cos(x) - 1)
eq = (cos(x) - 1)**y*(cos(x) + 1)**y
assert TR14(eq) == eq
eq = (cos(x) - 2)**y*(cos(x) + 1)
assert TR14(eq) == eq
eq = (tan(x) - 2)**2*(cos(x) + 1)
assert TR14(eq) == eq
i = symbols('i', integer=True)
assert TR14((cos(x) - 1)**i*(cos(x) + 1)**i) == ans**i
assert TR14((sin(x) - 1)**i*(sin(x) + 1)**i) == (-cos(x)**2)**i
# could use extraction in this case
eq = (cos(x) - 1)**(i + 1)*(cos(x) + 1)**i
assert TR14(eq) in [(cos(x) - 1)*ans**i, eq]
assert TR14((sin(x) - 1)*(sin(x) + 1)) == -cos(x)**2
p1 = (cos(x) + 1)*(cos(x) - 1)
p2 = (cos(y) - 1)*2*(cos(y) + 1)
p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1))
assert TR14(p1*p2*p3*(x - 1)) == -18*((x - 1)*sin(x)**2*sin(y)**4)
def test_TR15_16_17():
assert TR15(1 - 1/sin(x)**2) == -cot(x)**2
assert TR16(1 - 1/cos(x)**2) == -tan(x)**2
assert TR111(1 - 1/tan(x)**2) == 1 - cot(x)**2
def test_as_f_sign_1():
assert as_f_sign_1(x + 1) == (1, x, 1)
assert as_f_sign_1(x - 1) == (1, x, -1)
assert as_f_sign_1(-x + 1) == (-1, x, -1)
assert as_f_sign_1(-x - 1) == (-1, x, 1)
assert as_f_sign_1(2*x + 2) == (2, x, 1)
assert as_f_sign_1(x*y - y) == (y, x, -1)
assert as_f_sign_1(-x*y + y) == (-y, x, -1)
def test_issue_25590():
A = Symbol('A', commutative=False)
B = Symbol('B', commutative=False)
assert TR8(2*cos(x)*sin(x)*B*A) == sin(2*x)*B*A
assert TR13(tan(2)*tan(3)*B*A) == (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*B*A
# XXX The result may not be optimal than
# sin(2*x)*B*A + cos(x)**2 and may change in the future
assert (2*cos(x)*sin(x)*B*A + cos(x)**2).simplify() == sin(2*x)*B*A + cos(2*x)/2 + S.One/2

View File

@ -0,0 +1,54 @@
""" Unit tests for Hyper_Function"""
from sympy.core import symbols, Dummy, Tuple, S, Rational
from sympy.functions import hyper
from sympy.simplify.hyperexpand import Hyper_Function
def test_attrs():
a, b = symbols('a, b', cls=Dummy)
f = Hyper_Function([2, a], [b])
assert f.ap == Tuple(2, a)
assert f.bq == Tuple(b)
assert f.args == (Tuple(2, a), Tuple(b))
assert f.sizes == (2, 1)
def test_call():
a, b, x = symbols('a, b, x', cls=Dummy)
f = Hyper_Function([2, a], [b])
assert f(x) == hyper([2, a], [b], x)
def test_has():
a, b, c = symbols('a, b, c', cls=Dummy)
f = Hyper_Function([2, -a], [b])
assert f.has(a)
assert f.has(Tuple(b))
assert not f.has(c)
def test_eq():
assert Hyper_Function([1], []) == Hyper_Function([1], [])
assert (Hyper_Function([1], []) != Hyper_Function([1], [])) is False
assert Hyper_Function([1], []) != Hyper_Function([2], [])
assert Hyper_Function([1], []) != Hyper_Function([1, 2], [])
assert Hyper_Function([1], []) != Hyper_Function([1], [2])
def test_gamma():
assert Hyper_Function([2, 3], [-1]).gamma == 0
assert Hyper_Function([-2, -3], [-1]).gamma == 2
n = Dummy(integer=True)
assert Hyper_Function([-1, n, 1], []).gamma == 1
assert Hyper_Function([-1, -n, 1], []).gamma == 1
p = Dummy(integer=True, positive=True)
assert Hyper_Function([-1, p, 1], []).gamma == 1
assert Hyper_Function([-1, -p, 1], []).gamma == 2
def test_suitable_origin():
assert Hyper_Function((S.Half,), (Rational(3, 2),))._is_suitable_origin() is True
assert Hyper_Function((S.Half,), (S.Half,))._is_suitable_origin() is False
assert Hyper_Function((S.Half,), (Rational(-1, 2),))._is_suitable_origin() is False
assert Hyper_Function((S.Half,), (0,))._is_suitable_origin() is False
assert Hyper_Function((S.Half,), (-1, 1,))._is_suitable_origin() is False
assert Hyper_Function((S.Half, 0), (1,))._is_suitable_origin() is False
assert Hyper_Function((S.Half, 1),
(2, Rational(-2, 3)))._is_suitable_origin() is True
assert Hyper_Function((S.Half, 1),
(2, Rational(-2, 3), Rational(3, 2)))._is_suitable_origin() is True

View File

@ -0,0 +1,127 @@
from sympy.core.function import Function
from sympy.core.numbers import (Rational, pi)
from sympy.core.singleton import S
from sympy.core.symbol import symbols
from sympy.functions.combinatorial.factorials import (rf, binomial, factorial)
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.functions.special.gamma_functions import gamma
from sympy.simplify.gammasimp import gammasimp
from sympy.simplify.powsimp import powsimp
from sympy.simplify.simplify import simplify
from sympy.abc import x, y, n, k
def test_gammasimp():
R = Rational
# was part of test_combsimp_gamma() in test_combsimp.py
assert gammasimp(gamma(x)) == gamma(x)
assert gammasimp(gamma(x + 1)/x) == gamma(x)
assert gammasimp(gamma(x)/(x - 1)) == gamma(x - 1)
assert gammasimp(x*gamma(x)) == gamma(x + 1)
assert gammasimp((x + 1)*gamma(x + 1)) == gamma(x + 2)
assert gammasimp(gamma(x + y)*(x + y)) == gamma(x + y + 1)
assert gammasimp(x/gamma(x + 1)) == 1/gamma(x)
assert gammasimp((x + 1)**2/gamma(x + 2)) == (x + 1)/gamma(x + 1)
assert gammasimp(x*gamma(x) + gamma(x + 3)/(x + 2)) == \
(x + 2)*gamma(x + 1)
assert gammasimp(gamma(2*x)*x) == gamma(2*x + 1)/2
assert gammasimp(gamma(2*x)/(x - S.Half)) == 2*gamma(2*x - 1)
assert gammasimp(gamma(x)*gamma(1 - x)) == pi/sin(pi*x)
assert gammasimp(gamma(x)*gamma(-x)) == -pi/(x*sin(pi*x))
assert gammasimp(1/gamma(x + 3)/gamma(1 - x)) == \
sin(pi*x)/(pi*x*(x + 1)*(x + 2))
assert gammasimp(factorial(n + 2)) == gamma(n + 3)
assert gammasimp(binomial(n, k)) == \
gamma(n + 1)/(gamma(k + 1)*gamma(-k + n + 1))
assert powsimp(gammasimp(
gamma(x)*gamma(x + S.Half)*gamma(y)/gamma(x + y))) == \
2**(-2*x + 1)*sqrt(pi)*gamma(2*x)*gamma(y)/gamma(x + y)
assert gammasimp(1/gamma(x)/gamma(x - Rational(1, 3))/gamma(x + Rational(1, 3))) == \
3**(3*x - Rational(3, 2))/(2*pi*gamma(3*x - 1))
assert simplify(
gamma(S.Half + x/2)*gamma(1 + x/2)/gamma(1 + x)/sqrt(pi)*2**x) == 1
assert gammasimp(gamma(Rational(-1, 4))*gamma(Rational(-3, 4))) == 16*sqrt(2)*pi/3
assert powsimp(gammasimp(gamma(2*x)/gamma(x))) == \
2**(2*x - 1)*gamma(x + S.Half)/sqrt(pi)
# issue 6792
e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2
assert gammasimp(e) == -k
assert gammasimp(1/e) == -1/k
e = (gamma(x) + gamma(x + 1))/gamma(x)
assert gammasimp(e) == x + 1
assert gammasimp(1/e) == 1/(x + 1)
e = (gamma(x) + gamma(x + 2))*(gamma(x - 1) + gamma(x))/gamma(x)
assert gammasimp(e) == (x**2 + x + 1)*gamma(x + 1)/(x - 1)
e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2
assert gammasimp(e**2) == k**2
assert gammasimp(e**2/gamma(k + 1)) == k/gamma(k)
a = R(1, 2) + R(1, 3)
b = a + R(1, 3)
assert gammasimp(gamma(2*k)/gamma(k)*gamma(k + a)*gamma(k + b)
) == 3*2**(2*k + 1)*3**(-3*k - 2)*sqrt(pi)*gamma(3*k + R(3, 2))/2
# issue 9699
assert gammasimp((x + 1)*factorial(x)/gamma(y)) == gamma(x + 2)/gamma(y)
assert gammasimp(rf(x + n, k)*binomial(n, k)).simplify() == Piecewise(
(gamma(n + 1)*gamma(k + n + x)/(gamma(k + 1)*gamma(n + x)*gamma(-k + n + 1)), n > -x),
((-1)**k*gamma(n + 1)*gamma(-n - x + 1)/(gamma(k + 1)*gamma(-k + n + 1)*gamma(-k - n - x + 1)), True))
A, B = symbols('A B', commutative=False)
assert gammasimp(e*B*A) == gammasimp(e)*B*A
# check iteration
assert gammasimp(gamma(2*k)/gamma(k)*gamma(-k - R(1, 2))) == (
-2**(2*k + 1)*sqrt(pi)/(2*((2*k + 1)*cos(pi*k))))
assert gammasimp(
gamma(k)*gamma(k + R(1, 3))*gamma(k + R(2, 3))/gamma(k*R(3, 2))) == (
3*2**(3*k + 1)*3**(-3*k - S.Half)*sqrt(pi)*gamma(k*R(3, 2) + S.Half)/2)
# issue 6153
assert gammasimp(gamma(Rational(1, 4))/gamma(Rational(5, 4))) == 4
# was part of test_combsimp() in test_combsimp.py
assert gammasimp(binomial(n + 2, k + S.Half)) == gamma(n + 3)/ \
(gamma(k + R(3, 2))*gamma(-k + n + R(5, 2)))
assert gammasimp(binomial(n + 2, k + 2.0)) == \
gamma(n + 3)/(gamma(k + 3.0)*gamma(-k + n + 1))
# issue 11548
assert gammasimp(binomial(0, x)) == sin(pi*x)/(pi*x)
e = gamma(n + Rational(1, 3))*gamma(n + R(2, 3))
assert gammasimp(e) == e
assert gammasimp(gamma(4*n + S.Half)/gamma(2*n - R(3, 4))) == \
2**(4*n - R(5, 2))*(8*n - 3)*gamma(2*n + R(3, 4))/sqrt(pi)
i, m = symbols('i m', integer = True)
e = gamma(exp(i))
assert gammasimp(e) == e
e = gamma(m + 3)
assert gammasimp(e) == e
e = gamma(m + 1)/(gamma(i + 1)*gamma(-i + m + 1))
assert gammasimp(e) == e
p = symbols("p", integer=True, positive=True)
assert gammasimp(gamma(-p + 4)) == gamma(-p + 4)
def test_issue_22606():
fx = Function('f')(x)
eq = x + gamma(y)
# seems like ans should be `eq`, not `(x*y + gamma(y + 1))/y`
ans = gammasimp(eq)
assert gammasimp(eq.subs(x, fx)).subs(fx, x) == ans
assert gammasimp(eq.subs(x, cos(x))).subs(cos(x), x) == ans
assert 1/gammasimp(1/eq) == ans
assert gammasimp(fx.subs(x, eq)).args[0] == ans

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,366 @@
from sympy.core.function import Function
from sympy.core.mul import Mul
from sympy.core.numbers import (E, I, Rational, oo, pi)
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import sin
from sympy.functions.special.gamma_functions import gamma
from sympy.functions.special.hyper import hyper
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.simplify.powsimp import (powdenest, powsimp)
from sympy.simplify.simplify import (signsimp, simplify)
from sympy.core.symbol import Str
from sympy.abc import x, y, z, a, b
def test_powsimp():
x, y, z, n = symbols('x,y,z,n')
f = Function('f')
assert powsimp( 4**x * 2**(-x) * 2**(-x) ) == 1
assert powsimp( (-4)**x * (-2)**(-x) * 2**(-x) ) == 1
assert powsimp(
f(4**x * 2**(-x) * 2**(-x)) ) == f(4**x * 2**(-x) * 2**(-x))
assert powsimp( f(4**x * 2**(-x) * 2**(-x)), deep=True ) == f(1)
assert exp(x)*exp(y) == exp(x)*exp(y)
assert powsimp(exp(x)*exp(y)) == exp(x + y)
assert powsimp(exp(x)*exp(y)*2**x*2**y) == (2*E)**(x + y)
assert powsimp(exp(x)*exp(y)*2**x*2**y, combine='exp') == \
exp(x + y)*2**(x + y)
assert powsimp(exp(x)*exp(y)*exp(2)*sin(x) + sin(y) + 2**x*2**y) == \
exp(2 + x + y)*sin(x) + sin(y) + 2**(x + y)
assert powsimp(sin(exp(x)*exp(y))) == sin(exp(x)*exp(y))
assert powsimp(sin(exp(x)*exp(y)), deep=True) == sin(exp(x + y))
assert powsimp(x**2*x**y) == x**(2 + y)
# This should remain factored, because 'exp' with deep=True is supposed
# to act like old automatic exponent combining.
assert powsimp((1 + E*exp(E))*exp(-E), combine='exp', deep=True) == \
(1 + exp(1 + E))*exp(-E)
assert powsimp((1 + E*exp(E))*exp(-E), deep=True) == \
(1 + exp(1 + E))*exp(-E)
assert powsimp((1 + E*exp(E))*exp(-E)) == (1 + exp(1 + E))*exp(-E)
assert powsimp((1 + E*exp(E))*exp(-E), combine='exp') == \
(1 + exp(1 + E))*exp(-E)
assert powsimp((1 + E*exp(E))*exp(-E), combine='base') == \
(1 + E*exp(E))*exp(-E)
x, y = symbols('x,y', nonnegative=True)
n = Symbol('n', real=True)
assert powsimp(y**n * (y/x)**(-n)) == x**n
assert powsimp(x**(x**(x*y)*y**(x*y))*y**(x**(x*y)*y**(x*y)), deep=True) \
== (x*y)**(x*y)**(x*y)
assert powsimp(2**(2**(2*x)*x), deep=False) == 2**(2**(2*x)*x)
assert powsimp(2**(2**(2*x)*x), deep=True) == 2**(x*4**x)
assert powsimp(
exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \
exp(-x + exp(-x)*exp(-x*log(x)))
assert powsimp(
exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \
exp(-x + exp(-x)*exp(-x*log(x)))
assert powsimp((x + y)/(3*z), deep=False, combine='exp') == (x + y)/(3*z)
assert powsimp((x/3 + y/3)/z, deep=True, combine='exp') == (x/3 + y/3)/z
assert powsimp(exp(x)/(1 + exp(x)*exp(y)), deep=True) == \
exp(x)/(1 + exp(x + y))
assert powsimp(x*y**(z**x*z**y), deep=True) == x*y**(z**(x + y))
assert powsimp((z**x*z**y)**x, deep=True) == (z**(x + y))**x
assert powsimp(x*(z**x*z**y)**x, deep=True) == x*(z**(x + y))**x
p = symbols('p', positive=True)
assert powsimp((1/x)**log(2)/x) == (1/x)**(1 + log(2))
assert powsimp((1/p)**log(2)/p) == p**(-1 - log(2))
# coefficient of exponent can only be simplified for positive bases
assert powsimp(2**(2*x)) == 4**x
assert powsimp((-1)**(2*x)) == (-1)**(2*x)
i = symbols('i', integer=True)
assert powsimp((-1)**(2*i)) == 1
assert powsimp((-1)**(-x)) != (-1)**x # could be 1/((-1)**x), but is not
# force=True overrides assumptions
assert powsimp((-1)**(2*x), force=True) == 1
# rational exponents allow combining of negative terms
w, n, m = symbols('w n m', negative=True)
e = i/a # not a rational exponent if `a` is unknown
ex = w**e*n**e*m**e
assert powsimp(ex) == m**(i/a)*n**(i/a)*w**(i/a)
e = i/3
ex = w**e*n**e*m**e
assert powsimp(ex) == (-1)**i*(-m*n*w)**(i/3)
e = (3 + i)/i
ex = w**e*n**e*m**e
assert powsimp(ex) == (-1)**(3*e)*(-m*n*w)**e
eq = x**(a*Rational(2, 3))
# eq != (x**a)**(2/3) (try x = -1 and a = 3 to see)
assert powsimp(eq).exp == eq.exp == a*Rational(2, 3)
# powdenest goes the other direction
assert powsimp(2**(2*x)) == 4**x
assert powsimp(exp(p/2)) == exp(p/2)
# issue 6368
eq = Mul(*[sqrt(Dummy(imaginary=True)) for i in range(3)])
assert powsimp(eq) == eq and eq.is_Mul
assert all(powsimp(e) == e for e in (sqrt(x**a), sqrt(x**2)))
# issue 8836
assert str( powsimp(exp(I*pi/3)*root(-1,3)) ) == '(-1)**(2/3)'
# issue 9183
assert powsimp(-0.1**x) == -0.1**x
# issue 10095
assert powsimp((1/(2*E))**oo) == (exp(-1)/2)**oo
# PR 13131
eq = sin(2*x)**2*sin(2.0*x)**2
assert powsimp(eq) == eq
# issue 14615
assert powsimp(x**2*y**3*(x*y**2)**Rational(3, 2)
) == x*y*(x*y**2)**Rational(5, 2)
def test_powsimp_negated_base():
assert powsimp((-x + y)/sqrt(x - y)) == -sqrt(x - y)
assert powsimp((-x + y)*(-z + y)/sqrt(x - y)/sqrt(z - y)) == sqrt(x - y)*sqrt(z - y)
p = symbols('p', positive=True)
reps = {p: 2, a: S.Half}
assert powsimp((-p)**a/p**a).subs(reps) == ((-1)**a).subs(reps)
assert powsimp((-p)**a*p**a).subs(reps) == ((-p**2)**a).subs(reps)
n = symbols('n', negative=True)
reps = {p: -2, a: S.Half}
assert powsimp((-n)**a/n**a).subs(reps) == (-1)**(-a).subs(a, S.Half)
assert powsimp((-n)**a*n**a).subs(reps) == ((-n**2)**a).subs(reps)
# if x is 0 then the lhs is 0**a*oo**a which is not (-1)**a
eq = (-x)**a/x**a
assert powsimp(eq) == eq
def test_powsimp_nc():
x, y, z = symbols('x,y,z')
A, B, C = symbols('A B C', commutative=False)
assert powsimp(A**x*A**y, combine='all') == A**(x + y)
assert powsimp(A**x*A**y, combine='base') == A**x*A**y
assert powsimp(A**x*A**y, combine='exp') == A**(x + y)
assert powsimp(A**x*B**x, combine='all') == A**x*B**x
assert powsimp(A**x*B**x, combine='base') == A**x*B**x
assert powsimp(A**x*B**x, combine='exp') == A**x*B**x
assert powsimp(B**x*A**x, combine='all') == B**x*A**x
assert powsimp(B**x*A**x, combine='base') == B**x*A**x
assert powsimp(B**x*A**x, combine='exp') == B**x*A**x
assert powsimp(A**x*A**y*A**z, combine='all') == A**(x + y + z)
assert powsimp(A**x*A**y*A**z, combine='base') == A**x*A**y*A**z
assert powsimp(A**x*A**y*A**z, combine='exp') == A**(x + y + z)
assert powsimp(A**x*B**x*C**x, combine='all') == A**x*B**x*C**x
assert powsimp(A**x*B**x*C**x, combine='base') == A**x*B**x*C**x
assert powsimp(A**x*B**x*C**x, combine='exp') == A**x*B**x*C**x
assert powsimp(B**x*A**x*C**x, combine='all') == B**x*A**x*C**x
assert powsimp(B**x*A**x*C**x, combine='base') == B**x*A**x*C**x
assert powsimp(B**x*A**x*C**x, combine='exp') == B**x*A**x*C**x
def test_issue_6440():
assert powsimp(16*2**a*8**b) == 2**(a + 3*b + 4)
def test_powdenest():
x, y = symbols('x,y')
p, q = symbols('p q', positive=True)
i, j = symbols('i,j', integer=True)
assert powdenest(x) == x
assert powdenest(x + 2*(x**(a*Rational(2, 3)))**(3*x)) == (x + 2*(x**(a*Rational(2, 3)))**(3*x))
assert powdenest((exp(a*Rational(2, 3)))**(3*x)) # -X-> (exp(a/3))**(6*x)
assert powdenest((x**(a*Rational(2, 3)))**(3*x)) == ((x**(a*Rational(2, 3)))**(3*x))
assert powdenest(exp(3*x*log(2))) == 2**(3*x)
assert powdenest(sqrt(p**2)) == p
eq = p**(2*i)*q**(4*i)
assert powdenest(eq) == (p*q**2)**(2*i)
# -X-> (x**x)**i*(x**x)**j == x**(x*(i + j))
assert powdenest((x**x)**(i + j))
assert powdenest(exp(3*y*log(x))) == x**(3*y)
assert powdenest(exp(y*(log(a) + log(b)))) == (a*b)**y
assert powdenest(exp(3*(log(a) + log(b)))) == a**3*b**3
assert powdenest(((x**(2*i))**(3*y))**x) == ((x**(2*i))**(3*y))**x
assert powdenest(((x**(2*i))**(3*y))**x, force=True) == x**(6*i*x*y)
assert powdenest(((x**(a*Rational(2, 3)))**(3*y/i))**x) == \
(((x**(a*Rational(2, 3)))**(3*y/i))**x)
assert powdenest((x**(2*i)*y**(4*i))**z, force=True) == (x*y**2)**(2*i*z)
assert powdenest((p**(2*i)*q**(4*i))**j) == (p*q**2)**(2*i*j)
e = ((p**(2*a))**(3*y))**x
assert powdenest(e) == e
e = ((x**2*y**4)**a)**(x*y)
assert powdenest(e) == e
e = (((x**2*y**4)**a)**(x*y))**3
assert powdenest(e) == ((x**2*y**4)**a)**(3*x*y)
assert powdenest((((x**2*y**4)**a)**(x*y)), force=True) == \
(x*y**2)**(2*a*x*y)
assert powdenest((((x**2*y**4)**a)**(x*y))**3, force=True) == \
(x*y**2)**(6*a*x*y)
assert powdenest((x**2*y**6)**i) != (x*y**3)**(2*i)
x, y = symbols('x,y', positive=True)
assert powdenest((x**2*y**6)**i) == (x*y**3)**(2*i)
assert powdenest((x**(i*Rational(2, 3))*y**(i/2))**(2*i)) == (x**Rational(4, 3)*y)**(i**2)
assert powdenest(sqrt(x**(2*i)*y**(6*i))) == (x*y**3)**i
assert powdenest(4**x) == 2**(2*x)
assert powdenest((4**x)**y) == 2**(2*x*y)
assert powdenest(4**x*y) == 2**(2*x)*y
def test_powdenest_polar():
x, y, z = symbols('x y z', polar=True)
a, b, c = symbols('a b c')
assert powdenest((x*y*z)**a) == x**a*y**a*z**a
assert powdenest((x**a*y**b)**c) == x**(a*c)*y**(b*c)
assert powdenest(((x**a)**b*y**c)**c) == x**(a*b*c)*y**(c**2)
def test_issue_5805():
arg = ((gamma(x)*hyper((), (), x))*pi)**2
assert powdenest(arg) == (pi*gamma(x)*hyper((), (), x))**2
assert arg.is_positive is None
def test_issue_9324_powsimp_on_matrix_symbol():
M = MatrixSymbol('M', 10, 10)
expr = powsimp(M, deep=True)
assert expr == M
assert expr.args[0] == Str('M')
def test_issue_6367():
z = -5*sqrt(2)/(2*sqrt(2*sqrt(29) + 29)) + sqrt(-sqrt(29)/29 + S.Half)
assert Mul(*[powsimp(a) for a in Mul.make_args(z.normal())]) == 0
assert powsimp(z.normal()) == 0
assert simplify(z) == 0
assert powsimp(sqrt(2 + sqrt(3))*sqrt(2 - sqrt(3)) + 1) == 2
assert powsimp(z) != 0
def test_powsimp_polar():
from sympy.functions.elementary.complexes import polar_lift
from sympy.functions.elementary.exponential import exp_polar
x, y, z = symbols('x y z')
p, q, r = symbols('p q r', polar=True)
assert (polar_lift(-1))**(2*x) == exp_polar(2*pi*I*x)
assert powsimp(p**x * q**x) == (p*q)**x
assert p**x * (1/p)**x == 1
assert (1/p)**x == p**(-x)
assert exp_polar(x)*exp_polar(y) == exp_polar(x)*exp_polar(y)
assert powsimp(exp_polar(x)*exp_polar(y)) == exp_polar(x + y)
assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y) == \
(p*exp_polar(1))**(x + y)
assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y, combine='exp') == \
exp_polar(x + y)*p**(x + y)
assert powsimp(
exp_polar(x)*exp_polar(y)*exp_polar(2)*sin(x) + sin(y) + p**x*p**y) \
== p**(x + y) + sin(x)*exp_polar(2 + x + y) + sin(y)
assert powsimp(sin(exp_polar(x)*exp_polar(y))) == \
sin(exp_polar(x)*exp_polar(y))
assert powsimp(sin(exp_polar(x)*exp_polar(y)), deep=True) == \
sin(exp_polar(x + y))
def test_issue_5728():
b = x*sqrt(y)
a = sqrt(b)
c = sqrt(sqrt(x)*y)
assert powsimp(a*b) == sqrt(b)**3
assert powsimp(a*b**2*sqrt(y)) == sqrt(y)*a**5
assert powsimp(a*x**2*c**3*y) == c**3*a**5
assert powsimp(a*x*c**3*y**2) == c**7*a
assert powsimp(x*c**3*y**2) == c**7
assert powsimp(x*c**3*y) == x*y*c**3
assert powsimp(sqrt(x)*c**3*y) == c**5
assert powsimp(sqrt(x)*a**3*sqrt(y)) == sqrt(x)*sqrt(y)*a**3
assert powsimp(Mul(sqrt(x)*c**3*sqrt(y), y, evaluate=False)) == \
sqrt(x)*sqrt(y)**3*c**3
assert powsimp(a**2*a*x**2*y) == a**7
# symbolic powers work, too
b = x**y*y
a = b*sqrt(b)
assert a.is_Mul is True
assert powsimp(a) == sqrt(b)**3
# as does exp
a = x*exp(y*Rational(2, 3))
assert powsimp(a*sqrt(a)) == sqrt(a)**3
assert powsimp(a**2*sqrt(a)) == sqrt(a)**5
assert powsimp(a**2*sqrt(sqrt(a))) == sqrt(sqrt(a))**9
def test_issue_from_PR1599():
n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True)
assert (powsimp(sqrt(n1)*sqrt(n2)*sqrt(n3)) ==
-I*sqrt(-n1)*sqrt(-n2)*sqrt(-n3))
assert (powsimp(root(n1, 3)*root(n2, 3)*root(n3, 3)*root(n4, 3)) ==
-(-1)**Rational(1, 3)*
(-n1)**Rational(1, 3)*(-n2)**Rational(1, 3)*(-n3)**Rational(1, 3)*(-n4)**Rational(1, 3))
def test_issue_10195():
a = Symbol('a', integer=True)
l = Symbol('l', even=True, nonzero=True)
n = Symbol('n', odd=True)
e_x = (-1)**(n/2 - S.Half) - (-1)**(n*Rational(3, 2) - S.Half)
assert powsimp((-1)**(l/2)) == I**l
assert powsimp((-1)**(n/2)) == I**n
assert powsimp((-1)**(n*Rational(3, 2))) == -I**n
assert powsimp(e_x) == (-1)**(n/2 - S.Half) + (-1)**(n*Rational(3, 2) +
S.Half)
assert powsimp((-1)**(a*Rational(3, 2))) == (-I)**a
def test_issue_15709():
assert powsimp(3**x*Rational(2, 3)) == 2*3**(x-1)
assert powsimp(2*3**x/3) == 2*3**(x-1)
def test_issue_11981():
x, y = symbols('x y', commutative=False)
assert powsimp((x*y)**2 * (y*x)**2) == (x*y)**2 * (y*x)**2
def test_issue_17524():
a = symbols("a", real=True)
e = (-1 - a**2)*sqrt(1 + a**2)
assert signsimp(powsimp(e)) == signsimp(e) == -(a**2 + 1)**(S(3)/2)
def test_issue_19627():
# if you use force the user must verify
assert powdenest(sqrt(sin(x)**2), force=True) == sin(x)
assert powdenest((x**(S.Half/y))**(2*y), force=True) == x
from sympy.core.function import expand_power_base
e = 1 - a
expr = (exp(z/e)*x**(b/e)*y**((1 - b)/e))**e
assert powdenest(expand_power_base(expr, force=True), force=True
) == x**b*y**(1 - b)*exp(z)
def test_issue_22546():
p1, p2 = symbols('p1, p2', positive=True)
ref = powsimp(p1**z/p2**z)
e = z + 1
ans = ref.subs(z, e)
assert ans.is_Pow
assert powsimp(p1**e/p2**e) == ans
i = symbols('i', integer=True)
ref = powsimp(x**i/y**i)
e = i + 1
ans = ref.subs(i, e)
assert ans.is_Pow
assert powsimp(x**e/y**e) == ans

View File

@ -0,0 +1,490 @@
from sympy.core.add import Add
from sympy.core.function import (Derivative, Function, diff)
from sympy.core.mul import Mul
from sympy.core.numbers import (I, Rational)
from sympy.core.power import Pow
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, Wild, symbols)
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.polys.polytools import factor
from sympy.series.order import O
from sympy.simplify.radsimp import (collect, collect_const, fraction, radsimp, rcollect)
from sympy.core.expr import unchanged
from sympy.core.mul import _unevaluated_Mul as umul
from sympy.simplify.radsimp import (_unevaluated_Add,
collect_sqrt, fraction_expand, collect_abs)
from sympy.testing.pytest import raises
from sympy.abc import x, y, z, a, b, c, d
def test_radsimp():
r2 = sqrt(2)
r3 = sqrt(3)
r5 = sqrt(5)
r7 = sqrt(7)
assert fraction(radsimp(1/r2)) == (sqrt(2), 2)
assert radsimp(1/(1 + r2)) == \
-1 + sqrt(2)
assert radsimp(1/(r2 + r3)) == \
-sqrt(2) + sqrt(3)
assert fraction(radsimp(1/(1 + r2 + r3))) == \
(-sqrt(6) + sqrt(2) + 2, 4)
assert fraction(radsimp(1/(r2 + r3 + r5))) == \
(-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12)
assert fraction(radsimp(1/(1 + r2 + r3 + r5))) == (
(-34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) + 14*sqrt(30) +
93 + 46*sqrt(6) + 53*sqrt(5), 71))
assert fraction(radsimp(1/(r2 + r3 + r5 + r7))) == (
(-50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) + 22*sqrt(105)
+ 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7), 215))
z = radsimp(1/(1 + r2/3 + r3/5 + r5 + r7))
assert len((3616791619821680643598*z).args) == 16
assert radsimp(1/z) == 1/z
assert radsimp(1/z, max_terms=20).expand() == 1 + r2/3 + r3/5 + r5 + r7
assert radsimp(1/(r2*3)) == \
sqrt(2)/6
assert radsimp(1/(r2*a + r3 + r5 + r7)) == (
(8*sqrt(2)*a**7 - 8*sqrt(7)*a**6 - 8*sqrt(5)*a**6 - 8*sqrt(3)*a**6 -
180*sqrt(2)*a**5 + 8*sqrt(30)*a**5 + 8*sqrt(42)*a**5 + 8*sqrt(70)*a**5
- 24*sqrt(105)*a**4 + 84*sqrt(3)*a**4 + 100*sqrt(5)*a**4 +
116*sqrt(7)*a**4 - 72*sqrt(70)*a**3 - 40*sqrt(42)*a**3 -
8*sqrt(30)*a**3 + 782*sqrt(2)*a**3 - 462*sqrt(3)*a**2 -
302*sqrt(7)*a**2 - 254*sqrt(5)*a**2 + 120*sqrt(105)*a**2 -
795*sqrt(2)*a - 62*sqrt(30)*a + 82*sqrt(42)*a + 98*sqrt(70)*a -
118*sqrt(105) + 59*sqrt(7) + 295*sqrt(5) + 531*sqrt(3))/(16*a**8 -
480*a**6 + 3128*a**4 - 6360*a**2 + 3481))
assert radsimp(1/(r2*a + r2*b + r3 + r7)) == (
(sqrt(2)*a*(a + b)**2 - 5*sqrt(2)*a + sqrt(42)*a + sqrt(2)*b*(a +
b)**2 - 5*sqrt(2)*b + sqrt(42)*b - sqrt(7)*(a + b)**2 - sqrt(3)*(a +
b)**2 - 2*sqrt(3) + 2*sqrt(7))/(2*a**4 + 8*a**3*b + 12*a**2*b**2 -
20*a**2 + 8*a*b**3 - 40*a*b + 2*b**4 - 20*b**2 + 8))
assert radsimp(1/(r2*a + r2*b + r2*c + r2*d)) == \
sqrt(2)/(2*a + 2*b + 2*c + 2*d)
assert radsimp(1/(1 + r2*a + r2*b + r2*c + r2*d)) == (
(sqrt(2)*a + sqrt(2)*b + sqrt(2)*c + sqrt(2)*d - 1)/(2*a**2 + 4*a*b +
4*a*c + 4*a*d + 2*b**2 + 4*b*c + 4*b*d + 2*c**2 + 4*c*d + 2*d**2 - 1))
assert radsimp((y**2 - x)/(y - sqrt(x))) == \
sqrt(x) + y
assert radsimp(-(y**2 - x)/(y - sqrt(x))) == \
-(sqrt(x) + y)
assert radsimp(1/(1 - I + a*I)) == \
(-I*a + 1 + I)/(a**2 - 2*a + 2)
assert radsimp(1/((-x + y)*(x - sqrt(y)))) == \
(-x - sqrt(y))/((x - y)*(x**2 - y))
e = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y))
assert radsimp(e) == x*(3 + 3*sqrt(2))*(3*x - 3*sqrt(y))
assert radsimp(1/e) == (
(-9*x + 9*sqrt(2)*x - 9*sqrt(y) + 9*sqrt(2)*sqrt(y))/(9*x*(9*x**2 -
9*y)))
assert radsimp(1 + 1/(1 + sqrt(3))) == \
Mul(S.Half, -1 + sqrt(3), evaluate=False) + 1
A = symbols("A", commutative=False)
assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*A) == \
x**2 + sqrt(2)*x**2 - sqrt(2)*x*A
assert radsimp(1/sqrt(5 + 2 * sqrt(6))) == -sqrt(2) + sqrt(3)
assert radsimp(1/sqrt(5 + 2 * sqrt(6))**3) == -(-sqrt(3) + sqrt(2))**3
# issue 6532
assert fraction(radsimp(1/sqrt(x))) == (sqrt(x), x)
assert fraction(radsimp(1/sqrt(2*x + 3))) == (sqrt(2*x + 3), 2*x + 3)
assert fraction(radsimp(1/sqrt(2*(x + 3)))) == (sqrt(2*x + 6), 2*x + 6)
# issue 5994
e = S('-(2 + 2*sqrt(2) + 4*2**(1/4))/'
'(1 + 2**(3/4) + 3*2**(1/4) + 3*sqrt(2))')
assert radsimp(e).expand() == -2*2**Rational(3, 4) - 2*2**Rational(1, 4) + 2 + 2*sqrt(2)
# issue 5986 (modifications to radimp didn't initially recognize this so
# the test is included here)
assert radsimp(1/(-sqrt(5)/2 - S.Half + (-sqrt(5)/2 - S.Half)**2)) == 1
# from issue 5934
eq = (
(-240*sqrt(2)*sqrt(sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) -
360*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) -
120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) +
120*sqrt(2)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
120*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5) +
120*sqrt(10)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5))/(-36000 -
7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) +
24*sqrt(10)*sqrt(-sqrt(5) + 5))**2))
assert radsimp(eq) is S.NaN # it's 0/0
# work with normal form
e = 1/sqrt(sqrt(7)/7 + 2*sqrt(2) + 3*sqrt(3) + 5*sqrt(5)) + 3
assert radsimp(e) == (
-sqrt(sqrt(7) + 14*sqrt(2) + 21*sqrt(3) +
35*sqrt(5))*(-11654899*sqrt(35) - 1577436*sqrt(210) - 1278438*sqrt(15)
- 1346996*sqrt(10) + 1635060*sqrt(6) + 5709765 + 7539830*sqrt(14) +
8291415*sqrt(21))/1300423175 + 3)
# obey power rules
base = sqrt(3) - sqrt(2)
assert radsimp(1/base**3) == (sqrt(3) + sqrt(2))**3
assert radsimp(1/(-base)**3) == -(sqrt(2) + sqrt(3))**3
assert radsimp(1/(-base)**x) == (-base)**(-x)
assert radsimp(1/base**x) == (sqrt(2) + sqrt(3))**x
assert radsimp(root(1/(-1 - sqrt(2)), -x)) == (-1)**(-1/x)*(1 + sqrt(2))**(1/x)
# recurse
e = cos(1/(1 + sqrt(2)))
assert radsimp(e) == cos(-sqrt(2) + 1)
assert radsimp(e/2) == cos(-sqrt(2) + 1)/2
assert radsimp(1/e) == 1/cos(-sqrt(2) + 1)
assert radsimp(2/e) == 2/cos(-sqrt(2) + 1)
assert fraction(radsimp(e/sqrt(x))) == (sqrt(x)*cos(-sqrt(2)+1), x)
# test that symbolic denominators are not processed
r = 1 + sqrt(2)
assert radsimp(x/r, symbolic=False) == -x*(-sqrt(2) + 1)
assert radsimp(x/(y + r), symbolic=False) == x/(y + 1 + sqrt(2))
assert radsimp(x/(y + r)/r, symbolic=False) == \
-x*(-sqrt(2) + 1)/(y + 1 + sqrt(2))
# issue 7408
eq = sqrt(x)/sqrt(y)
assert radsimp(eq) == umul(sqrt(x), sqrt(y), 1/y)
assert radsimp(eq, symbolic=False) == eq
# issue 7498
assert radsimp(sqrt(x)/sqrt(y)**3) == umul(sqrt(x), sqrt(y**3), 1/y**3)
# for coverage
eq = sqrt(x)/y**2
assert radsimp(eq) == eq
def test_radsimp_issue_3214():
c, p = symbols('c p', positive=True)
s = sqrt(c**2 - p**2)
b = (c + I*p - s)/(c + I*p + s)
assert radsimp(b) == -I*(c + I*p - sqrt(c**2 - p**2))**2/(2*c*p)
def test_collect_1():
"""Collect with respect to Symbol"""
x, y, z, n = symbols('x,y,z,n')
assert collect(1, x) == 1
assert collect( x + y*x, x ) == x * (1 + y)
assert collect( x + x**2, x ) == x + x**2
assert collect( x**2 + y*x**2, x ) == (x**2)*(1 + y)
assert collect( x**2 + y*x, x ) == x*y + x**2
assert collect( 2*x**2 + y*x**2 + 3*x*y, [x] ) == x**2*(2 + y) + 3*x*y
assert collect( 2*x**2 + y*x**2 + 3*x*y, [y] ) == 2*x**2 + y*(x**2 + 3*x)
assert collect( ((1 + y + x)**4).expand(), x) == ((1 + y)**4).expand() + \
x*(4*(1 + y)**3).expand() + x**2*(6*(1 + y)**2).expand() + \
x**3*(4*(1 + y)).expand() + x**4
# symbols can be given as any iterable
expr = x + y
assert collect(expr, expr.free_symbols) == expr
assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None
) == x*exp(x) + 3*x + (y + 2)*sin(x)
assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x + y*x +
y*x*exp(x), x, exact=None
) == x*exp(x)*(y + 1) + (3 + y)*x + (y + 2)*sin(x)
def test_collect_2():
"""Collect with respect to a sum"""
a, b, x = symbols('a,b,x')
assert collect(a*(cos(x) + sin(x)) + b*(cos(x) + sin(x)),
sin(x) + cos(x)) == (a + b)*(cos(x) + sin(x))
def test_collect_3():
"""Collect with respect to a product"""
a, b, c = symbols('a,b,c')
f = Function('f')
x, y, z, n = symbols('x,y,z,n')
assert collect(-x/8 + x*y, -x) == x*(y - Rational(1, 8))
assert collect( 1 + x*(y**2), x*y ) == 1 + x*(y**2)
assert collect( x*y + a*x*y, x*y) == x*y*(1 + a)
assert collect( 1 + x*y + a*x*y, x*y) == 1 + x*y*(1 + a)
assert collect(a*x*f(x) + b*(x*f(x)), x*f(x)) == x*(a + b)*f(x)
assert collect(a*x*log(x) + b*(x*log(x)), x*log(x)) == x*(a + b)*log(x)
assert collect(a*x**2*log(x)**2 + b*(x*log(x))**2, x*log(x)) == \
x**2*log(x)**2*(a + b)
# with respect to a product of three symbols
assert collect(y*x*z + a*x*y*z, x*y*z) == (1 + a)*x*y*z
def test_collect_4():
"""Collect with respect to a power"""
a, b, c, x = symbols('a,b,c,x')
assert collect(a*x**c + b*x**c, x**c) == x**c*(a + b)
# issue 6096: 2 stays with c (unless c is integer or x is positive0
assert collect(a*x**(2*c) + b*x**(2*c), x**c) == x**(2*c)*(a + b)
def test_collect_5():
"""Collect with respect to a tuple"""
a, x, y, z, n = symbols('a,x,y,z,n')
assert collect(x**2*y**4 + z*(x*y**2)**2 + z + a*z, [x*y**2, z]) in [
z*(1 + a + x**2*y**4) + x**2*y**4,
z*(1 + a) + x**2*y**4*(1 + z) ]
assert collect((1 + (x + y) + (x + y)**2).expand(),
[x, y]) == 1 + y + x*(1 + 2*y) + x**2 + y**2
def test_collect_pr19431():
"""Unevaluated collect with respect to a product"""
a = symbols('a')
assert collect(a**2*(a**2 + 1), a**2, evaluate=False)[a**2] == (a**2 + 1)
def test_collect_D():
D = Derivative
f = Function('f')
x, a, b = symbols('x,a,b')
fx = D(f(x), x)
fxx = D(f(x), x, x)
assert collect(a*fx + b*fx, fx) == (a + b)*fx
assert collect(a*D(fx, x) + b*D(fx, x), fx) == (a + b)*D(fx, x)
assert collect(a*fxx + b*fxx, fx) == (a + b)*D(fx, x)
# issue 4784
assert collect(5*f(x) + 3*fx, fx) == 5*f(x) + 3*fx
assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x)) == \
(x*f(x) + f(x))*D(f(x), x) + f(x)
assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x), exact=True) == \
(x*f(x) + f(x))*D(f(x), x) + f(x)
assert collect(1/f(x) + 1/f(x)*diff(f(x), x) + x*diff(f(x), x)/f(x), f(x).diff(x), exact=True) == \
(1/f(x) + x/f(x))*D(f(x), x) + 1/f(x)
e = (1 + x*fx + fx)/f(x)
assert collect(e.expand(), fx) == fx*(x/f(x) + 1/f(x)) + 1/f(x)
def test_collect_func():
f = ((x + a + 1)**3).expand()
assert collect(f, x) == a**3 + 3*a**2 + 3*a + x**3 + x**2*(3*a + 3) + \
x*(3*a**2 + 6*a + 3) + 1
assert collect(f, x, factor) == x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + \
(a + 1)**3
assert collect(f, x, evaluate=False) == {
S.One: a**3 + 3*a**2 + 3*a + 1,
x: 3*a**2 + 6*a + 3, x**2: 3*a + 3,
x**3: 1
}
assert collect(f, x, factor, evaluate=False) == {
S.One: (a + 1)**3, x: 3*(a + 1)**2,
x**2: umul(S(3), a + 1), x**3: 1}
def test_collect_order():
a, b, x, t = symbols('a,b,x,t')
assert collect(t + t*x + t*x**2 + O(x**3), t) == t*(1 + x + x**2 + O(x**3))
assert collect(t + t*x + x**2 + O(x**3), t) == \
t*(1 + x + O(x**3)) + x**2 + O(x**3)
f = a*x + b*x + c*x**2 + d*x**2 + O(x**3)
g = x*(a + b) + x**2*(c + d) + O(x**3)
assert collect(f, x) == g
assert collect(f, x, distribute_order_term=False) == g
f = sin(a + b).series(b, 0, 10)
assert collect(f, [sin(a), cos(a)]) == \
sin(a)*cos(b).series(b, 0, 10) + cos(a)*sin(b).series(b, 0, 10)
assert collect(f, [sin(a), cos(a)], distribute_order_term=False) == \
sin(a)*cos(b).series(b, 0, 10).removeO() + \
cos(a)*sin(b).series(b, 0, 10).removeO() + O(b**10)
def test_rcollect():
assert rcollect((x**2*y + x*y + x + y)/(x + y), y) == \
(x + y*(1 + x + x**2))/(x + y)
assert rcollect(sqrt(-((x + 1)*(y + 1))), z) == sqrt(-((x + 1)*(y + 1)))
def test_collect_D_0():
D = Derivative
f = Function('f')
x, a, b = symbols('x,a,b')
fxx = D(f(x), x, x)
assert collect(a*fxx + b*fxx, fxx) == (a + b)*fxx
def test_collect_Wild():
"""Collect with respect to functions with Wild argument"""
a, b, x, y = symbols('a b x y')
f = Function('f')
w1 = Wild('.1')
w2 = Wild('.2')
assert collect(f(x) + a*f(x), f(w1)) == (1 + a)*f(x)
assert collect(f(x, y) + a*f(x, y), f(w1)) == f(x, y) + a*f(x, y)
assert collect(f(x, y) + a*f(x, y), f(w1, w2)) == (1 + a)*f(x, y)
assert collect(f(x, y) + a*f(x, y), f(w1, w1)) == f(x, y) + a*f(x, y)
assert collect(f(x, x) + a*f(x, x), f(w1, w1)) == (1 + a)*f(x, x)
assert collect(a*(x + 1)**y + (x + 1)**y, w1**y) == (1 + a)*(x + 1)**y
assert collect(a*(x + 1)**y + (x + 1)**y, w1**b) == \
a*(x + 1)**y + (x + 1)**y
assert collect(a*(x + 1)**y + (x + 1)**y, (x + 1)**w2) == \
(1 + a)*(x + 1)**y
assert collect(a*(x + 1)**y + (x + 1)**y, w1**w2) == (1 + a)*(x + 1)**y
def test_collect_const():
# coverage not provided by above tests
assert collect_const(2*sqrt(3) + 4*a*sqrt(5)) == \
2*(2*sqrt(5)*a + sqrt(3)) # let the primitive reabsorb
assert collect_const(2*sqrt(3) + 4*a*sqrt(5), sqrt(3)) == \
2*sqrt(3) + 4*a*sqrt(5)
assert collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)) == \
sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3)
# issue 5290
assert collect_const(2*x + 2*y + 1, 2) == \
collect_const(2*x + 2*y + 1) == \
Add(S.One, Mul(2, x + y, evaluate=False), evaluate=False)
assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False)
assert collect_const(2*x - 2*y - 2*z, 2) == \
Mul(2, x - y - z, evaluate=False)
assert collect_const(2*x - 2*y - 2*z, -2) == \
_unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False))
# this is why the content_primitive is used
eq = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2
assert collect_sqrt(eq + 2) == \
2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2
# issue 16296
assert collect_const(a + b + x/2 + y/2) == a + b + Mul(S.Half, x + y, evaluate=False)
def test_issue_13143():
f = Function('f')
fx = f(x).diff(x)
e = f(x) + fx + f(x)*fx
# collect function before derivative
assert collect(e, Wild('w')) == f(x)*(fx + 1) + fx
e = f(x) + f(x)*fx + x*fx*f(x)
assert collect(e, fx) == (x*f(x) + f(x))*fx + f(x)
assert collect(e, f(x)) == (x*fx + fx + 1)*f(x)
e = f(x) + fx + f(x)*fx
assert collect(e, [f(x), fx]) == f(x)*(1 + fx) + fx
assert collect(e, [fx, f(x)]) == fx*(1 + f(x)) + f(x)
def test_issue_6097():
assert collect(a*y**(2.0*x) + b*y**(2.0*x), y**x) == (a + b)*(y**x)**2.0
assert collect(a*2**(2.0*x) + b*2**(2.0*x), 2**x) == (a + b)*(2**x)**2.0
def test_fraction_expand():
eq = (x + y)*y/x
assert eq.expand(frac=True) == fraction_expand(eq) == (x*y + y**2)/x
assert eq.expand() == y + y**2/x
def test_fraction():
x, y, z = map(Symbol, 'xyz')
A = Symbol('A', commutative=False)
assert fraction(S.Half) == (1, 2)
assert fraction(x) == (x, 1)
assert fraction(1/x) == (1, x)
assert fraction(x/y) == (x, y)
assert fraction(x/2) == (x, 2)
assert fraction(x*y/z) == (x*y, z)
assert fraction(x/(y*z)) == (x, y*z)
assert fraction(1/y**2) == (1, y**2)
assert fraction(x/y**2) == (x, y**2)
assert fraction((x**2 + 1)/y) == (x**2 + 1, y)
assert fraction(x*(y + 1)/y**7) == (x*(y + 1), y**7)
assert fraction(exp(-x), exact=True) == (exp(-x), 1)
assert fraction((1/(x + y))/2, exact=True) == (1, Mul(2,(x + y), evaluate=False))
assert fraction(x*A/y) == (x*A, y)
assert fraction(x*A**-1/y) == (x*A**-1, y)
n = symbols('n', negative=True)
assert fraction(exp(n)) == (1, exp(-n))
assert fraction(exp(-n)) == (exp(-n), 1)
p = symbols('p', positive=True)
assert fraction(exp(-p)*log(p), exact=True) == (exp(-p)*log(p), 1)
m = Mul(1, 1, S.Half, evaluate=False)
assert fraction(m) == (1, 2)
assert fraction(m, exact=True) == (Mul(1, 1, evaluate=False), 2)
m = Mul(1, 1, S.Half, S.Half, Pow(1, -1, evaluate=False), evaluate=False)
assert fraction(m) == (1, 4)
assert fraction(m, exact=True) == \
(Mul(1, 1, evaluate=False), Mul(2, 2, 1, evaluate=False))
def test_issue_5615():
aA, Re, a, b, D = symbols('aA Re a b D')
e = ((D**3*a + b*aA**3)/Re).expand()
assert collect(e, [aA**3/Re, a]) == e
def test_issue_5933():
from sympy.geometry.polygon import (Polygon, RegularPolygon)
from sympy.simplify.radsimp import denom
x = Polygon(*RegularPolygon((0, 0), 1, 5).vertices).centroid.x
assert abs(denom(x).n()) > 1e-12
assert abs(denom(radsimp(x))) > 1e-12 # in case simplify didn't handle it
def test_issue_14608():
a, b = symbols('a b', commutative=False)
x, y = symbols('x y')
raises(AttributeError, lambda: collect(a*b + b*a, a))
assert collect(x*y + y*(x+1), a) == x*y + y*(x+1)
assert collect(x*y + y*(x+1) + a*b + b*a, y) == y*(2*x + 1) + a*b + b*a
def test_collect_abs():
s = abs(x) + abs(y)
assert collect_abs(s) == s
assert unchanged(Mul, abs(x), abs(y))
ans = Abs(x*y)
assert isinstance(ans, Abs)
assert collect_abs(abs(x)*abs(y)) == ans
assert collect_abs(1 + exp(abs(x)*abs(y))) == 1 + exp(ans)
# See https://github.com/sympy/sympy/issues/12910
p = Symbol('p', positive=True)
assert collect_abs(p/abs(1-p)).is_commutative is True
def test_issue_19149():
eq = exp(3*x/4)
assert collect(eq, exp(x)) == eq
def test_issue_19719():
a, b = symbols('a, b')
expr = a**2 * (b + 1) + (7 + 1/b)/a
collected = collect(expr, (a**2, 1/a), evaluate=False)
# Would return {_Dummy_20**(-2): b + 1, 1/a: 7 + 1/b} without xreplace
assert collected == {a**2: b + 1, 1/a: 7 + 1/b}
def test_issue_21355():
assert radsimp(1/(x + sqrt(x**2))) == 1/(x + sqrt(x**2))
assert radsimp(1/(x - sqrt(x**2))) == 1/(x - sqrt(x**2))

View File

@ -0,0 +1,78 @@
from sympy.core.numbers import (Rational, pi)
from sympy.functions.elementary.exponential import log
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.special.error_functions import erf
from sympy.polys.domains import GF
from sympy.simplify.ratsimp import (ratsimp, ratsimpmodprime)
from sympy.abc import x, y, z, t, a, b, c, d, e
def test_ratsimp():
f, g = 1/x + 1/y, (x + y)/(x*y)
assert f != g and ratsimp(f) == g
f, g = 1/(1 + 1/x), 1 - 1/(x + 1)
assert f != g and ratsimp(f) == g
f, g = x/(x + y) + y/(x + y), 1
assert f != g and ratsimp(f) == g
f, g = -x - y - y**2/(x + y) + x**2/(x + y), -2*y
assert f != g and ratsimp(f) == g
f = (a*c*x*y + a*c*z - b*d*x*y - b*d*z - b*t*x*y - b*t*x - b*t*z +
e*x)/(x*y + z)
G = [a*c - b*d - b*t + (-b*t*x + e*x)/(x*y + z),
a*c - b*d - b*t - ( b*t*x - e*x)/(x*y + z)]
assert f != g and ratsimp(f) in G
A = sqrt(pi)
B = log(erf(x) - 1)
C = log(erf(x) + 1)
D = 8 - 8*erf(x)
f = A*B/D - A*C/D + A*C*erf(x)/D - A*B*erf(x)/D + 2*A/D
assert ratsimp(f) == A*B/8 - A*C/8 - A/(4*erf(x) - 4)
def test_ratsimpmodprime():
a = y**5 + x + y
b = x - y
F = [x*y**5 - x - y]
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
(-x**2 - x*y - x - y) / (-x**2 + x*y)
a = x + y**2 - 2
b = x + y**2 - y - 1
F = [x*y - 1]
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
(1 + y - x)/(y - x)
a = 5*x**3 + 21*x**2 + 4*x*y + 23*x + 12*y + 15
b = 7*x**3 - y*x**2 + 31*x**2 + 2*x*y + 15*y + 37*x + 21
F = [x**2 + y**2 - 1]
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
(1 + 5*y - 5*x)/(8*y - 6*x)
a = x*y - x - 2*y + 4
b = x + y**2 - 2*y
F = [x - 2, y - 3]
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
Rational(2, 5)
# Test a bug where denominators would be dropped
assert ratsimpmodprime(x, [y - 2*x], order='lex') == \
y/2
a = (x**5 + 2*x**4 + 2*x**3 + 2*x**2 + x + 2/x + x**(-2))
assert ratsimpmodprime(a, [x + 1], domain=GF(2)) == 1
assert ratsimpmodprime(a, [x + 1], domain=GF(3)) == -1

View File

@ -0,0 +1,31 @@
from sympy.core.numbers import I
from sympy.core.symbol import symbols
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.trigonometric import (cos, cot, sin)
from sympy.testing.pytest import _both_exp_pow
x, y, z, n = symbols('x,y,z,n')
@_both_exp_pow
def test_has():
assert cot(x).has(x)
assert cot(x).has(cot)
assert not cot(x).has(sin)
assert sin(x).has(x)
assert sin(x).has(sin)
assert not sin(x).has(cot)
assert exp(x).has(exp)
@_both_exp_pow
def test_sin_exp_rewrite():
assert sin(x).rewrite(sin, exp) == -I/2*(exp(I*x) - exp(-I*x))
assert sin(x).rewrite(sin, exp).rewrite(exp, sin) == sin(x)
assert cos(x).rewrite(cos, exp).rewrite(exp, cos) == cos(x)
assert (sin(5*y) - sin(
2*x)).rewrite(sin, exp).rewrite(exp, sin) == sin(5*y) - sin(2*x)
assert sin(x + y).rewrite(sin, exp).rewrite(exp, sin) == sin(x + y)
assert cos(x + y).rewrite(cos, exp).rewrite(exp, cos) == cos(x + y)
# This next test currently passes... not clear whether it should or not?
assert cos(x).rewrite(cos, exp).rewrite(exp, sin) == cos(x)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,204 @@
from sympy.core.mul import Mul
from sympy.core.numbers import (I, Integer, Rational)
from sympy.core.symbol import Symbol
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import cos
from sympy.integrals.integrals import Integral
from sympy.simplify.sqrtdenest import sqrtdenest
from sympy.simplify.sqrtdenest import (
_subsets as subsets, _sqrt_numeric_denest)
r2, r3, r5, r6, r7, r10, r15, r29 = [sqrt(x) for x in (2, 3, 5, 6, 7, 10,
15, 29)]
def test_sqrtdenest():
d = {sqrt(5 + 2 * r6): r2 + r3,
sqrt(5. + 2 * r6): sqrt(5. + 2 * r6),
sqrt(5. + 4*sqrt(5 + 2 * r6)): sqrt(5.0 + 4*r2 + 4*r3),
sqrt(r2): sqrt(r2),
sqrt(5 + r7): sqrt(5 + r7),
sqrt(3 + sqrt(5 + 2*r7)):
3*r2*(5 + 2*r7)**Rational(1, 4)/(2*sqrt(6 + 3*r7)) +
r2*sqrt(6 + 3*r7)/(2*(5 + 2*r7)**Rational(1, 4)),
sqrt(3 + 2*r3): 3**Rational(3, 4)*(r6/2 + 3*r2/2)/3}
for i in d:
assert sqrtdenest(i) == d[i], i
def test_sqrtdenest2():
assert sqrtdenest(sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))) == \
r5 + sqrt(11 - 2*r29)
e = sqrt(-r5 + sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16))
assert sqrtdenest(e) == root(-2*r29 + 11, 4)
r = sqrt(1 + r7)
assert sqrtdenest(sqrt(1 + r)) == sqrt(1 + r)
e = sqrt(((1 + sqrt(1 + 2*sqrt(3 + r2 + r5)))**2).expand())
assert sqrtdenest(e) == 1 + sqrt(1 + 2*sqrt(r2 + r5 + 3))
assert sqrtdenest(sqrt(5*r3 + 6*r2)) == \
sqrt(2)*root(3, 4) + root(3, 4)**3
assert sqrtdenest(sqrt(((1 + r5 + sqrt(1 + r3))**2).expand())) == \
1 + r5 + sqrt(1 + r3)
assert sqrtdenest(sqrt(((1 + r5 + r7 + sqrt(1 + r3))**2).expand())) == \
1 + sqrt(1 + r3) + r5 + r7
e = sqrt(((1 + cos(2) + cos(3) + sqrt(1 + r3))**2).expand())
assert sqrtdenest(e) == cos(3) + cos(2) + 1 + sqrt(1 + r3)
e = sqrt(-2*r10 + 2*r2*sqrt(-2*r10 + 11) + 14)
assert sqrtdenest(e) == sqrt(-2*r10 - 2*r2 + 4*r5 + 14)
# check that the result is not more complicated than the input
z = sqrt(-2*r29 + cos(2) + 2*sqrt(-10*r29 + 55) + 16)
assert sqrtdenest(z) == z
assert sqrtdenest(sqrt(r6 + sqrt(15))) == sqrt(r6 + sqrt(15))
z = sqrt(15 - 2*sqrt(31) + 2*sqrt(55 - 10*r29))
assert sqrtdenest(z) == z
def test_sqrtdenest_rec():
assert sqrtdenest(sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 33)) == \
-r2 + r3 + 2*r7
assert sqrtdenest(sqrt(-28*r7 - 14*r5 + 4*sqrt(35) + 82)) == \
-7 + r5 + 2*r7
assert sqrtdenest(sqrt(6*r2/11 + 2*sqrt(22)/11 + 6*sqrt(11)/11 + 2)) == \
sqrt(11)*(r2 + 3 + sqrt(11))/11
assert sqrtdenest(sqrt(468*r3 + 3024*r2 + 2912*r6 + 19735)) == \
9*r3 + 26 + 56*r6
z = sqrt(-490*r3 - 98*sqrt(115) - 98*sqrt(345) - 2107)
assert sqrtdenest(z) == sqrt(-1)*(7*r5 + 7*r15 + 7*sqrt(23))
z = sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 34)
assert sqrtdenest(z) == z
assert sqrtdenest(sqrt(-8*r2 - 2*r5 + 18)) == -r10 + 1 + r2 + r5
assert sqrtdenest(sqrt(8*r2 + 2*r5 - 18)) == \
sqrt(-1)*(-r10 + 1 + r2 + r5)
assert sqrtdenest(sqrt(8*r2/3 + 14*r5/3 + Rational(154, 9))) == \
-r10/3 + r2 + r5 + 3
assert sqrtdenest(sqrt(sqrt(2*r6 + 5) + sqrt(2*r7 + 8))) == \
sqrt(1 + r2 + r3 + r7)
assert sqrtdenest(sqrt(4*r15 + 8*r5 + 12*r3 + 24)) == 1 + r3 + r5 + r15
w = 1 + r2 + r3 + r5 + r7
assert sqrtdenest(sqrt((w**2).expand())) == w
z = sqrt((w**2).expand() + 1)
assert sqrtdenest(z) == z
z = sqrt(2*r10 + 6*r2 + 4*r5 + 12 + 10*r15 + 30*r3)
assert sqrtdenest(z) == z
def test_issue_6241():
z = sqrt( -320 + 32*sqrt(5) + 64*r15)
assert sqrtdenest(z) == z
def test_sqrtdenest3():
z = sqrt(13 - 2*r10 + 2*r2*sqrt(-2*r10 + 11))
assert sqrtdenest(z) == -1 + r2 + r10
assert sqrtdenest(z, max_iter=1) == -1 + sqrt(2) + sqrt(10)
z = sqrt(sqrt(r2 + 2) + 2)
assert sqrtdenest(z) == z
assert sqrtdenest(sqrt(-2*r10 + 4*r2*sqrt(-2*r10 + 11) + 20)) == \
sqrt(-2*r10 - 4*r2 + 8*r5 + 20)
assert sqrtdenest(sqrt((112 + 70*r2) + (46 + 34*r2)*r5)) == \
r10 + 5 + 4*r2 + 3*r5
z = sqrt(5 + sqrt(2*r6 + 5)*sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16))
r = sqrt(-2*r29 + 11)
assert sqrtdenest(z) == sqrt(r2*r + r3*r + r10 + r15 + 5)
n = sqrt(2*r6/7 + 2*r7/7 + 2*sqrt(42)/7 + 2)
d = sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))
assert sqrtdenest(n/d) == r7*(1 + r6 + r7)/(Mul(7, (sqrt(-2*r29 + 11) + r5),
evaluate=False))
def test_sqrtdenest4():
# see Denest_en.pdf in https://github.com/sympy/sympy/issues/3192
z = sqrt(8 - r2*sqrt(5 - r5) - sqrt(3)*(1 + r5))
z1 = sqrtdenest(z)
c = sqrt(-r5 + 5)
z1 = ((-r15*c - r3*c + c + r5*c - r6 - r2 + r10 + sqrt(30))/4).expand()
assert sqrtdenest(z) == z1
z = sqrt(2*r2*sqrt(r2 + 2) + 5*r2 + 4*sqrt(r2 + 2) + 8)
assert sqrtdenest(z) == r2 + sqrt(r2 + 2) + 2
w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3)
z = sqrt((w**2).expand())
assert sqrtdenest(z) == w.expand()
def test_sqrt_symbolic_denest():
x = Symbol('x')
z = sqrt(((1 + sqrt(sqrt(2 + x) + 3))**2).expand())
assert sqrtdenest(z) == sqrt((1 + sqrt(sqrt(2 + x) + 3))**2)
z = sqrt(((1 + sqrt(sqrt(2 + cos(1)) + 3))**2).expand())
assert sqrtdenest(z) == 1 + sqrt(sqrt(2 + cos(1)) + 3)
z = ((1 + cos(2))**4 + 1).expand()
assert sqrtdenest(z) == z
z = sqrt(((1 + sqrt(sqrt(2 + cos(3*x)) + 3))**2 + 1).expand())
assert sqrtdenest(z) == z
c = cos(3)
c2 = c**2
assert sqrtdenest(sqrt(2*sqrt(1 + r3)*c + c2 + 1 + r3*c2)) == \
-1 - sqrt(1 + r3)*c
ra = sqrt(1 + r3)
z = sqrt(20*ra*sqrt(3 + 3*r3) + 12*r3*ra*sqrt(3 + 3*r3) + 64*r3 + 112)
assert sqrtdenest(z) == z
def test_issue_5857():
from sympy.abc import x, y
z = sqrt(1/(4*r3 + 7) + 1)
ans = (r2 + r6)/(r3 + 2)
assert sqrtdenest(z) == ans
assert sqrtdenest(1 + z) == 1 + ans
assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \
Integral(1 + ans, (x, 1, 2))
assert sqrtdenest(x + sqrt(y)) == x + sqrt(y)
ans = (r2 + r6)/(r3 + 2)
assert sqrtdenest(z) == ans
assert sqrtdenest(1 + z) == 1 + ans
assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \
Integral(1 + ans, (x, 1, 2))
assert sqrtdenest(x + sqrt(y)) == x + sqrt(y)
def test_subsets():
assert subsets(1) == [[1]]
assert subsets(4) == [
[1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [1, 0, 1, 0],
[0, 1, 1, 0], [1, 1, 1, 0], [0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1],
[1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]
def test_issue_5653():
assert sqrtdenest(
sqrt(2 + sqrt(2 + sqrt(2)))) == sqrt(2 + sqrt(2 + sqrt(2)))
def test_issue_12420():
assert sqrtdenest((3 - sqrt(2)*sqrt(4 + 3*I) + 3*I)/2) == I
e = 3 - sqrt(2)*sqrt(4 + I) + 3*I
assert sqrtdenest(e) == e
def test_sqrt_ratcomb():
assert sqrtdenest(sqrt(1 + r3) + sqrt(3 + 3*r3) - sqrt(10 + 6*r3)) == 0
def test_issue_18041():
e = -sqrt(-2 + 2*sqrt(3)*I)
assert sqrtdenest(e) == -1 - sqrt(3)*I
def test_issue_19914():
a = Integer(-8)
b = Integer(-1)
r = Integer(63)
d2 = a*a - b*b*r
assert _sqrt_numeric_denest(a, b, r, d2) == \
sqrt(14)*I/2 + 3*sqrt(2)*I/2
assert sqrtdenest(sqrt(-8-sqrt(63))) == sqrt(14)*I/2 + 3*sqrt(2)*I/2

View File

@ -0,0 +1,520 @@
from itertools import product
from sympy.core.function import (Subs, count_ops, diff, expand)
from sympy.core.numbers import (E, I, Rational, pi)
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.hyperbolic import (cosh, coth, sinh, tanh)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (cos, cot, sin, tan)
from sympy.functions.elementary.trigonometric import (acos, asin, atan2)
from sympy.functions.elementary.trigonometric import (asec, acsc)
from sympy.functions.elementary.trigonometric import (acot, atan)
from sympy.integrals.integrals import integrate
from sympy.matrices.dense import Matrix
from sympy.simplify.simplify import simplify
from sympy.simplify.trigsimp import (exptrigsimp, trigsimp)
from sympy.testing.pytest import XFAIL
from sympy.abc import x, y
def test_trigsimp1():
x, y = symbols('x,y')
assert trigsimp(1 - sin(x)**2) == cos(x)**2
assert trigsimp(1 - cos(x)**2) == sin(x)**2
assert trigsimp(sin(x)**2 + cos(x)**2) == 1
assert trigsimp(1 + tan(x)**2) == 1/cos(x)**2
assert trigsimp(1/cos(x)**2 - 1) == tan(x)**2
assert trigsimp(1/cos(x)**2 - tan(x)**2) == 1
assert trigsimp(1 + cot(x)**2) == 1/sin(x)**2
assert trigsimp(1/sin(x)**2 - 1) == 1/tan(x)**2
assert trigsimp(1/sin(x)**2 - cot(x)**2) == 1
assert trigsimp(5*cos(x)**2 + 5*sin(x)**2) == 5
assert trigsimp(5*cos(x/2)**2 + 2*sin(x/2)**2) == 3*cos(x)/2 + Rational(7, 2)
assert trigsimp(sin(x)/cos(x)) == tan(x)
assert trigsimp(2*tan(x)*cos(x)) == 2*sin(x)
assert trigsimp(cot(x)**3*sin(x)**3) == cos(x)**3
assert trigsimp(y*tan(x)**2/sin(x)**2) == y/cos(x)**2
assert trigsimp(cot(x)/cos(x)) == 1/sin(x)
assert trigsimp(sin(x + y) + sin(x - y)) == 2*sin(x)*cos(y)
assert trigsimp(sin(x + y) - sin(x - y)) == 2*sin(y)*cos(x)
assert trigsimp(cos(x + y) + cos(x - y)) == 2*cos(x)*cos(y)
assert trigsimp(cos(x + y) - cos(x - y)) == -2*sin(x)*sin(y)
assert trigsimp(tan(x + y) - tan(x)/(1 - tan(x)*tan(y))) == \
sin(y)/(-sin(y)*tan(x) + cos(y)) # -tan(y)/(tan(x)*tan(y) - 1)
assert trigsimp(sinh(x + y) + sinh(x - y)) == 2*sinh(x)*cosh(y)
assert trigsimp(sinh(x + y) - sinh(x - y)) == 2*sinh(y)*cosh(x)
assert trigsimp(cosh(x + y) + cosh(x - y)) == 2*cosh(x)*cosh(y)
assert trigsimp(cosh(x + y) - cosh(x - y)) == 2*sinh(x)*sinh(y)
assert trigsimp(tanh(x + y) - tanh(x)/(1 + tanh(x)*tanh(y))) == \
sinh(y)/(sinh(y)*tanh(x) + cosh(y))
assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2) == 1.0
e = 2*sin(x)**2 + 2*cos(x)**2
assert trigsimp(log(e)) == log(2)
def test_trigsimp1a():
assert trigsimp(sin(2)**2*cos(3)*exp(2)/cos(2)**2) == tan(2)**2*cos(3)*exp(2)
assert trigsimp(tan(2)**2*cos(3)*exp(2)*cos(2)**2) == sin(2)**2*cos(3)*exp(2)
assert trigsimp(cot(2)*cos(3)*exp(2)*sin(2)) == cos(3)*exp(2)*cos(2)
assert trigsimp(tan(2)*cos(3)*exp(2)/sin(2)) == cos(3)*exp(2)/cos(2)
assert trigsimp(cot(2)*cos(3)*exp(2)/cos(2)) == cos(3)*exp(2)/sin(2)
assert trigsimp(cot(2)*cos(3)*exp(2)*tan(2)) == cos(3)*exp(2)
assert trigsimp(sinh(2)*cos(3)*exp(2)/cosh(2)) == tanh(2)*cos(3)*exp(2)
assert trigsimp(tanh(2)*cos(3)*exp(2)*cosh(2)) == sinh(2)*cos(3)*exp(2)
assert trigsimp(coth(2)*cos(3)*exp(2)*sinh(2)) == cosh(2)*cos(3)*exp(2)
assert trigsimp(tanh(2)*cos(3)*exp(2)/sinh(2)) == cos(3)*exp(2)/cosh(2)
assert trigsimp(coth(2)*cos(3)*exp(2)/cosh(2)) == cos(3)*exp(2)/sinh(2)
assert trigsimp(coth(2)*cos(3)*exp(2)*tanh(2)) == cos(3)*exp(2)
def test_trigsimp2():
x, y = symbols('x,y')
assert trigsimp(cos(x)**2*sin(y)**2 + cos(x)**2*cos(y)**2 + sin(x)**2,
recursive=True) == 1
assert trigsimp(sin(x)**2*sin(y)**2 + sin(x)**2*cos(y)**2 + cos(x)**2,
recursive=True) == 1
assert trigsimp(
Subs(x, x, sin(y)**2 + cos(y)**2)) == Subs(x, x, 1)
def test_issue_4373():
x = Symbol("x")
assert abs(trigsimp(2.0*sin(x)**2 + 2.0*cos(x)**2) - 2.0) < 1e-10
def test_trigsimp3():
x, y = symbols('x,y')
assert trigsimp(sin(x)/cos(x)) == tan(x)
assert trigsimp(sin(x)**2/cos(x)**2) == tan(x)**2
assert trigsimp(sin(x)**3/cos(x)**3) == tan(x)**3
assert trigsimp(sin(x)**10/cos(x)**10) == tan(x)**10
assert trigsimp(cos(x)/sin(x)) == 1/tan(x)
assert trigsimp(cos(x)**2/sin(x)**2) == 1/tan(x)**2
assert trigsimp(cos(x)**10/sin(x)**10) == 1/tan(x)**10
assert trigsimp(tan(x)) == trigsimp(sin(x)/cos(x))
def test_issue_4661():
a, x, y = symbols('a x y')
eq = -4*sin(x)**4 + 4*cos(x)**4 - 8*cos(x)**2
assert trigsimp(eq) == -4
n = sin(x)**6 + 4*sin(x)**4*cos(x)**2 + 5*sin(x)**2*cos(x)**4 + 2*cos(x)**6
d = -sin(x)**2 - 2*cos(x)**2
assert simplify(n/d) == -1
assert trigsimp(-2*cos(x)**2 + cos(x)**4 - sin(x)**4) == -1
eq = (- sin(x)**3/4)*cos(x) + (cos(x)**3/4)*sin(x) - sin(2*x)*cos(2*x)/8
assert trigsimp(eq) == 0
def test_issue_4494():
a, b = symbols('a b')
eq = sin(a)**2*sin(b)**2 + cos(a)**2*cos(b)**2*tan(a)**2 + cos(a)**2
assert trigsimp(eq) == 1
def test_issue_5948():
a, x, y = symbols('a x y')
assert trigsimp(diff(integrate(cos(x)/sin(x)**7, x), x)) == \
cos(x)/sin(x)**7
def test_issue_4775():
a, x, y = symbols('a x y')
assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)) == sin(x + y)
assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)+3) == sin(x + y) + 3
def test_issue_4280():
a, x, y = symbols('a x y')
assert trigsimp(cos(x)**2 + cos(y)**2*sin(x)**2 + sin(y)**2*sin(x)**2) == 1
assert trigsimp(a**2*sin(x)**2 + a**2*cos(y)**2*cos(x)**2 + a**2*cos(x)**2*sin(y)**2) == a**2
assert trigsimp(a**2*cos(y)**2*sin(x)**2 + a**2*sin(y)**2*sin(x)**2) == a**2*sin(x)**2
def test_issue_3210():
eqs = (sin(2)*cos(3) + sin(3)*cos(2),
-sin(2)*sin(3) + cos(2)*cos(3),
sin(2)*cos(3) - sin(3)*cos(2),
sin(2)*sin(3) + cos(2)*cos(3),
sin(2)*sin(3) + cos(2)*cos(3) + cos(2),
sinh(2)*cosh(3) + sinh(3)*cosh(2),
sinh(2)*sinh(3) + cosh(2)*cosh(3),
)
assert [trigsimp(e) for e in eqs] == [
sin(5),
cos(5),
-sin(1),
cos(1),
cos(1) + cos(2),
sinh(5),
cosh(5),
]
def test_trigsimp_issues():
a, x, y = symbols('a x y')
# issue 4625 - factor_terms works, too
assert trigsimp(sin(x)**3 + cos(x)**2*sin(x)) == sin(x)
# issue 5948
assert trigsimp(diff(integrate(cos(x)/sin(x)**3, x), x)) == \
cos(x)/sin(x)**3
assert trigsimp(diff(integrate(sin(x)/cos(x)**3, x), x)) == \
sin(x)/cos(x)**3
# check integer exponents
e = sin(x)**y/cos(x)**y
assert trigsimp(e) == e
assert trigsimp(e.subs(y, 2)) == tan(x)**2
assert trigsimp(e.subs(x, 1)) == tan(1)**y
# check for multiple patterns
assert (cos(x)**2/sin(x)**2*cos(y)**2/sin(y)**2).trigsimp() == \
1/tan(x)**2/tan(y)**2
assert trigsimp(cos(x)/sin(x)*cos(x+y)/sin(x+y)) == \
1/(tan(x)*tan(x + y))
eq = cos(2)*(cos(3) + 1)**2/(cos(3) - 1)**2
assert trigsimp(eq) == eq.factor() # factor makes denom (-1 + cos(3))**2
assert trigsimp(cos(2)*(cos(3) + 1)**2*(cos(3) - 1)**2) == \
cos(2)*sin(3)**4
# issue 6789; this generates an expression that formerly caused
# trigsimp to hang
assert cot(x).equals(tan(x)) is False
# nan or the unchanged expression is ok, but not sin(1)
z = cos(x)**2 + sin(x)**2 - 1
z1 = tan(x)**2 - 1/cot(x)**2
n = (1 + z1/z)
assert trigsimp(sin(n)) != sin(1)
eq = x*(n - 1) - x*n
assert trigsimp(eq) is S.NaN
assert trigsimp(eq, recursive=True) is S.NaN
assert trigsimp(1).is_Integer
assert trigsimp(-sin(x)**4 - 2*sin(x)**2*cos(x)**2 - cos(x)**4) == -1
def test_trigsimp_issue_2515():
x = Symbol('x')
assert trigsimp(x*cos(x)*tan(x)) == x*sin(x)
assert trigsimp(-sin(x) + cos(x)*tan(x)) == 0
def test_trigsimp_issue_3826():
assert trigsimp(tan(2*x).expand(trig=True)) == tan(2*x)
def test_trigsimp_issue_4032():
n = Symbol('n', integer=True, positive=True)
assert trigsimp(2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2) == \
2**(n/2)*cos(pi*n/4)/2 + 2**n/4
def test_trigsimp_issue_7761():
assert trigsimp(cosh(pi/4)) == cosh(pi/4)
def test_trigsimp_noncommutative():
x, y = symbols('x,y')
A, B = symbols('A,B', commutative=False)
assert trigsimp(A - A*sin(x)**2) == A*cos(x)**2
assert trigsimp(A - A*cos(x)**2) == A*sin(x)**2
assert trigsimp(A*sin(x)**2 + A*cos(x)**2) == A
assert trigsimp(A + A*tan(x)**2) == A/cos(x)**2
assert trigsimp(A/cos(x)**2 - A) == A*tan(x)**2
assert trigsimp(A/cos(x)**2 - A*tan(x)**2) == A
assert trigsimp(A + A*cot(x)**2) == A/sin(x)**2
assert trigsimp(A/sin(x)**2 - A) == A/tan(x)**2
assert trigsimp(A/sin(x)**2 - A*cot(x)**2) == A
assert trigsimp(y*A*cos(x)**2 + y*A*sin(x)**2) == y*A
assert trigsimp(A*sin(x)/cos(x)) == A*tan(x)
assert trigsimp(A*tan(x)*cos(x)) == A*sin(x)
assert trigsimp(A*cot(x)**3*sin(x)**3) == A*cos(x)**3
assert trigsimp(y*A*tan(x)**2/sin(x)**2) == y*A/cos(x)**2
assert trigsimp(A*cot(x)/cos(x)) == A/sin(x)
assert trigsimp(A*sin(x + y) + A*sin(x - y)) == 2*A*sin(x)*cos(y)
assert trigsimp(A*sin(x + y) - A*sin(x - y)) == 2*A*sin(y)*cos(x)
assert trigsimp(A*cos(x + y) + A*cos(x - y)) == 2*A*cos(x)*cos(y)
assert trigsimp(A*cos(x + y) - A*cos(x - y)) == -2*A*sin(x)*sin(y)
assert trigsimp(A*sinh(x + y) + A*sinh(x - y)) == 2*A*sinh(x)*cosh(y)
assert trigsimp(A*sinh(x + y) - A*sinh(x - y)) == 2*A*sinh(y)*cosh(x)
assert trigsimp(A*cosh(x + y) + A*cosh(x - y)) == 2*A*cosh(x)*cosh(y)
assert trigsimp(A*cosh(x + y) - A*cosh(x - y)) == 2*A*sinh(x)*sinh(y)
assert trigsimp(A*cos(0.12345)**2 + A*sin(0.12345)**2) == 1.0*A
def test_hyperbolic_simp():
x, y = symbols('x,y')
assert trigsimp(sinh(x)**2 + 1) == cosh(x)**2
assert trigsimp(cosh(x)**2 - 1) == sinh(x)**2
assert trigsimp(cosh(x)**2 - sinh(x)**2) == 1
assert trigsimp(1 - tanh(x)**2) == 1/cosh(x)**2
assert trigsimp(1 - 1/cosh(x)**2) == tanh(x)**2
assert trigsimp(tanh(x)**2 + 1/cosh(x)**2) == 1
assert trigsimp(coth(x)**2 - 1) == 1/sinh(x)**2
assert trigsimp(1/sinh(x)**2 + 1) == 1/tanh(x)**2
assert trigsimp(coth(x)**2 - 1/sinh(x)**2) == 1
assert trigsimp(5*cosh(x)**2 - 5*sinh(x)**2) == 5
assert trigsimp(5*cosh(x/2)**2 - 2*sinh(x/2)**2) == 3*cosh(x)/2 + Rational(7, 2)
assert trigsimp(sinh(x)/cosh(x)) == tanh(x)
assert trigsimp(tanh(x)) == trigsimp(sinh(x)/cosh(x))
assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x)
assert trigsimp(2*tanh(x)*cosh(x)) == 2*sinh(x)
assert trigsimp(coth(x)**3*sinh(x)**3) == cosh(x)**3
assert trigsimp(y*tanh(x)**2/sinh(x)**2) == y/cosh(x)**2
assert trigsimp(coth(x)/cosh(x)) == 1/sinh(x)
for a in (pi/6*I, pi/4*I, pi/3*I):
assert trigsimp(sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x + a)
assert trigsimp(-sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x - a)
e = 2*cosh(x)**2 - 2*sinh(x)**2
assert trigsimp(log(e)) == log(2)
# issue 19535:
assert trigsimp(sqrt(cosh(x)**2 - 1)) == sqrt(sinh(x)**2)
assert trigsimp(cosh(x)**2*cosh(y)**2 - cosh(x)**2*sinh(y)**2 - sinh(x)**2,
recursive=True) == 1
assert trigsimp(sinh(x)**2*sinh(y)**2 - sinh(x)**2*cosh(y)**2 + cosh(x)**2,
recursive=True) == 1
assert abs(trigsimp(2.0*cosh(x)**2 - 2.0*sinh(x)**2) - 2.0) < 1e-10
assert trigsimp(sinh(x)**2/cosh(x)**2) == tanh(x)**2
assert trigsimp(sinh(x)**3/cosh(x)**3) == tanh(x)**3
assert trigsimp(sinh(x)**10/cosh(x)**10) == tanh(x)**10
assert trigsimp(cosh(x)**3/sinh(x)**3) == 1/tanh(x)**3
assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x)
assert trigsimp(cosh(x)**2/sinh(x)**2) == 1/tanh(x)**2
assert trigsimp(cosh(x)**10/sinh(x)**10) == 1/tanh(x)**10
assert trigsimp(x*cosh(x)*tanh(x)) == x*sinh(x)
assert trigsimp(-sinh(x) + cosh(x)*tanh(x)) == 0
assert tan(x) != 1/cot(x) # cot doesn't auto-simplify
assert trigsimp(tan(x) - 1/cot(x)) == 0
assert trigsimp(3*tanh(x)**7 - 2/coth(x)**7) == tanh(x)**7
def test_trigsimp_groebner():
from sympy.simplify.trigsimp import trigsimp_groebner
c = cos(x)
s = sin(x)
ex = (4*s*c + 12*s + 5*c**3 + 21*c**2 + 23*c + 15)/(
-s*c**2 + 2*s*c + 15*s + 7*c**3 + 31*c**2 + 37*c + 21)
resnum = (5*s - 5*c + 1)
resdenom = (8*s - 6*c)
results = [resnum/resdenom, (-resnum)/(-resdenom)]
assert trigsimp_groebner(ex) in results
assert trigsimp_groebner(s/c, hints=[tan]) == tan(x)
assert trigsimp_groebner(c*s) == c*s
assert trigsimp((-s + 1)/c + c/(-s + 1),
method='groebner') == 2/c
assert trigsimp((-s + 1)/c + c/(-s + 1),
method='groebner', polynomial=True) == 2/c
# Test quick=False works
assert trigsimp_groebner(ex, hints=[2]) in results
assert trigsimp_groebner(ex, hints=[int(2)]) in results
# test "I"
assert trigsimp_groebner(sin(I*x)/cos(I*x), hints=[tanh]) == I*tanh(x)
# test hyperbolic / sums
assert trigsimp_groebner((tanh(x)+tanh(y))/(1+tanh(x)*tanh(y)),
hints=[(tanh, x, y)]) == tanh(x + y)
def test_issue_2827_trigsimp_methods():
measure1 = lambda expr: len(str(expr))
measure2 = lambda expr: -count_ops(expr)
# Return the most complicated result
expr = (x + 1)/(x + sin(x)**2 + cos(x)**2)
ans = Matrix([1])
M = Matrix([expr])
assert trigsimp(M, method='fu', measure=measure1) == ans
assert trigsimp(M, method='fu', measure=measure2) != ans
# all methods should work with Basic expressions even if they
# aren't Expr
M = Matrix.eye(1)
assert all(trigsimp(M, method=m) == M for m in
'fu matching groebner old'.split())
# watch for E in exptrigsimp, not only exp()
eq = 1/sqrt(E) + E
assert exptrigsimp(eq) == eq
def test_issue_15129_trigsimp_methods():
t1 = Matrix([sin(Rational(1, 50)), cos(Rational(1, 50)), 0])
t2 = Matrix([sin(Rational(1, 25)), cos(Rational(1, 25)), 0])
t3 = Matrix([cos(Rational(1, 25)), sin(Rational(1, 25)), 0])
r1 = t1.dot(t2)
r2 = t1.dot(t3)
assert trigsimp(r1) == cos(Rational(1, 50))
assert trigsimp(r2) == sin(Rational(3, 50))
def test_exptrigsimp():
def valid(a, b):
from sympy.core.random import verify_numerically as tn
if not (tn(a, b) and a == b):
return False
return True
assert exptrigsimp(exp(x) + exp(-x)) == 2*cosh(x)
assert exptrigsimp(exp(x) - exp(-x)) == 2*sinh(x)
assert exptrigsimp((2*exp(x)-2*exp(-x))/(exp(x)+exp(-x))) == 2*tanh(x)
assert exptrigsimp((2*exp(2*x)-2)/(exp(2*x)+1)) == 2*tanh(x)
e = [cos(x) + I*sin(x), cos(x) - I*sin(x),
cosh(x) - sinh(x), cosh(x) + sinh(x)]
ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)]
assert all(valid(i, j) for i, j in zip(
[exptrigsimp(ei) for ei in e], ok))
ue = [cos(x) + sin(x), cos(x) - sin(x),
cosh(x) + I*sinh(x), cosh(x) - I*sinh(x)]
assert [exptrigsimp(ei) == ei for ei in ue]
res = []
ok = [y*tanh(1), 1/(y*tanh(1)), I*y*tan(1), -I/(y*tan(1)),
y*tanh(x), 1/(y*tanh(x)), I*y*tan(x), -I/(y*tan(x)),
y*tanh(1 + I), 1/(y*tanh(1 + I))]
for a in (1, I, x, I*x, 1 + I):
w = exp(a)
eq = y*(w - 1/w)/(w + 1/w)
res.append(simplify(eq))
res.append(simplify(1/eq))
assert all(valid(i, j) for i, j in zip(res, ok))
for a in range(1, 3):
w = exp(a)
e = w + 1/w
s = simplify(e)
assert s == exptrigsimp(e)
assert valid(s, 2*cosh(a))
e = w - 1/w
s = simplify(e)
assert s == exptrigsimp(e)
assert valid(s, 2*sinh(a))
def test_exptrigsimp_noncommutative():
a,b = symbols('a b', commutative=False)
x = Symbol('x', commutative=True)
assert exp(a + x) == exptrigsimp(exp(a)*exp(x))
p = exp(a)*exp(b) - exp(b)*exp(a)
assert p == exptrigsimp(p) != 0
def test_powsimp_on_numbers():
assert 2**(Rational(1, 3) - 2) == 2**Rational(1, 3)/4
@XFAIL
def test_issue_6811_fail():
# from doc/src/modules/physics/mechanics/examples.rst, the current `eq`
# at Line 576 (in different variables) was formerly the equivalent and
# shorter expression given below...it would be nice to get the short one
# back again
xp, y, x, z = symbols('xp, y, x, z')
eq = 4*(-19*sin(x)*y + 5*sin(3*x)*y + 15*cos(2*x)*z - 21*z)*xp/(9*cos(x) - 5*cos(3*x))
assert trigsimp(eq) == -2*(2*cos(x)*tan(x)*y + 3*z)*xp/cos(x)
def test_Piecewise():
e1 = x*(x + y) - y*(x + y)
e2 = sin(x)**2 + cos(x)**2
e3 = expand((x + y)*y/x)
# s1 = simplify(e1)
s2 = simplify(e2)
# s3 = simplify(e3)
# trigsimp tries not to touch non-trig containing args
assert trigsimp(Piecewise((e1, e3 < e2), (e3, True))) == \
Piecewise((e1, e3 < s2), (e3, True))
def test_issue_21594():
assert simplify(exp(Rational(1,2)) + exp(Rational(-1,2))) == cosh(S.Half)*2
def test_trigsimp_old():
x, y = symbols('x,y')
assert trigsimp(1 - sin(x)**2, old=True) == cos(x)**2
assert trigsimp(1 - cos(x)**2, old=True) == sin(x)**2
assert trigsimp(sin(x)**2 + cos(x)**2, old=True) == 1
assert trigsimp(1 + tan(x)**2, old=True) == 1/cos(x)**2
assert trigsimp(1/cos(x)**2 - 1, old=True) == tan(x)**2
assert trigsimp(1/cos(x)**2 - tan(x)**2, old=True) == 1
assert trigsimp(1 + cot(x)**2, old=True) == 1/sin(x)**2
assert trigsimp(1/sin(x)**2 - cot(x)**2, old=True) == 1
assert trigsimp(5*cos(x)**2 + 5*sin(x)**2, old=True) == 5
assert trigsimp(sin(x)/cos(x), old=True) == tan(x)
assert trigsimp(2*tan(x)*cos(x), old=True) == 2*sin(x)
assert trigsimp(cot(x)**3*sin(x)**3, old=True) == cos(x)**3
assert trigsimp(y*tan(x)**2/sin(x)**2, old=True) == y/cos(x)**2
assert trigsimp(cot(x)/cos(x), old=True) == 1/sin(x)
assert trigsimp(sin(x + y) + sin(x - y), old=True) == 2*sin(x)*cos(y)
assert trigsimp(sin(x + y) - sin(x - y), old=True) == 2*sin(y)*cos(x)
assert trigsimp(cos(x + y) + cos(x - y), old=True) == 2*cos(x)*cos(y)
assert trigsimp(cos(x + y) - cos(x - y), old=True) == -2*sin(x)*sin(y)
assert trigsimp(sinh(x + y) + sinh(x - y), old=True) == 2*sinh(x)*cosh(y)
assert trigsimp(sinh(x + y) - sinh(x - y), old=True) == 2*sinh(y)*cosh(x)
assert trigsimp(cosh(x + y) + cosh(x - y), old=True) == 2*cosh(x)*cosh(y)
assert trigsimp(cosh(x + y) - cosh(x - y), old=True) == 2*sinh(x)*sinh(y)
assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2, old=True) == 1.0
assert trigsimp(sin(x)/cos(x), old=True, method='combined') == tan(x)
assert trigsimp(sin(x)/cos(x), old=True, method='groebner') == sin(x)/cos(x)
assert trigsimp(sin(x)/cos(x), old=True, method='groebner', hints=[tan]) == tan(x)
assert trigsimp(1-sin(sin(x)**2+cos(x)**2)**2, old=True, deep=True) == cos(1)**2
def test_trigsimp_inverse():
alpha = symbols('alpha')
s, c = sin(alpha), cos(alpha)
for finv in [asin, acos, asec, acsc, atan, acot]:
f = finv.inverse(None)
assert alpha == trigsimp(finv(f(alpha)), inverse=True)
# test atan2(cos, sin), atan2(sin, cos), etc...
for a, b in [[c, s], [s, c]]:
for i, j in product([-1, 1], repeat=2):
angle = atan2(i*b, j*a)
angle_inverted = trigsimp(angle, inverse=True)
assert angle_inverted != angle # assures simplification happened
assert sin(angle_inverted) == trigsimp(sin(angle))
assert cos(angle_inverted) == trigsimp(cos(angle))