I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,60 @@
"""The module helps converting SymPy expressions into shorter forms of them.
for example:
the expression E**(pi*I) will be converted into -1
the expression (x+x)**2 will be converted into 4*x**2
"""
from .simplify import (simplify, hypersimp, hypersimilar,
logcombine, separatevars, posify, besselsimp, kroneckersimp,
signsimp, nsimplify)
from .fu import FU, fu
from .sqrtdenest import sqrtdenest
from .cse_main import cse
from .epathtools import epath, EPath
from .hyperexpand import hyperexpand
from .radsimp import collect, rcollect, radsimp, collect_const, fraction, numer, denom
from .trigsimp import trigsimp, exptrigsimp
from .powsimp import powsimp, powdenest
from .combsimp import combsimp
from .gammasimp import gammasimp
from .ratsimp import ratsimp, ratsimpmodprime
__all__ = [
'simplify', 'hypersimp', 'hypersimilar', 'logcombine', 'separatevars',
'posify', 'besselsimp', 'kroneckersimp', 'signsimp',
'nsimplify',
'FU', 'fu',
'sqrtdenest',
'cse',
'epath', 'EPath',
'hyperexpand',
'collect', 'rcollect', 'radsimp', 'collect_const', 'fraction', 'numer',
'denom',
'trigsimp', 'exptrigsimp',
'powsimp', 'powdenest',
'combsimp',
'gammasimp',
'ratsimp', 'ratsimpmodprime',
]

View File

@ -0,0 +1,114 @@
from sympy.core import Mul
from sympy.core.function import count_ops
from sympy.core.traversal import preorder_traversal, bottom_up
from sympy.functions.combinatorial.factorials import binomial, factorial
from sympy.functions import gamma
from sympy.simplify.gammasimp import gammasimp, _gammasimp
from sympy.utilities.timeutils import timethis
@timethis('combsimp')
def combsimp(expr):
r"""
Simplify combinatorial expressions.
Explanation
===========
This function takes as input an expression containing factorials,
binomials, Pochhammer symbol and other "combinatorial" functions,
and tries to minimize the number of those functions and reduce
the size of their arguments.
The algorithm works by rewriting all combinatorial functions as
gamma functions and applying gammasimp() except simplification
steps that may make an integer argument non-integer. See docstring
of gammasimp for more information.
Then it rewrites expression in terms of factorials and binomials by
rewriting gammas as factorials and converting (a+b)!/a!b! into
binomials.
If expression has gamma functions or combinatorial functions
with non-integer argument, it is automatically passed to gammasimp.
Examples
========
>>> from sympy.simplify import combsimp
>>> from sympy import factorial, binomial, symbols
>>> n, k = symbols('n k', integer = True)
>>> combsimp(factorial(n)/factorial(n - 3))
n*(n - 2)*(n - 1)
>>> combsimp(binomial(n+1, k+1)/binomial(n, k))
(n + 1)/(k + 1)
"""
expr = expr.rewrite(gamma, piecewise=False)
if any(isinstance(node, gamma) and not node.args[0].is_integer
for node in preorder_traversal(expr)):
return gammasimp(expr);
expr = _gammasimp(expr, as_comb = True)
expr = _gamma_as_comb(expr)
return expr
def _gamma_as_comb(expr):
"""
Helper function for combsimp.
Rewrites expression in terms of factorials and binomials
"""
expr = expr.rewrite(factorial)
def f(rv):
if not rv.is_Mul:
return rv
rvd = rv.as_powers_dict()
nd_fact_args = [[], []] # numerator, denominator
for k in rvd:
if isinstance(k, factorial) and rvd[k].is_Integer:
if rvd[k].is_positive:
nd_fact_args[0].extend([k.args[0]]*rvd[k])
else:
nd_fact_args[1].extend([k.args[0]]*-rvd[k])
rvd[k] = 0
if not nd_fact_args[0] or not nd_fact_args[1]:
return rv
hit = False
for m in range(2):
i = 0
while i < len(nd_fact_args[m]):
ai = nd_fact_args[m][i]
for j in range(i + 1, len(nd_fact_args[m])):
aj = nd_fact_args[m][j]
sum = ai + aj
if sum in nd_fact_args[1 - m]:
hit = True
nd_fact_args[1 - m].remove(sum)
del nd_fact_args[m][j]
del nd_fact_args[m][i]
rvd[binomial(sum, ai if count_ops(ai) <
count_ops(aj) else aj)] += (
-1 if m == 0 else 1)
break
else:
i += 1
if hit:
return Mul(*([k**rvd[k] for k in rvd] + [factorial(k)
for k in nd_fact_args[0]]))/Mul(*[factorial(k)
for k in nd_fact_args[1]])
return rv
return bottom_up(expr, f)

View File

@ -0,0 +1,946 @@
""" Tools for doing common subexpression elimination.
"""
from collections import defaultdict
from sympy.core import Basic, Mul, Add, Pow, sympify
from sympy.core.containers import Tuple, OrderedSet
from sympy.core.exprtools import factor_terms
from sympy.core.singleton import S
from sympy.core.sorting import ordered
from sympy.core.symbol import symbols, Symbol
from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
SparseMatrix, ImmutableSparseMatrix)
from sympy.matrices.expressions import (MatrixExpr, MatrixSymbol, MatMul,
MatAdd, MatPow, Inverse)
from sympy.matrices.expressions.matexpr import MatrixElement
from sympy.polys.rootoftools import RootOf
from sympy.utilities.iterables import numbered_symbols, sift, \
topological_sort, iterable
from . import cse_opts
# (preprocessor, postprocessor) pairs which are commonly useful. They should
# each take a SymPy expression and return a possibly transformed expression.
# When used in the function ``cse()``, the target expressions will be transformed
# by each of the preprocessor functions in order. After the common
# subexpressions are eliminated, each resulting expression will have the
# postprocessor functions transform them in *reverse* order in order to undo the
# transformation if necessary. This allows the algorithm to operate on
# a representation of the expressions that allows for more optimization
# opportunities.
# ``None`` can be used to specify no transformation for either the preprocessor or
# postprocessor.
basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post),
(factor_terms, None)]
# sometimes we want the output in a different format; non-trivial
# transformations can be put here for users
# ===============================================================
def reps_toposort(r):
"""Sort replacements ``r`` so (k1, v1) appears before (k2, v2)
if k2 is in v1's free symbols. This orders items in the
way that cse returns its results (hence, in order to use the
replacements in a substitution option it would make sense
to reverse the order).
Examples
========
>>> from sympy.simplify.cse_main import reps_toposort
>>> from sympy.abc import x, y
>>> from sympy import Eq
>>> for l, r in reps_toposort([(x, y + 1), (y, 2)]):
... print(Eq(l, r))
...
Eq(y, 2)
Eq(x, y + 1)
"""
r = sympify(r)
E = []
for c1, (k1, v1) in enumerate(r):
for c2, (k2, v2) in enumerate(r):
if k1 in v2.free_symbols:
E.append((c1, c2))
return [r[i] for i in topological_sort((range(len(r)), E))]
def cse_separate(r, e):
"""Move expressions that are in the form (symbol, expr) out of the
expressions and sort them into the replacements using the reps_toposort.
Examples
========
>>> from sympy.simplify.cse_main import cse_separate
>>> from sympy.abc import x, y, z
>>> from sympy import cos, exp, cse, Eq, symbols
>>> x0, x1 = symbols('x:2')
>>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
>>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [
... [[(x0, y + 1), (x, z + 1), (x1, x + 1)],
... [x1 + exp(x1/x0) + cos(x0), z - 2]],
... [[(x1, y + 1), (x, z + 1), (x0, x + 1)],
... [x0 + exp(x0/x1) + cos(x1), z - 2]]]
...
True
"""
d = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol)
r = r + [w.args for w in d[True]]
e = d[False]
return [reps_toposort(r), e]
def cse_release_variables(r, e):
"""
Return tuples giving ``(a, b)`` where ``a`` is a symbol and ``b`` is
either an expression or None. The value of None is used when a
symbol is no longer needed for subsequent expressions.
Use of such output can reduce the memory footprint of lambdified
expressions that contain large, repeated subexpressions.
Examples
========
>>> from sympy import cse
>>> from sympy.simplify.cse_main import cse_release_variables
>>> from sympy.abc import x, y
>>> eqs = [(x + y - 1)**2, x, x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)]
>>> defs, rvs = cse_release_variables(*cse(eqs))
>>> for i in defs:
... print(i)
...
(x0, x + y)
(x1, (x0 - 1)**2)
(x2, 2*x + 1)
(_3, x0/x2 + x1)
(_4, x2**x0)
(x2, None)
(_0, x1)
(x1, None)
(_2, x0)
(x0, None)
(_1, x)
>>> print(rvs)
(_0, _1, _2, _3, _4)
"""
if not r:
return r, e
s, p = zip(*r)
esyms = symbols('_:%d' % len(e))
syms = list(esyms)
s = list(s)
in_use = set(s)
p = list(p)
# sort e so those with most sub-expressions appear first
e = [(e[i], syms[i]) for i in range(len(e))]
e, syms = zip(*sorted(e,
key=lambda x: -sum(p[s.index(i)].count_ops()
for i in x[0].free_symbols & in_use)))
syms = list(syms)
p += e
rv = []
i = len(p) - 1
while i >= 0:
_p = p.pop()
c = in_use & _p.free_symbols
if c: # sorting for canonical results
rv.extend([(s, None) for s in sorted(c, key=str)])
if i >= len(r):
rv.append((syms.pop(), _p))
else:
rv.append((s[i], _p))
in_use -= c
i -= 1
rv.reverse()
return rv, esyms
# ====end of cse postprocess idioms===========================
def preprocess_for_cse(expr, optimizations):
""" Preprocess an expression to optimize for common subexpression
elimination.
Parameters
==========
expr : SymPy expression
The target expression to optimize.
optimizations : list of (callable, callable) pairs
The (preprocessor, postprocessor) pairs.
Returns
=======
expr : SymPy expression
The transformed expression.
"""
for pre, post in optimizations:
if pre is not None:
expr = pre(expr)
return expr
def postprocess_for_cse(expr, optimizations):
"""Postprocess an expression after common subexpression elimination to
return the expression to canonical SymPy form.
Parameters
==========
expr : SymPy expression
The target expression to transform.
optimizations : list of (callable, callable) pairs, optional
The (preprocessor, postprocessor) pairs. The postprocessors will be
applied in reversed order to undo the effects of the preprocessors
correctly.
Returns
=======
expr : SymPy expression
The transformed expression.
"""
for pre, post in reversed(optimizations):
if post is not None:
expr = post(expr)
return expr
class FuncArgTracker:
"""
A class which manages a mapping from functions to arguments and an inverse
mapping from arguments to functions.
"""
def __init__(self, funcs):
# To minimize the number of symbolic comparisons, all function arguments
# get assigned a value number.
self.value_numbers = {}
self.value_number_to_value = []
# Both of these maps use integer indices for arguments / functions.
self.arg_to_funcset = []
self.func_to_argset = []
for func_i, func in enumerate(funcs):
func_argset = OrderedSet()
for func_arg in func.args:
arg_number = self.get_or_add_value_number(func_arg)
func_argset.add(arg_number)
self.arg_to_funcset[arg_number].add(func_i)
self.func_to_argset.append(func_argset)
def get_args_in_value_order(self, argset):
"""
Return the list of arguments in sorted order according to their value
numbers.
"""
return [self.value_number_to_value[argn] for argn in sorted(argset)]
def get_or_add_value_number(self, value):
"""
Return the value number for the given argument.
"""
nvalues = len(self.value_numbers)
value_number = self.value_numbers.setdefault(value, nvalues)
if value_number == nvalues:
self.value_number_to_value.append(value)
self.arg_to_funcset.append(OrderedSet())
return value_number
def stop_arg_tracking(self, func_i):
"""
Remove the function func_i from the argument to function mapping.
"""
for arg in self.func_to_argset[func_i]:
self.arg_to_funcset[arg].remove(func_i)
def get_common_arg_candidates(self, argset, min_func_i=0):
"""Return a dict whose keys are function numbers. The entries of the dict are
the number of arguments said function has in common with
``argset``. Entries have at least 2 items in common. All keys have
value at least ``min_func_i``.
"""
count_map = defaultdict(lambda: 0)
if not argset:
return count_map
funcsets = [self.arg_to_funcset[arg] for arg in argset]
# As an optimization below, we handle the largest funcset separately from
# the others.
largest_funcset = max(funcsets, key=len)
for funcset in funcsets:
if largest_funcset is funcset:
continue
for func_i in funcset:
if func_i >= min_func_i:
count_map[func_i] += 1
# We pick the smaller of the two containers (count_map, largest_funcset)
# to iterate over to reduce the number of iterations needed.
(smaller_funcs_container,
larger_funcs_container) = sorted(
[largest_funcset, count_map],
key=len)
for func_i in smaller_funcs_container:
# Not already in count_map? It can't possibly be in the output, so
# skip it.
if count_map[func_i] < 1:
continue
if func_i in larger_funcs_container:
count_map[func_i] += 1
return {k: v for k, v in count_map.items() if v >= 2}
def get_subset_candidates(self, argset, restrict_to_funcset=None):
"""
Return a set of functions each of which whose argument list contains
``argset``, optionally filtered only to contain functions in
``restrict_to_funcset``.
"""
iarg = iter(argset)
indices = OrderedSet(
fi for fi in self.arg_to_funcset[next(iarg)])
if restrict_to_funcset is not None:
indices &= restrict_to_funcset
for arg in iarg:
indices &= self.arg_to_funcset[arg]
return indices
def update_func_argset(self, func_i, new_argset):
"""
Update a function with a new set of arguments.
"""
new_args = OrderedSet(new_argset)
old_args = self.func_to_argset[func_i]
for deleted_arg in old_args - new_args:
self.arg_to_funcset[deleted_arg].remove(func_i)
for added_arg in new_args - old_args:
self.arg_to_funcset[added_arg].add(func_i)
self.func_to_argset[func_i].clear()
self.func_to_argset[func_i].update(new_args)
class Unevaluated:
def __init__(self, func, args):
self.func = func
self.args = args
def __str__(self):
return "Uneval<{}>({})".format(
self.func, ", ".join(str(a) for a in self.args))
def as_unevaluated_basic(self):
return self.func(*self.args, evaluate=False)
@property
def free_symbols(self):
return set().union(*[a.free_symbols for a in self.args])
__repr__ = __str__
def match_common_args(func_class, funcs, opt_subs):
"""
Recognize and extract common subexpressions of function arguments within a
set of function calls. For instance, for the following function calls::
x + z + y
sin(x + y)
this will extract a common subexpression of `x + y`::
w = x + y
w + z
sin(w)
The function we work with is assumed to be associative and commutative.
Parameters
==========
func_class: class
The function class (e.g. Add, Mul)
funcs: list of functions
A list of function calls.
opt_subs: dict
A dictionary of substitutions which this function may update.
"""
# Sort to ensure that whole-function subexpressions come before the items
# that use them.
funcs = sorted(funcs, key=lambda f: len(f.args))
arg_tracker = FuncArgTracker(funcs)
changed = OrderedSet()
for i in range(len(funcs)):
common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
arg_tracker.func_to_argset[i], min_func_i=i + 1)
# Sort the candidates in order of match size.
# This makes us try combining smaller matches first.
common_arg_candidates = OrderedSet(sorted(
common_arg_candidates_counts.keys(),
key=lambda k: (common_arg_candidates_counts[k], k)))
while common_arg_candidates:
j = common_arg_candidates.pop(last=False)
com_args = arg_tracker.func_to_argset[i].intersection(
arg_tracker.func_to_argset[j])
if len(com_args) <= 1:
# This may happen if a set of common arguments was already
# combined in a previous iteration.
continue
# For all sets, replace the common symbols by the function
# over them, to allow recursive matches.
diff_i = arg_tracker.func_to_argset[i].difference(com_args)
if diff_i:
# com_func needs to be unevaluated to allow for recursive matches.
com_func = Unevaluated(
func_class, arg_tracker.get_args_in_value_order(com_args))
com_func_number = arg_tracker.get_or_add_value_number(com_func)
arg_tracker.update_func_argset(i, diff_i | OrderedSet([com_func_number]))
changed.add(i)
else:
# Treat the whole expression as a CSE.
#
# The reason this needs to be done is somewhat subtle. Within
# tree_cse(), to_eliminate only contains expressions that are
# seen more than once. The problem is unevaluated expressions
# do not compare equal to the evaluated equivalent. So
# tree_cse() won't mark funcs[i] as a CSE if we use an
# unevaluated version.
com_func_number = arg_tracker.get_or_add_value_number(funcs[i])
diff_j = arg_tracker.func_to_argset[j].difference(com_args)
arg_tracker.update_func_argset(j, diff_j | OrderedSet([com_func_number]))
changed.add(j)
for k in arg_tracker.get_subset_candidates(
com_args, common_arg_candidates):
diff_k = arg_tracker.func_to_argset[k].difference(com_args)
arg_tracker.update_func_argset(k, diff_k | OrderedSet([com_func_number]))
changed.add(k)
if i in changed:
opt_subs[funcs[i]] = Unevaluated(func_class,
arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i]))
arg_tracker.stop_arg_tracking(i)
def opt_cse(exprs, order='canonical'):
"""Find optimization opportunities in Adds, Muls, Pows and negative
coefficient Muls.
Parameters
==========
exprs : list of SymPy expressions
The expressions to optimize.
order : string, 'none' or 'canonical'
The order by which Mul and Add arguments are processed. For large
expressions where speed is a concern, use the setting order='none'.
Returns
=======
opt_subs : dictionary of expression substitutions
The expression substitutions which can be useful to optimize CSE.
Examples
========
>>> from sympy.simplify.cse_main import opt_cse
>>> from sympy.abc import x
>>> opt_subs = opt_cse([x**-2])
>>> k, v = list(opt_subs.keys())[0], list(opt_subs.values())[0]
>>> print((k, v.as_unevaluated_basic()))
(x**(-2), 1/(x**2))
"""
opt_subs = {}
adds = OrderedSet()
muls = OrderedSet()
seen_subexp = set()
collapsible_subexp = set()
def _find_opts(expr):
if not isinstance(expr, (Basic, Unevaluated)):
return
if expr.is_Atom or expr.is_Order:
return
if iterable(expr):
list(map(_find_opts, expr))
return
if expr in seen_subexp:
return expr
seen_subexp.add(expr)
list(map(_find_opts, expr.args))
if not isinstance(expr, MatrixExpr) and expr.could_extract_minus_sign():
# XXX -expr does not always work rigorously for some expressions
# containing UnevaluatedExpr.
# https://github.com/sympy/sympy/issues/24818
if isinstance(expr, Add):
neg_expr = Add(*(-i for i in expr.args))
else:
neg_expr = -expr
if not neg_expr.is_Atom:
opt_subs[expr] = Unevaluated(Mul, (S.NegativeOne, neg_expr))
seen_subexp.add(neg_expr)
expr = neg_expr
if isinstance(expr, (Mul, MatMul)):
if len(expr.args) == 1:
collapsible_subexp.add(expr)
else:
muls.add(expr)
elif isinstance(expr, (Add, MatAdd)):
if len(expr.args) == 1:
collapsible_subexp.add(expr)
else:
adds.add(expr)
elif isinstance(expr, Inverse):
# Do not want to treat `Inverse` as a `MatPow`
pass
elif isinstance(expr, (Pow, MatPow)):
base, exp = expr.base, expr.exp
if exp.could_extract_minus_sign():
opt_subs[expr] = Unevaluated(Pow, (Pow(base, -exp), -1))
for e in exprs:
if isinstance(e, (Basic, Unevaluated)):
_find_opts(e)
# Handle collapsing of multinary operations with single arguments
edges = [(s, s.args[0]) for s in collapsible_subexp
if s.args[0] in collapsible_subexp]
for e in reversed(topological_sort((collapsible_subexp, edges))):
opt_subs[e] = opt_subs.get(e.args[0], e.args[0])
# split muls into commutative
commutative_muls = OrderedSet()
for m in muls:
c, nc = m.args_cnc(cset=False)
if c:
c_mul = m.func(*c)
if nc:
if c_mul == 1:
new_obj = m.func(*nc)
else:
if isinstance(m, MatMul):
new_obj = m.func(c_mul, *nc, evaluate=False)
else:
new_obj = m.func(c_mul, m.func(*nc), evaluate=False)
opt_subs[m] = new_obj
if len(c) > 1:
commutative_muls.add(c_mul)
match_common_args(Add, adds, opt_subs)
match_common_args(Mul, commutative_muls, opt_subs)
return opt_subs
def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
"""Perform raw CSE on expression tree, taking opt_subs into account.
Parameters
==========
exprs : list of SymPy expressions
The expressions to reduce.
symbols : infinite iterator yielding unique Symbols
The symbols used to label the common subexpressions which are pulled
out.
opt_subs : dictionary of expression substitutions
The expressions to be substituted before any CSE action is performed.
order : string, 'none' or 'canonical'
The order by which Mul and Add arguments are processed. For large
expressions where speed is a concern, use the setting order='none'.
ignore : iterable of Symbols
Substitutions containing any Symbol from ``ignore`` will be ignored.
"""
if opt_subs is None:
opt_subs = {}
## Find repeated sub-expressions
to_eliminate = set()
seen_subexp = set()
excluded_symbols = set()
def _find_repeated(expr):
if not isinstance(expr, (Basic, Unevaluated)):
return
if isinstance(expr, RootOf):
return
if isinstance(expr, Basic) and (
expr.is_Atom or
expr.is_Order or
isinstance(expr, (MatrixSymbol, MatrixElement))):
if expr.is_Symbol:
excluded_symbols.add(expr.name)
return
if iterable(expr):
args = expr
else:
if expr in seen_subexp:
for ign in ignore:
if ign in expr.free_symbols:
break
else:
to_eliminate.add(expr)
return
seen_subexp.add(expr)
if expr in opt_subs:
expr = opt_subs[expr]
args = expr.args
list(map(_find_repeated, args))
for e in exprs:
if isinstance(e, Basic):
_find_repeated(e)
## Rebuild tree
# Remove symbols from the generator that conflict with names in the expressions.
symbols = (_ for _ in symbols if _.name not in excluded_symbols)
replacements = []
subs = {}
def _rebuild(expr):
if not isinstance(expr, (Basic, Unevaluated)):
return expr
if not expr.args:
return expr
if iterable(expr):
new_args = [_rebuild(arg) for arg in expr.args]
return expr.func(*new_args)
if expr in subs:
return subs[expr]
orig_expr = expr
if expr in opt_subs:
expr = opt_subs[expr]
# If enabled, parse Muls and Adds arguments by order to ensure
# replacement order independent from hashes
if order != 'none':
if isinstance(expr, (Mul, MatMul)):
c, nc = expr.args_cnc()
if c == [1]:
args = nc
else:
args = list(ordered(c)) + nc
elif isinstance(expr, (Add, MatAdd)):
args = list(ordered(expr.args))
else:
args = expr.args
else:
args = expr.args
new_args = list(map(_rebuild, args))
if isinstance(expr, Unevaluated) or new_args != args:
new_expr = expr.func(*new_args)
else:
new_expr = expr
if orig_expr in to_eliminate:
try:
sym = next(symbols)
except StopIteration:
raise ValueError("Symbols iterator ran out of symbols.")
if isinstance(orig_expr, MatrixExpr):
sym = MatrixSymbol(sym.name, orig_expr.rows,
orig_expr.cols)
subs[orig_expr] = sym
replacements.append((sym, new_expr))
return sym
else:
return new_expr
reduced_exprs = []
for e in exprs:
if isinstance(e, Basic):
reduced_e = _rebuild(e)
else:
reduced_e = e
reduced_exprs.append(reduced_e)
return replacements, reduced_exprs
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
order='canonical', ignore=(), list=True):
""" Perform common subexpression elimination on an expression.
Parameters
==========
exprs : list of SymPy expressions, or a single SymPy expression
The expressions to reduce.
symbols : infinite iterator yielding unique Symbols
The symbols used to label the common subexpressions which are pulled
out. The ``numbered_symbols`` generator is useful. The default is a
stream of symbols of the form "x0", "x1", etc. This must be an
infinite iterator.
optimizations : list of (callable, callable) pairs
The (preprocessor, postprocessor) pairs of external optimization
functions. Optionally 'basic' can be passed for a set of predefined
basic optimizations. Such 'basic' optimizations were used by default
in old implementation, however they can be really slow on larger
expressions. Now, no pre or post optimizations are made by default.
postprocess : a function which accepts the two return values of cse and
returns the desired form of output from cse, e.g. if you want the
replacements reversed the function might be the following lambda:
lambda r, e: return reversed(r), e
order : string, 'none' or 'canonical'
The order by which Mul and Add arguments are processed. If set to
'canonical', arguments will be canonically ordered. If set to 'none',
ordering will be faster but dependent on expressions hashes, thus
machine dependent and variable. For large expressions where speed is a
concern, use the setting order='none'.
ignore : iterable of Symbols
Substitutions containing any Symbol from ``ignore`` will be ignored.
list : bool, (default True)
Returns expression in list or else with same type as input (when False).
Returns
=======
replacements : list of (Symbol, expression) pairs
All of the common subexpressions that were replaced. Subexpressions
earlier in this list might show up in subexpressions later in this
list.
reduced_exprs : list of SymPy expressions
The reduced expressions with all of the replacements above.
Examples
========
>>> from sympy import cse, SparseMatrix
>>> from sympy.abc import x, y, z, w
>>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3])
List of expressions with recursive substitutions:
>>> m = SparseMatrix([x + y, x + y + z])
>>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
[x0],
[x1]])])
Note: the type and mutability of input matrices is retained.
>>> isinstance(_[1][-1], SparseMatrix)
True
The user may disallow substitutions containing certain symbols:
>>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
([(x0, x + 1)], [x0*y**2, 3*x0*y**2])
The default return value for the reduced expression(s) is a list, even if there is only
one expression. The `list` flag preserves the type of the input in the output:
>>> cse(x)
([], [x])
>>> cse(x, list=False)
([], x)
"""
if not list:
return _cse_homogeneous(exprs,
symbols=symbols, optimizations=optimizations,
postprocess=postprocess, order=order, ignore=ignore)
if isinstance(exprs, (int, float)):
exprs = sympify(exprs)
# Handle the case if just one expression was passed.
if isinstance(exprs, (Basic, MatrixBase)):
exprs = [exprs]
copy = exprs
temp = []
for e in exprs:
if isinstance(e, (Matrix, ImmutableMatrix)):
temp.append(Tuple(*e.flat()))
elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
temp.append(Tuple(*e.todok().items()))
else:
temp.append(e)
exprs = temp
del temp
if optimizations is None:
optimizations = []
elif optimizations == 'basic':
optimizations = basic_optimizations
# Preprocess the expressions to give us better optimization opportunities.
reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
if symbols is None:
symbols = numbered_symbols(cls=Symbol)
else:
# In case we get passed an iterable with an __iter__ method instead of
# an actual iterator.
symbols = iter(symbols)
# Find other optimization opportunities.
opt_subs = opt_cse(reduced_exprs, order)
# Main CSE algorithm.
replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
order, ignore)
# Postprocess the expressions to return the expressions to canonical form.
exprs = copy
for i, (sym, subtree) in enumerate(replacements):
subtree = postprocess_for_cse(subtree, optimizations)
replacements[i] = (sym, subtree)
reduced_exprs = [postprocess_for_cse(e, optimizations)
for e in reduced_exprs]
# Get the matrices back
for i, e in enumerate(exprs):
if isinstance(e, (Matrix, ImmutableMatrix)):
reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
if isinstance(e, ImmutableMatrix):
reduced_exprs[i] = reduced_exprs[i].as_immutable()
elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
m = SparseMatrix(e.rows, e.cols, {})
for k, v in reduced_exprs[i]:
m[k] = v
if isinstance(e, ImmutableSparseMatrix):
m = m.as_immutable()
reduced_exprs[i] = m
if postprocess is None:
return replacements, reduced_exprs
return postprocess(replacements, reduced_exprs)
def _cse_homogeneous(exprs, **kwargs):
"""
Same as ``cse`` but the ``reduced_exprs`` are returned
with the same type as ``exprs`` or a sympified version of the same.
Parameters
==========
exprs : an Expr, iterable of Expr or dictionary with Expr values
the expressions in which repeated subexpressions will be identified
kwargs : additional arguments for the ``cse`` function
Returns
=======
replacements : list of (Symbol, expression) pairs
All of the common subexpressions that were replaced. Subexpressions
earlier in this list might show up in subexpressions later in this
list.
reduced_exprs : list of SymPy expressions
The reduced expressions with all of the replacements above.
Examples
========
>>> from sympy.simplify.cse_main import cse
>>> from sympy import cos, Tuple, Matrix
>>> from sympy.abc import x
>>> output = lambda x: type(cse(x, list=False)[1])
>>> output(1)
<class 'sympy.core.numbers.One'>
>>> output('cos(x)')
<class 'str'>
>>> output(cos(x))
cos
>>> output(Tuple(1, x))
<class 'sympy.core.containers.Tuple'>
>>> output(Matrix([[1,0], [0,1]]))
<class 'sympy.matrices.dense.MutableDenseMatrix'>
>>> output([1, x])
<class 'list'>
>>> output((1, x))
<class 'tuple'>
>>> output({1, x})
<class 'set'>
"""
if isinstance(exprs, str):
replacements, reduced_exprs = _cse_homogeneous(
sympify(exprs), **kwargs)
return replacements, repr(reduced_exprs)
if isinstance(exprs, (list, tuple, set)):
replacements, reduced_exprs = cse(exprs, **kwargs)
return replacements, type(exprs)(reduced_exprs)
if isinstance(exprs, dict):
keys = list(exprs.keys()) # In order to guarantee the order of the elements.
replacements, values = cse([exprs[k] for k in keys], **kwargs)
reduced_exprs = dict(zip(keys, values))
return replacements, reduced_exprs
try:
replacements, (reduced_exprs,) = cse(exprs, **kwargs)
except TypeError: # For example 'mpf' objects
return [], exprs
else:
return replacements, reduced_exprs

View File

@ -0,0 +1,52 @@
""" Optimizations of the expression tree representation for better CSE
opportunities.
"""
from sympy.core import Add, Basic, Mul
from sympy.core.singleton import S
from sympy.core.sorting import default_sort_key
from sympy.core.traversal import preorder_traversal
def sub_pre(e):
""" Replace y - x with -(x - y) if -1 can be extracted from y - x.
"""
# replacing Add, A, from which -1 can be extracted with -1*-A
adds = [a for a in e.atoms(Add) if a.could_extract_minus_sign()]
reps = {}
ignore = set()
for a in adds:
na = -a
if na.is_Mul: # e.g. MatExpr
ignore.add(a)
continue
reps[a] = Mul._from_args([S.NegativeOne, na])
e = e.xreplace(reps)
# repeat again for persisting Adds but mark these with a leading 1, -1
# e.g. y - x -> 1*-1*(x - y)
if isinstance(e, Basic):
negs = {}
for a in sorted(e.atoms(Add), key=default_sort_key):
if a in ignore:
continue
if a in reps:
negs[a] = reps[a]
elif a.could_extract_minus_sign():
negs[a] = Mul._from_args([S.One, S.NegativeOne, -a])
e = e.xreplace(negs)
return e
def sub_post(e):
""" Replace 1*-1*x with -x.
"""
replacements = []
for node in preorder_traversal(e):
if isinstance(node, Mul) and \
node.args[0] is S.One and node.args[1] is S.NegativeOne:
replacements.append((node, -Mul._from_args(node.args[2:])))
for node, replacement in replacements:
e = e.xreplace({node: replacement})
return e

View File

@ -0,0 +1,356 @@
"""Tools for manipulation of expressions using paths. """
from sympy.core import Basic
class EPath:
r"""
Manipulate expressions using paths.
EPath grammar in EBNF notation::
literal ::= /[A-Za-z_][A-Za-z_0-9]*/
number ::= /-?\d+/
type ::= literal
attribute ::= literal "?"
all ::= "*"
slice ::= "[" number? (":" number? (":" number?)?)? "]"
range ::= all | slice
query ::= (type | attribute) ("|" (type | attribute))*
selector ::= range | query range?
path ::= "/" selector ("/" selector)*
See the docstring of the epath() function.
"""
__slots__ = ("_path", "_epath")
def __new__(cls, path):
"""Construct new EPath. """
if isinstance(path, EPath):
return path
if not path:
raise ValueError("empty EPath")
_path = path
if path[0] == '/':
path = path[1:]
else:
raise NotImplementedError("non-root EPath")
epath = []
for selector in path.split('/'):
selector = selector.strip()
if not selector:
raise ValueError("empty selector")
index = 0
for c in selector:
if c.isalnum() or c in ('_', '|', '?'):
index += 1
else:
break
attrs = []
types = []
if index:
elements = selector[:index]
selector = selector[index:]
for element in elements.split('|'):
element = element.strip()
if not element:
raise ValueError("empty element")
if element.endswith('?'):
attrs.append(element[:-1])
else:
types.append(element)
span = None
if selector == '*':
pass
else:
if selector.startswith('['):
try:
i = selector.index(']')
except ValueError:
raise ValueError("expected ']', got EOL")
_span, span = selector[1:i], []
if ':' not in _span:
span = int(_span)
else:
for elt in _span.split(':', 3):
if not elt:
span.append(None)
else:
span.append(int(elt))
span = slice(*span)
selector = selector[i + 1:]
if selector:
raise ValueError("trailing characters in selector")
epath.append((attrs, types, span))
obj = object.__new__(cls)
obj._path = _path
obj._epath = epath
return obj
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self._path)
def _get_ordered_args(self, expr):
"""Sort ``expr.args`` using printing order. """
if expr.is_Add:
return expr.as_ordered_terms()
elif expr.is_Mul:
return expr.as_ordered_factors()
else:
return expr.args
def _hasattrs(self, expr, attrs):
"""Check if ``expr`` has any of ``attrs``. """
for attr in attrs:
if not hasattr(expr, attr):
return False
return True
def _hastypes(self, expr, types):
"""Check if ``expr`` is any of ``types``. """
_types = [ cls.__name__ for cls in expr.__class__.mro() ]
return bool(set(_types).intersection(types))
def _has(self, expr, attrs, types):
"""Apply ``_hasattrs`` and ``_hastypes`` to ``expr``. """
if not (attrs or types):
return True
if attrs and self._hasattrs(expr, attrs):
return True
if types and self._hastypes(expr, types):
return True
return False
def apply(self, expr, func, args=None, kwargs=None):
"""
Modify parts of an expression selected by a path.
Examples
========
>>> from sympy.simplify.epathtools import EPath
>>> from sympy import sin, cos, E
>>> from sympy.abc import x, y, z, t
>>> path = EPath("/*/[0]/Symbol")
>>> expr = [((x, 1), 2), ((3, y), z)]
>>> path.apply(expr, lambda expr: expr**2)
[((x**2, 1), 2), ((3, y**2), z)]
>>> path = EPath("/*/*/Symbol")
>>> expr = t + sin(x + 1) + cos(x + y + E)
>>> path.apply(expr, lambda expr: 2*expr)
t + sin(2*x + 1) + cos(2*x + 2*y + E)
"""
def _apply(path, expr, func):
if not path:
return func(expr)
else:
selector, path = path[0], path[1:]
attrs, types, span = selector
if isinstance(expr, Basic):
if not expr.is_Atom:
args, basic = self._get_ordered_args(expr), True
else:
return expr
elif hasattr(expr, '__iter__'):
args, basic = expr, False
else:
return expr
args = list(args)
if span is not None:
if isinstance(span, slice):
indices = range(*span.indices(len(args)))
else:
indices = [span]
else:
indices = range(len(args))
for i in indices:
try:
arg = args[i]
except IndexError:
continue
if self._has(arg, attrs, types):
args[i] = _apply(path, arg, func)
if basic:
return expr.func(*args)
else:
return expr.__class__(args)
_args, _kwargs = args or (), kwargs or {}
_func = lambda expr: func(expr, *_args, **_kwargs)
return _apply(self._epath, expr, _func)
def select(self, expr):
"""
Retrieve parts of an expression selected by a path.
Examples
========
>>> from sympy.simplify.epathtools import EPath
>>> from sympy import sin, cos, E
>>> from sympy.abc import x, y, z, t
>>> path = EPath("/*/[0]/Symbol")
>>> expr = [((x, 1), 2), ((3, y), z)]
>>> path.select(expr)
[x, y]
>>> path = EPath("/*/*/Symbol")
>>> expr = t + sin(x + 1) + cos(x + y + E)
>>> path.select(expr)
[x, x, y]
"""
result = []
def _select(path, expr):
if not path:
result.append(expr)
else:
selector, path = path[0], path[1:]
attrs, types, span = selector
if isinstance(expr, Basic):
args = self._get_ordered_args(expr)
elif hasattr(expr, '__iter__'):
args = expr
else:
return
if span is not None:
if isinstance(span, slice):
args = args[span]
else:
try:
args = [args[span]]
except IndexError:
return
for arg in args:
if self._has(arg, attrs, types):
_select(path, arg)
_select(self._epath, expr)
return result
def epath(path, expr=None, func=None, args=None, kwargs=None):
r"""
Manipulate parts of an expression selected by a path.
Explanation
===========
This function allows to manipulate large nested expressions in single
line of code, utilizing techniques to those applied in XML processing
standards (e.g. XPath).
If ``func`` is ``None``, :func:`epath` retrieves elements selected by
the ``path``. Otherwise it applies ``func`` to each matching element.
Note that it is more efficient to create an EPath object and use the select
and apply methods of that object, since this will compile the path string
only once. This function should only be used as a convenient shortcut for
interactive use.
This is the supported syntax:
* select all: ``/*``
Equivalent of ``for arg in args:``.
* select slice: ``/[0]`` or ``/[1:5]`` or ``/[1:5:2]``
Supports standard Python's slice syntax.
* select by type: ``/list`` or ``/list|tuple``
Emulates ``isinstance()``.
* select by attribute: ``/__iter__?``
Emulates ``hasattr()``.
Parameters
==========
path : str | EPath
A path as a string or a compiled EPath.
expr : Basic | iterable
An expression or a container of expressions.
func : callable (optional)
A callable that will be applied to matching parts.
args : tuple (optional)
Additional positional arguments to ``func``.
kwargs : dict (optional)
Additional keyword arguments to ``func``.
Examples
========
>>> from sympy.simplify.epathtools import epath
>>> from sympy import sin, cos, E
>>> from sympy.abc import x, y, z, t
>>> path = "/*/[0]/Symbol"
>>> expr = [((x, 1), 2), ((3, y), z)]
>>> epath(path, expr)
[x, y]
>>> epath(path, expr, lambda expr: expr**2)
[((x**2, 1), 2), ((3, y**2), z)]
>>> path = "/*/*/Symbol"
>>> expr = t + sin(x + 1) + cos(x + y + E)
>>> epath(path, expr)
[x, x, y]
>>> epath(path, expr, lambda expr: 2*expr)
t + sin(2*x + 1) + cos(2*x + 2*y + E)
"""
_epath = EPath(path)
if expr is None:
return _epath
if func is None:
return _epath.select(expr)
else:
return _epath.apply(expr, func, args, kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,497 @@
from sympy.core import Function, S, Mul, Pow, Add
from sympy.core.sorting import ordered, default_sort_key
from sympy.core.function import expand_func
from sympy.core.symbol import Dummy
from sympy.functions import gamma, sqrt, sin
from sympy.polys import factor, cancel
from sympy.utilities.iterables import sift, uniq
def gammasimp(expr):
r"""
Simplify expressions with gamma functions.
Explanation
===========
This function takes as input an expression containing gamma
functions or functions that can be rewritten in terms of gamma
functions and tries to minimize the number of those functions and
reduce the size of their arguments.
The algorithm works by rewriting all gamma functions as expressions
involving rising factorials (Pochhammer symbols) and applies
recurrence relations and other transformations applicable to rising
factorials, to reduce their arguments, possibly letting the resulting
rising factorial to cancel. Rising factorials with the second argument
being an integer are expanded into polynomial forms and finally all
other rising factorial are rewritten in terms of gamma functions.
Then the following two steps are performed.
1. Reduce the number of gammas by applying the reflection theorem
gamma(x)*gamma(1-x) == pi/sin(pi*x).
2. Reduce the number of gammas by applying the multiplication theorem
gamma(x)*gamma(x+1/n)*...*gamma(x+(n-1)/n) == C*gamma(n*x).
It then reduces the number of prefactors by absorbing them into gammas
where possible and expands gammas with rational argument.
All transformation rules can be found (or were derived from) here:
.. [1] https://functions.wolfram.com/GammaBetaErf/Pochhammer/17/01/02/
.. [2] https://functions.wolfram.com/GammaBetaErf/Pochhammer/27/01/0005/
Examples
========
>>> from sympy.simplify import gammasimp
>>> from sympy import gamma, Symbol
>>> from sympy.abc import x
>>> n = Symbol('n', integer = True)
>>> gammasimp(gamma(x)/gamma(x - 3))
(x - 3)*(x - 2)*(x - 1)
>>> gammasimp(gamma(n + 3))
gamma(n + 3)
"""
expr = expr.rewrite(gamma)
# compute_ST will be looking for Functions and we don't want
# it looking for non-gamma functions: issue 22606
# so we mask free, non-gamma functions
f = expr.atoms(Function)
# take out gammas
gammas = {i for i in f if isinstance(i, gamma)}
if not gammas:
return expr # avoid side effects like factoring
f -= gammas
# keep only those without bound symbols
f = f & expr.as_dummy().atoms(Function)
if f:
dum, fun, simp = zip(*[
(Dummy(), fi, fi.func(*[
_gammasimp(a, as_comb=False) for a in fi.args]))
for fi in ordered(f)])
d = expr.xreplace(dict(zip(fun, dum)))
return _gammasimp(d, as_comb=False).xreplace(dict(zip(dum, simp)))
return _gammasimp(expr, as_comb=False)
def _gammasimp(expr, as_comb):
"""
Helper function for gammasimp and combsimp.
Explanation
===========
Simplifies expressions written in terms of gamma function. If
as_comb is True, it tries to preserve integer arguments. See
docstring of gammasimp for more information. This was part of
combsimp() in combsimp.py.
"""
expr = expr.replace(gamma,
lambda n: _rf(1, (n - 1).expand()))
if as_comb:
expr = expr.replace(_rf,
lambda a, b: gamma(b + 1))
else:
expr = expr.replace(_rf,
lambda a, b: gamma(a + b)/gamma(a))
def rule_gamma(expr, level=0):
""" Simplify products of gamma functions further. """
if expr.is_Atom:
return expr
def gamma_rat(x):
# helper to simplify ratios of gammas
was = x.count(gamma)
xx = x.replace(gamma, lambda n: _rf(1, (n - 1).expand()
).replace(_rf, lambda a, b: gamma(a + b)/gamma(a)))
if xx.count(gamma) < was:
x = xx
return x
def gamma_factor(x):
# return True if there is a gamma factor in shallow args
if isinstance(x, gamma):
return True
if x.is_Add or x.is_Mul:
return any(gamma_factor(xi) for xi in x.args)
if x.is_Pow and (x.exp.is_integer or x.base.is_positive):
return gamma_factor(x.base)
return False
# recursion step
if level == 0:
expr = expr.func(*[rule_gamma(x, level + 1) for x in expr.args])
level += 1
if not expr.is_Mul:
return expr
# non-commutative step
if level == 1:
args, nc = expr.args_cnc()
if not args:
return expr
if nc:
return rule_gamma(Mul._from_args(args), level + 1)*Mul._from_args(nc)
level += 1
# pure gamma handling, not factor absorption
if level == 2:
T, F = sift(expr.args, gamma_factor, binary=True)
gamma_ind = Mul(*F)
d = Mul(*T)
nd, dd = d.as_numer_denom()
for ipass in range(2):
args = list(ordered(Mul.make_args(nd)))
for i, ni in enumerate(args):
if ni.is_Add:
ni, dd = Add(*[
rule_gamma(gamma_rat(a/dd), level + 1) for a in ni.args]
).as_numer_denom()
args[i] = ni
if not dd.has(gamma):
break
nd = Mul(*args)
if ipass == 0 and not gamma_factor(nd):
break
nd, dd = dd, nd # now process in reversed order
expr = gamma_ind*nd/dd
if not (expr.is_Mul and (gamma_factor(dd) or gamma_factor(nd))):
return expr
level += 1
# iteration until constant
if level == 3:
while True:
was = expr
expr = rule_gamma(expr, 4)
if expr == was:
return expr
numer_gammas = []
denom_gammas = []
numer_others = []
denom_others = []
def explicate(p):
if p is S.One:
return None, []
b, e = p.as_base_exp()
if e.is_Integer:
if isinstance(b, gamma):
return True, [b.args[0]]*e
else:
return False, [b]*e
else:
return False, [p]
newargs = list(ordered(expr.args))
while newargs:
n, d = newargs.pop().as_numer_denom()
isg, l = explicate(n)
if isg:
numer_gammas.extend(l)
elif isg is False:
numer_others.extend(l)
isg, l = explicate(d)
if isg:
denom_gammas.extend(l)
elif isg is False:
denom_others.extend(l)
# =========== level 2 work: pure gamma manipulation =========
if not as_comb:
# Try to reduce the number of gamma factors by applying the
# reflection formula gamma(x)*gamma(1-x) = pi/sin(pi*x)
for gammas, numer, denom in [(
numer_gammas, numer_others, denom_others),
(denom_gammas, denom_others, numer_others)]:
new = []
while gammas:
g1 = gammas.pop()
if g1.is_integer:
new.append(g1)
continue
for i, g2 in enumerate(gammas):
n = g1 + g2 - 1
if not n.is_Integer:
continue
numer.append(S.Pi)
denom.append(sin(S.Pi*g1))
gammas.pop(i)
if n > 0:
for k in range(n):
numer.append(1 - g1 + k)
elif n < 0:
for k in range(-n):
denom.append(-g1 - k)
break
else:
new.append(g1)
# /!\ updating IN PLACE
gammas[:] = new
# Try to reduce the number of gammas by using the duplication
# theorem to cancel an upper and lower: gamma(2*s)/gamma(s) =
# 2**(2*s + 1)/(4*sqrt(pi))*gamma(s + 1/2). Although this could
# be done with higher argument ratios like gamma(3*x)/gamma(x),
# this would not reduce the number of gammas as in this case.
for ng, dg, no, do in [(numer_gammas, denom_gammas, numer_others,
denom_others),
(denom_gammas, numer_gammas, denom_others,
numer_others)]:
while True:
for x in ng:
for y in dg:
n = x - 2*y
if n.is_Integer:
break
else:
continue
break
else:
break
ng.remove(x)
dg.remove(y)
if n > 0:
for k in range(n):
no.append(2*y + k)
elif n < 0:
for k in range(-n):
do.append(2*y - 1 - k)
ng.append(y + S.Half)
no.append(2**(2*y - 1))
do.append(sqrt(S.Pi))
# Try to reduce the number of gamma factors by applying the
# multiplication theorem (used when n gammas with args differing
# by 1/n mod 1 are encountered).
#
# run of 2 with args differing by 1/2
#
# >>> gammasimp(gamma(x)*gamma(x+S.Half))
# 2*sqrt(2)*2**(-2*x - 1/2)*sqrt(pi)*gamma(2*x)
#
# run of 3 args differing by 1/3 (mod 1)
#
# >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(2)/3))
# 6*3**(-3*x - 1/2)*pi*gamma(3*x)
# >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(5)/3))
# 2*3**(-3*x - 1/2)*pi*(3*x + 2)*gamma(3*x)
#
def _run(coeffs):
# find runs in coeffs such that the difference in terms (mod 1)
# of t1, t2, ..., tn is 1/n
u = list(uniq(coeffs))
for i in range(len(u)):
dj = ([((u[j] - u[i]) % 1, j) for j in range(i + 1, len(u))])
for one, j in dj:
if one.p == 1 and one.q != 1:
n = one.q
got = [i]
get = list(range(1, n))
for d, j in dj:
m = n*d
if m.is_Integer and m in get:
get.remove(m)
got.append(j)
if not get:
break
else:
continue
for i, j in enumerate(got):
c = u[j]
coeffs.remove(c)
got[i] = c
return one.q, got[0], got[1:]
def _mult_thm(gammas, numer, denom):
# pull off and analyze the leading coefficient from each gamma arg
# looking for runs in those Rationals
# expr -> coeff + resid -> rats[resid] = coeff
rats = {}
for g in gammas:
c, resid = g.as_coeff_Add()
rats.setdefault(resid, []).append(c)
# look for runs in Rationals for each resid
keys = sorted(rats, key=default_sort_key)
for resid in keys:
coeffs = sorted(rats[resid])
new = []
while True:
run = _run(coeffs)
if run is None:
break
# process the sequence that was found:
# 1) convert all the gamma functions to have the right
# argument (could be off by an integer)
# 2) append the factors corresponding to the theorem
# 3) append the new gamma function
n, ui, other = run
# (1)
for u in other:
con = resid + u - 1
for k in range(int(u - ui)):
numer.append(con - k)
con = n*(resid + ui) # for (2) and (3)
# (2)
numer.append((2*S.Pi)**(S(n - 1)/2)*
n**(S.Half - con))
# (3)
new.append(con)
# restore resid to coeffs
rats[resid] = [resid + c for c in coeffs] + new
# rebuild the gamma arguments
g = []
for resid in keys:
g += rats[resid]
# /!\ updating IN PLACE
gammas[:] = g
for l, numer, denom in [(numer_gammas, numer_others, denom_others),
(denom_gammas, denom_others, numer_others)]:
_mult_thm(l, numer, denom)
# =========== level >= 2 work: factor absorption =========
if level >= 2:
# Try to absorb factors into the gammas: x*gamma(x) -> gamma(x + 1)
# and gamma(x)/(x - 1) -> gamma(x - 1)
# This code (in particular repeated calls to find_fuzzy) can be very
# slow.
def find_fuzzy(l, x):
if not l:
return
S1, T1 = compute_ST(x)
for y in l:
S2, T2 = inv[y]
if T1 != T2 or (not S1.intersection(S2) and
(S1 != set() or S2 != set())):
continue
# XXX we want some simplification (e.g. cancel or
# simplify) but no matter what it's slow.
a = len(cancel(x/y).free_symbols)
b = len(x.free_symbols)
c = len(y.free_symbols)
# TODO is there a better heuristic?
if a == 0 and (b > 0 or c > 0):
return y
# We thus try to avoid expensive calls by building the following
# "invariants": For every factor or gamma function argument
# - the set of free symbols S
# - the set of functional components T
# We will only try to absorb if T1==T2 and (S1 intersect S2 != emptyset
# or S1 == S2 == emptyset)
inv = {}
def compute_ST(expr):
if expr in inv:
return inv[expr]
return (expr.free_symbols, expr.atoms(Function).union(
{e.exp for e in expr.atoms(Pow)}))
def update_ST(expr):
inv[expr] = compute_ST(expr)
for expr in numer_gammas + denom_gammas + numer_others + denom_others:
update_ST(expr)
for gammas, numer, denom in [(
numer_gammas, numer_others, denom_others),
(denom_gammas, denom_others, numer_others)]:
new = []
while gammas:
g = gammas.pop()
cont = True
while cont:
cont = False
y = find_fuzzy(numer, g)
if y is not None:
numer.remove(y)
if y != g:
numer.append(y/g)
update_ST(y/g)
g += 1
cont = True
y = find_fuzzy(denom, g - 1)
if y is not None:
denom.remove(y)
if y != g - 1:
numer.append((g - 1)/y)
update_ST((g - 1)/y)
g -= 1
cont = True
new.append(g)
# /!\ updating IN PLACE
gammas[:] = new
# =========== rebuild expr ==================================
return Mul(*[gamma(g) for g in numer_gammas]) \
/ Mul(*[gamma(g) for g in denom_gammas]) \
* Mul(*numer_others) / Mul(*denom_others)
was = factor(expr)
# (for some reason we cannot use Basic.replace in this case)
expr = rule_gamma(was)
if expr != was:
expr = factor(expr)
expr = expr.replace(gamma,
lambda n: expand_func(gamma(n)) if n.is_Rational else gamma(n))
return expr
class _rf(Function):
@classmethod
def eval(cls, a, b):
if b.is_Integer:
if not b:
return S.One
n = int(b)
if n > 0:
return Mul(*[a + i for i in range(n)])
elif n < 0:
return 1/Mul(*[a - i for i in range(1, -n + 1)])
else:
if b.is_Add:
c, _b = b.as_coeff_Add()
if c.is_Integer:
if c > 0:
return _rf(a, _b)*_rf(a + _b, c)
elif c < 0:
return _rf(a, _b)/_rf(a + _b + c, -c)
if a.is_Add:
c, _a = a.as_coeff_Add()
if c.is_Integer:
if c > 0:
return _rf(_a, b)*_rf(_a + b, c)/_rf(_a, c)
elif c < 0:
return _rf(_a, b)*_rf(_a + c, -c)/_rf(_a + b + c, -c)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,18 @@
""" This module cooks up a docstring when imported. Its only purpose is to
be displayed in the sphinx documentation. """
from sympy.core.relational import Eq
from sympy.functions.special.hyper import hyper
from sympy.printing.latex import latex
from sympy.simplify.hyperexpand import FormulaCollection
c = FormulaCollection()
doc = ""
for f in c.formulae:
obj = Eq(hyper(f.func.ap, f.func.bq, f.z),
f.closed_form.rewrite('nonrepsmall'))
doc += ".. math::\n %s\n" % latex(obj)
__doc__ = doc

View File

@ -0,0 +1,714 @@
from collections import defaultdict
from functools import reduce
from math import prod
from sympy.core.function import expand_log, count_ops, _coeff_isneg
from sympy.core import sympify, Basic, Dummy, S, Add, Mul, Pow, expand_mul, factor_terms
from sympy.core.sorting import ordered, default_sort_key
from sympy.core.numbers import Integer, Rational
from sympy.core.mul import _keep_coeff
from sympy.core.rules import Transform
from sympy.functions import exp_polar, exp, log, root, polarify, unpolarify
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.polys import lcm, gcd
from sympy.ntheory.factor_ import multiplicity
def powsimp(expr, deep=False, combine='all', force=False, measure=count_ops):
"""
Reduce expression by combining powers with similar bases and exponents.
Explanation
===========
If ``deep`` is ``True`` then powsimp() will also simplify arguments of
functions. By default ``deep`` is set to ``False``.
If ``force`` is ``True`` then bases will be combined without checking for
assumptions, e.g. sqrt(x)*sqrt(y) -> sqrt(x*y) which is not true
if x and y are both negative.
You can make powsimp() only combine bases or only combine exponents by
changing combine='base' or combine='exp'. By default, combine='all',
which does both. combine='base' will only combine::
a a a 2x x
x * y => (x*y) as well as things like 2 => 4
and combine='exp' will only combine
::
a b (a + b)
x * x => x
combine='exp' will strictly only combine exponents in the way that used
to be automatic. Also use deep=True if you need the old behavior.
When combine='all', 'exp' is evaluated first. Consider the first
example below for when there could be an ambiguity relating to this.
This is done so things like the second example can be completely
combined. If you want 'base' combined first, do something like
powsimp(powsimp(expr, combine='base'), combine='exp').
Examples
========
>>> from sympy import powsimp, exp, log, symbols
>>> from sympy.abc import x, y, z, n
>>> powsimp(x**y*x**z*y**z, combine='all')
x**(y + z)*y**z
>>> powsimp(x**y*x**z*y**z, combine='exp')
x**(y + z)*y**z
>>> powsimp(x**y*x**z*y**z, combine='base', force=True)
x**y*(x*y)**z
>>> powsimp(x**z*x**y*n**z*n**y, combine='all', force=True)
(n*x)**(y + z)
>>> powsimp(x**z*x**y*n**z*n**y, combine='exp')
n**(y + z)*x**(y + z)
>>> powsimp(x**z*x**y*n**z*n**y, combine='base', force=True)
(n*x)**y*(n*x)**z
>>> x, y = symbols('x y', positive=True)
>>> powsimp(log(exp(x)*exp(y)))
log(exp(x)*exp(y))
>>> powsimp(log(exp(x)*exp(y)), deep=True)
x + y
Radicals with Mul bases will be combined if combine='exp'
>>> from sympy import sqrt
>>> x, y = symbols('x y')
Two radicals are automatically joined through Mul:
>>> a=sqrt(x*sqrt(y))
>>> a*a**3 == a**4
True
But if an integer power of that radical has been
autoexpanded then Mul does not join the resulting factors:
>>> a**4 # auto expands to a Mul, no longer a Pow
x**2*y
>>> _*a # so Mul doesn't combine them
x**2*y*sqrt(x*sqrt(y))
>>> powsimp(_) # but powsimp will
(x*sqrt(y))**(5/2)
>>> powsimp(x*y*a) # but won't when doing so would violate assumptions
x*y*sqrt(x*sqrt(y))
"""
def recurse(arg, **kwargs):
_deep = kwargs.get('deep', deep)
_combine = kwargs.get('combine', combine)
_force = kwargs.get('force', force)
_measure = kwargs.get('measure', measure)
return powsimp(arg, _deep, _combine, _force, _measure)
expr = sympify(expr)
if (not isinstance(expr, Basic) or isinstance(expr, MatrixSymbol) or (
expr.is_Atom or expr in (exp_polar(0), exp_polar(1)))):
return expr
if deep or expr.is_Add or expr.is_Mul and _y not in expr.args:
expr = expr.func(*[recurse(w) for w in expr.args])
if expr.is_Pow:
return recurse(expr*_y, deep=False)/_y
if not expr.is_Mul:
return expr
# handle the Mul
if combine in ('exp', 'all'):
# Collect base/exp data, while maintaining order in the
# non-commutative parts of the product
c_powers = defaultdict(list)
nc_part = []
newexpr = []
coeff = S.One
for term in expr.args:
if term.is_Rational:
coeff *= term
continue
if term.is_Pow:
term = _denest_pow(term)
if term.is_commutative:
b, e = term.as_base_exp()
if deep:
b, e = [recurse(i) for i in [b, e]]
if b.is_Pow or isinstance(b, exp):
# don't let smthg like sqrt(x**a) split into x**a, 1/2
# or else it will be joined as x**(a/2) later
b, e = b**e, S.One
c_powers[b].append(e)
else:
# This is the logic that combines exponents for equal,
# but non-commutative bases: A**x*A**y == A**(x+y).
if nc_part:
b1, e1 = nc_part[-1].as_base_exp()
b2, e2 = term.as_base_exp()
if (b1 == b2 and
e1.is_commutative and e2.is_commutative):
nc_part[-1] = Pow(b1, Add(e1, e2))
continue
nc_part.append(term)
# add up exponents of common bases
for b, e in ordered(iter(c_powers.items())):
# allow 2**x/4 -> 2**(x - 2); don't do this when b and e are
# Numbers since autoevaluation will undo it, e.g.
# 2**(1/3)/4 -> 2**(1/3 - 2) -> 2**(1/3)/4
if (b and b.is_Rational and not all(ei.is_Number for ei in e) and \
coeff is not S.One and
b not in (S.One, S.NegativeOne)):
m = multiplicity(abs(b), abs(coeff))
if m:
e.append(m)
coeff /= b**m
c_powers[b] = Add(*e)
if coeff is not S.One:
if coeff in c_powers:
c_powers[coeff] += S.One
else:
c_powers[coeff] = S.One
# convert to plain dictionary
c_powers = dict(c_powers)
# check for base and inverted base pairs
be = list(c_powers.items())
skip = set() # skip if we already saw them
for b, e in be:
if b in skip:
continue
bpos = b.is_positive or b.is_polar
if bpos:
binv = 1/b
if b != binv and binv in c_powers:
if b.as_numer_denom()[0] is S.One:
c_powers.pop(b)
c_powers[binv] -= e
else:
skip.add(binv)
e = c_powers.pop(binv)
c_powers[b] -= e
# check for base and negated base pairs
be = list(c_powers.items())
_n = S.NegativeOne
for b, e in be:
if (b.is_Symbol or b.is_Add) and -b in c_powers and b in c_powers:
if (b.is_positive is not None or e.is_integer):
if e.is_integer or b.is_negative:
c_powers[-b] += c_powers.pop(b)
else: # (-b).is_positive so use its e
e = c_powers.pop(-b)
c_powers[b] += e
if _n in c_powers:
c_powers[_n] += e
else:
c_powers[_n] = e
# filter c_powers and convert to a list
c_powers = [(b, e) for b, e in c_powers.items() if e]
# ==============================================================
# check for Mul bases of Rational powers that can be combined with
# separated bases, e.g. x*sqrt(x*y)*sqrt(x*sqrt(x*y)) ->
# (x*sqrt(x*y))**(3/2)
# ---------------- helper functions
def ratq(x):
'''Return Rational part of x's exponent as it appears in the bkey.
'''
return bkey(x)[0][1]
def bkey(b, e=None):
'''Return (b**s, c.q), c.p where e -> c*s. If e is not given then
it will be taken by using as_base_exp() on the input b.
e.g.
x**3/2 -> (x, 2), 3
x**y -> (x**y, 1), 1
x**(2*y/3) -> (x**y, 3), 2
exp(x/2) -> (exp(a), 2), 1
'''
if e is not None: # coming from c_powers or from below
if e.is_Integer:
return (b, S.One), e
elif e.is_Rational:
return (b, Integer(e.q)), Integer(e.p)
else:
c, m = e.as_coeff_Mul(rational=True)
if c is not S.One:
if m.is_integer:
return (b, Integer(c.q)), m*Integer(c.p)
return (b**m, Integer(c.q)), Integer(c.p)
else:
return (b**e, S.One), S.One
else:
return bkey(*b.as_base_exp())
def update(b):
'''Decide what to do with base, b. If its exponent is now an
integer multiple of the Rational denominator, then remove it
and put the factors of its base in the common_b dictionary or
update the existing bases if necessary. If it has been zeroed
out, simply remove the base.
'''
newe, r = divmod(common_b[b], b[1])
if not r:
common_b.pop(b)
if newe:
for m in Mul.make_args(b[0]**newe):
b, e = bkey(m)
if b not in common_b:
common_b[b] = 0
common_b[b] += e
if b[1] != 1:
bases.append(b)
# ---------------- end of helper functions
# assemble a dictionary of the factors having a Rational power
common_b = {}
done = []
bases = []
for b, e in c_powers:
b, e = bkey(b, e)
if b in common_b:
common_b[b] = common_b[b] + e
else:
common_b[b] = e
if b[1] != 1 and b[0].is_Mul:
bases.append(b)
bases.sort(key=default_sort_key) # this makes tie-breaking canonical
bases.sort(key=measure, reverse=True) # handle longest first
for base in bases:
if base not in common_b: # it may have been removed already
continue
b, exponent = base
last = False # True when no factor of base is a radical
qlcm = 1 # the lcm of the radical denominators
while True:
bstart = b
qstart = qlcm
bb = [] # list of factors
ee = [] # (factor's expo. and it's current value in common_b)
for bi in Mul.make_args(b):
bib, bie = bkey(bi)
if bib not in common_b or common_b[bib] < bie:
ee = bb = [] # failed
break
ee.append([bie, common_b[bib]])
bb.append(bib)
if ee:
# find the number of integral extractions possible
# e.g. [(1, 2), (2, 2)] -> min(2/1, 2/2) -> 1
min1 = ee[0][1]//ee[0][0]
for i in range(1, len(ee)):
rat = ee[i][1]//ee[i][0]
if rat < 1:
break
min1 = min(min1, rat)
else:
# update base factor counts
# e.g. if ee = [(2, 5), (3, 6)] then min1 = 2
# and the new base counts will be 5-2*2 and 6-2*3
for i in range(len(bb)):
common_b[bb[i]] -= min1*ee[i][0]
update(bb[i])
# update the count of the base
# e.g. x**2*y*sqrt(x*sqrt(y)) the count of x*sqrt(y)
# will increase by 4 to give bkey (x*sqrt(y), 2, 5)
common_b[base] += min1*qstart*exponent
if (last # no more radicals in base
or len(common_b) == 1 # nothing left to join with
or all(k[1] == 1 for k in common_b) # no rad's in common_b
):
break
# see what we can exponentiate base by to remove any radicals
# so we know what to search for
# e.g. if base were x**(1/2)*y**(1/3) then we should
# exponentiate by 6 and look for powers of x and y in the ratio
# of 2 to 3
qlcm = lcm([ratq(bi) for bi in Mul.make_args(bstart)])
if qlcm == 1:
break # we are done
b = bstart**qlcm
qlcm *= qstart
if all(ratq(bi) == 1 for bi in Mul.make_args(b)):
last = True # we are going to be done after this next pass
# this base no longer can find anything to join with and
# since it was longer than any other we are done with it
b, q = base
done.append((b, common_b.pop(base)*Rational(1, q)))
# update c_powers and get ready to continue with powsimp
c_powers = done
# there may be terms still in common_b that were bases that were
# identified as needing processing, so remove those, too
for (b, q), e in common_b.items():
if (b.is_Pow or isinstance(b, exp)) and \
q is not S.One and not b.exp.is_Rational:
b, be = b.as_base_exp()
b = b**(be/q)
else:
b = root(b, q)
c_powers.append((b, e))
check = len(c_powers)
c_powers = dict(c_powers)
assert len(c_powers) == check # there should have been no duplicates
# ==============================================================
# rebuild the expression
newexpr = expr.func(*(newexpr + [Pow(b, e) for b, e in c_powers.items()]))
if combine == 'exp':
return expr.func(newexpr, expr.func(*nc_part))
else:
return recurse(expr.func(*nc_part), combine='base') * \
recurse(newexpr, combine='base')
elif combine == 'base':
# Build c_powers and nc_part. These must both be lists not
# dicts because exp's are not combined.
c_powers = []
nc_part = []
for term in expr.args:
if term.is_commutative:
c_powers.append(list(term.as_base_exp()))
else:
nc_part.append(term)
# Pull out numerical coefficients from exponent if assumptions allow
# e.g., 2**(2*x) => 4**x
for i in range(len(c_powers)):
b, e = c_powers[i]
if not (all(x.is_nonnegative for x in b.as_numer_denom()) or e.is_integer or force or b.is_polar):
continue
exp_c, exp_t = e.as_coeff_Mul(rational=True)
if exp_c is not S.One and exp_t is not S.One:
c_powers[i] = [Pow(b, exp_c), exp_t]
# Combine bases whenever they have the same exponent and
# assumptions allow
# first gather the potential bases under the common exponent
c_exp = defaultdict(list)
for b, e in c_powers:
if deep:
e = recurse(e)
if e.is_Add and (b.is_positive or e.is_integer):
e = factor_terms(e)
if _coeff_isneg(e):
e = -e
b = 1/b
c_exp[e].append(b)
del c_powers
# Merge back in the results of the above to form a new product
c_powers = defaultdict(list)
for e in c_exp:
bases = c_exp[e]
# calculate the new base for e
if len(bases) == 1:
new_base = bases[0]
elif e.is_integer or force:
new_base = expr.func(*bases)
else:
# see which ones can be joined
unk = []
nonneg = []
neg = []
for bi in bases:
if bi.is_negative:
neg.append(bi)
elif bi.is_nonnegative:
nonneg.append(bi)
elif bi.is_polar:
nonneg.append(
bi) # polar can be treated like non-negative
else:
unk.append(bi)
if len(unk) == 1 and not neg or len(neg) == 1 and not unk:
# a single neg or a single unk can join the rest
nonneg.extend(unk + neg)
unk = neg = []
elif neg:
# their negative signs cancel in groups of 2*q if we know
# that e = p/q else we have to treat them as unknown
israt = False
if e.is_Rational:
israt = True
else:
p, d = e.as_numer_denom()
if p.is_integer and d.is_integer:
israt = True
if israt:
neg = [-w for w in neg]
unk.extend([S.NegativeOne]*len(neg))
else:
unk.extend(neg)
neg = []
del israt
# these shouldn't be joined
for b in unk:
c_powers[b].append(e)
# here is a new joined base
new_base = expr.func(*(nonneg + neg))
# if there are positive parts they will just get separated
# again unless some change is made
def _terms(e):
# return the number of terms of this expression
# when multiplied out -- assuming no joining of terms
if e.is_Add:
return sum(_terms(ai) for ai in e.args)
if e.is_Mul:
return prod([_terms(mi) for mi in e.args])
return 1
xnew_base = expand_mul(new_base, deep=False)
if len(Add.make_args(xnew_base)) < _terms(new_base):
new_base = factor_terms(xnew_base)
c_powers[new_base].append(e)
# break out the powers from c_powers now
c_part = [Pow(b, ei) for b, e in c_powers.items() for ei in e]
# we're done
return expr.func(*(c_part + nc_part))
else:
raise ValueError("combine must be one of ('all', 'exp', 'base').")
def powdenest(eq, force=False, polar=False):
r"""
Collect exponents on powers as assumptions allow.
Explanation
===========
Given ``(bb**be)**e``, this can be simplified as follows:
* if ``bb`` is positive, or
* ``e`` is an integer, or
* ``|be| < 1`` then this simplifies to ``bb**(be*e)``
Given a product of powers raised to a power, ``(bb1**be1 *
bb2**be2...)**e``, simplification can be done as follows:
- if e is positive, the gcd of all bei can be joined with e;
- all non-negative bb can be separated from those that are negative
and their gcd can be joined with e; autosimplification already
handles this separation.
- integer factors from powers that have integers in the denominator
of the exponent can be removed from any term and the gcd of such
integers can be joined with e
Setting ``force`` to ``True`` will make symbols that are not explicitly
negative behave as though they are positive, resulting in more
denesting.
Setting ``polar`` to ``True`` will do simplifications on the Riemann surface of
the logarithm, also resulting in more denestings.
When there are sums of logs in exp() then a product of powers may be
obtained e.g. ``exp(3*(log(a) + 2*log(b)))`` - > ``a**3*b**6``.
Examples
========
>>> from sympy.abc import a, b, x, y, z
>>> from sympy import Symbol, exp, log, sqrt, symbols, powdenest
>>> powdenest((x**(2*a/3))**(3*x))
(x**(2*a/3))**(3*x)
>>> powdenest(exp(3*x*log(2)))
2**(3*x)
Assumptions may prevent expansion:
>>> powdenest(sqrt(x**2))
sqrt(x**2)
>>> p = symbols('p', positive=True)
>>> powdenest(sqrt(p**2))
p
No other expansion is done.
>>> i, j = symbols('i,j', integer=True)
>>> powdenest((x**x)**(i + j)) # -X-> (x**x)**i*(x**x)**j
x**(x*(i + j))
But exp() will be denested by moving all non-log terms outside of
the function; this may result in the collapsing of the exp to a power
with a different base:
>>> powdenest(exp(3*y*log(x)))
x**(3*y)
>>> powdenest(exp(y*(log(a) + log(b))))
(a*b)**y
>>> powdenest(exp(3*(log(a) + log(b))))
a**3*b**3
If assumptions allow, symbols can also be moved to the outermost exponent:
>>> i = Symbol('i', integer=True)
>>> powdenest(((x**(2*i))**(3*y))**x)
((x**(2*i))**(3*y))**x
>>> powdenest(((x**(2*i))**(3*y))**x, force=True)
x**(6*i*x*y)
>>> powdenest(((x**(2*a/3))**(3*y/i))**x)
((x**(2*a/3))**(3*y/i))**x
>>> powdenest((x**(2*i)*y**(4*i))**z, force=True)
(x*y**2)**(2*i*z)
>>> n = Symbol('n', negative=True)
>>> powdenest((x**i)**y, force=True)
x**(i*y)
>>> powdenest((n**i)**x, force=True)
(n**i)**x
"""
from sympy.simplify.simplify import posify
if force:
def _denest(b, e):
if not isinstance(b, (Pow, exp)):
return b.is_positive, Pow(b, e, evaluate=False)
return _denest(b.base, b.exp*e)
reps = []
for p in eq.atoms(Pow, exp):
if isinstance(p.base, (Pow, exp)):
ok, dp = _denest(*p.args)
if ok is not False:
reps.append((p, dp))
if reps:
eq = eq.subs(reps)
eq, reps = posify(eq)
return powdenest(eq, force=False, polar=polar).xreplace(reps)
if polar:
eq, rep = polarify(eq)
return unpolarify(powdenest(unpolarify(eq, exponents_only=True)), rep)
new = powsimp(eq)
return new.xreplace(Transform(
_denest_pow, filter=lambda m: m.is_Pow or isinstance(m, exp)))
_y = Dummy('y')
def _denest_pow(eq):
"""
Denest powers.
This is a helper function for powdenest that performs the actual
transformation.
"""
from sympy.simplify.simplify import logcombine
b, e = eq.as_base_exp()
if b.is_Pow or isinstance(b, exp) and e != 1:
new = b._eval_power(e)
if new is not None:
eq = new
b, e = new.as_base_exp()
# denest exp with log terms in exponent
if b is S.Exp1 and e.is_Mul:
logs = []
other = []
for ei in e.args:
if any(isinstance(ai, log) for ai in Add.make_args(ei)):
logs.append(ei)
else:
other.append(ei)
logs = logcombine(Mul(*logs))
return Pow(exp(logs), Mul(*other))
_, be = b.as_base_exp()
if be is S.One and not (b.is_Mul or
b.is_Rational and b.q != 1 or
b.is_positive):
return eq
# denest eq which is either pos**e or Pow**e or Mul**e or
# Mul(b1**e1, b2**e2)
# handle polar numbers specially
polars, nonpolars = [], []
for bb in Mul.make_args(b):
if bb.is_polar:
polars.append(bb.as_base_exp())
else:
nonpolars.append(bb)
if len(polars) == 1 and not polars[0][0].is_Mul:
return Pow(polars[0][0], polars[0][1]*e)*powdenest(Mul(*nonpolars)**e)
elif polars:
return Mul(*[powdenest(bb**(ee*e)) for (bb, ee) in polars]) \
*powdenest(Mul(*nonpolars)**e)
if b.is_Integer:
# use log to see if there is a power here
logb = expand_log(log(b))
if logb.is_Mul:
c, logb = logb.args
e *= c
base = logb.args[0]
return Pow(base, e)
# if b is not a Mul or any factor is an atom then there is nothing to do
if not b.is_Mul or any(s.is_Atom for s in Mul.make_args(b)):
return eq
# let log handle the case of the base of the argument being a Mul, e.g.
# sqrt(x**(2*i)*y**(6*i)) -> x**i*y**(3**i) if x and y are positive; we
# will take the log, expand it, and then factor out the common powers that
# now appear as coefficient. We do this manually since terms_gcd pulls out
# fractions, terms_gcd(x+x*y/2) -> x*(y + 2)/2 and we don't want the 1/2;
# gcd won't pull out numerators from a fraction: gcd(3*x, 9*x/2) -> x but
# we want 3*x. Neither work with noncommutatives.
def nc_gcd(aa, bb):
a, b = [i.as_coeff_Mul() for i in [aa, bb]]
c = gcd(a[0], b[0]).as_numer_denom()[0]
g = Mul(*(a[1].args_cnc(cset=True)[0] & b[1].args_cnc(cset=True)[0]))
return _keep_coeff(c, g)
glogb = expand_log(log(b))
if glogb.is_Add:
args = glogb.args
g = reduce(nc_gcd, args)
if g != 1:
cg, rg = g.as_coeff_Mul()
glogb = _keep_coeff(cg, rg*Add(*[a/g for a in args]))
# now put the log back together again
if isinstance(glogb, log) or not glogb.is_Mul:
if glogb.args[0].is_Pow or isinstance(glogb.args[0], exp):
glogb = _denest_pow(glogb.args[0])
if (abs(glogb.exp) < 1) == True:
return Pow(glogb.base, glogb.exp*e)
return eq
# the log(b) was a Mul so join any adds with logcombine
add = []
other = []
for a in glogb.args:
if a.is_Add:
add.append(a)
else:
other.append(a)
return Pow(exp(logcombine(Mul(*add))), e*Mul(*other))

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,222 @@
from itertools import combinations_with_replacement
from sympy.core import symbols, Add, Dummy
from sympy.core.numbers import Rational
from sympy.polys import cancel, ComputationFailed, parallel_poly_from_expr, reduced, Poly
from sympy.polys.monomials import Monomial, monomial_div
from sympy.polys.polyerrors import DomainError, PolificationFailed
from sympy.utilities.misc import debug, debugf
def ratsimp(expr):
"""
Put an expression over a common denominator, cancel and reduce.
Examples
========
>>> from sympy import ratsimp
>>> from sympy.abc import x, y
>>> ratsimp(1/x + 1/y)
(x + y)/(x*y)
"""
f, g = cancel(expr).as_numer_denom()
try:
Q, r = reduced(f, [g], field=True, expand=False)
except ComputationFailed:
return f/g
return Add(*Q) + cancel(r/g)
def ratsimpmodprime(expr, G, *gens, quick=True, polynomial=False, **args):
"""
Simplifies a rational expression ``expr`` modulo the prime ideal
generated by ``G``. ``G`` should be a Groebner basis of the
ideal.
Examples
========
>>> from sympy.simplify.ratsimp import ratsimpmodprime
>>> from sympy.abc import x, y
>>> eq = (x + y**5 + y)/(x - y)
>>> ratsimpmodprime(eq, [x*y**5 - x - y], x, y, order='lex')
(-x**2 - x*y - x - y)/(-x**2 + x*y)
If ``polynomial`` is ``False``, the algorithm computes a rational
simplification which minimizes the sum of the total degrees of
the numerator and the denominator.
If ``polynomial`` is ``True``, this function just brings numerator and
denominator into a canonical form. This is much faster, but has
potentially worse results.
References
==========
.. [1] M. Monagan, R. Pearce, Rational Simplification Modulo a Polynomial
Ideal, https://dl.acm.org/doi/pdf/10.1145/1145768.1145809
(specifically, the second algorithm)
"""
from sympy.solvers.solvers import solve
debug('ratsimpmodprime', expr)
# usual preparation of polynomials:
num, denom = cancel(expr).as_numer_denom()
try:
polys, opt = parallel_poly_from_expr([num, denom] + G, *gens, **args)
except PolificationFailed:
return expr
domain = opt.domain
if domain.has_assoc_Field:
opt.domain = domain.get_field()
else:
raise DomainError(
"Cannot compute rational simplification over %s" % domain)
# compute only once
leading_monomials = [g.LM(opt.order) for g in polys[2:]]
tested = set()
def staircase(n):
"""
Compute all monomials with degree less than ``n`` that are
not divisible by any element of ``leading_monomials``.
"""
if n == 0:
return [1]
S = []
for mi in combinations_with_replacement(range(len(opt.gens)), n):
m = [0]*len(opt.gens)
for i in mi:
m[i] += 1
if all(monomial_div(m, lmg) is None for lmg in
leading_monomials):
S.append(m)
return [Monomial(s).as_expr(*opt.gens) for s in S] + staircase(n - 1)
def _ratsimpmodprime(a, b, allsol, N=0, D=0):
r"""
Computes a rational simplification of ``a/b`` which minimizes
the sum of the total degrees of the numerator and the denominator.
Explanation
===========
The algorithm proceeds by looking at ``a * d - b * c`` modulo
the ideal generated by ``G`` for some ``c`` and ``d`` with degree
less than ``a`` and ``b`` respectively.
The coefficients of ``c`` and ``d`` are indeterminates and thus
the coefficients of the normalform of ``a * d - b * c`` are
linear polynomials in these indeterminates.
If these linear polynomials, considered as system of
equations, have a nontrivial solution, then `\frac{a}{b}
\equiv \frac{c}{d}` modulo the ideal generated by ``G``. So,
by construction, the degree of ``c`` and ``d`` is less than
the degree of ``a`` and ``b``, so a simpler representation
has been found.
After a simpler representation has been found, the algorithm
tries to reduce the degree of the numerator and denominator
and returns the result afterwards.
As an extension, if quick=False, we look at all possible degrees such
that the total degree is less than *or equal to* the best current
solution. We retain a list of all solutions of minimal degree, and try
to find the best one at the end.
"""
c, d = a, b
steps = 0
maxdeg = a.total_degree() + b.total_degree()
if quick:
bound = maxdeg - 1
else:
bound = maxdeg
while N + D <= bound:
if (N, D) in tested:
break
tested.add((N, D))
M1 = staircase(N)
M2 = staircase(D)
debugf('%s / %s: %s, %s', (N, D, M1, M2))
Cs = symbols("c:%d" % len(M1), cls=Dummy)
Ds = symbols("d:%d" % len(M2), cls=Dummy)
ng = Cs + Ds
c_hat = Poly(
sum(Cs[i] * M1[i] for i in range(len(M1))), opt.gens + ng)
d_hat = Poly(
sum(Ds[i] * M2[i] for i in range(len(M2))), opt.gens + ng)
r = reduced(a * d_hat - b * c_hat, G, opt.gens + ng,
order=opt.order, polys=True)[1]
S = Poly(r, gens=opt.gens).coeffs()
sol = solve(S, Cs + Ds, particular=True, quick=True)
if sol and not all(s == 0 for s in sol.values()):
c = c_hat.subs(sol)
d = d_hat.subs(sol)
# The "free" variables occurring before as parameters
# might still be in the substituted c, d, so set them
# to the value chosen before:
c = c.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds))))))
d = d.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds))))))
c = Poly(c, opt.gens)
d = Poly(d, opt.gens)
if d == 0:
raise ValueError('Ideal not prime?')
allsol.append((c_hat, d_hat, S, Cs + Ds))
if N + D != maxdeg:
allsol = [allsol[-1]]
break
steps += 1
N += 1
D += 1
if steps > 0:
c, d, allsol = _ratsimpmodprime(c, d, allsol, N, D - steps)
c, d, allsol = _ratsimpmodprime(c, d, allsol, N - steps, D)
return c, d, allsol
# preprocessing. this improves performance a bit when deg(num)
# and deg(denom) are large:
num = reduced(num, G, opt.gens, order=opt.order)[1]
denom = reduced(denom, G, opt.gens, order=opt.order)[1]
if polynomial:
return (num/denom).cancel()
c, d, allsol = _ratsimpmodprime(
Poly(num, opt.gens, domain=opt.domain), Poly(denom, opt.gens, domain=opt.domain), [])
if not quick and allsol:
debugf('Looking for best minimal solution. Got: %s', len(allsol))
newsol = []
for c_hat, d_hat, S, ng in allsol:
sol = solve(S, ng, particular=True, quick=False)
# all values of sol should be numbers; if not, solve is broken
newsol.append((c_hat.subs(sol), d_hat.subs(sol)))
c, d = min(newsol, key=lambda x: len(x[0].terms()) + len(x[1].terms()))
if not domain.is_Field:
cn, c = c.clear_denoms(convert=True)
dn, d = d.clear_denoms(convert=True)
r = Rational(cn, dn)
else:
r = Rational(1)
return (c*r.q)/(d*r.p)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,678 @@
from sympy.core import Add, Expr, Mul, S, sympify
from sympy.core.function import _mexpand, count_ops, expand_mul
from sympy.core.sorting import default_sort_key
from sympy.core.symbol import Dummy
from sympy.functions import root, sign, sqrt
from sympy.polys import Poly, PolynomialError
def is_sqrt(expr):
"""Return True if expr is a sqrt, otherwise False."""
return expr.is_Pow and expr.exp.is_Rational and abs(expr.exp) is S.Half
def sqrt_depth(p) -> int:
"""Return the maximum depth of any square root argument of p.
>>> from sympy.functions.elementary.miscellaneous import sqrt
>>> from sympy.simplify.sqrtdenest import sqrt_depth
Neither of these square roots contains any other square roots
so the depth is 1:
>>> sqrt_depth(1 + sqrt(2)*(1 + sqrt(3)))
1
The sqrt(3) is contained within a square root so the depth is
2:
>>> sqrt_depth(1 + sqrt(2)*sqrt(1 + sqrt(3)))
2
"""
if p is S.ImaginaryUnit:
return 1
if p.is_Atom:
return 0
if p.is_Add or p.is_Mul:
return max(sqrt_depth(x) for x in p.args)
if is_sqrt(p):
return sqrt_depth(p.base) + 1
return 0
def is_algebraic(p):
"""Return True if p is comprised of only Rationals or square roots
of Rationals and algebraic operations.
Examples
========
>>> from sympy.functions.elementary.miscellaneous import sqrt
>>> from sympy.simplify.sqrtdenest import is_algebraic
>>> from sympy import cos
>>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*sqrt(2))))
True
>>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*cos(2))))
False
"""
if p.is_Rational:
return True
elif p.is_Atom:
return False
elif is_sqrt(p) or p.is_Pow and p.exp.is_Integer:
return is_algebraic(p.base)
elif p.is_Add or p.is_Mul:
return all(is_algebraic(x) for x in p.args)
else:
return False
def _subsets(n):
"""
Returns all possible subsets of the set (0, 1, ..., n-1) except the
empty set, listed in reversed lexicographical order according to binary
representation, so that the case of the fourth root is treated last.
Examples
========
>>> from sympy.simplify.sqrtdenest import _subsets
>>> _subsets(2)
[[1, 0], [0, 1], [1, 1]]
"""
if n == 1:
a = [[1]]
elif n == 2:
a = [[1, 0], [0, 1], [1, 1]]
elif n == 3:
a = [[1, 0, 0], [0, 1, 0], [1, 1, 0],
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]]
else:
b = _subsets(n - 1)
a0 = [x + [0] for x in b]
a1 = [x + [1] for x in b]
a = a0 + [[0]*(n - 1) + [1]] + a1
return a
def sqrtdenest(expr, max_iter=3):
"""Denests sqrts in an expression that contain other square roots
if possible, otherwise returns the expr unchanged. This is based on the
algorithms of [1].
Examples
========
>>> from sympy.simplify.sqrtdenest import sqrtdenest
>>> from sympy import sqrt
>>> sqrtdenest(sqrt(5 + 2 * sqrt(6)))
sqrt(2) + sqrt(3)
See Also
========
sympy.solvers.solvers.unrad
References
==========
.. [1] https://web.archive.org/web/20210806201615/https://researcher.watson.ibm.com/researcher/files/us-fagin/symb85.pdf
.. [2] D. J. Jeffrey and A. D. Rich, 'Symplifying Square Roots of Square Roots
by Denesting' (available at https://www.cybertester.com/data/denest.pdf)
"""
expr = expand_mul(expr)
for i in range(max_iter):
z = _sqrtdenest0(expr)
if expr == z:
return expr
expr = z
return expr
def _sqrt_match(p):
"""Return [a, b, r] for p.match(a + b*sqrt(r)) where, in addition to
matching, sqrt(r) also has then maximal sqrt_depth among addends of p.
Examples
========
>>> from sympy.functions.elementary.miscellaneous import sqrt
>>> from sympy.simplify.sqrtdenest import _sqrt_match
>>> _sqrt_match(1 + sqrt(2) + sqrt(2)*sqrt(3) + 2*sqrt(1+sqrt(5)))
[1 + sqrt(2) + sqrt(6), 2, 1 + sqrt(5)]
"""
from sympy.simplify.radsimp import split_surds
p = _mexpand(p)
if p.is_Number:
res = (p, S.Zero, S.Zero)
elif p.is_Add:
pargs = sorted(p.args, key=default_sort_key)
sqargs = [x**2 for x in pargs]
if all(sq.is_Rational and sq.is_positive for sq in sqargs):
r, b, a = split_surds(p)
res = a, b, r
return list(res)
# to make the process canonical, the argument is included in the tuple
# so when the max is selected, it will be the largest arg having a
# given depth
v = [(sqrt_depth(x), x, i) for i, x in enumerate(pargs)]
nmax = max(v, key=default_sort_key)
if nmax[0] == 0:
res = []
else:
# select r
depth, _, i = nmax
r = pargs.pop(i)
v.pop(i)
b = S.One
if r.is_Mul:
bv = []
rv = []
for x in r.args:
if sqrt_depth(x) < depth:
bv.append(x)
else:
rv.append(x)
b = Mul._from_args(bv)
r = Mul._from_args(rv)
# collect terms comtaining r
a1 = []
b1 = [b]
for x in v:
if x[0] < depth:
a1.append(x[1])
else:
x1 = x[1]
if x1 == r:
b1.append(1)
else:
if x1.is_Mul:
x1args = list(x1.args)
if r in x1args:
x1args.remove(r)
b1.append(Mul(*x1args))
else:
a1.append(x[1])
else:
a1.append(x[1])
a = Add(*a1)
b = Add(*b1)
res = (a, b, r**2)
else:
b, r = p.as_coeff_Mul()
if is_sqrt(r):
res = (S.Zero, b, r**2)
else:
res = []
return list(res)
class SqrtdenestStopIteration(StopIteration):
pass
def _sqrtdenest0(expr):
"""Returns expr after denesting its arguments."""
if is_sqrt(expr):
n, d = expr.as_numer_denom()
if d is S.One: # n is a square root
if n.base.is_Add:
args = sorted(n.base.args, key=default_sort_key)
if len(args) > 2 and all((x**2).is_Integer for x in args):
try:
return _sqrtdenest_rec(n)
except SqrtdenestStopIteration:
pass
expr = sqrt(_mexpand(Add(*[_sqrtdenest0(x) for x in args])))
return _sqrtdenest1(expr)
else:
n, d = [_sqrtdenest0(i) for i in (n, d)]
return n/d
if isinstance(expr, Add):
cs = []
args = []
for arg in expr.args:
c, a = arg.as_coeff_Mul()
cs.append(c)
args.append(a)
if all(c.is_Rational for c in cs) and all(is_sqrt(arg) for arg in args):
return _sqrt_ratcomb(cs, args)
if isinstance(expr, Expr):
args = expr.args
if args:
return expr.func(*[_sqrtdenest0(a) for a in args])
return expr
def _sqrtdenest_rec(expr):
"""Helper that denests the square root of three or more surds.
Explanation
===========
It returns the denested expression; if it cannot be denested it
throws SqrtdenestStopIteration
Algorithm: expr.base is in the extension Q_m = Q(sqrt(r_1),..,sqrt(r_k));
split expr.base = a + b*sqrt(r_k), where `a` and `b` are on
Q_(m-1) = Q(sqrt(r_1),..,sqrt(r_(k-1))); then a**2 - b**2*r_k is
on Q_(m-1); denest sqrt(a**2 - b**2*r_k) and so on.
See [1], section 6.
Examples
========
>>> from sympy import sqrt
>>> from sympy.simplify.sqrtdenest import _sqrtdenest_rec
>>> _sqrtdenest_rec(sqrt(-72*sqrt(2) + 158*sqrt(5) + 498))
-sqrt(10) + sqrt(2) + 9 + 9*sqrt(5)
>>> w=-6*sqrt(55)-6*sqrt(35)-2*sqrt(22)-2*sqrt(14)+2*sqrt(77)+6*sqrt(10)+65
>>> _sqrtdenest_rec(sqrt(w))
-sqrt(11) - sqrt(7) + sqrt(2) + 3*sqrt(5)
"""
from sympy.simplify.radsimp import radsimp, rad_rationalize, split_surds
if not expr.is_Pow:
return sqrtdenest(expr)
if expr.base < 0:
return sqrt(-1)*_sqrtdenest_rec(sqrt(-expr.base))
g, a, b = split_surds(expr.base)
a = a*sqrt(g)
if a < b:
a, b = b, a
c2 = _mexpand(a**2 - b**2)
if len(c2.args) > 2:
g, a1, b1 = split_surds(c2)
a1 = a1*sqrt(g)
if a1 < b1:
a1, b1 = b1, a1
c2_1 = _mexpand(a1**2 - b1**2)
c_1 = _sqrtdenest_rec(sqrt(c2_1))
d_1 = _sqrtdenest_rec(sqrt(a1 + c_1))
num, den = rad_rationalize(b1, d_1)
c = _mexpand(d_1/sqrt(2) + num/(den*sqrt(2)))
else:
c = _sqrtdenest1(sqrt(c2))
if sqrt_depth(c) > 1:
raise SqrtdenestStopIteration
ac = a + c
if len(ac.args) >= len(expr.args):
if count_ops(ac) >= count_ops(expr.base):
raise SqrtdenestStopIteration
d = sqrtdenest(sqrt(ac))
if sqrt_depth(d) > 1:
raise SqrtdenestStopIteration
num, den = rad_rationalize(b, d)
r = d/sqrt(2) + num/(den*sqrt(2))
r = radsimp(r)
return _mexpand(r)
def _sqrtdenest1(expr, denester=True):
"""Return denested expr after denesting with simpler methods or, that
failing, using the denester."""
from sympy.simplify.simplify import radsimp
if not is_sqrt(expr):
return expr
a = expr.base
if a.is_Atom:
return expr
val = _sqrt_match(a)
if not val:
return expr
a, b, r = val
# try a quick numeric denesting
d2 = _mexpand(a**2 - b**2*r)
if d2.is_Rational:
if d2.is_positive:
z = _sqrt_numeric_denest(a, b, r, d2)
if z is not None:
return z
else:
# fourth root case
# sqrtdenest(sqrt(3 + 2*sqrt(3))) =
# sqrt(2)*3**(1/4)/2 + sqrt(2)*3**(3/4)/2
dr2 = _mexpand(-d2*r)
dr = sqrt(dr2)
if dr.is_Rational:
z = _sqrt_numeric_denest(_mexpand(b*r), a, r, dr2)
if z is not None:
return z/root(r, 4)
else:
z = _sqrt_symbolic_denest(a, b, r)
if z is not None:
return z
if not denester or not is_algebraic(expr):
return expr
res = sqrt_biquadratic_denest(expr, a, b, r, d2)
if res:
return res
# now call to the denester
av0 = [a, b, r, d2]
z = _denester([radsimp(expr**2)], av0, 0, sqrt_depth(expr))[0]
if av0[1] is None:
return expr
if z is not None:
if sqrt_depth(z) == sqrt_depth(expr) and count_ops(z) > count_ops(expr):
return expr
return z
return expr
def _sqrt_symbolic_denest(a, b, r):
"""Given an expression, sqrt(a + b*sqrt(b)), return the denested
expression or None.
Explanation
===========
If r = ra + rb*sqrt(rr), try replacing sqrt(rr) in ``a`` with
(y**2 - ra)/rb, and if the result is a quadratic, ca*y**2 + cb*y + cc, and
(cb + b)**2 - 4*ca*cc is 0, then sqrt(a + b*sqrt(r)) can be rewritten as
sqrt(ca*(sqrt(r) + (cb + b)/(2*ca))**2).
Examples
========
>>> from sympy.simplify.sqrtdenest import _sqrt_symbolic_denest, sqrtdenest
>>> from sympy import sqrt, Symbol
>>> from sympy.abc import x
>>> a, b, r = 16 - 2*sqrt(29), 2, -10*sqrt(29) + 55
>>> _sqrt_symbolic_denest(a, b, r)
sqrt(11 - 2*sqrt(29)) + sqrt(5)
If the expression is numeric, it will be simplified:
>>> w = sqrt(sqrt(sqrt(3) + 1) + 1) + 1 + sqrt(2)
>>> sqrtdenest(sqrt((w**2).expand()))
1 + sqrt(2) + sqrt(1 + sqrt(1 + sqrt(3)))
Otherwise, it will only be simplified if assumptions allow:
>>> w = w.subs(sqrt(3), sqrt(x + 3))
>>> sqrtdenest(sqrt((w**2).expand()))
sqrt((sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2))**2)
Notice that the argument of the sqrt is a square. If x is made positive
then the sqrt of the square is resolved:
>>> _.subs(x, Symbol('x', positive=True))
sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2)
"""
a, b, r = map(sympify, (a, b, r))
rval = _sqrt_match(r)
if not rval:
return None
ra, rb, rr = rval
if rb:
y = Dummy('y', positive=True)
try:
newa = Poly(a.subs(sqrt(rr), (y**2 - ra)/rb), y)
except PolynomialError:
return None
if newa.degree() == 2:
ca, cb, cc = newa.all_coeffs()
cb += b
if _mexpand(cb**2 - 4*ca*cc).equals(0):
z = sqrt(ca*(sqrt(r) + cb/(2*ca))**2)
if z.is_number:
z = _mexpand(Mul._from_args(z.as_content_primitive()))
return z
def _sqrt_numeric_denest(a, b, r, d2):
r"""Helper that denest
$\sqrt{a + b \sqrt{r}}, d^2 = a^2 - b^2 r > 0$
If it cannot be denested, it returns ``None``.
"""
d = sqrt(d2)
s = a + d
# sqrt_depth(res) <= sqrt_depth(s) + 1
# sqrt_depth(expr) = sqrt_depth(r) + 2
# there is denesting if sqrt_depth(s) + 1 < sqrt_depth(r) + 2
# if s**2 is Number there is a fourth root
if sqrt_depth(s) < sqrt_depth(r) + 1 or (s**2).is_Rational:
s1, s2 = sign(s), sign(b)
if s1 == s2 == -1:
s1 = s2 = 1
res = (s1 * sqrt(a + d) + s2 * sqrt(a - d)) * sqrt(2) / 2
return res.expand()
def sqrt_biquadratic_denest(expr, a, b, r, d2):
"""denest expr = sqrt(a + b*sqrt(r))
where a, b, r are linear combinations of square roots of
positive rationals on the rationals (SQRR) and r > 0, b != 0,
d2 = a**2 - b**2*r > 0
If it cannot denest it returns None.
Explanation
===========
Search for a solution A of type SQRR of the biquadratic equation
4*A**4 - 4*a*A**2 + b**2*r = 0 (1)
sqd = sqrt(a**2 - b**2*r)
Choosing the sqrt to be positive, the possible solutions are
A = sqrt(a/2 +/- sqd/2)
Since a, b, r are SQRR, then a**2 - b**2*r is a SQRR,
so if sqd can be denested, it is done by
_sqrtdenest_rec, and the result is a SQRR.
Similarly for A.
Examples of solutions (in both cases a and sqd are positive):
Example of expr with solution sqrt(a/2 + sqd/2) but not
solution sqrt(a/2 - sqd/2):
expr = sqrt(-sqrt(15) - sqrt(2)*sqrt(-sqrt(5) + 5) - sqrt(3) + 8)
a = -sqrt(15) - sqrt(3) + 8; sqd = -2*sqrt(5) - 2 + 4*sqrt(3)
Example of expr with solution sqrt(a/2 - sqd/2) but not
solution sqrt(a/2 + sqd/2):
w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3)
expr = sqrt((w**2).expand())
a = 4*sqrt(6) + 8*sqrt(2) + 47 + 28*sqrt(3)
sqd = 29 + 20*sqrt(3)
Define B = b/2*A; eq.(1) implies a = A**2 + B**2*r; then
expr**2 = a + b*sqrt(r) = (A + B*sqrt(r))**2
Examples
========
>>> from sympy import sqrt
>>> from sympy.simplify.sqrtdenest import _sqrt_match, sqrt_biquadratic_denest
>>> z = sqrt((2*sqrt(2) + 4)*sqrt(2 + sqrt(2)) + 5*sqrt(2) + 8)
>>> a, b, r = _sqrt_match(z**2)
>>> d2 = a**2 - b**2*r
>>> sqrt_biquadratic_denest(z, a, b, r, d2)
sqrt(2) + sqrt(sqrt(2) + 2) + 2
"""
from sympy.simplify.radsimp import radsimp, rad_rationalize
if r <= 0 or d2 < 0 or not b or sqrt_depth(expr.base) < 2:
return None
for x in (a, b, r):
for y in x.args:
y2 = y**2
if not y2.is_Integer or not y2.is_positive:
return None
sqd = _mexpand(sqrtdenest(sqrt(radsimp(d2))))
if sqrt_depth(sqd) > 1:
return None
x1, x2 = [a/2 + sqd/2, a/2 - sqd/2]
# look for a solution A with depth 1
for x in (x1, x2):
A = sqrtdenest(sqrt(x))
if sqrt_depth(A) > 1:
continue
Bn, Bd = rad_rationalize(b, _mexpand(2*A))
B = Bn/Bd
z = A + B*sqrt(r)
if z < 0:
z = -z
return _mexpand(z)
return None
def _denester(nested, av0, h, max_depth_level):
"""Denests a list of expressions that contain nested square roots.
Explanation
===========
Algorithm based on <http://www.almaden.ibm.com/cs/people/fagin/symb85.pdf>.
It is assumed that all of the elements of 'nested' share the same
bottom-level radicand. (This is stated in the paper, on page 177, in
the paragraph immediately preceding the algorithm.)
When evaluating all of the arguments in parallel, the bottom-level
radicand only needs to be denested once. This means that calling
_denester with x arguments results in a recursive invocation with x+1
arguments; hence _denester has polynomial complexity.
However, if the arguments were evaluated separately, each call would
result in two recursive invocations, and the algorithm would have
exponential complexity.
This is discussed in the paper in the middle paragraph of page 179.
"""
from sympy.simplify.simplify import radsimp
if h > max_depth_level:
return None, None
if av0[1] is None:
return None, None
if (av0[0] is None and
all(n.is_Number for n in nested)): # no arguments are nested
for f in _subsets(len(nested)): # test subset 'f' of nested
p = _mexpand(Mul(*[nested[i] for i in range(len(f)) if f[i]]))
if f.count(1) > 1 and f[-1]:
p = -p
sqp = sqrt(p)
if sqp.is_Rational:
return sqp, f # got a perfect square so return its square root.
# Otherwise, return the radicand from the previous invocation.
return sqrt(nested[-1]), [0]*len(nested)
else:
R = None
if av0[0] is not None:
values = [av0[:2]]
R = av0[2]
nested2 = [av0[3], R]
av0[0] = None
else:
values = list(filter(None, [_sqrt_match(expr) for expr in nested]))
for v in values:
if v[2]: # Since if b=0, r is not defined
if R is not None:
if R != v[2]:
av0[1] = None
return None, None
else:
R = v[2]
if R is None:
# return the radicand from the previous invocation
return sqrt(nested[-1]), [0]*len(nested)
nested2 = [_mexpand(v[0]**2) -
_mexpand(R*v[1]**2) for v in values] + [R]
d, f = _denester(nested2, av0, h + 1, max_depth_level)
if not f:
return None, None
if not any(f[i] for i in range(len(nested))):
v = values[-1]
return sqrt(v[0] + _mexpand(v[1]*d)), f
else:
p = Mul(*[nested[i] for i in range(len(nested)) if f[i]])
v = _sqrt_match(p)
if 1 in f and f.index(1) < len(nested) - 1 and f[len(nested) - 1]:
v[0] = -v[0]
v[1] = -v[1]
if not f[len(nested)]: # Solution denests with square roots
vad = _mexpand(v[0] + d)
if vad <= 0:
# return the radicand from the previous invocation.
return sqrt(nested[-1]), [0]*len(nested)
if not(sqrt_depth(vad) <= sqrt_depth(R) + 1 or
(vad**2).is_Number):
av0[1] = None
return None, None
sqvad = _sqrtdenest1(sqrt(vad), denester=False)
if not (sqrt_depth(sqvad) <= sqrt_depth(R) + 1):
av0[1] = None
return None, None
sqvad1 = radsimp(1/sqvad)
res = _mexpand(sqvad/sqrt(2) + (v[1]*sqrt(R)*sqvad1/sqrt(2)))
return res, f
# sign(v[1])*sqrt(_mexpand(v[1]**2*R*vad1/2))), f
else: # Solution requires a fourth root
s2 = _mexpand(v[1]*R) + d
if s2 <= 0:
return sqrt(nested[-1]), [0]*len(nested)
FR, s = root(_mexpand(R), 4), sqrt(s2)
return _mexpand(s/(sqrt(2)*FR) + v[0]*FR/(sqrt(2)*s)), f
def _sqrt_ratcomb(cs, args):
"""Denest rational combinations of radicals.
Based on section 5 of [1].
Examples
========
>>> from sympy import sqrt
>>> from sympy.simplify.sqrtdenest import sqrtdenest
>>> z = sqrt(1+sqrt(3)) + sqrt(3+3*sqrt(3)) - sqrt(10+6*sqrt(3))
>>> sqrtdenest(z)
0
"""
from sympy.simplify.radsimp import radsimp
# check if there exists a pair of sqrt that can be denested
def find(a):
n = len(a)
for i in range(n - 1):
for j in range(i + 1, n):
s1 = a[i].base
s2 = a[j].base
p = _mexpand(s1 * s2)
s = sqrtdenest(sqrt(p))
if s != sqrt(p):
return s, i, j
indices = find(args)
if indices is None:
return Add(*[c * arg for c, arg in zip(cs, args)])
s, i1, i2 = indices
c2 = cs.pop(i2)
args.pop(i2)
a1 = args[i1]
# replace a2 by s/a1
cs[i1] += radsimp(c2 * s / a1.base)
return _sqrt_ratcomb(cs, args)

View File

@ -0,0 +1,75 @@
from sympy.core.numbers import Rational
from sympy.core.symbol import symbols
from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial)
from sympy.functions.special.gamma_functions import gamma
from sympy.simplify.combsimp import combsimp
from sympy.abc import x
def test_combsimp():
k, m, n = symbols('k m n', integer = True)
assert combsimp(factorial(n)) == factorial(n)
assert combsimp(binomial(n, k)) == binomial(n, k)
assert combsimp(factorial(n)/factorial(n - 3)) == n*(-1 + n)*(-2 + n)
assert combsimp(binomial(n + 1, k + 1)/binomial(n, k)) == (1 + n)/(1 + k)
assert combsimp(binomial(3*n + 4, n + 1)/binomial(3*n + 1, n)) == \
Rational(3, 2)*((3*n + 2)*(3*n + 4)/((n + 1)*(2*n + 3)))
assert combsimp(factorial(n)**2/factorial(n - 3)) == \
factorial(n)*n*(-1 + n)*(-2 + n)
assert combsimp(factorial(n)*binomial(n + 1, k + 1)/binomial(n, k)) == \
factorial(n + 1)/(1 + k)
assert combsimp(gamma(n + 3)) == factorial(n + 2)
assert combsimp(factorial(x)) == gamma(x + 1)
# issue 9699
assert combsimp((n + 1)*factorial(n)) == factorial(n + 1)
assert combsimp(factorial(n)/n) == factorial(n-1)
# issue 6658
assert combsimp(binomial(n, n - k)) == binomial(n, k)
# issue 6341, 7135
assert combsimp(factorial(n)/(factorial(k)*factorial(n - k))) == \
binomial(n, k)
assert combsimp(factorial(k)*factorial(n - k)/factorial(n)) == \
1/binomial(n, k)
assert combsimp(factorial(2*n)/factorial(n)**2) == binomial(2*n, n)
assert combsimp(factorial(2*n)*factorial(k)*factorial(n - k)/
factorial(n)**3) == binomial(2*n, n)/binomial(n, k)
assert combsimp(factorial(n*(1 + n) - n**2 - n)) == 1
assert combsimp(6*FallingFactorial(-4, n)/factorial(n)) == \
(-1)**n*(n + 1)*(n + 2)*(n + 3)
assert combsimp(6*FallingFactorial(-4, n - 1)/factorial(n - 1)) == \
(-1)**(n - 1)*n*(n + 1)*(n + 2)
assert combsimp(6*FallingFactorial(-4, n - 3)/factorial(n - 3)) == \
(-1)**(n - 3)*n*(n - 1)*(n - 2)
assert combsimp(6*FallingFactorial(-4, -n - 1)/factorial(-n - 1)) == \
-(-1)**(-n - 1)*n*(n - 1)*(n - 2)
assert combsimp(6*RisingFactorial(4, n)/factorial(n)) == \
(n + 1)*(n + 2)*(n + 3)
assert combsimp(6*RisingFactorial(4, n - 1)/factorial(n - 1)) == \
n*(n + 1)*(n + 2)
assert combsimp(6*RisingFactorial(4, n - 3)/factorial(n - 3)) == \
n*(n - 1)*(n - 2)
assert combsimp(6*RisingFactorial(4, -n - 1)/factorial(-n - 1)) == \
-n*(n - 1)*(n - 2)
def test_issue_6878():
n = symbols('n', integer=True)
assert combsimp(RisingFactorial(-10, n)) == 3628800*(-1)**n/factorial(10 - n)
def test_issue_14528():
p = symbols("p", integer=True, positive=True)
assert combsimp(binomial(1,p)) == 1/(factorial(p)*factorial(1-p))
assert combsimp(factorial(2-p)) == factorial(2-p)

View File

@ -0,0 +1,758 @@
from functools import reduce
import itertools
from operator import add
from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.core.add import Add
from sympy.core.containers import Tuple
from sympy.core.expr import UnevaluatedExpr
from sympy.core.function import Function
from sympy.core.mul import Mul
from sympy.core.power import Pow
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, symbols)
from sympy.core.sympify import sympify
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.matrices.dense import Matrix
from sympy.matrices.expressions import Inverse, MatAdd, MatMul, Transpose
from sympy.polys.rootoftools import CRootOf
from sympy.series.order import O
from sympy.simplify.cse_main import cse
from sympy.simplify.simplify import signsimp
from sympy.tensor.indexed import (Idx, IndexedBase)
from sympy.core.function import count_ops
from sympy.simplify.cse_opts import sub_pre, sub_post
from sympy.functions.special.hyper import meijerg
from sympy.simplify import cse_main, cse_opts
from sympy.utilities.iterables import subsets
from sympy.testing.pytest import XFAIL, raises
from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix,
ImmutableDenseMatrix, ImmutableSparseMatrix)
from sympy.matrices.expressions import MatrixSymbol
w, x, y, z = symbols('w,x,y,z')
x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13')
def test_numbered_symbols():
ns = cse_main.numbered_symbols(prefix='y')
assert list(itertools.islice(
ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)]
ns = cse_main.numbered_symbols(prefix='y')
assert list(itertools.islice(
ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)]
ns = cse_main.numbered_symbols()
assert list(itertools.islice(
ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)]
# Dummy "optimization" functions for testing.
def opt1(expr):
return expr + y
def opt2(expr):
return expr*z
def test_preprocess_for_cse():
assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y
assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x
assert cse_main.preprocess_for_cse(x, [(None, None)]) == x
assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y
assert cse_main.preprocess_for_cse(
x, [(opt1, None), (opt2, None)]) == (x + y)*z
def test_postprocess_for_cse():
assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x
assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y
assert cse_main.postprocess_for_cse(x, [(None, None)]) == x
assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z
# Note the reverse order of application.
assert cse_main.postprocess_for_cse(
x, [(None, opt1), (None, opt2)]) == x*z + y
def test_cse_single():
# Simple substitution.
e = Add(Pow(x + y, 2), sqrt(x + y))
substs, reduced = cse([e])
assert substs == [(x0, x + y)]
assert reduced == [sqrt(x0) + x0**2]
subst42, (red42,) = cse([42]) # issue_15082
assert len(subst42) == 0 and red42 == 42
subst_half, (red_half,) = cse([0.5])
assert len(subst_half) == 0 and red_half == 0.5
def test_cse_single2():
# Simple substitution, test for being able to pass the expression directly
e = Add(Pow(x + y, 2), sqrt(x + y))
substs, reduced = cse(e)
assert substs == [(x0, x + y)]
assert reduced == [sqrt(x0) + x0**2]
substs, reduced = cse(Matrix([[1]]))
assert isinstance(reduced[0], Matrix)
subst42, (red42,) = cse(42) # issue 15082
assert len(subst42) == 0 and red42 == 42
subst_half, (red_half,) = cse(0.5) # issue 15082
assert len(subst_half) == 0 and red_half == 0.5
def test_cse_not_possible():
# No substitution possible.
e = Add(x, y)
substs, reduced = cse([e])
assert substs == []
assert reduced == [x + y]
# issue 6329
eq = (meijerg((1, 2), (y, 4), (5,), [], x) +
meijerg((1, 3), (y, 4), (5,), [], x))
assert cse(eq) == ([], [eq])
def test_nested_substitution():
# Substitution within a substitution.
e = Add(Pow(w*x + y, 2), sqrt(w*x + y))
substs, reduced = cse([e])
assert substs == [(x0, w*x + y)]
assert reduced == [sqrt(x0) + x0**2]
def test_subtraction_opt():
# Make sure subtraction is optimized.
e = (x - y)*(z - y) + exp((x - y)*(z - y))
substs, reduced = cse(
[e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
assert substs == [(x0, (x - y)*(y - z))]
assert reduced == [-x0 + exp(-x0)]
e = -(x - y)*(z - y) + exp(-(x - y)*(z - y))
substs, reduced = cse(
[e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
assert substs == [(x0, (x - y)*(y - z))]
assert reduced == [x0 + exp(x0)]
# issue 4077
n = -1 + 1/x
e = n/x/(-n)**2 - 1/n/x
assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \
([], [0])
assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \
([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3])
def test_multiple_expressions():
e1 = (x + y)*z
e2 = (x + y)*w
substs, reduced = cse([e1, e2])
assert substs == [(x0, x + y)]
assert reduced == [x0*z, x0*w]
l = [w*x*y + z, w*y]
substs, reduced = cse(l)
rsubsts, _ = cse(reversed(l))
assert substs == rsubsts
assert reduced == [z + x*x0, x0]
l = [w*x*y, w*x*y + z, w*y]
substs, reduced = cse(l)
rsubsts, _ = cse(reversed(l))
assert substs == rsubsts
assert reduced == [x1, x1 + z, x0]
l = [(x - z)*(y - z), x - z, y - z]
substs, reduced = cse(l)
rsubsts, _ = cse(reversed(l))
assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
assert reduced == [x1*x2, x1, x2]
l = [w*y + w + x + y + z, w*x*y]
assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
assert cse([x + y, x + z]) == ([], [x + y, x + z])
assert cse([x*y, z + x*y, x*y*z + 3]) == \
([(x0, x*y)], [x0, z + x0, 3 + x0*z])
@XFAIL # CSE of non-commutative Mul terms is disabled
def test_non_commutative_cse():
A, B, C = symbols('A B C', commutative=False)
l = [A*B*C, A*C]
assert cse(l) == ([], l)
l = [A*B*C, A*B]
assert cse(l) == ([(x0, A*B)], [x0*C, x0])
# Test if CSE of non-commutative Mul terms is disabled
def test_bypass_non_commutatives():
A, B, C = symbols('A B C', commutative=False)
l = [A*B*C, A*C]
assert cse(l) == ([], l)
l = [A*B*C, A*B]
assert cse(l) == ([], l)
l = [B*C, A*B*C]
assert cse(l) == ([], l)
@XFAIL # CSE fails when replacing non-commutative sub-expressions
def test_non_commutative_order():
A, B, C = symbols('A B C', commutative=False)
x0 = symbols('x0', commutative=False)
l = [B+C, A*(B+C)]
assert cse(l) == ([(x0, B+C)], [x0, A*x0])
@XFAIL # Worked in gh-11232, but was reverted due to performance considerations
def test_issue_10228():
assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0])
assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0])
assert cse((w + 2*x + y + z, w + x + 1)) == (
[(x0, w + x)], [x0 + x + y + z, x0 + 1])
assert cse(((w + x + y + z)*(w - x))/(w + x)) == (
[(x0, w + x)], [(x0 + y + z)*(w - x)/x0])
a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m')
exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2)
assert cse(exprs) == (
[(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1]
)
@XFAIL
def test_powers():
assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0])
def test_issue_4498():
assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \
([], [(w - z)/(x - y)])
def test_issue_4020():
assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \
== ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)])
def test_issue_4203():
assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0])
def test_issue_6263():
e = Eq(x*(-x + 1) + x*(x - 1), 0)
assert cse(e, optimizations='basic') == ([], [True])
def test_issue_25043():
c = symbols("c")
x = symbols("x0", real=True)
cse_expr = cse(c*x**2 + c*(x**4 - x**2))[-1][-1]
free = cse_expr.free_symbols
assert len(free) == len({i.name for i in free})
def test_dont_cse_tuples():
from sympy.core.function import Subs
f = Function("f")
g = Function("g")
name_val, (expr,) = cse(
Subs(f(x, y), (x, y), (0, 1))
+ Subs(g(x, y), (x, y), (0, 1)))
assert name_val == []
assert expr == (Subs(f(x, y), (x, y), (0, 1))
+ Subs(g(x, y), (x, y), (0, 1)))
name_val, (expr,) = cse(
Subs(f(x, y), (x, y), (0, x + y))
+ Subs(g(x, y), (x, y), (0, x + y)))
assert name_val == [(x0, x + y)]
assert expr == Subs(f(x, y), (x, y), (0, x0)) + \
Subs(g(x, y), (x, y), (0, x0))
def test_pow_invpow():
assert cse(1/x**2 + x**2) == \
([(x0, x**2)], [x0 + 1/x0])
assert cse(x**2 + (1 + 1/x**2)/x**2) == \
([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)])
assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \
([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1])
assert cse(cos(1/x**2) + sin(1/x**2)) == \
([(x0, x**(-2))], [sin(x0) + cos(x0)])
assert cse(cos(x**2) + sin(x**2)) == \
([(x0, x**2)], [sin(x0) + cos(x0)])
assert cse(y/(2 + x**2) + z/x**2/y) == \
([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)])
assert cse(exp(x**2) + x**2*cos(1/x**2)) == \
([(x0, x**2)], [x0*cos(1/x0) + exp(x0)])
assert cse((1 + 1/x**2)/x**2) == \
([(x0, x**(-2))], [x0*(x0 + 1)])
assert cse(x**(2*y) + x**(-2*y)) == \
([(x0, x**(2*y))], [x0 + 1/x0])
def test_postprocess():
eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)],
postprocess=cse_main.cse_separate) == \
[[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)],
[x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]]
def test_issue_4499():
# previously, this gave 16 constants
from sympy.abc import a, b
B = Function('B')
G = Function('G')
t = Tuple(*
(a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a -
b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1),
sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b,
sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1,
sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1),
(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1,
sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b,
-2*a))
c = cse(t)
ans = (
[(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)),
(x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)),
(x8, B(b, x4)), (x9, x6*B(x2, x4))],
[(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9,
1, 0, S.Half, z/2, -x3, -x1, -x0)])
assert ans == c
def test_issue_6169():
r = CRootOf(x**6 - 4*x**5 - 2, 1)
assert cse(r) == ([], [r])
# and a check that the right thing is done with the new
# mechanism
assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y
def test_cse_Indexed():
len_y = 5
y = IndexedBase('y', shape=(len_y,))
x = IndexedBase('x', shape=(len_y,))
i = Idx('i', len_y-1)
expr1 = (y[i+1]-y[i])/(x[i+1]-x[i])
expr2 = 1/(x[i+1]-x[i])
replacements, reduced_exprs = cse([expr1, expr2])
assert len(replacements) > 0
def test_cse_MatrixSymbol():
# MatrixSymbols have non-Basic args, so make sure that works
A = MatrixSymbol("A", 3, 3)
assert cse(A) == ([], [A])
n = symbols('n', integer=True)
B = MatrixSymbol("B", n, n)
assert cse(B) == ([], [B])
assert cse(A[0] * A[0]) == ([], [A[0]*A[0]])
assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0])
def test_cse_MatrixExpr():
A = MatrixSymbol('A', 3, 3)
y = MatrixSymbol('y', 3, 1)
expr1 = (A.T*A).I * A * y
expr2 = (A.T*A) * A * y
replacements, reduced_exprs = cse([expr1, expr2])
assert len(replacements) > 0
replacements, reduced_exprs = cse([expr1 + expr2, expr1])
assert replacements
replacements, reduced_exprs = cse([A**2, A + A**2])
assert replacements
def test_Piecewise():
f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True))
ans = cse(f)
actual_ans = ([(x0, x*y)],
[Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))])
assert ans == actual_ans
def test_ignore_order_terms():
eq = exp(x).series(x,0,3) + sin(y+x**3) - 1
assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)])
def test_name_conflict():
z1 = x0 + y
z2 = x2 + x3
l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
substs, reduced = cse(l)
assert [e.subs(reversed(substs)) for e in reduced] == l
def test_name_conflict_cust_symbols():
z1 = x0 + y
z2 = x2 + x3
l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
substs, reduced = cse(l, symbols("x:10"))
assert [e.subs(reversed(substs)) for e in reduced] == l
def test_symbols_exhausted_error():
l = cos(x+y)+x+y+cos(w+y)+sin(w+y)
sym = [x, y, z]
with raises(ValueError):
cse(l, symbols=sym)
def test_issue_7840():
# daveknippers' example
C393 = sympify( \
'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \
C391 > 2.35), (C392, True)), True))'
)
C391 = sympify( \
'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))'
)
C393 = C393.subs('C391',C391)
# simple substitution
sub = {}
sub['C390'] = 0.703451854
sub['C392'] = 1.01417794
ss_answer = C393.subs(sub)
# cse
substitutions,new_eqn = cse(C393)
for pair in substitutions:
sub[pair[0].name] = pair[1].subs(sub)
cse_answer = new_eqn[0].subs(sub)
# both methods should be the same
assert ss_answer == cse_answer
# GitRay's example
expr = sympify(
"Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \
(Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \
Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \
Symbol('AUTO'))), (Symbol('OFF'), true)), true))"
)
substitutions, new_eqn = cse(expr)
# this Piecewise should be exactly the same
assert new_eqn[0] == expr
# there should not be any replacements
assert len(substitutions) < 1
def test_issue_8891():
for cls in (MutableDenseMatrix, MutableSparseMatrix,
ImmutableDenseMatrix, ImmutableSparseMatrix):
m = cls(2, 2, [x + y, 0, 0, 0])
res = cse([x + y, m])
ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])])
assert res == ans
assert isinstance(res[1][-1], cls)
def test_issue_11230():
# a specific test that always failed
a, b, f, k, l, i = symbols('a b f k l i')
p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l]
R, C = cse(p)
assert not any(i.is_Mul for a in C for i in a.args)
# random tests for the issue
from sympy.core.random import choice
from sympy.core.function import expand_mul
s = symbols('a:m')
# 35 Mul tests, none of which should ever fail
ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)]
for p in subsets(ex, 3):
p = list(p)
R, C = cse(p)
assert not any(i.is_Mul for a in C for i in a.args)
for ri in reversed(R):
for i in range(len(C)):
C[i] = C[i].subs(*ri)
assert p == C
# 35 Add tests, none of which should ever fail
ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)]
for p in subsets(ex, 3):
p = list(p)
R, C = cse(p)
assert not any(i.is_Add for a in C for i in a.args)
for ri in reversed(R):
for i in range(len(C)):
C[i] = C[i].subs(*ri)
# use expand_mul to handle cases like this:
# p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g]
# x0 = 2*(b + e) is identified giving a rebuilt p that
# is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]`
assert p == [expand_mul(i) for i in C]
@XFAIL
def test_issue_11577():
def check(eq):
r, c = cse(eq)
assert eq.count_ops() >= \
len(r) + sum(i[1].count_ops() for i in r) + \
count_ops(c)
eq = x**5*y**2 + x**5*y + x**5
assert cse(eq) == (
[(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1])
# ([(x0, x**5*y)], [x0*y + x0 + x**5]) or
# ([(x0, x**5)], [x0*y**2 + x0*y + x0])
check(eq)
eq = x**2/(y + 1)**2 + x/(y + 1)
assert cse(eq) == (
[(x0, y + 1)], [x**2/x0**2 + x/x0])
# ([(x0, x/(y + 1))], [x0**2 + x0])
check(eq)
def test_hollow_rejection():
eq = [x + 3, x + 4]
assert cse(eq) == ([], eq)
def test_cse_ignore():
exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))]
subst1, red1 = cse(exprs)
assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y"
subst2, red2 = cse(exprs, ignore=(y,)) # y is not allowed in substitutions
assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored"
assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression"
def test_cse_ignore_issue_15002():
l = [
w*exp(x)*exp(-z),
exp(y)*exp(x)*exp(-z)
]
substs, reduced = cse(l, ignore=(x,))
rl = [e.subs(reversed(substs)) for e in reduced]
assert rl == l
def test_cse_unevaluated():
xp1 = UnevaluatedExpr(x + 1)
# This used to cause RecursionError
[(x0, ue)], [red] = cse([(-1 - xp1) / (1 - xp1)])
if ue == xp1:
assert red == (-1 - x0) / (1 - x0)
elif ue == -xp1:
assert red == (-1 + x0) / (1 + x0)
else:
msg = f'Expected common subexpression {xp1} or {-xp1}, instead got {ue}'
assert False, msg
def test_cse__performance():
nexprs, nterms = 3, 20
x = symbols('x:%d' % nterms)
exprs = [
reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)])
for i in range(nexprs)
]
assert (exprs[0] + exprs[1]).simplify() == 0
subst, red = cse(exprs)
assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE"
for i, e in enumerate(red):
assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0
def test_issue_12070():
exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z]
subst, red = cse(exprs)
assert 6 >= (len(subst) + sum(v.count_ops() for k, v in subst) +
count_ops(red))
def test_issue_13000():
eq = x/(-4*x**2 + y**2)
cse_eq = cse(eq)[1][0]
assert cse_eq == eq
def test_issue_18203():
eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1)
assert cse(eq) == ([], [eq])
def test_unevaluated_mul():
eq = Mul(x + y, x + y, evaluate=False)
assert cse(eq) == ([(x0, x + y)], [x0**2])
def test_cse_release_variables():
from sympy.simplify.cse_main import cse_release_variables
_0, _1, _2, _3, _4 = symbols('_:5')
eqs = [(x + y - 1)**2, x,
x + y, (x + y)/(2*x + 1) + (x + y - 1)**2,
(2*x + 1)**(x + y)]
r, e = cse(eqs, postprocess=cse_release_variables)
# this can change in keeping with the intention of the function
assert r, e == ([
(x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1),
(_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1),
(x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4))
r.reverse()
r = [(s, v) for s, v in r if v is not None]
assert eqs == [i.subs(r) for i in e]
def test_cse_list():
_cse = lambda x: cse(x, list=False)
assert _cse(x) == ([], x)
assert _cse('x') == ([], 'x')
it = [x]
for c in (list, tuple, set):
assert _cse(c(it)) == ([], c(it))
#Tuple works different from tuple:
assert _cse(Tuple(*it)) == ([], Tuple(*it))
d = {x: 1}
assert _cse(d) == ([], d)
def test_issue_18991():
A = MatrixSymbol('A', 2, 2)
assert signsimp(-A * A - A) == -A * A - A
def test_unevaluated_Mul():
m = [Mul(1, 2, evaluate=False)]
assert cse(m) == ([], m)
def test_cse_matrix_expression_inverse():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = Inverse(A)
cse_expr = cse(x)
assert cse_expr == ([], [Inverse(A)])
def test_cse_matrix_expression_matmul_inverse():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
b = ImmutableDenseMatrix(symbols('b:2'))
x = MatMul(Inverse(A), b)
cse_expr = cse(x)
assert cse_expr == ([], [x])
def test_cse_matrix_negate_matrix():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatMul(S.NegativeOne, A)
cse_expr = cse(x)
assert cse_expr == ([], [x])
def test_cse_matrix_negate_matmul_not_extracted():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
x = MatMul(S.NegativeOne, A, B)
cse_expr = cse(x)
assert cse_expr == ([], [x])
@XFAIL # No simplification rule for nested associative operations
def test_cse_matrix_nested_matmul_collapsed():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
x = MatMul(S.NegativeOne, MatMul(A, B))
cse_expr = cse(x)
assert cse_expr == ([], [MatMul(S.NegativeOne, A, B)])
def test_cse_matrix_optimize_out_single_argument_mul():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatMul(MatMul(MatMul(A)))
cse_expr = cse(x)
assert cse_expr == ([], [A])
@XFAIL # Multiple simplification passed not supported in CSE
def test_cse_matrix_optimize_out_single_argument_mul_combined():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatAdd(MatMul(MatMul(MatMul(A))), MatMul(MatMul(A)), MatMul(A), A)
cse_expr = cse(x)
assert cse_expr == ([], [MatMul(4, A)])
def test_cse_matrix_optimize_out_single_argument_add():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatAdd(MatAdd(MatAdd(MatAdd(A))))
cse_expr = cse(x)
assert cse_expr == ([], [A])
@XFAIL # Multiple simplification passed not supported in CSE
def test_cse_matrix_optimize_out_single_argument_add_combined():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
x = MatMul(MatAdd(MatAdd(MatAdd(A))), MatAdd(MatAdd(A)), MatAdd(A), A)
cse_expr = cse(x)
assert cse_expr == ([], [MatMul(4, A)])
def test_cse_matrix_expression_matrix_solve():
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
b = ImmutableDenseMatrix(symbols('b:2'))
x = MatrixSolve(A, b)
cse_expr = cse(x)
assert cse_expr == ([], [x])
def test_cse_matrix_matrix_expression():
X = ImmutableDenseMatrix(symbols('X:4')).reshape(2, 2)
y = ImmutableDenseMatrix(symbols('y:2'))
b = MatMul(Inverse(MatMul(Transpose(X), X)), Transpose(X), y)
cse_expr = cse(b)
x0 = MatrixSymbol('x0', 2, 2)
reduced_expr_expected = MatMul(Inverse(MatMul(x0, X)), x0, y)
assert cse_expr == ([(x0, Transpose(X))], [reduced_expr_expected])
def test_cse_matrix_kalman_filter():
"""Kalman Filter example from Matthew Rocklin's SciPy 2013 talk.
Talk titled: "Matrix Expressions and BLAS/LAPACK; SciPy 2013 Presentation"
Video: https://pyvideo.org/scipy-2013/matrix-expressions-and-blaslapack-scipy-2013-pr.html
Notes
=====
Equations are:
new_mu = mu + Sigma*H.T * (R + H*Sigma*H.T).I * (H*mu - data)
= MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
new_Sigma = Sigma - Sigma*H.T * (R + H*Sigma*H.T).I * H * Sigma
= MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H)), Inverse(MatAdd(R, MatMul(H*Sigma*Transpose(H)))), H, Sigma))
"""
N = 2
mu = ImmutableDenseMatrix(symbols(f'mu:{N}'))
Sigma = ImmutableDenseMatrix(symbols(f'Sigma:{N * N}')).reshape(N, N)
H = ImmutableDenseMatrix(symbols(f'H:{N * N}')).reshape(N, N)
R = ImmutableDenseMatrix(symbols(f'R:{N * N}')).reshape(N, N)
data = ImmutableDenseMatrix(symbols(f'data:{N}'))
new_mu = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
new_Sigma = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), H, Sigma))
cse_expr = cse([new_mu, new_Sigma])
x0 = MatrixSymbol('x0', N, N)
x1 = MatrixSymbol('x1', N, N)
replacements_expected = [
(x0, Transpose(H)),
(x1, Inverse(MatAdd(R, MatMul(H, Sigma, x0)))),
]
reduced_exprs_expected = [
MatAdd(mu, MatMul(Sigma, x0, x1, MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))),
MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, x0, x1, H, Sigma)),
]
assert cse_expr == (replacements_expected, reduced_exprs_expected)

View File

@ -0,0 +1,90 @@
"""Tests for tools for manipulation of expressions using paths. """
from sympy.simplify.epathtools import epath, EPath
from sympy.testing.pytest import raises
from sympy.core.numbers import E
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.abc import x, y, z, t
def test_epath_select():
expr = [((x, 1, t), 2), ((3, y, 4), z)]
assert epath("/*", expr) == [((x, 1, t), 2), ((3, y, 4), z)]
assert epath("/*/*", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/*/*/*", expr) == [x, 1, t, 3, y, 4]
assert epath("/*/*/*/*", expr) == []
assert epath("/[:]", expr) == [((x, 1, t), 2), ((3, y, 4), z)]
assert epath("/[:]/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/[:]/[:]/[:]", expr) == [x, 1, t, 3, y, 4]
assert epath("/[:]/[:]/[:]/[:]", expr) == []
assert epath("/*/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/*/[0]", expr) == [(x, 1, t), (3, y, 4)]
assert epath("/*/[1]", expr) == [2, z]
assert epath("/*/[2]", expr) == []
assert epath("/*/int", expr) == [2]
assert epath("/*/Symbol", expr) == [z]
assert epath("/*/tuple", expr) == [(x, 1, t), (3, y, 4)]
assert epath("/*/__iter__?", expr) == [(x, 1, t), (3, y, 4)]
assert epath("/*/int|tuple", expr) == [(x, 1, t), 2, (3, y, 4)]
assert epath("/*/Symbol|tuple", expr) == [(x, 1, t), (3, y, 4), z]
assert epath("/*/int|Symbol|tuple", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/*/int|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4)]
assert epath("/*/Symbol|__iter__?", expr) == [(x, 1, t), (3, y, 4), z]
assert epath(
"/*/int|Symbol|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4), z]
assert epath("/*/[0]/int", expr) == [1, 3, 4]
assert epath("/*/[0]/Symbol", expr) == [x, t, y]
assert epath("/*/[0]/int[1:]", expr) == [1, 4]
assert epath("/*/[0]/Symbol[1:]", expr) == [t, y]
assert epath("/Symbol", x + y + z + 1) == [x, y, z]
assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E)) == [x, x, y]
def test_epath_apply():
expr = [((x, 1, t), 2), ((3, y, 4), z)]
func = lambda expr: expr**2
assert epath("/*", expr, list) == [[(x, 1, t), 2], [(3, y, 4), z]]
assert epath("/*/[0]", expr, list) == [([x, 1, t], 2), ([3, y, 4], z)]
assert epath("/*/[1]", expr, func) == [((x, 1, t), 4), ((3, y, 4), z**2)]
assert epath("/*/[2]", expr, list) == expr
assert epath("/*/[0]/int", expr, func) == [((x, 1, t), 2), ((9, y, 16), z)]
assert epath("/*/[0]/Symbol", expr, func) == [((x**2, 1, t**2), 2),
((3, y**2, 4), z)]
assert epath(
"/*/[0]/int[1:]", expr, func) == [((x, 1, t), 2), ((3, y, 16), z)]
assert epath("/*/[0]/Symbol[1:]", expr, func) == [((x, 1, t**2),
2), ((3, y**2, 4), z)]
assert epath("/Symbol", x + y + z + 1, func) == x**2 + y**2 + z**2 + 1
assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E), func) == \
t + sin(x**2 + 1) + cos(x**2 + y**2 + E)
def test_EPath():
assert EPath("/*/[0]")._path == "/*/[0]"
assert EPath(EPath("/*/[0]"))._path == "/*/[0]"
assert isinstance(epath("/*/[0]"), EPath) is True
assert repr(EPath("/*/[0]")) == "EPath('/*/[0]')"
raises(ValueError, lambda: EPath(""))
raises(ValueError, lambda: EPath("/"))
raises(ValueError, lambda: EPath("/|x"))
raises(ValueError, lambda: EPath("/["))
raises(ValueError, lambda: EPath("/[0]%"))
raises(NotImplementedError, lambda: EPath("Symbol"))

View File

@ -0,0 +1,492 @@
from sympy.core.add import Add
from sympy.core.mul import Mul
from sympy.core.numbers import (I, Rational, pi)
from sympy.core.parameters import evaluate
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, Symbol, symbols)
from sympy.functions.elementary.hyperbolic import (cosh, coth, csch, sech, sinh, tanh)
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import (cos, cot, csc, sec, sin, tan)
from sympy.simplify.powsimp import powsimp
from sympy.simplify.fu import (
L, TR1, TR10, TR10i, TR11, _TR11, TR12, TR12i, TR13, TR14, TR15, TR16,
TR111, TR2, TR2i, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TRmorrie, _TR56 as T,
TRpower, hyper_as_trig, fu, process_common_addends, trig_split,
as_f_sign_1)
from sympy.core.random import verify_numerically
from sympy.abc import a, b, c, x, y, z
def test_TR1():
assert TR1(2*csc(x) + sec(x)) == 1/cos(x) + 2/sin(x)
def test_TR2():
assert TR2(tan(x)) == sin(x)/cos(x)
assert TR2(cot(x)) == cos(x)/sin(x)
assert TR2(tan(tan(x) - sin(x)/cos(x))) == 0
def test_TR2i():
# just a reminder that ratios of powers only simplify if both
# numerator and denominator satisfy the condition that each
# has a positive base or an integer exponent; e.g. the following,
# at y=-1, x=1/2 gives sqrt(2)*I != -sqrt(2)*I
assert powsimp(2**x/y**x) != (2/y)**x
assert TR2i(sin(x)/cos(x)) == tan(x)
assert TR2i(sin(x)*sin(y)/cos(x)) == tan(x)*sin(y)
assert TR2i(1/(sin(x)/cos(x))) == 1/tan(x)
assert TR2i(1/(sin(x)*sin(y)/cos(x))) == 1/tan(x)/sin(y)
assert TR2i(sin(x)/2/(cos(x) + 1)) == sin(x)/(cos(x) + 1)/2
assert TR2i(sin(x)/2/(cos(x) + 1), half=True) == tan(x/2)/2
assert TR2i(sin(1)/(cos(1) + 1), half=True) == tan(S.Half)
assert TR2i(sin(2)/(cos(2) + 1), half=True) == tan(1)
assert TR2i(sin(4)/(cos(4) + 1), half=True) == tan(2)
assert TR2i(sin(5)/(cos(5) + 1), half=True) == tan(5*S.Half)
assert TR2i((cos(1) + 1)/sin(1), half=True) == 1/tan(S.Half)
assert TR2i((cos(2) + 1)/sin(2), half=True) == 1/tan(1)
assert TR2i((cos(4) + 1)/sin(4), half=True) == 1/tan(2)
assert TR2i((cos(5) + 1)/sin(5), half=True) == 1/tan(5*S.Half)
assert TR2i((cos(1) + 1)**(-a)*sin(1)**a, half=True) == tan(S.Half)**a
assert TR2i((cos(2) + 1)**(-a)*sin(2)**a, half=True) == tan(1)**a
assert TR2i((cos(4) + 1)**(-a)*sin(4)**a, half=True) == (cos(4) + 1)**(-a)*sin(4)**a
assert TR2i((cos(5) + 1)**(-a)*sin(5)**a, half=True) == (cos(5) + 1)**(-a)*sin(5)**a
assert TR2i((cos(1) + 1)**a*sin(1)**(-a), half=True) == tan(S.Half)**(-a)
assert TR2i((cos(2) + 1)**a*sin(2)**(-a), half=True) == tan(1)**(-a)
assert TR2i((cos(4) + 1)**a*sin(4)**(-a), half=True) == (cos(4) + 1)**a*sin(4)**(-a)
assert TR2i((cos(5) + 1)**a*sin(5)**(-a), half=True) == (cos(5) + 1)**a*sin(5)**(-a)
i = symbols('i', integer=True)
assert TR2i(((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**(-i)
assert TR2i(1/((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**i
def test_TR3():
assert TR3(cos(y - x*(y - x))) == cos(x*(x - y) + y)
assert cos(pi/2 + x) == -sin(x)
assert cos(30*pi/2 + x) == -cos(x)
for f in (cos, sin, tan, cot, csc, sec):
i = f(pi*Rational(3, 7))
j = TR3(i)
assert verify_numerically(i, j) and i.func != j.func
with evaluate(False):
eq = cos(9*pi/22)
assert eq.has(9*pi) and TR3(eq) == sin(pi/11)
def test_TR4():
for i in [0, pi/6, pi/4, pi/3, pi/2]:
with evaluate(False):
eq = cos(i)
assert isinstance(eq, cos) and TR4(eq) == cos(i)
def test__TR56():
h = lambda x: 1 - x
assert T(sin(x)**3, sin, cos, h, 4, False) == sin(x)*(-cos(x)**2 + 1)
assert T(sin(x)**10, sin, cos, h, 4, False) == sin(x)**10
assert T(sin(x)**6, sin, cos, h, 6, False) == (-cos(x)**2 + 1)**3
assert T(sin(x)**6, sin, cos, h, 6, True) == sin(x)**6
assert T(sin(x)**8, sin, cos, h, 10, True) == (-cos(x)**2 + 1)**4
# issue 17137
assert T(sin(x)**I, sin, cos, h, 4, True) == sin(x)**I
assert T(sin(x)**(2*I + 1), sin, cos, h, 4, True) == sin(x)**(2*I + 1)
def test_TR5():
assert TR5(sin(x)**2) == -cos(x)**2 + 1
assert TR5(sin(x)**-2) == sin(x)**(-2)
assert TR5(sin(x)**4) == (-cos(x)**2 + 1)**2
def test_TR6():
assert TR6(cos(x)**2) == -sin(x)**2 + 1
assert TR6(cos(x)**-2) == cos(x)**(-2)
assert TR6(cos(x)**4) == (-sin(x)**2 + 1)**2
def test_TR7():
assert TR7(cos(x)**2) == cos(2*x)/2 + S.Half
assert TR7(cos(x)**2 + 1) == cos(2*x)/2 + Rational(3, 2)
def test_TR8():
assert TR8(cos(2)*cos(3)) == cos(5)/2 + cos(1)/2
assert TR8(cos(2)*sin(3)) == sin(5)/2 + sin(1)/2
assert TR8(sin(2)*sin(3)) == -cos(5)/2 + cos(1)/2
assert TR8(sin(1)*sin(2)*sin(3)) == sin(4)/4 - sin(6)/4 + sin(2)/4
assert TR8(cos(2)*cos(3)*cos(4)*cos(5)) == \
cos(4)/4 + cos(10)/8 + cos(2)/8 + cos(8)/8 + cos(14)/8 + \
cos(6)/8 + Rational(1, 8)
assert TR8(cos(2)*cos(3)*cos(4)*cos(5)*cos(6)) == \
cos(10)/8 + cos(4)/8 + 3*cos(2)/16 + cos(16)/16 + cos(8)/8 + \
cos(14)/16 + cos(20)/16 + cos(12)/16 + Rational(1, 16) + cos(6)/8
assert TR8(sin(pi*Rational(3, 7))**2*cos(pi*Rational(3, 7))**2/(16*sin(pi/7)**2)) == Rational(1, 64)
def test_TR9():
a = S.Half
b = 3*a
assert TR9(a) == a
assert TR9(cos(1) + cos(2)) == 2*cos(a)*cos(b)
assert TR9(cos(1) - cos(2)) == 2*sin(a)*sin(b)
assert TR9(sin(1) - sin(2)) == -2*sin(a)*cos(b)
assert TR9(sin(1) + sin(2)) == 2*sin(b)*cos(a)
assert TR9(cos(1) + 2*sin(1) + 2*sin(2)) == cos(1) + 4*sin(b)*cos(a)
assert TR9(cos(4) + cos(2) + 2*cos(1)*cos(3)) == 4*cos(1)*cos(3)
assert TR9((cos(4) + cos(2))/cos(3)/2 + cos(3)) == 2*cos(1)*cos(2)
assert TR9(cos(3) + cos(4) + cos(5) + cos(6)) == \
4*cos(S.Half)*cos(1)*cos(Rational(9, 2))
assert TR9(cos(3) + cos(3)*cos(2)) == cos(3) + cos(2)*cos(3)
assert TR9(-cos(y) + cos(x*y)) == -2*sin(x*y/2 - y/2)*sin(x*y/2 + y/2)
assert TR9(-sin(y) + sin(x*y)) == 2*sin(x*y/2 - y/2)*cos(x*y/2 + y/2)
c = cos(x)
s = sin(x)
for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
for a in ((c, s), (s, c), (cos(x), cos(x*y)), (sin(x), sin(x*y))):
args = zip(si, a)
ex = Add(*[Mul(*ai) for ai in args])
t = TR9(ex)
assert not (a[0].func == a[1].func and (
not verify_numerically(ex, t.expand(trig=True)) or t.is_Add)
or a[1].func != a[0].func and ex != t)
def test_TR10():
assert TR10(cos(a + b)) == -sin(a)*sin(b) + cos(a)*cos(b)
assert TR10(sin(a + b)) == sin(a)*cos(b) + sin(b)*cos(a)
assert TR10(sin(a + b + c)) == \
(-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \
(sin(a)*cos(b) + sin(b)*cos(a))*cos(c)
assert TR10(cos(a + b + c)) == \
(-sin(a)*sin(b) + cos(a)*cos(b))*cos(c) - \
(sin(a)*cos(b) + sin(b)*cos(a))*sin(c)
def test_TR10i():
assert TR10i(cos(1)*cos(3) + sin(1)*sin(3)) == cos(2)
assert TR10i(cos(1)*cos(3) - sin(1)*sin(3)) == cos(4)
assert TR10i(cos(1)*sin(3) - sin(1)*cos(3)) == sin(2)
assert TR10i(cos(1)*sin(3) + sin(1)*cos(3)) == sin(4)
assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + 7) == sin(4) + 7
assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) == cos(3) + sin(4)
assert TR10i(2*cos(1)*sin(3) + 2*sin(1)*cos(3) + cos(3)) == \
2*sin(4) + cos(3)
assert TR10i(cos(2)*cos(3) + sin(2)*(cos(1)*sin(2) + cos(2)*sin(1))) == \
cos(1)
eq = (cos(2)*cos(3) + sin(2)*(
cos(1)*sin(2) + cos(2)*sin(1)))*cos(5) + sin(1)*sin(5)
assert TR10i(eq) == TR10i(eq.expand()) == cos(4)
assert TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) == \
2*sqrt(2)*x*sin(x + pi/6)
assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) +
cos(x)/sqrt(6)/3 + sin(x)/sqrt(2)/3) == 4*sqrt(6)*sin(x + pi/6)/9
assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) +
cos(y)/sqrt(6)/3 + sin(y)/sqrt(2)/3) == \
sqrt(6)*sin(x + pi/6)/3 + sqrt(6)*sin(y + pi/6)/9
assert TR10i(cos(x) + sqrt(3)*sin(x) + 2*sqrt(3)*cos(x + pi/6)) == 4*cos(x)
assert TR10i(cos(x) + sqrt(3)*sin(x) +
2*sqrt(3)*cos(x + pi/6) + 4*sin(x)) == 4*sqrt(2)*sin(x + pi/4)
assert TR10i(cos(2)*sin(3) + sin(2)*cos(4)) == \
sin(2)*cos(4) + sin(3)*cos(2)
A = Symbol('A', commutative=False)
assert TR10i(sqrt(2)*cos(x)*A + sqrt(6)*sin(x)*A) == \
2*sqrt(2)*sin(x + pi/6)*A
c = cos(x)
s = sin(x)
h = sin(y)
r = cos(y)
for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
for argsi in ((c*r, s*h), (c*h, s*r)): # explicit 2-args
args = zip(si, argsi)
ex = Add(*[Mul(*ai) for ai in args])
t = TR10i(ex)
assert not (ex - t.expand(trig=True) or t.is_Add)
c = cos(x)
s = sin(x)
h = sin(pi/6)
r = cos(pi/6)
for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
for argsi in ((c*r, s*h), (c*h, s*r)): # induced
args = zip(si, argsi)
ex = Add(*[Mul(*ai) for ai in args])
t = TR10i(ex)
assert not (ex - t.expand(trig=True) or t.is_Add)
def test_TR11():
assert TR11(sin(2*x)) == 2*sin(x)*cos(x)
assert TR11(sin(4*x)) == 4*((-sin(x)**2 + cos(x)**2)*sin(x)*cos(x))
assert TR11(sin(x*Rational(4, 3))) == \
4*((-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3))
assert TR11(cos(2*x)) == -sin(x)**2 + cos(x)**2
assert TR11(cos(4*x)) == \
(-sin(x)**2 + cos(x)**2)**2 - 4*sin(x)**2*cos(x)**2
assert TR11(cos(2)) == cos(2)
assert TR11(cos(pi*Rational(3, 7)), pi*Rational(2, 7)) == -cos(pi*Rational(2, 7))**2 + sin(pi*Rational(2, 7))**2
assert TR11(cos(4), 2) == -sin(2)**2 + cos(2)**2
assert TR11(cos(6), 2) == cos(6)
assert TR11(sin(x)/cos(x/2), x/2) == 2*sin(x/2)
def test__TR11():
assert _TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) == \
4*sin(x/8)*sin(x/6)*sin(2*x),_TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8)))
assert _TR11(sin(x/3)/cos(x/6)) == 2*sin(x/6)
assert _TR11(cos(x/6)/sin(x/3)) == 1/(2*sin(x/6))
assert _TR11(sin(2*x)*cos(x/8)/sin(x/4)) == sin(2*x)/(2*sin(x/8)), _TR11(sin(2*x)*cos(x/8)/sin(x/4))
assert _TR11(sin(x)/sin(x/2)) == 2*cos(x/2)
def test_TR12():
assert TR12(tan(x + y)) == (tan(x) + tan(y))/(-tan(x)*tan(y) + 1)
assert TR12(tan(x + y + z)) ==\
(tan(z) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1))/(
1 - (tan(x) + tan(y))*tan(z)/(-tan(x)*tan(y) + 1))
assert TR12(tan(x*y)) == tan(x*y)
def test_TR13():
assert TR13(tan(3)*tan(2)) == -tan(2)/tan(5) - tan(3)/tan(5) + 1
assert TR13(cot(3)*cot(2)) == 1 + cot(3)*cot(5) + cot(2)*cot(5)
assert TR13(tan(1)*tan(2)*tan(3)) == \
(-tan(2)/tan(5) - tan(3)/tan(5) + 1)*tan(1)
assert TR13(tan(1)*tan(2)*cot(3)) == \
(-tan(2)/tan(3) + 1 - tan(1)/tan(3))*cot(3)
def test_L():
assert L(cos(x) + sin(x)) == 2
def test_fu():
assert fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) == Rational(3, 2)
assert fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) == 2*sqrt(2)*sin(x + pi/3)
eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2
assert fu(eq) == cos(x)**4 - 2*cos(y)**2 + 2
assert fu(S.Half - cos(2*x)/2) == sin(x)**2
assert fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) == \
sqrt(2)*sin(a + b + pi/4)
assert fu(sqrt(3)*cos(x)/2 + sin(x)/2) == sin(x + pi/3)
assert fu(1 - sin(2*x)**2/4 - sin(y)**2 - cos(x)**4) == \
-cos(x)**2 + cos(y)**2
assert fu(cos(pi*Rational(4, 9))) == sin(pi/18)
assert fu(cos(pi/9)*cos(pi*Rational(2, 9))*cos(pi*Rational(3, 9))*cos(pi*Rational(4, 9))) == Rational(1, 16)
assert fu(
tan(pi*Rational(7, 18)) + tan(pi*Rational(5, 18)) - sqrt(3)*tan(pi*Rational(5, 18))*tan(pi*Rational(7, 18))) == \
-sqrt(3)
assert fu(tan(1)*tan(2)) == tan(1)*tan(2)
expr = Mul(*[cos(2**i) for i in range(10)])
assert fu(expr) == sin(1024)/(1024*sin(1))
# issue #18059:
assert fu(cos(x) + sqrt(sin(x)**2)) == cos(x) + sqrt(sin(x)**2)
assert fu((-14*sin(x)**3 + 35*sin(x) + 6*sqrt(3)*cos(x)**3 + 9*sqrt(3)*cos(x))/((cos(2*x) + 4))) == \
7*sin(x) + 3*sqrt(3)*cos(x)
def test_objective():
assert fu(sin(x)/cos(x), measure=lambda x: x.count_ops()) == \
tan(x)
assert fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) == \
sin(x)/cos(x)
def test_process_common_addends():
# this tests that the args are not evaluated as they are given to do
# and that key2 works when key1 is False
do = lambda x: Add(*[i**(i%2) for i in x.args])
assert process_common_addends(Add(*[1, 2, 3, 4], evaluate=False), do,
key2=lambda x: x%2, key1=False) == 1**1 + 3**1 + 2**0 + 4**0
def test_trig_split():
assert trig_split(cos(x), cos(y)) == (1, 1, 1, x, y, True)
assert trig_split(2*cos(x), -2*cos(y)) == (2, 1, -1, x, y, True)
assert trig_split(cos(x)*sin(y), cos(y)*sin(y)) == \
(sin(y), 1, 1, x, y, True)
assert trig_split(cos(x), -sqrt(3)*sin(x), two=True) == \
(2, 1, -1, x, pi/6, False)
assert trig_split(cos(x), sin(x), two=True) == \
(sqrt(2), 1, 1, x, pi/4, False)
assert trig_split(cos(x), -sin(x), two=True) == \
(sqrt(2), 1, -1, x, pi/4, False)
assert trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) == \
(2*sqrt(2), 1, -1, x, pi/6, False)
assert trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) == \
(-2*sqrt(2), 1, 1, x, pi/3, False)
assert trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) == \
(sqrt(6)/3, 1, 1, x, pi/6, False)
assert trig_split(-sqrt(6)*cos(x)*sin(y),
-sqrt(2)*sin(x)*sin(y), two=True) == \
(-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False)
assert trig_split(cos(x), sin(x)) is None
assert trig_split(cos(x), sin(z)) is None
assert trig_split(2*cos(x), -sin(x)) is None
assert trig_split(cos(x), -sqrt(3)*sin(x)) is None
assert trig_split(cos(x)*cos(y), sin(x)*sin(z)) is None
assert trig_split(cos(x)*cos(y), sin(x)*sin(y)) is None
assert trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) is \
None
assert trig_split(sqrt(3)*sqrt(x), cos(3), two=True) is None
assert trig_split(sqrt(3)*root(x, 3), sin(3)*cos(2), two=True) is None
assert trig_split(cos(5)*cos(6), cos(7)*sin(5), two=True) is None
def test_TRmorrie():
assert TRmorrie(7*Mul(*[cos(i) for i in range(10)])) == \
7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3))
assert TRmorrie(x) == x
assert TRmorrie(2*x) == 2*x
e = cos(pi/7)*cos(pi*Rational(2, 7))*cos(pi*Rational(4, 7))
assert TR8(TRmorrie(e)) == Rational(-1, 8)
e = Mul(*[cos(2**i*pi/17) for i in range(1, 17)])
assert TR8(TR3(TRmorrie(e))) == Rational(1, 65536)
# issue 17063
eq = cos(x)/cos(x/2)
assert TRmorrie(eq) == eq
# issue #20430
eq = cos(x/2)*sin(x/2)*cos(x)**3
assert TRmorrie(eq) == sin(2*x)*cos(x)**2/4
def test_TRpower():
assert TRpower(1/sin(x)**2) == 1/sin(x)**2
assert TRpower(cos(x)**3*sin(x/2)**4) == \
(3*cos(x)/4 + cos(3*x)/4)*(-cos(x)/2 + cos(2*x)/8 + Rational(3, 8))
for k in range(2, 8):
assert verify_numerically(sin(x)**k, TRpower(sin(x)**k))
assert verify_numerically(cos(x)**k, TRpower(cos(x)**k))
def test_hyper_as_trig():
from sympy.simplify.fu import _osborne, _osbornei
eq = sinh(x)**2 + cosh(x)**2
t, f = hyper_as_trig(eq)
assert f(fu(t)) == cosh(2*x)
e, f = hyper_as_trig(tanh(x + y))
assert f(TR12(e)) == (tanh(x) + tanh(y))/(tanh(x)*tanh(y) + 1)
d = Dummy()
assert _osborne(sinh(x), d) == I*sin(x*d)
assert _osborne(tanh(x), d) == I*tan(x*d)
assert _osborne(coth(x), d) == cot(x*d)/I
assert _osborne(cosh(x), d) == cos(x*d)
assert _osborne(sech(x), d) == sec(x*d)
assert _osborne(csch(x), d) == csc(x*d)/I
for func in (sinh, cosh, tanh, coth, sech, csch):
h = func(pi)
assert _osbornei(_osborne(h, d), d) == h
# /!\ the _osborne functions are not meant to work
# in the o(i(trig, d), d) direction so we just check
# that they work as they are supposed to work
assert _osbornei(cos(x*y + z), y) == cosh(x + z*I)
assert _osbornei(sin(x*y + z), y) == sinh(x + z*I)/I
assert _osbornei(tan(x*y + z), y) == tanh(x + z*I)/I
assert _osbornei(cot(x*y + z), y) == coth(x + z*I)*I
assert _osbornei(sec(x*y + z), y) == sech(x + z*I)
assert _osbornei(csc(x*y + z), y) == csch(x + z*I)*I
def test_TR12i():
ta, tb, tc = [tan(i) for i in (a, b, c)]
assert TR12i((ta + tb)/(-ta*tb + 1)) == tan(a + b)
assert TR12i((ta + tb)/(ta*tb - 1)) == -tan(a + b)
assert TR12i((-ta - tb)/(ta*tb - 1)) == tan(a + b)
eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1))
assert TR12i(eq.expand()) == \
-3*tan(a + b)*tan(a + c)/(tan(a) + tan(b) - 1)/2
assert TR12i(tan(x)/sin(x)) == tan(x)/sin(x)
eq = (ta + cos(2))/(-ta*tb + 1)
assert TR12i(eq) == eq
eq = (ta + tb + 2)**2/(-ta*tb + 1)
assert TR12i(eq) == eq
eq = ta/(-ta*tb + 1)
assert TR12i(eq) == eq
eq = (((ta + tb)*(a + 1)).expand())**2/(ta*tb - 1)
assert TR12i(eq) == -(a + 1)**2*tan(a + b)
def test_TR14():
eq = (cos(x) - 1)*(cos(x) + 1)
ans = -sin(x)**2
assert TR14(eq) == ans
assert TR14(1/eq) == 1/ans
assert TR14((cos(x) - 1)**2*(cos(x) + 1)**2) == ans**2
assert TR14((cos(x) - 1)**2*(cos(x) + 1)**3) == ans**2*(cos(x) + 1)
assert TR14((cos(x) - 1)**3*(cos(x) + 1)**2) == ans**2*(cos(x) - 1)
eq = (cos(x) - 1)**y*(cos(x) + 1)**y
assert TR14(eq) == eq
eq = (cos(x) - 2)**y*(cos(x) + 1)
assert TR14(eq) == eq
eq = (tan(x) - 2)**2*(cos(x) + 1)
assert TR14(eq) == eq
i = symbols('i', integer=True)
assert TR14((cos(x) - 1)**i*(cos(x) + 1)**i) == ans**i
assert TR14((sin(x) - 1)**i*(sin(x) + 1)**i) == (-cos(x)**2)**i
# could use extraction in this case
eq = (cos(x) - 1)**(i + 1)*(cos(x) + 1)**i
assert TR14(eq) in [(cos(x) - 1)*ans**i, eq]
assert TR14((sin(x) - 1)*(sin(x) + 1)) == -cos(x)**2
p1 = (cos(x) + 1)*(cos(x) - 1)
p2 = (cos(y) - 1)*2*(cos(y) + 1)
p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1))
assert TR14(p1*p2*p3*(x - 1)) == -18*((x - 1)*sin(x)**2*sin(y)**4)
def test_TR15_16_17():
assert TR15(1 - 1/sin(x)**2) == -cot(x)**2
assert TR16(1 - 1/cos(x)**2) == -tan(x)**2
assert TR111(1 - 1/tan(x)**2) == 1 - cot(x)**2
def test_as_f_sign_1():
assert as_f_sign_1(x + 1) == (1, x, 1)
assert as_f_sign_1(x - 1) == (1, x, -1)
assert as_f_sign_1(-x + 1) == (-1, x, -1)
assert as_f_sign_1(-x - 1) == (-1, x, 1)
assert as_f_sign_1(2*x + 2) == (2, x, 1)
assert as_f_sign_1(x*y - y) == (y, x, -1)
assert as_f_sign_1(-x*y + y) == (-y, x, -1)
def test_issue_25590():
A = Symbol('A', commutative=False)
B = Symbol('B', commutative=False)
assert TR8(2*cos(x)*sin(x)*B*A) == sin(2*x)*B*A
assert TR13(tan(2)*tan(3)*B*A) == (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*B*A
# XXX The result may not be optimal than
# sin(2*x)*B*A + cos(x)**2 and may change in the future
assert (2*cos(x)*sin(x)*B*A + cos(x)**2).simplify() == sin(2*x)*B*A + cos(2*x)/2 + S.One/2

View File

@ -0,0 +1,54 @@
""" Unit tests for Hyper_Function"""
from sympy.core import symbols, Dummy, Tuple, S, Rational
from sympy.functions import hyper
from sympy.simplify.hyperexpand import Hyper_Function
def test_attrs():
a, b = symbols('a, b', cls=Dummy)
f = Hyper_Function([2, a], [b])
assert f.ap == Tuple(2, a)
assert f.bq == Tuple(b)
assert f.args == (Tuple(2, a), Tuple(b))
assert f.sizes == (2, 1)
def test_call():
a, b, x = symbols('a, b, x', cls=Dummy)
f = Hyper_Function([2, a], [b])
assert f(x) == hyper([2, a], [b], x)
def test_has():
a, b, c = symbols('a, b, c', cls=Dummy)
f = Hyper_Function([2, -a], [b])
assert f.has(a)
assert f.has(Tuple(b))
assert not f.has(c)
def test_eq():
assert Hyper_Function([1], []) == Hyper_Function([1], [])
assert (Hyper_Function([1], []) != Hyper_Function([1], [])) is False
assert Hyper_Function([1], []) != Hyper_Function([2], [])
assert Hyper_Function([1], []) != Hyper_Function([1, 2], [])
assert Hyper_Function([1], []) != Hyper_Function([1], [2])
def test_gamma():
assert Hyper_Function([2, 3], [-1]).gamma == 0
assert Hyper_Function([-2, -3], [-1]).gamma == 2
n = Dummy(integer=True)
assert Hyper_Function([-1, n, 1], []).gamma == 1
assert Hyper_Function([-1, -n, 1], []).gamma == 1
p = Dummy(integer=True, positive=True)
assert Hyper_Function([-1, p, 1], []).gamma == 1
assert Hyper_Function([-1, -p, 1], []).gamma == 2
def test_suitable_origin():
assert Hyper_Function((S.Half,), (Rational(3, 2),))._is_suitable_origin() is True
assert Hyper_Function((S.Half,), (S.Half,))._is_suitable_origin() is False
assert Hyper_Function((S.Half,), (Rational(-1, 2),))._is_suitable_origin() is False
assert Hyper_Function((S.Half,), (0,))._is_suitable_origin() is False
assert Hyper_Function((S.Half,), (-1, 1,))._is_suitable_origin() is False
assert Hyper_Function((S.Half, 0), (1,))._is_suitable_origin() is False
assert Hyper_Function((S.Half, 1),
(2, Rational(-2, 3)))._is_suitable_origin() is True
assert Hyper_Function((S.Half, 1),
(2, Rational(-2, 3), Rational(3, 2)))._is_suitable_origin() is True

View File

@ -0,0 +1,127 @@
from sympy.core.function import Function
from sympy.core.numbers import (Rational, pi)
from sympy.core.singleton import S
from sympy.core.symbol import symbols
from sympy.functions.combinatorial.factorials import (rf, binomial, factorial)
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.functions.special.gamma_functions import gamma
from sympy.simplify.gammasimp import gammasimp
from sympy.simplify.powsimp import powsimp
from sympy.simplify.simplify import simplify
from sympy.abc import x, y, n, k
def test_gammasimp():
R = Rational
# was part of test_combsimp_gamma() in test_combsimp.py
assert gammasimp(gamma(x)) == gamma(x)
assert gammasimp(gamma(x + 1)/x) == gamma(x)
assert gammasimp(gamma(x)/(x - 1)) == gamma(x - 1)
assert gammasimp(x*gamma(x)) == gamma(x + 1)
assert gammasimp((x + 1)*gamma(x + 1)) == gamma(x + 2)
assert gammasimp(gamma(x + y)*(x + y)) == gamma(x + y + 1)
assert gammasimp(x/gamma(x + 1)) == 1/gamma(x)
assert gammasimp((x + 1)**2/gamma(x + 2)) == (x + 1)/gamma(x + 1)
assert gammasimp(x*gamma(x) + gamma(x + 3)/(x + 2)) == \
(x + 2)*gamma(x + 1)
assert gammasimp(gamma(2*x)*x) == gamma(2*x + 1)/2
assert gammasimp(gamma(2*x)/(x - S.Half)) == 2*gamma(2*x - 1)
assert gammasimp(gamma(x)*gamma(1 - x)) == pi/sin(pi*x)
assert gammasimp(gamma(x)*gamma(-x)) == -pi/(x*sin(pi*x))
assert gammasimp(1/gamma(x + 3)/gamma(1 - x)) == \
sin(pi*x)/(pi*x*(x + 1)*(x + 2))
assert gammasimp(factorial(n + 2)) == gamma(n + 3)
assert gammasimp(binomial(n, k)) == \
gamma(n + 1)/(gamma(k + 1)*gamma(-k + n + 1))
assert powsimp(gammasimp(
gamma(x)*gamma(x + S.Half)*gamma(y)/gamma(x + y))) == \
2**(-2*x + 1)*sqrt(pi)*gamma(2*x)*gamma(y)/gamma(x + y)
assert gammasimp(1/gamma(x)/gamma(x - Rational(1, 3))/gamma(x + Rational(1, 3))) == \
3**(3*x - Rational(3, 2))/(2*pi*gamma(3*x - 1))
assert simplify(
gamma(S.Half + x/2)*gamma(1 + x/2)/gamma(1 + x)/sqrt(pi)*2**x) == 1
assert gammasimp(gamma(Rational(-1, 4))*gamma(Rational(-3, 4))) == 16*sqrt(2)*pi/3
assert powsimp(gammasimp(gamma(2*x)/gamma(x))) == \
2**(2*x - 1)*gamma(x + S.Half)/sqrt(pi)
# issue 6792
e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2
assert gammasimp(e) == -k
assert gammasimp(1/e) == -1/k
e = (gamma(x) + gamma(x + 1))/gamma(x)
assert gammasimp(e) == x + 1
assert gammasimp(1/e) == 1/(x + 1)
e = (gamma(x) + gamma(x + 2))*(gamma(x - 1) + gamma(x))/gamma(x)
assert gammasimp(e) == (x**2 + x + 1)*gamma(x + 1)/(x - 1)
e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2
assert gammasimp(e**2) == k**2
assert gammasimp(e**2/gamma(k + 1)) == k/gamma(k)
a = R(1, 2) + R(1, 3)
b = a + R(1, 3)
assert gammasimp(gamma(2*k)/gamma(k)*gamma(k + a)*gamma(k + b)
) == 3*2**(2*k + 1)*3**(-3*k - 2)*sqrt(pi)*gamma(3*k + R(3, 2))/2
# issue 9699
assert gammasimp((x + 1)*factorial(x)/gamma(y)) == gamma(x + 2)/gamma(y)
assert gammasimp(rf(x + n, k)*binomial(n, k)).simplify() == Piecewise(
(gamma(n + 1)*gamma(k + n + x)/(gamma(k + 1)*gamma(n + x)*gamma(-k + n + 1)), n > -x),
((-1)**k*gamma(n + 1)*gamma(-n - x + 1)/(gamma(k + 1)*gamma(-k + n + 1)*gamma(-k - n - x + 1)), True))
A, B = symbols('A B', commutative=False)
assert gammasimp(e*B*A) == gammasimp(e)*B*A
# check iteration
assert gammasimp(gamma(2*k)/gamma(k)*gamma(-k - R(1, 2))) == (
-2**(2*k + 1)*sqrt(pi)/(2*((2*k + 1)*cos(pi*k))))
assert gammasimp(
gamma(k)*gamma(k + R(1, 3))*gamma(k + R(2, 3))/gamma(k*R(3, 2))) == (
3*2**(3*k + 1)*3**(-3*k - S.Half)*sqrt(pi)*gamma(k*R(3, 2) + S.Half)/2)
# issue 6153
assert gammasimp(gamma(Rational(1, 4))/gamma(Rational(5, 4))) == 4
# was part of test_combsimp() in test_combsimp.py
assert gammasimp(binomial(n + 2, k + S.Half)) == gamma(n + 3)/ \
(gamma(k + R(3, 2))*gamma(-k + n + R(5, 2)))
assert gammasimp(binomial(n + 2, k + 2.0)) == \
gamma(n + 3)/(gamma(k + 3.0)*gamma(-k + n + 1))
# issue 11548
assert gammasimp(binomial(0, x)) == sin(pi*x)/(pi*x)
e = gamma(n + Rational(1, 3))*gamma(n + R(2, 3))
assert gammasimp(e) == e
assert gammasimp(gamma(4*n + S.Half)/gamma(2*n - R(3, 4))) == \
2**(4*n - R(5, 2))*(8*n - 3)*gamma(2*n + R(3, 4))/sqrt(pi)
i, m = symbols('i m', integer = True)
e = gamma(exp(i))
assert gammasimp(e) == e
e = gamma(m + 3)
assert gammasimp(e) == e
e = gamma(m + 1)/(gamma(i + 1)*gamma(-i + m + 1))
assert gammasimp(e) == e
p = symbols("p", integer=True, positive=True)
assert gammasimp(gamma(-p + 4)) == gamma(-p + 4)
def test_issue_22606():
fx = Function('f')(x)
eq = x + gamma(y)
# seems like ans should be `eq`, not `(x*y + gamma(y + 1))/y`
ans = gammasimp(eq)
assert gammasimp(eq.subs(x, fx)).subs(fx, x) == ans
assert gammasimp(eq.subs(x, cos(x))).subs(cos(x), x) == ans
assert 1/gammasimp(1/eq) == ans
assert gammasimp(fx.subs(x, eq)).args[0] == ans

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,366 @@
from sympy.core.function import Function
from sympy.core.mul import Mul
from sympy.core.numbers import (E, I, Rational, oo, pi)
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import sin
from sympy.functions.special.gamma_functions import gamma
from sympy.functions.special.hyper import hyper
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.simplify.powsimp import (powdenest, powsimp)
from sympy.simplify.simplify import (signsimp, simplify)
from sympy.core.symbol import Str
from sympy.abc import x, y, z, a, b
def test_powsimp():
x, y, z, n = symbols('x,y,z,n')
f = Function('f')
assert powsimp( 4**x * 2**(-x) * 2**(-x) ) == 1
assert powsimp( (-4)**x * (-2)**(-x) * 2**(-x) ) == 1
assert powsimp(
f(4**x * 2**(-x) * 2**(-x)) ) == f(4**x * 2**(-x) * 2**(-x))
assert powsimp( f(4**x * 2**(-x) * 2**(-x)), deep=True ) == f(1)
assert exp(x)*exp(y) == exp(x)*exp(y)
assert powsimp(exp(x)*exp(y)) == exp(x + y)
assert powsimp(exp(x)*exp(y)*2**x*2**y) == (2*E)**(x + y)
assert powsimp(exp(x)*exp(y)*2**x*2**y, combine='exp') == \
exp(x + y)*2**(x + y)
assert powsimp(exp(x)*exp(y)*exp(2)*sin(x) + sin(y) + 2**x*2**y) == \
exp(2 + x + y)*sin(x) + sin(y) + 2**(x + y)
assert powsimp(sin(exp(x)*exp(y))) == sin(exp(x)*exp(y))
assert powsimp(sin(exp(x)*exp(y)), deep=True) == sin(exp(x + y))
assert powsimp(x**2*x**y) == x**(2 + y)
# This should remain factored, because 'exp' with deep=True is supposed
# to act like old automatic exponent combining.
assert powsimp((1 + E*exp(E))*exp(-E), combine='exp', deep=True) == \
(1 + exp(1 + E))*exp(-E)
assert powsimp((1 + E*exp(E))*exp(-E), deep=True) == \
(1 + exp(1 + E))*exp(-E)
assert powsimp((1 + E*exp(E))*exp(-E)) == (1 + exp(1 + E))*exp(-E)
assert powsimp((1 + E*exp(E))*exp(-E), combine='exp') == \
(1 + exp(1 + E))*exp(-E)
assert powsimp((1 + E*exp(E))*exp(-E), combine='base') == \
(1 + E*exp(E))*exp(-E)
x, y = symbols('x,y', nonnegative=True)
n = Symbol('n', real=True)
assert powsimp(y**n * (y/x)**(-n)) == x**n
assert powsimp(x**(x**(x*y)*y**(x*y))*y**(x**(x*y)*y**(x*y)), deep=True) \
== (x*y)**(x*y)**(x*y)
assert powsimp(2**(2**(2*x)*x), deep=False) == 2**(2**(2*x)*x)
assert powsimp(2**(2**(2*x)*x), deep=True) == 2**(x*4**x)
assert powsimp(
exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \
exp(-x + exp(-x)*exp(-x*log(x)))
assert powsimp(
exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \
exp(-x + exp(-x)*exp(-x*log(x)))
assert powsimp((x + y)/(3*z), deep=False, combine='exp') == (x + y)/(3*z)
assert powsimp((x/3 + y/3)/z, deep=True, combine='exp') == (x/3 + y/3)/z
assert powsimp(exp(x)/(1 + exp(x)*exp(y)), deep=True) == \
exp(x)/(1 + exp(x + y))
assert powsimp(x*y**(z**x*z**y), deep=True) == x*y**(z**(x + y))
assert powsimp((z**x*z**y)**x, deep=True) == (z**(x + y))**x
assert powsimp(x*(z**x*z**y)**x, deep=True) == x*(z**(x + y))**x
p = symbols('p', positive=True)
assert powsimp((1/x)**log(2)/x) == (1/x)**(1 + log(2))
assert powsimp((1/p)**log(2)/p) == p**(-1 - log(2))
# coefficient of exponent can only be simplified for positive bases
assert powsimp(2**(2*x)) == 4**x
assert powsimp((-1)**(2*x)) == (-1)**(2*x)
i = symbols('i', integer=True)
assert powsimp((-1)**(2*i)) == 1
assert powsimp((-1)**(-x)) != (-1)**x # could be 1/((-1)**x), but is not
# force=True overrides assumptions
assert powsimp((-1)**(2*x), force=True) == 1
# rational exponents allow combining of negative terms
w, n, m = symbols('w n m', negative=True)
e = i/a # not a rational exponent if `a` is unknown
ex = w**e*n**e*m**e
assert powsimp(ex) == m**(i/a)*n**(i/a)*w**(i/a)
e = i/3
ex = w**e*n**e*m**e
assert powsimp(ex) == (-1)**i*(-m*n*w)**(i/3)
e = (3 + i)/i
ex = w**e*n**e*m**e
assert powsimp(ex) == (-1)**(3*e)*(-m*n*w)**e
eq = x**(a*Rational(2, 3))
# eq != (x**a)**(2/3) (try x = -1 and a = 3 to see)
assert powsimp(eq).exp == eq.exp == a*Rational(2, 3)
# powdenest goes the other direction
assert powsimp(2**(2*x)) == 4**x
assert powsimp(exp(p/2)) == exp(p/2)
# issue 6368
eq = Mul(*[sqrt(Dummy(imaginary=True)) for i in range(3)])
assert powsimp(eq) == eq and eq.is_Mul
assert all(powsimp(e) == e for e in (sqrt(x**a), sqrt(x**2)))
# issue 8836
assert str( powsimp(exp(I*pi/3)*root(-1,3)) ) == '(-1)**(2/3)'
# issue 9183
assert powsimp(-0.1**x) == -0.1**x
# issue 10095
assert powsimp((1/(2*E))**oo) == (exp(-1)/2)**oo
# PR 13131
eq = sin(2*x)**2*sin(2.0*x)**2
assert powsimp(eq) == eq
# issue 14615
assert powsimp(x**2*y**3*(x*y**2)**Rational(3, 2)
) == x*y*(x*y**2)**Rational(5, 2)
def test_powsimp_negated_base():
assert powsimp((-x + y)/sqrt(x - y)) == -sqrt(x - y)
assert powsimp((-x + y)*(-z + y)/sqrt(x - y)/sqrt(z - y)) == sqrt(x - y)*sqrt(z - y)
p = symbols('p', positive=True)
reps = {p: 2, a: S.Half}
assert powsimp((-p)**a/p**a).subs(reps) == ((-1)**a).subs(reps)
assert powsimp((-p)**a*p**a).subs(reps) == ((-p**2)**a).subs(reps)
n = symbols('n', negative=True)
reps = {p: -2, a: S.Half}
assert powsimp((-n)**a/n**a).subs(reps) == (-1)**(-a).subs(a, S.Half)
assert powsimp((-n)**a*n**a).subs(reps) == ((-n**2)**a).subs(reps)
# if x is 0 then the lhs is 0**a*oo**a which is not (-1)**a
eq = (-x)**a/x**a
assert powsimp(eq) == eq
def test_powsimp_nc():
x, y, z = symbols('x,y,z')
A, B, C = symbols('A B C', commutative=False)
assert powsimp(A**x*A**y, combine='all') == A**(x + y)
assert powsimp(A**x*A**y, combine='base') == A**x*A**y
assert powsimp(A**x*A**y, combine='exp') == A**(x + y)
assert powsimp(A**x*B**x, combine='all') == A**x*B**x
assert powsimp(A**x*B**x, combine='base') == A**x*B**x
assert powsimp(A**x*B**x, combine='exp') == A**x*B**x
assert powsimp(B**x*A**x, combine='all') == B**x*A**x
assert powsimp(B**x*A**x, combine='base') == B**x*A**x
assert powsimp(B**x*A**x, combine='exp') == B**x*A**x
assert powsimp(A**x*A**y*A**z, combine='all') == A**(x + y + z)
assert powsimp(A**x*A**y*A**z, combine='base') == A**x*A**y*A**z
assert powsimp(A**x*A**y*A**z, combine='exp') == A**(x + y + z)
assert powsimp(A**x*B**x*C**x, combine='all') == A**x*B**x*C**x
assert powsimp(A**x*B**x*C**x, combine='base') == A**x*B**x*C**x
assert powsimp(A**x*B**x*C**x, combine='exp') == A**x*B**x*C**x
assert powsimp(B**x*A**x*C**x, combine='all') == B**x*A**x*C**x
assert powsimp(B**x*A**x*C**x, combine='base') == B**x*A**x*C**x
assert powsimp(B**x*A**x*C**x, combine='exp') == B**x*A**x*C**x
def test_issue_6440():
assert powsimp(16*2**a*8**b) == 2**(a + 3*b + 4)
def test_powdenest():
x, y = symbols('x,y')
p, q = symbols('p q', positive=True)
i, j = symbols('i,j', integer=True)
assert powdenest(x) == x
assert powdenest(x + 2*(x**(a*Rational(2, 3)))**(3*x)) == (x + 2*(x**(a*Rational(2, 3)))**(3*x))
assert powdenest((exp(a*Rational(2, 3)))**(3*x)) # -X-> (exp(a/3))**(6*x)
assert powdenest((x**(a*Rational(2, 3)))**(3*x)) == ((x**(a*Rational(2, 3)))**(3*x))
assert powdenest(exp(3*x*log(2))) == 2**(3*x)
assert powdenest(sqrt(p**2)) == p
eq = p**(2*i)*q**(4*i)
assert powdenest(eq) == (p*q**2)**(2*i)
# -X-> (x**x)**i*(x**x)**j == x**(x*(i + j))
assert powdenest((x**x)**(i + j))
assert powdenest(exp(3*y*log(x))) == x**(3*y)
assert powdenest(exp(y*(log(a) + log(b)))) == (a*b)**y
assert powdenest(exp(3*(log(a) + log(b)))) == a**3*b**3
assert powdenest(((x**(2*i))**(3*y))**x) == ((x**(2*i))**(3*y))**x
assert powdenest(((x**(2*i))**(3*y))**x, force=True) == x**(6*i*x*y)
assert powdenest(((x**(a*Rational(2, 3)))**(3*y/i))**x) == \
(((x**(a*Rational(2, 3)))**(3*y/i))**x)
assert powdenest((x**(2*i)*y**(4*i))**z, force=True) == (x*y**2)**(2*i*z)
assert powdenest((p**(2*i)*q**(4*i))**j) == (p*q**2)**(2*i*j)
e = ((p**(2*a))**(3*y))**x
assert powdenest(e) == e
e = ((x**2*y**4)**a)**(x*y)
assert powdenest(e) == e
e = (((x**2*y**4)**a)**(x*y))**3
assert powdenest(e) == ((x**2*y**4)**a)**(3*x*y)
assert powdenest((((x**2*y**4)**a)**(x*y)), force=True) == \
(x*y**2)**(2*a*x*y)
assert powdenest((((x**2*y**4)**a)**(x*y))**3, force=True) == \
(x*y**2)**(6*a*x*y)
assert powdenest((x**2*y**6)**i) != (x*y**3)**(2*i)
x, y = symbols('x,y', positive=True)
assert powdenest((x**2*y**6)**i) == (x*y**3)**(2*i)
assert powdenest((x**(i*Rational(2, 3))*y**(i/2))**(2*i)) == (x**Rational(4, 3)*y)**(i**2)
assert powdenest(sqrt(x**(2*i)*y**(6*i))) == (x*y**3)**i
assert powdenest(4**x) == 2**(2*x)
assert powdenest((4**x)**y) == 2**(2*x*y)
assert powdenest(4**x*y) == 2**(2*x)*y
def test_powdenest_polar():
x, y, z = symbols('x y z', polar=True)
a, b, c = symbols('a b c')
assert powdenest((x*y*z)**a) == x**a*y**a*z**a
assert powdenest((x**a*y**b)**c) == x**(a*c)*y**(b*c)
assert powdenest(((x**a)**b*y**c)**c) == x**(a*b*c)*y**(c**2)
def test_issue_5805():
arg = ((gamma(x)*hyper((), (), x))*pi)**2
assert powdenest(arg) == (pi*gamma(x)*hyper((), (), x))**2
assert arg.is_positive is None
def test_issue_9324_powsimp_on_matrix_symbol():
M = MatrixSymbol('M', 10, 10)
expr = powsimp(M, deep=True)
assert expr == M
assert expr.args[0] == Str('M')
def test_issue_6367():
z = -5*sqrt(2)/(2*sqrt(2*sqrt(29) + 29)) + sqrt(-sqrt(29)/29 + S.Half)
assert Mul(*[powsimp(a) for a in Mul.make_args(z.normal())]) == 0
assert powsimp(z.normal()) == 0
assert simplify(z) == 0
assert powsimp(sqrt(2 + sqrt(3))*sqrt(2 - sqrt(3)) + 1) == 2
assert powsimp(z) != 0
def test_powsimp_polar():
from sympy.functions.elementary.complexes import polar_lift
from sympy.functions.elementary.exponential import exp_polar
x, y, z = symbols('x y z')
p, q, r = symbols('p q r', polar=True)
assert (polar_lift(-1))**(2*x) == exp_polar(2*pi*I*x)
assert powsimp(p**x * q**x) == (p*q)**x
assert p**x * (1/p)**x == 1
assert (1/p)**x == p**(-x)
assert exp_polar(x)*exp_polar(y) == exp_polar(x)*exp_polar(y)
assert powsimp(exp_polar(x)*exp_polar(y)) == exp_polar(x + y)
assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y) == \
(p*exp_polar(1))**(x + y)
assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y, combine='exp') == \
exp_polar(x + y)*p**(x + y)
assert powsimp(
exp_polar(x)*exp_polar(y)*exp_polar(2)*sin(x) + sin(y) + p**x*p**y) \
== p**(x + y) + sin(x)*exp_polar(2 + x + y) + sin(y)
assert powsimp(sin(exp_polar(x)*exp_polar(y))) == \
sin(exp_polar(x)*exp_polar(y))
assert powsimp(sin(exp_polar(x)*exp_polar(y)), deep=True) == \
sin(exp_polar(x + y))
def test_issue_5728():
b = x*sqrt(y)
a = sqrt(b)
c = sqrt(sqrt(x)*y)
assert powsimp(a*b) == sqrt(b)**3
assert powsimp(a*b**2*sqrt(y)) == sqrt(y)*a**5
assert powsimp(a*x**2*c**3*y) == c**3*a**5
assert powsimp(a*x*c**3*y**2) == c**7*a
assert powsimp(x*c**3*y**2) == c**7
assert powsimp(x*c**3*y) == x*y*c**3
assert powsimp(sqrt(x)*c**3*y) == c**5
assert powsimp(sqrt(x)*a**3*sqrt(y)) == sqrt(x)*sqrt(y)*a**3
assert powsimp(Mul(sqrt(x)*c**3*sqrt(y), y, evaluate=False)) == \
sqrt(x)*sqrt(y)**3*c**3
assert powsimp(a**2*a*x**2*y) == a**7
# symbolic powers work, too
b = x**y*y
a = b*sqrt(b)
assert a.is_Mul is True
assert powsimp(a) == sqrt(b)**3
# as does exp
a = x*exp(y*Rational(2, 3))
assert powsimp(a*sqrt(a)) == sqrt(a)**3
assert powsimp(a**2*sqrt(a)) == sqrt(a)**5
assert powsimp(a**2*sqrt(sqrt(a))) == sqrt(sqrt(a))**9
def test_issue_from_PR1599():
n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True)
assert (powsimp(sqrt(n1)*sqrt(n2)*sqrt(n3)) ==
-I*sqrt(-n1)*sqrt(-n2)*sqrt(-n3))
assert (powsimp(root(n1, 3)*root(n2, 3)*root(n3, 3)*root(n4, 3)) ==
-(-1)**Rational(1, 3)*
(-n1)**Rational(1, 3)*(-n2)**Rational(1, 3)*(-n3)**Rational(1, 3)*(-n4)**Rational(1, 3))
def test_issue_10195():
a = Symbol('a', integer=True)
l = Symbol('l', even=True, nonzero=True)
n = Symbol('n', odd=True)
e_x = (-1)**(n/2 - S.Half) - (-1)**(n*Rational(3, 2) - S.Half)
assert powsimp((-1)**(l/2)) == I**l
assert powsimp((-1)**(n/2)) == I**n
assert powsimp((-1)**(n*Rational(3, 2))) == -I**n
assert powsimp(e_x) == (-1)**(n/2 - S.Half) + (-1)**(n*Rational(3, 2) +
S.Half)
assert powsimp((-1)**(a*Rational(3, 2))) == (-I)**a
def test_issue_15709():
assert powsimp(3**x*Rational(2, 3)) == 2*3**(x-1)
assert powsimp(2*3**x/3) == 2*3**(x-1)
def test_issue_11981():
x, y = symbols('x y', commutative=False)
assert powsimp((x*y)**2 * (y*x)**2) == (x*y)**2 * (y*x)**2
def test_issue_17524():
a = symbols("a", real=True)
e = (-1 - a**2)*sqrt(1 + a**2)
assert signsimp(powsimp(e)) == signsimp(e) == -(a**2 + 1)**(S(3)/2)
def test_issue_19627():
# if you use force the user must verify
assert powdenest(sqrt(sin(x)**2), force=True) == sin(x)
assert powdenest((x**(S.Half/y))**(2*y), force=True) == x
from sympy.core.function import expand_power_base
e = 1 - a
expr = (exp(z/e)*x**(b/e)*y**((1 - b)/e))**e
assert powdenest(expand_power_base(expr, force=True), force=True
) == x**b*y**(1 - b)*exp(z)
def test_issue_22546():
p1, p2 = symbols('p1, p2', positive=True)
ref = powsimp(p1**z/p2**z)
e = z + 1
ans = ref.subs(z, e)
assert ans.is_Pow
assert powsimp(p1**e/p2**e) == ans
i = symbols('i', integer=True)
ref = powsimp(x**i/y**i)
e = i + 1
ans = ref.subs(i, e)
assert ans.is_Pow
assert powsimp(x**e/y**e) == ans

View File

@ -0,0 +1,490 @@
from sympy.core.add import Add
from sympy.core.function import (Derivative, Function, diff)
from sympy.core.mul import Mul
from sympy.core.numbers import (I, Rational)
from sympy.core.power import Pow
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, Wild, symbols)
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.polys.polytools import factor
from sympy.series.order import O
from sympy.simplify.radsimp import (collect, collect_const, fraction, radsimp, rcollect)
from sympy.core.expr import unchanged
from sympy.core.mul import _unevaluated_Mul as umul
from sympy.simplify.radsimp import (_unevaluated_Add,
collect_sqrt, fraction_expand, collect_abs)
from sympy.testing.pytest import raises
from sympy.abc import x, y, z, a, b, c, d
def test_radsimp():
r2 = sqrt(2)
r3 = sqrt(3)
r5 = sqrt(5)
r7 = sqrt(7)
assert fraction(radsimp(1/r2)) == (sqrt(2), 2)
assert radsimp(1/(1 + r2)) == \
-1 + sqrt(2)
assert radsimp(1/(r2 + r3)) == \
-sqrt(2) + sqrt(3)
assert fraction(radsimp(1/(1 + r2 + r3))) == \
(-sqrt(6) + sqrt(2) + 2, 4)
assert fraction(radsimp(1/(r2 + r3 + r5))) == \
(-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12)
assert fraction(radsimp(1/(1 + r2 + r3 + r5))) == (
(-34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) + 14*sqrt(30) +
93 + 46*sqrt(6) + 53*sqrt(5), 71))
assert fraction(radsimp(1/(r2 + r3 + r5 + r7))) == (
(-50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) + 22*sqrt(105)
+ 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7), 215))
z = radsimp(1/(1 + r2/3 + r3/5 + r5 + r7))
assert len((3616791619821680643598*z).args) == 16
assert radsimp(1/z) == 1/z
assert radsimp(1/z, max_terms=20).expand() == 1 + r2/3 + r3/5 + r5 + r7
assert radsimp(1/(r2*3)) == \
sqrt(2)/6
assert radsimp(1/(r2*a + r3 + r5 + r7)) == (
(8*sqrt(2)*a**7 - 8*sqrt(7)*a**6 - 8*sqrt(5)*a**6 - 8*sqrt(3)*a**6 -
180*sqrt(2)*a**5 + 8*sqrt(30)*a**5 + 8*sqrt(42)*a**5 + 8*sqrt(70)*a**5
- 24*sqrt(105)*a**4 + 84*sqrt(3)*a**4 + 100*sqrt(5)*a**4 +
116*sqrt(7)*a**4 - 72*sqrt(70)*a**3 - 40*sqrt(42)*a**3 -
8*sqrt(30)*a**3 + 782*sqrt(2)*a**3 - 462*sqrt(3)*a**2 -
302*sqrt(7)*a**2 - 254*sqrt(5)*a**2 + 120*sqrt(105)*a**2 -
795*sqrt(2)*a - 62*sqrt(30)*a + 82*sqrt(42)*a + 98*sqrt(70)*a -
118*sqrt(105) + 59*sqrt(7) + 295*sqrt(5) + 531*sqrt(3))/(16*a**8 -
480*a**6 + 3128*a**4 - 6360*a**2 + 3481))
assert radsimp(1/(r2*a + r2*b + r3 + r7)) == (
(sqrt(2)*a*(a + b)**2 - 5*sqrt(2)*a + sqrt(42)*a + sqrt(2)*b*(a +
b)**2 - 5*sqrt(2)*b + sqrt(42)*b - sqrt(7)*(a + b)**2 - sqrt(3)*(a +
b)**2 - 2*sqrt(3) + 2*sqrt(7))/(2*a**4 + 8*a**3*b + 12*a**2*b**2 -
20*a**2 + 8*a*b**3 - 40*a*b + 2*b**4 - 20*b**2 + 8))
assert radsimp(1/(r2*a + r2*b + r2*c + r2*d)) == \
sqrt(2)/(2*a + 2*b + 2*c + 2*d)
assert radsimp(1/(1 + r2*a + r2*b + r2*c + r2*d)) == (
(sqrt(2)*a + sqrt(2)*b + sqrt(2)*c + sqrt(2)*d - 1)/(2*a**2 + 4*a*b +
4*a*c + 4*a*d + 2*b**2 + 4*b*c + 4*b*d + 2*c**2 + 4*c*d + 2*d**2 - 1))
assert radsimp((y**2 - x)/(y - sqrt(x))) == \
sqrt(x) + y
assert radsimp(-(y**2 - x)/(y - sqrt(x))) == \
-(sqrt(x) + y)
assert radsimp(1/(1 - I + a*I)) == \
(-I*a + 1 + I)/(a**2 - 2*a + 2)
assert radsimp(1/((-x + y)*(x - sqrt(y)))) == \
(-x - sqrt(y))/((x - y)*(x**2 - y))
e = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y))
assert radsimp(e) == x*(3 + 3*sqrt(2))*(3*x - 3*sqrt(y))
assert radsimp(1/e) == (
(-9*x + 9*sqrt(2)*x - 9*sqrt(y) + 9*sqrt(2)*sqrt(y))/(9*x*(9*x**2 -
9*y)))
assert radsimp(1 + 1/(1 + sqrt(3))) == \
Mul(S.Half, -1 + sqrt(3), evaluate=False) + 1
A = symbols("A", commutative=False)
assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*A) == \
x**2 + sqrt(2)*x**2 - sqrt(2)*x*A
assert radsimp(1/sqrt(5 + 2 * sqrt(6))) == -sqrt(2) + sqrt(3)
assert radsimp(1/sqrt(5 + 2 * sqrt(6))**3) == -(-sqrt(3) + sqrt(2))**3
# issue 6532
assert fraction(radsimp(1/sqrt(x))) == (sqrt(x), x)
assert fraction(radsimp(1/sqrt(2*x + 3))) == (sqrt(2*x + 3), 2*x + 3)
assert fraction(radsimp(1/sqrt(2*(x + 3)))) == (sqrt(2*x + 6), 2*x + 6)
# issue 5994
e = S('-(2 + 2*sqrt(2) + 4*2**(1/4))/'
'(1 + 2**(3/4) + 3*2**(1/4) + 3*sqrt(2))')
assert radsimp(e).expand() == -2*2**Rational(3, 4) - 2*2**Rational(1, 4) + 2 + 2*sqrt(2)
# issue 5986 (modifications to radimp didn't initially recognize this so
# the test is included here)
assert radsimp(1/(-sqrt(5)/2 - S.Half + (-sqrt(5)/2 - S.Half)**2)) == 1
# from issue 5934
eq = (
(-240*sqrt(2)*sqrt(sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) -
360*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) -
120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) +
120*sqrt(2)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
120*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5) +
120*sqrt(10)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5))/(-36000 -
7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) +
24*sqrt(10)*sqrt(-sqrt(5) + 5))**2))
assert radsimp(eq) is S.NaN # it's 0/0
# work with normal form
e = 1/sqrt(sqrt(7)/7 + 2*sqrt(2) + 3*sqrt(3) + 5*sqrt(5)) + 3
assert radsimp(e) == (
-sqrt(sqrt(7) + 14*sqrt(2) + 21*sqrt(3) +
35*sqrt(5))*(-11654899*sqrt(35) - 1577436*sqrt(210) - 1278438*sqrt(15)
- 1346996*sqrt(10) + 1635060*sqrt(6) + 5709765 + 7539830*sqrt(14) +
8291415*sqrt(21))/1300423175 + 3)
# obey power rules
base = sqrt(3) - sqrt(2)
assert radsimp(1/base**3) == (sqrt(3) + sqrt(2))**3
assert radsimp(1/(-base)**3) == -(sqrt(2) + sqrt(3))**3
assert radsimp(1/(-base)**x) == (-base)**(-x)
assert radsimp(1/base**x) == (sqrt(2) + sqrt(3))**x
assert radsimp(root(1/(-1 - sqrt(2)), -x)) == (-1)**(-1/x)*(1 + sqrt(2))**(1/x)
# recurse
e = cos(1/(1 + sqrt(2)))
assert radsimp(e) == cos(-sqrt(2) + 1)
assert radsimp(e/2) == cos(-sqrt(2) + 1)/2
assert radsimp(1/e) == 1/cos(-sqrt(2) + 1)
assert radsimp(2/e) == 2/cos(-sqrt(2) + 1)
assert fraction(radsimp(e/sqrt(x))) == (sqrt(x)*cos(-sqrt(2)+1), x)
# test that symbolic denominators are not processed
r = 1 + sqrt(2)
assert radsimp(x/r, symbolic=False) == -x*(-sqrt(2) + 1)
assert radsimp(x/(y + r), symbolic=False) == x/(y + 1 + sqrt(2))
assert radsimp(x/(y + r)/r, symbolic=False) == \
-x*(-sqrt(2) + 1)/(y + 1 + sqrt(2))
# issue 7408
eq = sqrt(x)/sqrt(y)
assert radsimp(eq) == umul(sqrt(x), sqrt(y), 1/y)
assert radsimp(eq, symbolic=False) == eq
# issue 7498
assert radsimp(sqrt(x)/sqrt(y)**3) == umul(sqrt(x), sqrt(y**3), 1/y**3)
# for coverage
eq = sqrt(x)/y**2
assert radsimp(eq) == eq
def test_radsimp_issue_3214():
c, p = symbols('c p', positive=True)
s = sqrt(c**2 - p**2)
b = (c + I*p - s)/(c + I*p + s)
assert radsimp(b) == -I*(c + I*p - sqrt(c**2 - p**2))**2/(2*c*p)
def test_collect_1():
"""Collect with respect to Symbol"""
x, y, z, n = symbols('x,y,z,n')
assert collect(1, x) == 1
assert collect( x + y*x, x ) == x * (1 + y)
assert collect( x + x**2, x ) == x + x**2
assert collect( x**2 + y*x**2, x ) == (x**2)*(1 + y)
assert collect( x**2 + y*x, x ) == x*y + x**2
assert collect( 2*x**2 + y*x**2 + 3*x*y, [x] ) == x**2*(2 + y) + 3*x*y
assert collect( 2*x**2 + y*x**2 + 3*x*y, [y] ) == 2*x**2 + y*(x**2 + 3*x)
assert collect( ((1 + y + x)**4).expand(), x) == ((1 + y)**4).expand() + \
x*(4*(1 + y)**3).expand() + x**2*(6*(1 + y)**2).expand() + \
x**3*(4*(1 + y)).expand() + x**4
# symbols can be given as any iterable
expr = x + y
assert collect(expr, expr.free_symbols) == expr
assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None
) == x*exp(x) + 3*x + (y + 2)*sin(x)
assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x + y*x +
y*x*exp(x), x, exact=None
) == x*exp(x)*(y + 1) + (3 + y)*x + (y + 2)*sin(x)
def test_collect_2():
"""Collect with respect to a sum"""
a, b, x = symbols('a,b,x')
assert collect(a*(cos(x) + sin(x)) + b*(cos(x) + sin(x)),
sin(x) + cos(x)) == (a + b)*(cos(x) + sin(x))
def test_collect_3():
"""Collect with respect to a product"""
a, b, c = symbols('a,b,c')
f = Function('f')
x, y, z, n = symbols('x,y,z,n')
assert collect(-x/8 + x*y, -x) == x*(y - Rational(1, 8))
assert collect( 1 + x*(y**2), x*y ) == 1 + x*(y**2)
assert collect( x*y + a*x*y, x*y) == x*y*(1 + a)
assert collect( 1 + x*y + a*x*y, x*y) == 1 + x*y*(1 + a)
assert collect(a*x*f(x) + b*(x*f(x)), x*f(x)) == x*(a + b)*f(x)
assert collect(a*x*log(x) + b*(x*log(x)), x*log(x)) == x*(a + b)*log(x)
assert collect(a*x**2*log(x)**2 + b*(x*log(x))**2, x*log(x)) == \
x**2*log(x)**2*(a + b)
# with respect to a product of three symbols
assert collect(y*x*z + a*x*y*z, x*y*z) == (1 + a)*x*y*z
def test_collect_4():
"""Collect with respect to a power"""
a, b, c, x = symbols('a,b,c,x')
assert collect(a*x**c + b*x**c, x**c) == x**c*(a + b)
# issue 6096: 2 stays with c (unless c is integer or x is positive0
assert collect(a*x**(2*c) + b*x**(2*c), x**c) == x**(2*c)*(a + b)
def test_collect_5():
"""Collect with respect to a tuple"""
a, x, y, z, n = symbols('a,x,y,z,n')
assert collect(x**2*y**4 + z*(x*y**2)**2 + z + a*z, [x*y**2, z]) in [
z*(1 + a + x**2*y**4) + x**2*y**4,
z*(1 + a) + x**2*y**4*(1 + z) ]
assert collect((1 + (x + y) + (x + y)**2).expand(),
[x, y]) == 1 + y + x*(1 + 2*y) + x**2 + y**2
def test_collect_pr19431():
"""Unevaluated collect with respect to a product"""
a = symbols('a')
assert collect(a**2*(a**2 + 1), a**2, evaluate=False)[a**2] == (a**2 + 1)
def test_collect_D():
D = Derivative
f = Function('f')
x, a, b = symbols('x,a,b')
fx = D(f(x), x)
fxx = D(f(x), x, x)
assert collect(a*fx + b*fx, fx) == (a + b)*fx
assert collect(a*D(fx, x) + b*D(fx, x), fx) == (a + b)*D(fx, x)
assert collect(a*fxx + b*fxx, fx) == (a + b)*D(fx, x)
# issue 4784
assert collect(5*f(x) + 3*fx, fx) == 5*f(x) + 3*fx
assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x)) == \
(x*f(x) + f(x))*D(f(x), x) + f(x)
assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x), exact=True) == \
(x*f(x) + f(x))*D(f(x), x) + f(x)
assert collect(1/f(x) + 1/f(x)*diff(f(x), x) + x*diff(f(x), x)/f(x), f(x).diff(x), exact=True) == \
(1/f(x) + x/f(x))*D(f(x), x) + 1/f(x)
e = (1 + x*fx + fx)/f(x)
assert collect(e.expand(), fx) == fx*(x/f(x) + 1/f(x)) + 1/f(x)
def test_collect_func():
f = ((x + a + 1)**3).expand()
assert collect(f, x) == a**3 + 3*a**2 + 3*a + x**3 + x**2*(3*a + 3) + \
x*(3*a**2 + 6*a + 3) + 1
assert collect(f, x, factor) == x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + \
(a + 1)**3
assert collect(f, x, evaluate=False) == {
S.One: a**3 + 3*a**2 + 3*a + 1,
x: 3*a**2 + 6*a + 3, x**2: 3*a + 3,
x**3: 1
}
assert collect(f, x, factor, evaluate=False) == {
S.One: (a + 1)**3, x: 3*(a + 1)**2,
x**2: umul(S(3), a + 1), x**3: 1}
def test_collect_order():
a, b, x, t = symbols('a,b,x,t')
assert collect(t + t*x + t*x**2 + O(x**3), t) == t*(1 + x + x**2 + O(x**3))
assert collect(t + t*x + x**2 + O(x**3), t) == \
t*(1 + x + O(x**3)) + x**2 + O(x**3)
f = a*x + b*x + c*x**2 + d*x**2 + O(x**3)
g = x*(a + b) + x**2*(c + d) + O(x**3)
assert collect(f, x) == g
assert collect(f, x, distribute_order_term=False) == g
f = sin(a + b).series(b, 0, 10)
assert collect(f, [sin(a), cos(a)]) == \
sin(a)*cos(b).series(b, 0, 10) + cos(a)*sin(b).series(b, 0, 10)
assert collect(f, [sin(a), cos(a)], distribute_order_term=False) == \
sin(a)*cos(b).series(b, 0, 10).removeO() + \
cos(a)*sin(b).series(b, 0, 10).removeO() + O(b**10)
def test_rcollect():
assert rcollect((x**2*y + x*y + x + y)/(x + y), y) == \
(x + y*(1 + x + x**2))/(x + y)
assert rcollect(sqrt(-((x + 1)*(y + 1))), z) == sqrt(-((x + 1)*(y + 1)))
def test_collect_D_0():
D = Derivative
f = Function('f')
x, a, b = symbols('x,a,b')
fxx = D(f(x), x, x)
assert collect(a*fxx + b*fxx, fxx) == (a + b)*fxx
def test_collect_Wild():
"""Collect with respect to functions with Wild argument"""
a, b, x, y = symbols('a b x y')
f = Function('f')
w1 = Wild('.1')
w2 = Wild('.2')
assert collect(f(x) + a*f(x), f(w1)) == (1 + a)*f(x)
assert collect(f(x, y) + a*f(x, y), f(w1)) == f(x, y) + a*f(x, y)
assert collect(f(x, y) + a*f(x, y), f(w1, w2)) == (1 + a)*f(x, y)
assert collect(f(x, y) + a*f(x, y), f(w1, w1)) == f(x, y) + a*f(x, y)
assert collect(f(x, x) + a*f(x, x), f(w1, w1)) == (1 + a)*f(x, x)
assert collect(a*(x + 1)**y + (x + 1)**y, w1**y) == (1 + a)*(x + 1)**y
assert collect(a*(x + 1)**y + (x + 1)**y, w1**b) == \
a*(x + 1)**y + (x + 1)**y
assert collect(a*(x + 1)**y + (x + 1)**y, (x + 1)**w2) == \
(1 + a)*(x + 1)**y
assert collect(a*(x + 1)**y + (x + 1)**y, w1**w2) == (1 + a)*(x + 1)**y
def test_collect_const():
# coverage not provided by above tests
assert collect_const(2*sqrt(3) + 4*a*sqrt(5)) == \
2*(2*sqrt(5)*a + sqrt(3)) # let the primitive reabsorb
assert collect_const(2*sqrt(3) + 4*a*sqrt(5), sqrt(3)) == \
2*sqrt(3) + 4*a*sqrt(5)
assert collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)) == \
sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3)
# issue 5290
assert collect_const(2*x + 2*y + 1, 2) == \
collect_const(2*x + 2*y + 1) == \
Add(S.One, Mul(2, x + y, evaluate=False), evaluate=False)
assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False)
assert collect_const(2*x - 2*y - 2*z, 2) == \
Mul(2, x - y - z, evaluate=False)
assert collect_const(2*x - 2*y - 2*z, -2) == \
_unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False))
# this is why the content_primitive is used
eq = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2
assert collect_sqrt(eq + 2) == \
2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2
# issue 16296
assert collect_const(a + b + x/2 + y/2) == a + b + Mul(S.Half, x + y, evaluate=False)
def test_issue_13143():
f = Function('f')
fx = f(x).diff(x)
e = f(x) + fx + f(x)*fx
# collect function before derivative
assert collect(e, Wild('w')) == f(x)*(fx + 1) + fx
e = f(x) + f(x)*fx + x*fx*f(x)
assert collect(e, fx) == (x*f(x) + f(x))*fx + f(x)
assert collect(e, f(x)) == (x*fx + fx + 1)*f(x)
e = f(x) + fx + f(x)*fx
assert collect(e, [f(x), fx]) == f(x)*(1 + fx) + fx
assert collect(e, [fx, f(x)]) == fx*(1 + f(x)) + f(x)
def test_issue_6097():
assert collect(a*y**(2.0*x) + b*y**(2.0*x), y**x) == (a + b)*(y**x)**2.0
assert collect(a*2**(2.0*x) + b*2**(2.0*x), 2**x) == (a + b)*(2**x)**2.0
def test_fraction_expand():
eq = (x + y)*y/x
assert eq.expand(frac=True) == fraction_expand(eq) == (x*y + y**2)/x
assert eq.expand() == y + y**2/x
def test_fraction():
x, y, z = map(Symbol, 'xyz')
A = Symbol('A', commutative=False)
assert fraction(S.Half) == (1, 2)
assert fraction(x) == (x, 1)
assert fraction(1/x) == (1, x)
assert fraction(x/y) == (x, y)
assert fraction(x/2) == (x, 2)
assert fraction(x*y/z) == (x*y, z)
assert fraction(x/(y*z)) == (x, y*z)
assert fraction(1/y**2) == (1, y**2)
assert fraction(x/y**2) == (x, y**2)
assert fraction((x**2 + 1)/y) == (x**2 + 1, y)
assert fraction(x*(y + 1)/y**7) == (x*(y + 1), y**7)
assert fraction(exp(-x), exact=True) == (exp(-x), 1)
assert fraction((1/(x + y))/2, exact=True) == (1, Mul(2,(x + y), evaluate=False))
assert fraction(x*A/y) == (x*A, y)
assert fraction(x*A**-1/y) == (x*A**-1, y)
n = symbols('n', negative=True)
assert fraction(exp(n)) == (1, exp(-n))
assert fraction(exp(-n)) == (exp(-n), 1)
p = symbols('p', positive=True)
assert fraction(exp(-p)*log(p), exact=True) == (exp(-p)*log(p), 1)
m = Mul(1, 1, S.Half, evaluate=False)
assert fraction(m) == (1, 2)
assert fraction(m, exact=True) == (Mul(1, 1, evaluate=False), 2)
m = Mul(1, 1, S.Half, S.Half, Pow(1, -1, evaluate=False), evaluate=False)
assert fraction(m) == (1, 4)
assert fraction(m, exact=True) == \
(Mul(1, 1, evaluate=False), Mul(2, 2, 1, evaluate=False))
def test_issue_5615():
aA, Re, a, b, D = symbols('aA Re a b D')
e = ((D**3*a + b*aA**3)/Re).expand()
assert collect(e, [aA**3/Re, a]) == e
def test_issue_5933():
from sympy.geometry.polygon import (Polygon, RegularPolygon)
from sympy.simplify.radsimp import denom
x = Polygon(*RegularPolygon((0, 0), 1, 5).vertices).centroid.x
assert abs(denom(x).n()) > 1e-12
assert abs(denom(radsimp(x))) > 1e-12 # in case simplify didn't handle it
def test_issue_14608():
a, b = symbols('a b', commutative=False)
x, y = symbols('x y')
raises(AttributeError, lambda: collect(a*b + b*a, a))
assert collect(x*y + y*(x+1), a) == x*y + y*(x+1)
assert collect(x*y + y*(x+1) + a*b + b*a, y) == y*(2*x + 1) + a*b + b*a
def test_collect_abs():
s = abs(x) + abs(y)
assert collect_abs(s) == s
assert unchanged(Mul, abs(x), abs(y))
ans = Abs(x*y)
assert isinstance(ans, Abs)
assert collect_abs(abs(x)*abs(y)) == ans
assert collect_abs(1 + exp(abs(x)*abs(y))) == 1 + exp(ans)
# See https://github.com/sympy/sympy/issues/12910
p = Symbol('p', positive=True)
assert collect_abs(p/abs(1-p)).is_commutative is True
def test_issue_19149():
eq = exp(3*x/4)
assert collect(eq, exp(x)) == eq
def test_issue_19719():
a, b = symbols('a, b')
expr = a**2 * (b + 1) + (7 + 1/b)/a
collected = collect(expr, (a**2, 1/a), evaluate=False)
# Would return {_Dummy_20**(-2): b + 1, 1/a: 7 + 1/b} without xreplace
assert collected == {a**2: b + 1, 1/a: 7 + 1/b}
def test_issue_21355():
assert radsimp(1/(x + sqrt(x**2))) == 1/(x + sqrt(x**2))
assert radsimp(1/(x - sqrt(x**2))) == 1/(x - sqrt(x**2))

View File

@ -0,0 +1,78 @@
from sympy.core.numbers import (Rational, pi)
from sympy.functions.elementary.exponential import log
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.special.error_functions import erf
from sympy.polys.domains import GF
from sympy.simplify.ratsimp import (ratsimp, ratsimpmodprime)
from sympy.abc import x, y, z, t, a, b, c, d, e
def test_ratsimp():
f, g = 1/x + 1/y, (x + y)/(x*y)
assert f != g and ratsimp(f) == g
f, g = 1/(1 + 1/x), 1 - 1/(x + 1)
assert f != g and ratsimp(f) == g
f, g = x/(x + y) + y/(x + y), 1
assert f != g and ratsimp(f) == g
f, g = -x - y - y**2/(x + y) + x**2/(x + y), -2*y
assert f != g and ratsimp(f) == g
f = (a*c*x*y + a*c*z - b*d*x*y - b*d*z - b*t*x*y - b*t*x - b*t*z +
e*x)/(x*y + z)
G = [a*c - b*d - b*t + (-b*t*x + e*x)/(x*y + z),
a*c - b*d - b*t - ( b*t*x - e*x)/(x*y + z)]
assert f != g and ratsimp(f) in G
A = sqrt(pi)
B = log(erf(x) - 1)
C = log(erf(x) + 1)
D = 8 - 8*erf(x)
f = A*B/D - A*C/D + A*C*erf(x)/D - A*B*erf(x)/D + 2*A/D
assert ratsimp(f) == A*B/8 - A*C/8 - A/(4*erf(x) - 4)
def test_ratsimpmodprime():
a = y**5 + x + y
b = x - y
F = [x*y**5 - x - y]
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
(-x**2 - x*y - x - y) / (-x**2 + x*y)
a = x + y**2 - 2
b = x + y**2 - y - 1
F = [x*y - 1]
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
(1 + y - x)/(y - x)
a = 5*x**3 + 21*x**2 + 4*x*y + 23*x + 12*y + 15
b = 7*x**3 - y*x**2 + 31*x**2 + 2*x*y + 15*y + 37*x + 21
F = [x**2 + y**2 - 1]
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
(1 + 5*y - 5*x)/(8*y - 6*x)
a = x*y - x - 2*y + 4
b = x + y**2 - 2*y
F = [x - 2, y - 3]
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
Rational(2, 5)
# Test a bug where denominators would be dropped
assert ratsimpmodprime(x, [y - 2*x], order='lex') == \
y/2
a = (x**5 + 2*x**4 + 2*x**3 + 2*x**2 + x + 2/x + x**(-2))
assert ratsimpmodprime(a, [x + 1], domain=GF(2)) == 1
assert ratsimpmodprime(a, [x + 1], domain=GF(3)) == -1

View File

@ -0,0 +1,31 @@
from sympy.core.numbers import I
from sympy.core.symbol import symbols
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.trigonometric import (cos, cot, sin)
from sympy.testing.pytest import _both_exp_pow
x, y, z, n = symbols('x,y,z,n')
@_both_exp_pow
def test_has():
assert cot(x).has(x)
assert cot(x).has(cot)
assert not cot(x).has(sin)
assert sin(x).has(x)
assert sin(x).has(sin)
assert not sin(x).has(cot)
assert exp(x).has(exp)
@_both_exp_pow
def test_sin_exp_rewrite():
assert sin(x).rewrite(sin, exp) == -I/2*(exp(I*x) - exp(-I*x))
assert sin(x).rewrite(sin, exp).rewrite(exp, sin) == sin(x)
assert cos(x).rewrite(cos, exp).rewrite(exp, cos) == cos(x)
assert (sin(5*y) - sin(
2*x)).rewrite(sin, exp).rewrite(exp, sin) == sin(5*y) - sin(2*x)
assert sin(x + y).rewrite(sin, exp).rewrite(exp, sin) == sin(x + y)
assert cos(x + y).rewrite(cos, exp).rewrite(exp, cos) == cos(x + y)
# This next test currently passes... not clear whether it should or not?
assert cos(x).rewrite(cos, exp).rewrite(exp, sin) == cos(x)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,204 @@
from sympy.core.mul import Mul
from sympy.core.numbers import (I, Integer, Rational)
from sympy.core.symbol import Symbol
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import cos
from sympy.integrals.integrals import Integral
from sympy.simplify.sqrtdenest import sqrtdenest
from sympy.simplify.sqrtdenest import (
_subsets as subsets, _sqrt_numeric_denest)
r2, r3, r5, r6, r7, r10, r15, r29 = [sqrt(x) for x in (2, 3, 5, 6, 7, 10,
15, 29)]
def test_sqrtdenest():
d = {sqrt(5 + 2 * r6): r2 + r3,
sqrt(5. + 2 * r6): sqrt(5. + 2 * r6),
sqrt(5. + 4*sqrt(5 + 2 * r6)): sqrt(5.0 + 4*r2 + 4*r3),
sqrt(r2): sqrt(r2),
sqrt(5 + r7): sqrt(5 + r7),
sqrt(3 + sqrt(5 + 2*r7)):
3*r2*(5 + 2*r7)**Rational(1, 4)/(2*sqrt(6 + 3*r7)) +
r2*sqrt(6 + 3*r7)/(2*(5 + 2*r7)**Rational(1, 4)),
sqrt(3 + 2*r3): 3**Rational(3, 4)*(r6/2 + 3*r2/2)/3}
for i in d:
assert sqrtdenest(i) == d[i], i
def test_sqrtdenest2():
assert sqrtdenest(sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))) == \
r5 + sqrt(11 - 2*r29)
e = sqrt(-r5 + sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16))
assert sqrtdenest(e) == root(-2*r29 + 11, 4)
r = sqrt(1 + r7)
assert sqrtdenest(sqrt(1 + r)) == sqrt(1 + r)
e = sqrt(((1 + sqrt(1 + 2*sqrt(3 + r2 + r5)))**2).expand())
assert sqrtdenest(e) == 1 + sqrt(1 + 2*sqrt(r2 + r5 + 3))
assert sqrtdenest(sqrt(5*r3 + 6*r2)) == \
sqrt(2)*root(3, 4) + root(3, 4)**3
assert sqrtdenest(sqrt(((1 + r5 + sqrt(1 + r3))**2).expand())) == \
1 + r5 + sqrt(1 + r3)
assert sqrtdenest(sqrt(((1 + r5 + r7 + sqrt(1 + r3))**2).expand())) == \
1 + sqrt(1 + r3) + r5 + r7
e = sqrt(((1 + cos(2) + cos(3) + sqrt(1 + r3))**2).expand())
assert sqrtdenest(e) == cos(3) + cos(2) + 1 + sqrt(1 + r3)
e = sqrt(-2*r10 + 2*r2*sqrt(-2*r10 + 11) + 14)
assert sqrtdenest(e) == sqrt(-2*r10 - 2*r2 + 4*r5 + 14)
# check that the result is not more complicated than the input
z = sqrt(-2*r29 + cos(2) + 2*sqrt(-10*r29 + 55) + 16)
assert sqrtdenest(z) == z
assert sqrtdenest(sqrt(r6 + sqrt(15))) == sqrt(r6 + sqrt(15))
z = sqrt(15 - 2*sqrt(31) + 2*sqrt(55 - 10*r29))
assert sqrtdenest(z) == z
def test_sqrtdenest_rec():
assert sqrtdenest(sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 33)) == \
-r2 + r3 + 2*r7
assert sqrtdenest(sqrt(-28*r7 - 14*r5 + 4*sqrt(35) + 82)) == \
-7 + r5 + 2*r7
assert sqrtdenest(sqrt(6*r2/11 + 2*sqrt(22)/11 + 6*sqrt(11)/11 + 2)) == \
sqrt(11)*(r2 + 3 + sqrt(11))/11
assert sqrtdenest(sqrt(468*r3 + 3024*r2 + 2912*r6 + 19735)) == \
9*r3 + 26 + 56*r6
z = sqrt(-490*r3 - 98*sqrt(115) - 98*sqrt(345) - 2107)
assert sqrtdenest(z) == sqrt(-1)*(7*r5 + 7*r15 + 7*sqrt(23))
z = sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 34)
assert sqrtdenest(z) == z
assert sqrtdenest(sqrt(-8*r2 - 2*r5 + 18)) == -r10 + 1 + r2 + r5
assert sqrtdenest(sqrt(8*r2 + 2*r5 - 18)) == \
sqrt(-1)*(-r10 + 1 + r2 + r5)
assert sqrtdenest(sqrt(8*r2/3 + 14*r5/3 + Rational(154, 9))) == \
-r10/3 + r2 + r5 + 3
assert sqrtdenest(sqrt(sqrt(2*r6 + 5) + sqrt(2*r7 + 8))) == \
sqrt(1 + r2 + r3 + r7)
assert sqrtdenest(sqrt(4*r15 + 8*r5 + 12*r3 + 24)) == 1 + r3 + r5 + r15
w = 1 + r2 + r3 + r5 + r7
assert sqrtdenest(sqrt((w**2).expand())) == w
z = sqrt((w**2).expand() + 1)
assert sqrtdenest(z) == z
z = sqrt(2*r10 + 6*r2 + 4*r5 + 12 + 10*r15 + 30*r3)
assert sqrtdenest(z) == z
def test_issue_6241():
z = sqrt( -320 + 32*sqrt(5) + 64*r15)
assert sqrtdenest(z) == z
def test_sqrtdenest3():
z = sqrt(13 - 2*r10 + 2*r2*sqrt(-2*r10 + 11))
assert sqrtdenest(z) == -1 + r2 + r10
assert sqrtdenest(z, max_iter=1) == -1 + sqrt(2) + sqrt(10)
z = sqrt(sqrt(r2 + 2) + 2)
assert sqrtdenest(z) == z
assert sqrtdenest(sqrt(-2*r10 + 4*r2*sqrt(-2*r10 + 11) + 20)) == \
sqrt(-2*r10 - 4*r2 + 8*r5 + 20)
assert sqrtdenest(sqrt((112 + 70*r2) + (46 + 34*r2)*r5)) == \
r10 + 5 + 4*r2 + 3*r5
z = sqrt(5 + sqrt(2*r6 + 5)*sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16))
r = sqrt(-2*r29 + 11)
assert sqrtdenest(z) == sqrt(r2*r + r3*r + r10 + r15 + 5)
n = sqrt(2*r6/7 + 2*r7/7 + 2*sqrt(42)/7 + 2)
d = sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))
assert sqrtdenest(n/d) == r7*(1 + r6 + r7)/(Mul(7, (sqrt(-2*r29 + 11) + r5),
evaluate=False))
def test_sqrtdenest4():
# see Denest_en.pdf in https://github.com/sympy/sympy/issues/3192
z = sqrt(8 - r2*sqrt(5 - r5) - sqrt(3)*(1 + r5))
z1 = sqrtdenest(z)
c = sqrt(-r5 + 5)
z1 = ((-r15*c - r3*c + c + r5*c - r6 - r2 + r10 + sqrt(30))/4).expand()
assert sqrtdenest(z) == z1
z = sqrt(2*r2*sqrt(r2 + 2) + 5*r2 + 4*sqrt(r2 + 2) + 8)
assert sqrtdenest(z) == r2 + sqrt(r2 + 2) + 2
w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3)
z = sqrt((w**2).expand())
assert sqrtdenest(z) == w.expand()
def test_sqrt_symbolic_denest():
x = Symbol('x')
z = sqrt(((1 + sqrt(sqrt(2 + x) + 3))**2).expand())
assert sqrtdenest(z) == sqrt((1 + sqrt(sqrt(2 + x) + 3))**2)
z = sqrt(((1 + sqrt(sqrt(2 + cos(1)) + 3))**2).expand())
assert sqrtdenest(z) == 1 + sqrt(sqrt(2 + cos(1)) + 3)
z = ((1 + cos(2))**4 + 1).expand()
assert sqrtdenest(z) == z
z = sqrt(((1 + sqrt(sqrt(2 + cos(3*x)) + 3))**2 + 1).expand())
assert sqrtdenest(z) == z
c = cos(3)
c2 = c**2
assert sqrtdenest(sqrt(2*sqrt(1 + r3)*c + c2 + 1 + r3*c2)) == \
-1 - sqrt(1 + r3)*c
ra = sqrt(1 + r3)
z = sqrt(20*ra*sqrt(3 + 3*r3) + 12*r3*ra*sqrt(3 + 3*r3) + 64*r3 + 112)
assert sqrtdenest(z) == z
def test_issue_5857():
from sympy.abc import x, y
z = sqrt(1/(4*r3 + 7) + 1)
ans = (r2 + r6)/(r3 + 2)
assert sqrtdenest(z) == ans
assert sqrtdenest(1 + z) == 1 + ans
assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \
Integral(1 + ans, (x, 1, 2))
assert sqrtdenest(x + sqrt(y)) == x + sqrt(y)
ans = (r2 + r6)/(r3 + 2)
assert sqrtdenest(z) == ans
assert sqrtdenest(1 + z) == 1 + ans
assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \
Integral(1 + ans, (x, 1, 2))
assert sqrtdenest(x + sqrt(y)) == x + sqrt(y)
def test_subsets():
assert subsets(1) == [[1]]
assert subsets(4) == [
[1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [1, 0, 1, 0],
[0, 1, 1, 0], [1, 1, 1, 0], [0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1],
[1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]
def test_issue_5653():
assert sqrtdenest(
sqrt(2 + sqrt(2 + sqrt(2)))) == sqrt(2 + sqrt(2 + sqrt(2)))
def test_issue_12420():
assert sqrtdenest((3 - sqrt(2)*sqrt(4 + 3*I) + 3*I)/2) == I
e = 3 - sqrt(2)*sqrt(4 + I) + 3*I
assert sqrtdenest(e) == e
def test_sqrt_ratcomb():
assert sqrtdenest(sqrt(1 + r3) + sqrt(3 + 3*r3) - sqrt(10 + 6*r3)) == 0
def test_issue_18041():
e = -sqrt(-2 + 2*sqrt(3)*I)
assert sqrtdenest(e) == -1 - sqrt(3)*I
def test_issue_19914():
a = Integer(-8)
b = Integer(-1)
r = Integer(63)
d2 = a*a - b*b*r
assert _sqrt_numeric_denest(a, b, r, d2) == \
sqrt(14)*I/2 + 3*sqrt(2)*I/2
assert sqrtdenest(sqrt(-8-sqrt(63))) == sqrt(14)*I/2 + 3*sqrt(2)*I/2

View File

@ -0,0 +1,520 @@
from itertools import product
from sympy.core.function import (Subs, count_ops, diff, expand)
from sympy.core.numbers import (E, I, Rational, pi)
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.hyperbolic import (cosh, coth, sinh, tanh)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (cos, cot, sin, tan)
from sympy.functions.elementary.trigonometric import (acos, asin, atan2)
from sympy.functions.elementary.trigonometric import (asec, acsc)
from sympy.functions.elementary.trigonometric import (acot, atan)
from sympy.integrals.integrals import integrate
from sympy.matrices.dense import Matrix
from sympy.simplify.simplify import simplify
from sympy.simplify.trigsimp import (exptrigsimp, trigsimp)
from sympy.testing.pytest import XFAIL
from sympy.abc import x, y
def test_trigsimp1():
x, y = symbols('x,y')
assert trigsimp(1 - sin(x)**2) == cos(x)**2
assert trigsimp(1 - cos(x)**2) == sin(x)**2
assert trigsimp(sin(x)**2 + cos(x)**2) == 1
assert trigsimp(1 + tan(x)**2) == 1/cos(x)**2
assert trigsimp(1/cos(x)**2 - 1) == tan(x)**2
assert trigsimp(1/cos(x)**2 - tan(x)**2) == 1
assert trigsimp(1 + cot(x)**2) == 1/sin(x)**2
assert trigsimp(1/sin(x)**2 - 1) == 1/tan(x)**2
assert trigsimp(1/sin(x)**2 - cot(x)**2) == 1
assert trigsimp(5*cos(x)**2 + 5*sin(x)**2) == 5
assert trigsimp(5*cos(x/2)**2 + 2*sin(x/2)**2) == 3*cos(x)/2 + Rational(7, 2)
assert trigsimp(sin(x)/cos(x)) == tan(x)
assert trigsimp(2*tan(x)*cos(x)) == 2*sin(x)
assert trigsimp(cot(x)**3*sin(x)**3) == cos(x)**3
assert trigsimp(y*tan(x)**2/sin(x)**2) == y/cos(x)**2
assert trigsimp(cot(x)/cos(x)) == 1/sin(x)
assert trigsimp(sin(x + y) + sin(x - y)) == 2*sin(x)*cos(y)
assert trigsimp(sin(x + y) - sin(x - y)) == 2*sin(y)*cos(x)
assert trigsimp(cos(x + y) + cos(x - y)) == 2*cos(x)*cos(y)
assert trigsimp(cos(x + y) - cos(x - y)) == -2*sin(x)*sin(y)
assert trigsimp(tan(x + y) - tan(x)/(1 - tan(x)*tan(y))) == \
sin(y)/(-sin(y)*tan(x) + cos(y)) # -tan(y)/(tan(x)*tan(y) - 1)
assert trigsimp(sinh(x + y) + sinh(x - y)) == 2*sinh(x)*cosh(y)
assert trigsimp(sinh(x + y) - sinh(x - y)) == 2*sinh(y)*cosh(x)
assert trigsimp(cosh(x + y) + cosh(x - y)) == 2*cosh(x)*cosh(y)
assert trigsimp(cosh(x + y) - cosh(x - y)) == 2*sinh(x)*sinh(y)
assert trigsimp(tanh(x + y) - tanh(x)/(1 + tanh(x)*tanh(y))) == \
sinh(y)/(sinh(y)*tanh(x) + cosh(y))
assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2) == 1.0
e = 2*sin(x)**2 + 2*cos(x)**2
assert trigsimp(log(e)) == log(2)
def test_trigsimp1a():
assert trigsimp(sin(2)**2*cos(3)*exp(2)/cos(2)**2) == tan(2)**2*cos(3)*exp(2)
assert trigsimp(tan(2)**2*cos(3)*exp(2)*cos(2)**2) == sin(2)**2*cos(3)*exp(2)
assert trigsimp(cot(2)*cos(3)*exp(2)*sin(2)) == cos(3)*exp(2)*cos(2)
assert trigsimp(tan(2)*cos(3)*exp(2)/sin(2)) == cos(3)*exp(2)/cos(2)
assert trigsimp(cot(2)*cos(3)*exp(2)/cos(2)) == cos(3)*exp(2)/sin(2)
assert trigsimp(cot(2)*cos(3)*exp(2)*tan(2)) == cos(3)*exp(2)
assert trigsimp(sinh(2)*cos(3)*exp(2)/cosh(2)) == tanh(2)*cos(3)*exp(2)
assert trigsimp(tanh(2)*cos(3)*exp(2)*cosh(2)) == sinh(2)*cos(3)*exp(2)
assert trigsimp(coth(2)*cos(3)*exp(2)*sinh(2)) == cosh(2)*cos(3)*exp(2)
assert trigsimp(tanh(2)*cos(3)*exp(2)/sinh(2)) == cos(3)*exp(2)/cosh(2)
assert trigsimp(coth(2)*cos(3)*exp(2)/cosh(2)) == cos(3)*exp(2)/sinh(2)
assert trigsimp(coth(2)*cos(3)*exp(2)*tanh(2)) == cos(3)*exp(2)
def test_trigsimp2():
x, y = symbols('x,y')
assert trigsimp(cos(x)**2*sin(y)**2 + cos(x)**2*cos(y)**2 + sin(x)**2,
recursive=True) == 1
assert trigsimp(sin(x)**2*sin(y)**2 + sin(x)**2*cos(y)**2 + cos(x)**2,
recursive=True) == 1
assert trigsimp(
Subs(x, x, sin(y)**2 + cos(y)**2)) == Subs(x, x, 1)
def test_issue_4373():
x = Symbol("x")
assert abs(trigsimp(2.0*sin(x)**2 + 2.0*cos(x)**2) - 2.0) < 1e-10
def test_trigsimp3():
x, y = symbols('x,y')
assert trigsimp(sin(x)/cos(x)) == tan(x)
assert trigsimp(sin(x)**2/cos(x)**2) == tan(x)**2
assert trigsimp(sin(x)**3/cos(x)**3) == tan(x)**3
assert trigsimp(sin(x)**10/cos(x)**10) == tan(x)**10
assert trigsimp(cos(x)/sin(x)) == 1/tan(x)
assert trigsimp(cos(x)**2/sin(x)**2) == 1/tan(x)**2
assert trigsimp(cos(x)**10/sin(x)**10) == 1/tan(x)**10
assert trigsimp(tan(x)) == trigsimp(sin(x)/cos(x))
def test_issue_4661():
a, x, y = symbols('a x y')
eq = -4*sin(x)**4 + 4*cos(x)**4 - 8*cos(x)**2
assert trigsimp(eq) == -4
n = sin(x)**6 + 4*sin(x)**4*cos(x)**2 + 5*sin(x)**2*cos(x)**4 + 2*cos(x)**6
d = -sin(x)**2 - 2*cos(x)**2
assert simplify(n/d) == -1
assert trigsimp(-2*cos(x)**2 + cos(x)**4 - sin(x)**4) == -1
eq = (- sin(x)**3/4)*cos(x) + (cos(x)**3/4)*sin(x) - sin(2*x)*cos(2*x)/8
assert trigsimp(eq) == 0
def test_issue_4494():
a, b = symbols('a b')
eq = sin(a)**2*sin(b)**2 + cos(a)**2*cos(b)**2*tan(a)**2 + cos(a)**2
assert trigsimp(eq) == 1
def test_issue_5948():
a, x, y = symbols('a x y')
assert trigsimp(diff(integrate(cos(x)/sin(x)**7, x), x)) == \
cos(x)/sin(x)**7
def test_issue_4775():
a, x, y = symbols('a x y')
assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)) == sin(x + y)
assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)+3) == sin(x + y) + 3
def test_issue_4280():
a, x, y = symbols('a x y')
assert trigsimp(cos(x)**2 + cos(y)**2*sin(x)**2 + sin(y)**2*sin(x)**2) == 1
assert trigsimp(a**2*sin(x)**2 + a**2*cos(y)**2*cos(x)**2 + a**2*cos(x)**2*sin(y)**2) == a**2
assert trigsimp(a**2*cos(y)**2*sin(x)**2 + a**2*sin(y)**2*sin(x)**2) == a**2*sin(x)**2
def test_issue_3210():
eqs = (sin(2)*cos(3) + sin(3)*cos(2),
-sin(2)*sin(3) + cos(2)*cos(3),
sin(2)*cos(3) - sin(3)*cos(2),
sin(2)*sin(3) + cos(2)*cos(3),
sin(2)*sin(3) + cos(2)*cos(3) + cos(2),
sinh(2)*cosh(3) + sinh(3)*cosh(2),
sinh(2)*sinh(3) + cosh(2)*cosh(3),
)
assert [trigsimp(e) for e in eqs] == [
sin(5),
cos(5),
-sin(1),
cos(1),
cos(1) + cos(2),
sinh(5),
cosh(5),
]
def test_trigsimp_issues():
a, x, y = symbols('a x y')
# issue 4625 - factor_terms works, too
assert trigsimp(sin(x)**3 + cos(x)**2*sin(x)) == sin(x)
# issue 5948
assert trigsimp(diff(integrate(cos(x)/sin(x)**3, x), x)) == \
cos(x)/sin(x)**3
assert trigsimp(diff(integrate(sin(x)/cos(x)**3, x), x)) == \
sin(x)/cos(x)**3
# check integer exponents
e = sin(x)**y/cos(x)**y
assert trigsimp(e) == e
assert trigsimp(e.subs(y, 2)) == tan(x)**2
assert trigsimp(e.subs(x, 1)) == tan(1)**y
# check for multiple patterns
assert (cos(x)**2/sin(x)**2*cos(y)**2/sin(y)**2).trigsimp() == \
1/tan(x)**2/tan(y)**2
assert trigsimp(cos(x)/sin(x)*cos(x+y)/sin(x+y)) == \
1/(tan(x)*tan(x + y))
eq = cos(2)*(cos(3) + 1)**2/(cos(3) - 1)**2
assert trigsimp(eq) == eq.factor() # factor makes denom (-1 + cos(3))**2
assert trigsimp(cos(2)*(cos(3) + 1)**2*(cos(3) - 1)**2) == \
cos(2)*sin(3)**4
# issue 6789; this generates an expression that formerly caused
# trigsimp to hang
assert cot(x).equals(tan(x)) is False
# nan or the unchanged expression is ok, but not sin(1)
z = cos(x)**2 + sin(x)**2 - 1
z1 = tan(x)**2 - 1/cot(x)**2
n = (1 + z1/z)
assert trigsimp(sin(n)) != sin(1)
eq = x*(n - 1) - x*n
assert trigsimp(eq) is S.NaN
assert trigsimp(eq, recursive=True) is S.NaN
assert trigsimp(1).is_Integer
assert trigsimp(-sin(x)**4 - 2*sin(x)**2*cos(x)**2 - cos(x)**4) == -1
def test_trigsimp_issue_2515():
x = Symbol('x')
assert trigsimp(x*cos(x)*tan(x)) == x*sin(x)
assert trigsimp(-sin(x) + cos(x)*tan(x)) == 0
def test_trigsimp_issue_3826():
assert trigsimp(tan(2*x).expand(trig=True)) == tan(2*x)
def test_trigsimp_issue_4032():
n = Symbol('n', integer=True, positive=True)
assert trigsimp(2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2) == \
2**(n/2)*cos(pi*n/4)/2 + 2**n/4
def test_trigsimp_issue_7761():
assert trigsimp(cosh(pi/4)) == cosh(pi/4)
def test_trigsimp_noncommutative():
x, y = symbols('x,y')
A, B = symbols('A,B', commutative=False)
assert trigsimp(A - A*sin(x)**2) == A*cos(x)**2
assert trigsimp(A - A*cos(x)**2) == A*sin(x)**2
assert trigsimp(A*sin(x)**2 + A*cos(x)**2) == A
assert trigsimp(A + A*tan(x)**2) == A/cos(x)**2
assert trigsimp(A/cos(x)**2 - A) == A*tan(x)**2
assert trigsimp(A/cos(x)**2 - A*tan(x)**2) == A
assert trigsimp(A + A*cot(x)**2) == A/sin(x)**2
assert trigsimp(A/sin(x)**2 - A) == A/tan(x)**2
assert trigsimp(A/sin(x)**2 - A*cot(x)**2) == A
assert trigsimp(y*A*cos(x)**2 + y*A*sin(x)**2) == y*A
assert trigsimp(A*sin(x)/cos(x)) == A*tan(x)
assert trigsimp(A*tan(x)*cos(x)) == A*sin(x)
assert trigsimp(A*cot(x)**3*sin(x)**3) == A*cos(x)**3
assert trigsimp(y*A*tan(x)**2/sin(x)**2) == y*A/cos(x)**2
assert trigsimp(A*cot(x)/cos(x)) == A/sin(x)
assert trigsimp(A*sin(x + y) + A*sin(x - y)) == 2*A*sin(x)*cos(y)
assert trigsimp(A*sin(x + y) - A*sin(x - y)) == 2*A*sin(y)*cos(x)
assert trigsimp(A*cos(x + y) + A*cos(x - y)) == 2*A*cos(x)*cos(y)
assert trigsimp(A*cos(x + y) - A*cos(x - y)) == -2*A*sin(x)*sin(y)
assert trigsimp(A*sinh(x + y) + A*sinh(x - y)) == 2*A*sinh(x)*cosh(y)
assert trigsimp(A*sinh(x + y) - A*sinh(x - y)) == 2*A*sinh(y)*cosh(x)
assert trigsimp(A*cosh(x + y) + A*cosh(x - y)) == 2*A*cosh(x)*cosh(y)
assert trigsimp(A*cosh(x + y) - A*cosh(x - y)) == 2*A*sinh(x)*sinh(y)
assert trigsimp(A*cos(0.12345)**2 + A*sin(0.12345)**2) == 1.0*A
def test_hyperbolic_simp():
x, y = symbols('x,y')
assert trigsimp(sinh(x)**2 + 1) == cosh(x)**2
assert trigsimp(cosh(x)**2 - 1) == sinh(x)**2
assert trigsimp(cosh(x)**2 - sinh(x)**2) == 1
assert trigsimp(1 - tanh(x)**2) == 1/cosh(x)**2
assert trigsimp(1 - 1/cosh(x)**2) == tanh(x)**2
assert trigsimp(tanh(x)**2 + 1/cosh(x)**2) == 1
assert trigsimp(coth(x)**2 - 1) == 1/sinh(x)**2
assert trigsimp(1/sinh(x)**2 + 1) == 1/tanh(x)**2
assert trigsimp(coth(x)**2 - 1/sinh(x)**2) == 1
assert trigsimp(5*cosh(x)**2 - 5*sinh(x)**2) == 5
assert trigsimp(5*cosh(x/2)**2 - 2*sinh(x/2)**2) == 3*cosh(x)/2 + Rational(7, 2)
assert trigsimp(sinh(x)/cosh(x)) == tanh(x)
assert trigsimp(tanh(x)) == trigsimp(sinh(x)/cosh(x))
assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x)
assert trigsimp(2*tanh(x)*cosh(x)) == 2*sinh(x)
assert trigsimp(coth(x)**3*sinh(x)**3) == cosh(x)**3
assert trigsimp(y*tanh(x)**2/sinh(x)**2) == y/cosh(x)**2
assert trigsimp(coth(x)/cosh(x)) == 1/sinh(x)
for a in (pi/6*I, pi/4*I, pi/3*I):
assert trigsimp(sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x + a)
assert trigsimp(-sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x - a)
e = 2*cosh(x)**2 - 2*sinh(x)**2
assert trigsimp(log(e)) == log(2)
# issue 19535:
assert trigsimp(sqrt(cosh(x)**2 - 1)) == sqrt(sinh(x)**2)
assert trigsimp(cosh(x)**2*cosh(y)**2 - cosh(x)**2*sinh(y)**2 - sinh(x)**2,
recursive=True) == 1
assert trigsimp(sinh(x)**2*sinh(y)**2 - sinh(x)**2*cosh(y)**2 + cosh(x)**2,
recursive=True) == 1
assert abs(trigsimp(2.0*cosh(x)**2 - 2.0*sinh(x)**2) - 2.0) < 1e-10
assert trigsimp(sinh(x)**2/cosh(x)**2) == tanh(x)**2
assert trigsimp(sinh(x)**3/cosh(x)**3) == tanh(x)**3
assert trigsimp(sinh(x)**10/cosh(x)**10) == tanh(x)**10
assert trigsimp(cosh(x)**3/sinh(x)**3) == 1/tanh(x)**3
assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x)
assert trigsimp(cosh(x)**2/sinh(x)**2) == 1/tanh(x)**2
assert trigsimp(cosh(x)**10/sinh(x)**10) == 1/tanh(x)**10
assert trigsimp(x*cosh(x)*tanh(x)) == x*sinh(x)
assert trigsimp(-sinh(x) + cosh(x)*tanh(x)) == 0
assert tan(x) != 1/cot(x) # cot doesn't auto-simplify
assert trigsimp(tan(x) - 1/cot(x)) == 0
assert trigsimp(3*tanh(x)**7 - 2/coth(x)**7) == tanh(x)**7
def test_trigsimp_groebner():
from sympy.simplify.trigsimp import trigsimp_groebner
c = cos(x)
s = sin(x)
ex = (4*s*c + 12*s + 5*c**3 + 21*c**2 + 23*c + 15)/(
-s*c**2 + 2*s*c + 15*s + 7*c**3 + 31*c**2 + 37*c + 21)
resnum = (5*s - 5*c + 1)
resdenom = (8*s - 6*c)
results = [resnum/resdenom, (-resnum)/(-resdenom)]
assert trigsimp_groebner(ex) in results
assert trigsimp_groebner(s/c, hints=[tan]) == tan(x)
assert trigsimp_groebner(c*s) == c*s
assert trigsimp((-s + 1)/c + c/(-s + 1),
method='groebner') == 2/c
assert trigsimp((-s + 1)/c + c/(-s + 1),
method='groebner', polynomial=True) == 2/c
# Test quick=False works
assert trigsimp_groebner(ex, hints=[2]) in results
assert trigsimp_groebner(ex, hints=[int(2)]) in results
# test "I"
assert trigsimp_groebner(sin(I*x)/cos(I*x), hints=[tanh]) == I*tanh(x)
# test hyperbolic / sums
assert trigsimp_groebner((tanh(x)+tanh(y))/(1+tanh(x)*tanh(y)),
hints=[(tanh, x, y)]) == tanh(x + y)
def test_issue_2827_trigsimp_methods():
measure1 = lambda expr: len(str(expr))
measure2 = lambda expr: -count_ops(expr)
# Return the most complicated result
expr = (x + 1)/(x + sin(x)**2 + cos(x)**2)
ans = Matrix([1])
M = Matrix([expr])
assert trigsimp(M, method='fu', measure=measure1) == ans
assert trigsimp(M, method='fu', measure=measure2) != ans
# all methods should work with Basic expressions even if they
# aren't Expr
M = Matrix.eye(1)
assert all(trigsimp(M, method=m) == M for m in
'fu matching groebner old'.split())
# watch for E in exptrigsimp, not only exp()
eq = 1/sqrt(E) + E
assert exptrigsimp(eq) == eq
def test_issue_15129_trigsimp_methods():
t1 = Matrix([sin(Rational(1, 50)), cos(Rational(1, 50)), 0])
t2 = Matrix([sin(Rational(1, 25)), cos(Rational(1, 25)), 0])
t3 = Matrix([cos(Rational(1, 25)), sin(Rational(1, 25)), 0])
r1 = t1.dot(t2)
r2 = t1.dot(t3)
assert trigsimp(r1) == cos(Rational(1, 50))
assert trigsimp(r2) == sin(Rational(3, 50))
def test_exptrigsimp():
def valid(a, b):
from sympy.core.random import verify_numerically as tn
if not (tn(a, b) and a == b):
return False
return True
assert exptrigsimp(exp(x) + exp(-x)) == 2*cosh(x)
assert exptrigsimp(exp(x) - exp(-x)) == 2*sinh(x)
assert exptrigsimp((2*exp(x)-2*exp(-x))/(exp(x)+exp(-x))) == 2*tanh(x)
assert exptrigsimp((2*exp(2*x)-2)/(exp(2*x)+1)) == 2*tanh(x)
e = [cos(x) + I*sin(x), cos(x) - I*sin(x),
cosh(x) - sinh(x), cosh(x) + sinh(x)]
ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)]
assert all(valid(i, j) for i, j in zip(
[exptrigsimp(ei) for ei in e], ok))
ue = [cos(x) + sin(x), cos(x) - sin(x),
cosh(x) + I*sinh(x), cosh(x) - I*sinh(x)]
assert [exptrigsimp(ei) == ei for ei in ue]
res = []
ok = [y*tanh(1), 1/(y*tanh(1)), I*y*tan(1), -I/(y*tan(1)),
y*tanh(x), 1/(y*tanh(x)), I*y*tan(x), -I/(y*tan(x)),
y*tanh(1 + I), 1/(y*tanh(1 + I))]
for a in (1, I, x, I*x, 1 + I):
w = exp(a)
eq = y*(w - 1/w)/(w + 1/w)
res.append(simplify(eq))
res.append(simplify(1/eq))
assert all(valid(i, j) for i, j in zip(res, ok))
for a in range(1, 3):
w = exp(a)
e = w + 1/w
s = simplify(e)
assert s == exptrigsimp(e)
assert valid(s, 2*cosh(a))
e = w - 1/w
s = simplify(e)
assert s == exptrigsimp(e)
assert valid(s, 2*sinh(a))
def test_exptrigsimp_noncommutative():
a,b = symbols('a b', commutative=False)
x = Symbol('x', commutative=True)
assert exp(a + x) == exptrigsimp(exp(a)*exp(x))
p = exp(a)*exp(b) - exp(b)*exp(a)
assert p == exptrigsimp(p) != 0
def test_powsimp_on_numbers():
assert 2**(Rational(1, 3) - 2) == 2**Rational(1, 3)/4
@XFAIL
def test_issue_6811_fail():
# from doc/src/modules/physics/mechanics/examples.rst, the current `eq`
# at Line 576 (in different variables) was formerly the equivalent and
# shorter expression given below...it would be nice to get the short one
# back again
xp, y, x, z = symbols('xp, y, x, z')
eq = 4*(-19*sin(x)*y + 5*sin(3*x)*y + 15*cos(2*x)*z - 21*z)*xp/(9*cos(x) - 5*cos(3*x))
assert trigsimp(eq) == -2*(2*cos(x)*tan(x)*y + 3*z)*xp/cos(x)
def test_Piecewise():
e1 = x*(x + y) - y*(x + y)
e2 = sin(x)**2 + cos(x)**2
e3 = expand((x + y)*y/x)
# s1 = simplify(e1)
s2 = simplify(e2)
# s3 = simplify(e3)
# trigsimp tries not to touch non-trig containing args
assert trigsimp(Piecewise((e1, e3 < e2), (e3, True))) == \
Piecewise((e1, e3 < s2), (e3, True))
def test_issue_21594():
assert simplify(exp(Rational(1,2)) + exp(Rational(-1,2))) == cosh(S.Half)*2
def test_trigsimp_old():
x, y = symbols('x,y')
assert trigsimp(1 - sin(x)**2, old=True) == cos(x)**2
assert trigsimp(1 - cos(x)**2, old=True) == sin(x)**2
assert trigsimp(sin(x)**2 + cos(x)**2, old=True) == 1
assert trigsimp(1 + tan(x)**2, old=True) == 1/cos(x)**2
assert trigsimp(1/cos(x)**2 - 1, old=True) == tan(x)**2
assert trigsimp(1/cos(x)**2 - tan(x)**2, old=True) == 1
assert trigsimp(1 + cot(x)**2, old=True) == 1/sin(x)**2
assert trigsimp(1/sin(x)**2 - cot(x)**2, old=True) == 1
assert trigsimp(5*cos(x)**2 + 5*sin(x)**2, old=True) == 5
assert trigsimp(sin(x)/cos(x), old=True) == tan(x)
assert trigsimp(2*tan(x)*cos(x), old=True) == 2*sin(x)
assert trigsimp(cot(x)**3*sin(x)**3, old=True) == cos(x)**3
assert trigsimp(y*tan(x)**2/sin(x)**2, old=True) == y/cos(x)**2
assert trigsimp(cot(x)/cos(x), old=True) == 1/sin(x)
assert trigsimp(sin(x + y) + sin(x - y), old=True) == 2*sin(x)*cos(y)
assert trigsimp(sin(x + y) - sin(x - y), old=True) == 2*sin(y)*cos(x)
assert trigsimp(cos(x + y) + cos(x - y), old=True) == 2*cos(x)*cos(y)
assert trigsimp(cos(x + y) - cos(x - y), old=True) == -2*sin(x)*sin(y)
assert trigsimp(sinh(x + y) + sinh(x - y), old=True) == 2*sinh(x)*cosh(y)
assert trigsimp(sinh(x + y) - sinh(x - y), old=True) == 2*sinh(y)*cosh(x)
assert trigsimp(cosh(x + y) + cosh(x - y), old=True) == 2*cosh(x)*cosh(y)
assert trigsimp(cosh(x + y) - cosh(x - y), old=True) == 2*sinh(x)*sinh(y)
assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2, old=True) == 1.0
assert trigsimp(sin(x)/cos(x), old=True, method='combined') == tan(x)
assert trigsimp(sin(x)/cos(x), old=True, method='groebner') == sin(x)/cos(x)
assert trigsimp(sin(x)/cos(x), old=True, method='groebner', hints=[tan]) == tan(x)
assert trigsimp(1-sin(sin(x)**2+cos(x)**2)**2, old=True, deep=True) == cos(1)**2
def test_trigsimp_inverse():
alpha = symbols('alpha')
s, c = sin(alpha), cos(alpha)
for finv in [asin, acos, asec, acsc, atan, acot]:
f = finv.inverse(None)
assert alpha == trigsimp(finv(f(alpha)), inverse=True)
# test atan2(cos, sin), atan2(sin, cos), etc...
for a, b in [[c, s], [s, c]]:
for i, j in product([-1, 1], repeat=2):
angle = atan2(i*b, j*a)
angle_inverted = trigsimp(angle, inverse=True)
assert angle_inverted != angle # assures simplification happened
assert sin(angle_inverted) == trigsimp(sin(angle))
assert cos(angle_inverted) == trigsimp(cos(angle))

View File

@ -0,0 +1,15 @@
from sympy.core.traversal import use as _use
from sympy.utilities.decorator import deprecated
use = deprecated(
"""
Using use from the sympy.simplify.traversaltools submodule is
deprecated.
Instead, use use from the top-level sympy namespace, like
sympy.use
""",
deprecated_since_version="1.10",
active_deprecations_target="deprecated-traversal-functions-moved"
)(_use)

File diff suppressed because it is too large Load Diff