324 lines
12 KiB
Python
324 lines
12 KiB
Python
from sympy.core.containers import Tuple
|
|
from sympy.core.basic import Basic
|
|
from sympy.core.expr import Expr
|
|
from sympy.core.function import AppliedUndef
|
|
from sympy.core.relational import Relational
|
|
from sympy.core.symbol import Dummy
|
|
from sympy.core.sympify import sympify
|
|
from sympy.logic.boolalg import BooleanFunction
|
|
from sympy.sets.fancysets import ImageSet
|
|
from sympy.sets.sets import FiniteSet
|
|
from sympy.tensor.indexed import Indexed
|
|
|
|
|
|
def _get_free_symbols(exprs):
|
|
"""Returns the free symbols of a symbolic expression.
|
|
|
|
If the expression contains any of these elements, assume that they are
|
|
the "free symbols" of the expression:
|
|
|
|
* indexed objects
|
|
* applied undefined function (useful for sympy.physics.mechanics module)
|
|
"""
|
|
if not isinstance(exprs, (list, tuple, set)):
|
|
exprs = [exprs]
|
|
if all(callable(e) for e in exprs):
|
|
return set()
|
|
|
|
free = set().union(*[e.atoms(Indexed) for e in exprs])
|
|
free = free.union(*[e.atoms(AppliedUndef) for e in exprs])
|
|
return free or set().union(*[e.free_symbols for e in exprs])
|
|
|
|
|
|
def extract_solution(set_sol, n=10):
    """Extract numerical solutions from a set solution (computed by solveset,
    linsolve, nonlinsolve). Often, it is not trivial to get something useful
    out of them.

    Parameters
    ==========

    set_sol : set expression
        The solution in which every ImageSet is going to be replaced.
    n : int, optional
        In order to replace ImageSet with FiniteSet, an iterator is created
        for each ImageSet contained in `set_sol`, starting from 0 up to `n`.
        Default value: 10.

    Returns
    =======

    The result of substituting, in ``set_sol``, each ImageSet with a
    FiniteSet made of (at most) the first ``n`` elements it yields.
    """
    images = set_sol.find(ImageSet)
    for im in images:
        # Pairing with ``range(n)`` caps the number of elements at ``n`` and
        # stops cleanly if the image set runs out of elements earlier,
        # instead of leaking a StopIteration from a bare ``next()`` call.
        elements = [elem for _, elem in zip(range(n), im)]
        set_sol = set_sol.subs(im, FiniteSet(*elements))
    return set_sol
|
|
|
|
|
|
def _plot_sympify(args):
    """Recursively walk the arguments passed to a plot function and apply
    sympify to each of them, except to strings, dictionaries and callables.

    Generally, users can provide the following arguments to a plot function:

    expr, range1 [tuple, opt], ..., label [str, opt], rendering_kw [dict, opt]

    `expr, range1, ...` can be sympified, whereas `label, rendering_kw` can't.
    In particular, whenever a special character like $, {, }, ... is used in
    the `label`, sympify will raise an error.

    If ``args`` is itself an ``Expr`` it is returned unchanged; otherwise a
    new list with the processed items is returned. Nested lists/tuples are
    converted to ``Tuple`` objects (built with ``sympify=False``).
    """
    if isinstance(args, Expr):
        return args

    items = list(args)
    for idx, item in enumerate(items):
        if isinstance(item, (list, tuple)):
            # process the nested container and wrap it without re-sympifying
            items[idx] = Tuple(*_plot_sympify(item), sympify=False)
            continue

        if isinstance(item, (str, dict)) or callable(item):
            # labels, rendering keywords and callables are left untouched
            continue

        # NOTE: check if it is a vector from sympy.physics.vector module
        # without importing the module (because it slows down SymPy's
        # import process and triggers SymPy's optional-dependencies
        # tests to fail).
        looks_like_vector = (
            (item.__class__.__name__ == "Vector")
            and not isinstance(item, Basic)
        )
        if looks_like_vector:
            continue

        items[idx] = sympify(item)
    return items
|
|
|
|
|
|
def _create_ranges(exprs, ranges, npar, label="", params=None):
    """Validate the user-provided ranges and create the missing ones.

    This function does two things:

    1. Check if the number of free symbols is in agreement with the type of
       plot chosen. For example, plot() requires 1 free symbol;
       plot3d() requires 2 free symbols.
    2. Sometime users create plots without providing ranges for the
       variables. Here we create the necessary ranges, each one defaulting
       to ``(symbol, -10, 10)``.

    Parameters
    ==========

    exprs : iterable
        The expressions from which to extract the free symbols
    ranges : iterable
        The limiting ranges provided by the user. Missing ranges are
        appended to this very object, which is also the return value.
    npar : int
        The number of free symbols required by the plot functions.
        For example,
        npar=1 for plot, npar=2 for plot3d, ...
    label : str
        Accepted for signature compatibility; not used by this function.
    params : dict
        A dictionary mapping symbols to parameters for interactive plot.
        Its keys are not counted as free symbols.

    Raises
    ======

    ValueError
        If there are more free symbols or more ranges than ``npar``, if two
        ranges share the same symbol, or if the free symbols are not all
        covered by the ranges.
    """
    def default_range(sym):
        return Tuple(sym, -10, 10)

    free_symbols = _get_free_symbols(exprs)
    if params is not None:
        free_symbols = free_symbols.difference(params.keys())

    n_free = len(free_symbols)
    if n_free > npar:
        raise ValueError(
            "Too many free symbols.\n"
            + "Expected {} free symbols.\n".format(npar)
            + "Received {}: {}".format(n_free, free_symbols)
        )

    n_given = len(ranges)
    if n_given > npar:
        raise ValueError(
            "Too many ranges. Received %s, expected %s" % (n_given, npar))

    # symbols appearing in the ranges provided by the user
    range_symbols = {r[0] for r in ranges}
    if len(range_symbols) != n_given:
        raise ValueError("Multiple ranges with the same symbol")

    if n_given < npar:
        # one default range for every free symbol that has none yet
        for sym in free_symbols.difference(range_symbols):
            ranges.append(default_range(sym))
        # whatever room is left gets filled with dummy symbols
        still_missing = npar - len(ranges)
        for _ in range(still_missing):
            ranges.append(default_range(Dummy()))

    if n_free == npar:
        # there could be times when this condition is not met, for example
        # plotting the function f(x, y) = x (which is a plane); in this case,
        # free_symbols = {x} whereas the range symbols are {x, y}
        # (or x and Dummy)
        range_symbols = {r[0] for r in ranges}
        if free_symbols.difference(range_symbols):
            raise ValueError(
                "Incompatible free symbols of the expressions with "
                "the ranges.\n"
                + "Free symbols in the expressions: {}\n".format(free_symbols)
                + "Free symbols in the ranges: {}".format(range_symbols)
            )
    return ranges
|
|
|
|
|
|
def _is_range(r):
    """Tell whether ``r`` is a plotting range.

    A range is defined as (symbol, start, end). start and end should
    be numbers.

    Returns ``True`` only when ``r`` is a ``Tuple`` of three elements whose
    second and third entries report ``is_number``; ``False`` otherwise.
    """
    # TODO: prange check goes here
    if not (isinstance(r, Tuple) and len(r) == 3):
        return False

    # A Tuple created with ``sympify=False`` may hold plain Python objects
    # (strings, dicts, callables) that have no ``is_number`` attribute:
    # such entries are simply not numbers, so they must not raise here.
    return all(
        (not isinstance(bound, str)) and getattr(bound, "is_number", False)
        for bound in r.args[1:]
    )
|
|
|
|
|
|
def _unpack_args(*args):
|
|
"""Given a list/tuple of arguments previously processed by _plot_sympify()
|
|
and/or _check_arguments(), separates and returns its components:
|
|
expressions, ranges, label and rendering keywords.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy import cos, sin, symbols
|
|
>>> from sympy.plotting.utils import _plot_sympify, _unpack_args
|
|
>>> x, y = symbols('x, y')
|
|
>>> args = (sin(x), (x, -10, 10), "f1")
|
|
>>> args = _plot_sympify(args)
|
|
>>> _unpack_args(*args)
|
|
([sin(x)], [(x, -10, 10)], 'f1', None)
|
|
|
|
>>> args = (sin(x**2 + y**2), (x, -2, 2), (y, -3, 3), "f2")
|
|
>>> args = _plot_sympify(args)
|
|
>>> _unpack_args(*args)
|
|
([sin(x**2 + y**2)], [(x, -2, 2), (y, -3, 3)], 'f2', None)
|
|
|
|
>>> args = (sin(x + y), cos(x - y), x + y, (x, -2, 2), (y, -3, 3), "f3")
|
|
>>> args = _plot_sympify(args)
|
|
>>> _unpack_args(*args)
|
|
([sin(x + y), cos(x - y), x + y], [(x, -2, 2), (y, -3, 3)], 'f3', None)
|
|
"""
|
|
ranges = [t for t in args if _is_range(t)]
|
|
labels = [t for t in args if isinstance(t, str)]
|
|
label = None if not labels else labels[0]
|
|
rendering_kw = [t for t in args if isinstance(t, dict)]
|
|
rendering_kw = None if not rendering_kw else rendering_kw[0]
|
|
# NOTE: why None? because args might have been preprocessed by
|
|
# _check_arguments, so None might represent the rendering_kw
|
|
results = [not (_is_range(a) or isinstance(a, (str, dict)) or (a is None)) for a in args]
|
|
exprs = [a for a, b in zip(args, results) if b]
|
|
return exprs, ranges, label, rendering_kw
|
|
|
|
|
|
def _check_arguments(args, nexpr, npar, **kwargs):
|
|
"""Checks the arguments and converts into tuples of the
|
|
form (exprs, ranges, label, rendering_kw).
|
|
|
|
Parameters
|
|
==========
|
|
|
|
args
|
|
The arguments provided to the plot functions
|
|
nexpr
|
|
The number of sub-expression forming an expression to be plotted.
|
|
For example:
|
|
nexpr=1 for plot.
|
|
nexpr=2 for plot_parametric: a curve is represented by a tuple of two
|
|
elements.
|
|
nexpr=1 for plot3d.
|
|
nexpr=3 for plot3d_parametric_line: a curve is represented by a tuple
|
|
of three elements.
|
|
npar
|
|
The number of free symbols required by the plot functions. For example,
|
|
npar=1 for plot, npar=2 for plot3d, ...
|
|
**kwargs :
|
|
keyword arguments passed to the plotting function. It will be used to
|
|
verify if ``params`` has ben provided.
|
|
|
|
Examples
|
|
========
|
|
|
|
.. plot::
|
|
:context: reset
|
|
:format: doctest
|
|
:include-source: True
|
|
|
|
>>> from sympy import cos, sin, symbols
|
|
>>> from sympy.plotting.plot import _check_arguments
|
|
>>> x = symbols('x')
|
|
>>> _check_arguments([cos(x), sin(x)], 2, 1)
|
|
[(cos(x), sin(x), (x, -10, 10), None, None)]
|
|
|
|
>>> _check_arguments([cos(x), sin(x), "test"], 2, 1)
|
|
[(cos(x), sin(x), (x, -10, 10), 'test', None)]
|
|
|
|
>>> _check_arguments([cos(x), sin(x), "test", {"a": 0, "b": 1}], 2, 1)
|
|
[(cos(x), sin(x), (x, -10, 10), 'test', {'a': 0, 'b': 1})]
|
|
|
|
>>> _check_arguments([x, x**2], 1, 1)
|
|
[(x, (x, -10, 10), None, None), (x**2, (x, -10, 10), None, None)]
|
|
"""
|
|
if not args:
|
|
return []
|
|
output = []
|
|
params = kwargs.get("params", None)
|
|
|
|
if all(isinstance(a, (Expr, Relational, BooleanFunction)) for a in args[:nexpr]):
|
|
# In this case, with a single plot command, we are plotting either:
|
|
# 1. one expression
|
|
# 2. multiple expressions over the same range
|
|
|
|
exprs, ranges, label, rendering_kw = _unpack_args(*args)
|
|
free_symbols = set().union(*[e.free_symbols for e in exprs])
|
|
ranges = _create_ranges(exprs, ranges, npar, label, params)
|
|
|
|
if nexpr > 1:
|
|
# in case of plot_parametric or plot3d_parametric_line, there will
|
|
# be 2 or 3 expressions defining a curve. Group them together.
|
|
if len(exprs) == nexpr:
|
|
exprs = (tuple(exprs),)
|
|
for expr in exprs:
|
|
# need this if-else to deal with both plot/plot3d and
|
|
# plot_parametric/plot3d_parametric_line
|
|
is_expr = isinstance(expr, (Expr, Relational, BooleanFunction))
|
|
e = (expr,) if is_expr else expr
|
|
output.append((*e, *ranges, label, rendering_kw))
|
|
|
|
else:
|
|
# In this case, we are plotting multiple expressions, each one with its
|
|
# range. Each "expression" to be plotted has the following form:
|
|
# (expr, range, label) where label is optional
|
|
|
|
_, ranges, labels, rendering_kw = _unpack_args(*args)
|
|
labels = [labels] if labels else []
|
|
|
|
# number of expressions
|
|
n = (len(ranges) + len(labels) +
|
|
(len(rendering_kw) if rendering_kw is not None else 0))
|
|
new_args = args[:-n] if n > 0 else args
|
|
|
|
# at this point, new_args might just be [expr]. But I need it to be
|
|
# [[expr]] in order to be able to loop over
|
|
# [expr, range [opt], label [opt]]
|
|
if not isinstance(new_args[0], (list, tuple, Tuple)):
|
|
new_args = [new_args]
|
|
|
|
# Each arg has the form (expr1, expr2, ..., range1 [optional], ...,
|
|
# label [optional], rendering_kw [optional])
|
|
for arg in new_args:
|
|
# look for "local" range and label. If there is not, use "global".
|
|
l = [a for a in arg if isinstance(a, str)]
|
|
if not l:
|
|
l = labels
|
|
r = [a for a in arg if _is_range(a)]
|
|
if not r:
|
|
r = ranges.copy()
|
|
rend_kw = [a for a in arg if isinstance(a, dict)]
|
|
rend_kw = rendering_kw if len(rend_kw) == 0 else rend_kw[0]
|
|
|
|
# NOTE: arg = arg[:nexpr] may raise an exception if lambda
|
|
# functions are used. Execute the following instead:
|
|
arg = [arg[i] for i in range(nexpr)]
|
|
free_symbols = set()
|
|
if all(not callable(a) for a in arg):
|
|
free_symbols = free_symbols.union(*[a.free_symbols for a in arg])
|
|
if len(r) != npar:
|
|
r = _create_ranges(arg, r, npar, "", params)
|
|
|
|
label = None if not l else l[0]
|
|
output.append((*arg, *r, label, rend_kw))
|
|
return output
|